In [None]:
# Enable IPython's autoreload so edits to imported project modules
# (e.g. utilities.Model_utilities) are picked up before each cell runs,
# without restarting the kernel. `%reload_ext` also works if already loaded.
%reload_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
from tensorflow.keras import mixed_precision
import utilities.Model_utilities as my_model_util

# Mixed-precision policy. NOTE(review): the policy object used to be constructed but
# never installed, so training actually ran under the default float32 policy.
# It is now applied only when USE_MIXED_PRECISION is True (default keeps the
# previous float32 behaviour). It must be installed before the model is built.
# Keras recommends the model's final layer output float32 under mixed_float16 —
# confirm unet_custom does so before enabling.
USE_MIXED_PRECISION = False
policy = mixed_precision.Policy('mixed_float16')
if USE_MIXED_PRECISION:
    mixed_precision.set_global_policy(policy)

# Upper bound (in MB) on what TensorFlow may allocate on the first GPU.
GPU_MEMORY_LIMIT_MB = 4096

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Expose the first GPU as one logical device capped at GPU_MEMORY_LIMIT_MB.
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)])
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Logical devices must be configured before the GPUs have been initialized;
        # re-running this cell in a live kernel typically ends up here.
        print(e)


In [None]:
# Network input size in pixels (height x width) and number of segmentation classes.
# NOTE(review): both sides are multiples of 32 (544 = 17*32, 736 = 23*32); presumably
# chosen so the U-Net's repeated down/up-sampling stays aligned — confirm.
image_height, image_width = 544, 736
n_classes = 7

In [None]:
# Build the segmentation model via the project's unet_custom helper, sized to the
# configured input and class count. Call model.summary() to inspect the layers.
MODEL_DEPTH = 3
DROPOUT_RATE = 0.1
model = my_model_util.unet_custom(
    n_classes,
    image_height,
    image_width,
    model_depth=MODEL_DEPTH,
    dropout=DROPOUT_RATE,
)

In [None]:
# Train the model on batches produced by train_gen from the image/label directories,
# then save it. NOTE(review): training images are read from "tmp/test/" — confirm this
# is the intended training split and not a held-out test set.
TRAIN_IMAGE_DIR = "tmp/test/"
TRAIN_LABEL_DIR = "tmp/label/"
BATCH_SIZE = 2
STEPS_PER_EPOCH = 300
EPOCHS = 15
MODEL_DIR = 'models/'

gen = my_model_util.train_gen(TRAIN_IMAGE_DIR,
                              TRAIN_LABEL_DIR, batch_size=BATCH_SIZE,
                              n_classes=n_classes, height=image_height, width=image_width)
# NOTE(review): MSE is an unusual loss for multi-class segmentation (categorical
# cross-entropy is the common choice); kept as-is to preserve existing results.
model.compile(optimizer='adam',
              loss='mse',
              metrics=[my_model_util.iou_coef])
# Keep the History object so per-epoch loss / IoU can be plotted after training.
history = model.fit(gen, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)
model.save(MODEL_DIR)

In [None]:
# Run the trained model on one dataset image via the project's predict_model helper.
# NOTE(review): hard-coded absolute path on an external drive; not portable.
preview_image_path = '/media/anaph/My Passport/dataset/images/9image4801.png'
pr = my_model_util.predict_model(
    model, preview_image_path, n_classes, image_height, image_width)

In [None]:
import matplotlib.pyplot as plt
# Display the raw prediction returned by predict_model as an image.
fig, ax = plt.subplots()
ax.imshow(pr)
ax.set_title('Model prediction');

In [None]:
# Class names, presumably in label-index order matching the model's output classes
# (confirm against the labelling used by train_gen), and one display colour each.
class_labels = ['Ground', 'Biker', 'Pedestrian', 'Skateboarder', 'Cart', 'Car', 'Bus']
# The label list and n_classes are defined in separate cells; fail fast if they drift
# apart, since a mismatch would likely mis-colour or break the visualisation.
assert len(class_labels) == n_classes, (
    f"{len(class_labels)} class labels but n_classes={n_classes}")
color = my_model_util.gen_color_for_labels(class_labels)

In [None]:
import cv2
# Predict on one dataset image and overlay the predicted classes on the source pixels.
# The path is defined once so prediction and overlay always use the same image.
# NOTE(review): hard-coded absolute path on an external drive; not portable.
overlay_image_path = '/media/anaph/My Passport/dataset/images/0image8901.png'
image = cv2.imread(overlay_image_path)
# cv2.imread returns None (rather than raising) when the file is missing or unreadable.
if image is None:
    raise FileNotFoundError(f"Could not read image: {overlay_image_path}")
pr = my_model_util.predict_model(model, overlay_image_path, n_classes, image_height, image_width)
o = my_model_util.visualize_segmentation(pr, image, color)
# NOTE(review): cv2.imread loads channels in BGR order while plt.imshow expects RGB;
# unless visualize_segmentation converts, the displayed colours are channel-swapped.
plt.figure(figsize=(9, 9))
plt.imshow(o.astype('uint8'));