<a href="https://colab.research.google.com/github/Abinesh-18/Deep-Lab-based-semantic-segmentation/blob/main/Deep_Lab_based_semantic_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Gather/unzip data
from google.colab import drive
from IPython.display import clear_output
# Mount Google Drive so the dataset archive stored there can be read from this runtime.
drive.mount('/content/gdrive')
# Copy the archive into the current working directory and extract it there.
!cp /content/gdrive/MyDrive/AdvCV/celeb.zip .
!unzip celeb.zip
# Clear the (long) output printed by the commands above.
clear_output()

In [None]:
import tensorflow as tf
tf.config.run_functions_eagerly(True)
from tensorflow import keras
import glob
import numpy as np
import cv2

In [None]:
# Locate the images and the per-part annotation masks of CelebAMask-HQ.
img_paths = glob.glob('CelebAMask-HQ/CelebA-HQ-img/*.jpg')
# NOTE: `recursive=True` only affects patterns containing '**', so it is not needed here.
mask_paths = glob.glob('CelebAMask-HQ/CelebAMask-HQ-mask-anno/*/*.png')

# 18 classes + background (19 total).
# Names are in alphabetical order, i.e. the same order the sorted mask paths of one image have.
classes = {
    0: 'background',
    1: 'cloth',
    2: 'ear_r',
    3: 'eye_g',
    4: 'hair',
    5: 'hat',
    6: 'l_brow',
    7: 'l_ear',
    8: 'l_eye',
    9: 'l_lip',
    10: 'mouth',
    11: 'neck',
    12: 'neck_l',
    13: 'nose',
    14: 'r_brow',
    15: 'r_ear',
    16: 'r_eye',
    17: 'skin',
    18: 'u_lip',
}

# data['img'][image_id]       -> path of the image file
# data['img_annos'][image_id] -> sorted list of mask paths belonging to that image
data = {
    'img': {},
    'img_annos': {}
}


def id_from_path(path):
  '''Return the numeric image id encoded at the start of a file name.

  Works for image files ('.../123.jpg' -> 123) and for mask files
  ('.../00012_l_brow.png' -> 12). Only the last path component is inspected,
  so the result does not depend on how deep the dataset directory is nested.
  '''
  file_name = path.replace('\\', '/').split('/')[-1]
  stem = file_name.split('.')[0]
  return int(stem.split('_')[0])


# Get image ids (numbers) to identify img + help find annotations
ids = []
for img_path in img_paths:
  img_id = id_from_path(img_path)
  ids.append(img_id)
  data['img'][img_id] = img_path
  data['img_annos'][img_id] = []

for mask_path in mask_paths:
  # setdefault keeps a stray mask (one without a matching image) from raising KeyError;
  # such ids never reach the generators because they are absent from `ids`.
  data['img_annos'].setdefault(id_from_path(mask_path), []).append(mask_path)

for annos in data['img_annos'].values():
  annos.sort()

### Partition data into train/val split (first 80% of the sorted ids train, rest validate)
ids.sort()
split_at = int(len(ids) * 0.8)
train_ids = ids[:split_at]
val_ids = ids[split_at:]

In [None]:
# Dataset generator
class DataGenerator(keras.utils.Sequence):
  '''Keras Sequence that yields (image batch, label-map batch) pairs.

  Each batch is a tuple (X, y):
    X -- array of shape (batch_size, *dim, n_channels) holding the images as read by
         cv2.imread (BGR channel order), min-max normalized to float values.
    y -- array of shape (batch_size, *dim, 1) holding one class index per pixel;
         pixels not covered by any mask keep the value 0 (background).

  Arguments:
    list_IDs    -- image ids this generator iterates over.
    imgs        -- mapping image id -> image file path.
    labels      -- mapping image id -> list of mask file paths for that image.
    class_names -- mapping (or list) class index -> class name, indexable by 0..n_classes-1.
    n_channels  -- number of image channels stored in X.
    n_classes   -- number of class indices looked up in class_names.
    batch_size  -- samples per batch; a trailing incomplete batch is dropped.
    dim         -- (height, width) every image and mask is resized to.
    shuffle     -- reshuffle the sample order at the end of every epoch.

  No per-pixel class weights are produced; batches are plain (X, y) pairs.
  '''
  def __init__(self, list_IDs, imgs, labels, class_names, n_channels, n_classes, batch_size=32, dim=(512,512), shuffle=False):
    super().__init__()
    self.list_IDs = list_IDs
    self.imgs = imgs
    self.labels = labels
    self.class_names = class_names
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.batch_size = batch_size
    self.dim = dim
    self.shuffle = shuffle
    # Exact class-name -> class-index lookup used to label each mask file.
    self._class_index = {class_names[j]: j for j in range(n_classes)}
    self.on_epoch_end()

  def __len__(self):
    '''Number of full batches per epoch.'''
    return len(self.list_IDs) // self.batch_size

  def __getitem__(self, index):
    '''Generate one batch of data'''
    # Positions (into list_IDs) of the samples that make up batch `index`
    batch_positions = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    list_IDs_temp = [self.list_IDs[k] for k in batch_positions]
    return self.__data_generation(list_IDs_temp)

  def on_epoch_end(self):
    '''Updates indexes after each epoch'''
    self.indexes = np.arange(len(self.list_IDs))
    if self.shuffle:
      np.random.shuffle(self.indexes)

  def _masks_by_class(self, ID):
    '''Return [(class_index, mask_path), ...] for one image, ordered by class index.

    The class is taken from the file name, which is expected to look like
    ID_CLASSNAME.png (class names may themselves contain underscores). Matching the
    exact name avoids two problems of substring matching against a moving pointer:
    the first mask of every image being skipped, and e.g. a 'neck_l' mask being
    labelled as 'neck'. Masks whose class name is unknown are ignored.
    '''
    pairs = []
    for mask_path in self.labels[ID]:
      stem = mask_path.replace('\\', '/').split('/')[-1].split('.')[0]
      class_name = stem.partition('_')[2]
      class_idx = self._class_index.get(class_name)
      if class_idx is not None:
        pairs.append((class_idx, mask_path))
    pairs.sort()
    return pairs

  def __data_generation(self, list_IDs_temp):
    '''Load, resize and normalize the images of one batch and build their label maps.'''
    height, width = self.dim
    # cv2.resize expects (width, height), the reverse of the numpy (rows, cols) order.
    dsize = (width, height)

    X = np.empty((self.batch_size, *self.dim, self.n_channels))
    # All masks of an image are merged into a single channel of class indices.
    y = np.zeros((self.batch_size, *self.dim, 1))

    for i, ID in enumerate(list_IDs_temp):
      img = cv2.imread(self.imgs[ID])
      if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(self.imgs[ID]))
      img = cv2.resize(img, dsize)
      # Normalize image
      img = cv2.normalize(img, None, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
      X[i,] = img

      # Masks are painted in ascending class order, so a higher class index overwrites
      # a lower one wherever two masks overlap.
      # NOTE(review): if the 'skin' mask covers the whole face, painting it late hides
      # the facial-part classes underneath -- confirm against the dataset.
      for class_idx, mask_path in self._masks_by_class(ID):
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
          raise FileNotFoundError('Could not read mask: {}'.format(mask_path))
        # Nearest-neighbour keeps the mask strictly 0/255; the default bilinear
        # interpolation blends edge pixels, which the `== 255` test would then drop.
        mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)
        y[i][mask == 255] = class_idx

    return X,y

# Training and validation generators share every setting; only the id list differs.
_generator_settings = dict(n_channels=3, n_classes=19, batch_size=8, dim=(256,256), shuffle=True)
train_generator = DataGenerator(train_ids, data['img'], data['img_annos'], classes, **_generator_settings)
val_generator = DataGenerator(val_ids, data['img'], data['img_annos'], classes, **_generator_settings)

In [None]:
from tensorflow.keras import layers

'''
================================
DeepLab implementation from https://keras.io/examples/vision/deeplabv3_plus/
================================
'''

def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    padding="same",
    use_bias=False,
):
    """Apply Conv2D -> BatchNormalization -> ReLU to `block_input`.

    Args:
        block_input: input tensor.
        num_filters: number of convolution filters (output channels).
        kernel_size: convolution kernel size.
        dilation_rate: dilation rate of the convolution.
        padding: padding mode forwarded to Conv2D. It was previously accepted but
            ignored (the layer always used "same"); the default keeps that behavior.
        use_bias: whether the convolution has a bias term.

    Returns:
        The activated output tensor.
    """
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding=padding,
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return tf.nn.relu(x)


def DilatedSpatialPyramidPooling(dspp_input):
    """Build the pyramid-pooling head on top of `dspp_input`.

    Five branches are computed from the same input and concatenated along the
    channel axis: a globally average-pooled branch (upsampled back to the input's
    spatial size), a 1x1 convolution, and three 3x3 convolutions with dilation
    rates 6, 12 and 18. A final 1x1 convolution block fuses them.
    """
    in_height, in_width = dspp_input.shape[-3], dspp_input.shape[-2]

    # Image-level branch: pool over the full feature map, then scale back up.
    pooled = layers.AveragePooling2D(pool_size=(in_height, in_width))(dspp_input)
    pooled = convolution_block(pooled, kernel_size=1, use_bias=True)
    scale = (in_height // pooled.shape[1], in_width // pooled.shape[2])
    out_pool = layers.UpSampling2D(size=scale, interpolation="bilinear")(pooled)

    # Parallel convolution branches with growing receptive fields.
    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
    dilated = [
        convolution_block(dspp_input, kernel_size=3, dilation_rate=rate)
        for rate in (6, 12, 18)
    ]

    merged = layers.Concatenate(axis=-1)([out_pool, out_1, *dilated])
    return convolution_block(merged, kernel_size=1)

def DeeplabV3Plus(image_size, num_classes):
    """Assemble a DeepLabV3+ model on an ImageNet-initialized ResNet50 backbone.

    Args:
        image_size: side length of the square 3-channel input.
        num_classes: number of output channels (one score per class).

    Returns:
        A keras.Model mapping (image_size, image_size, 3) inputs to
        per-pixel class scores. The final convolution has no activation,
        so the outputs are raw scores rather than probabilities.
    """
    model_input = keras.Input(shape=(image_size, image_size, 3))
    backbone = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=model_input
    )

    # Encoder: deep backbone features passed through the pyramid-pooling head.
    deep = backbone.get_layer("conv4_block6_2_relu").output
    deep = DilatedSpatialPyramidPooling(deep)

    # Decoder: bring the deep features to 1/4 of the input resolution ...
    quarter = image_size // 4
    input_a = layers.UpSampling2D(
        size=(quarter // deep.shape[1], quarter // deep.shape[2]),
        interpolation="bilinear",
    )(deep)
    # ... and pair them with shallow backbone features reduced to 48 channels.
    input_b = backbone.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    fused = layers.Concatenate(axis=-1)([input_a, input_b])
    fused = convolution_block(fused)
    fused = convolution_block(fused)
    # Upsample back to the full input resolution before classifying each pixel.
    fused = layers.UpSampling2D(
        size=(image_size // fused.shape[1], image_size // fused.shape[2]),
        interpolation="bilinear",
    )(fused)
    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(fused)
    return keras.Model(inputs=model_input, outputs=model_output)


# Look up the GPU device name, print it, and build the 19-class 256x256 model
# inside that device scope.
device_name = tf.test.gpu_device_name()
print(device_name)    
with tf.device(device_name):
  model = DeeplabV3Plus(image_size=256, num_classes=19)

# model.summary()

/device:GPU:0
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
# Sparse categorical cross-entropy on raw model scores (from_logits=True).
# Kept as an alternative: the compile call in this cell passes dice_coef_loss instead.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# loss = keras.losses.CategoricalCrossentropy(from_logits=True)

from keras import backend as K

# https://stackoverflow.com/questions/49284455/keras-custom-function-implementing-jaccard/50832690
def jaccard_loss(y_true, y_pred, smooth=100):
    """Mean smoothed Jaccard distance between `y_true` and softmax(`y_pred`).

    The softmax is taken over the last axis of `y_pred`; intersection and sum are
    reduced over that same axis, and the result is averaged over what remains.
    `smooth` is added to numerator and denominator and also scales the distance.

    NOTE(review): the element-wise products assume `y_true` is one-hot with the
    same shape as `y_pred` -- confirm before using it with sparse labels.
    """
    probs = tf.math.softmax(y_pred, axis=-1)
    overlap = tf.reduce_sum(y_true * probs, axis=-1)
    total = tf.reduce_sum(y_true + probs, axis=-1)
    union = total - overlap
    jaccard = (overlap + smooth) / (union + smooth)
    distance = (1 - jaccard) * smooth
    return tf.reduce_mean(distance)

def dice_loss(y_true, y_pred):
  """Dice loss of softmax(`y_pred`) against `y_true`, skipping channel 0.

  Channel 0 is sliced off both tensors, sums are taken over axes 1 and 2, and
  `1 - 2*overlap/total` is returned without further reduction.

  NOTE(review): slicing `y_true[..., 1:]` assumes one-hot targets with one channel
  per class -- confirm before using it with sparse labels.
  """
  probs = tf.math.softmax(y_pred, axis=-1)[..., 1:]
  targets = y_true[..., 1:]
  spatial_axes = (1, 2)
  overlap = tf.reduce_sum(targets * probs, axis=spatial_axes)
  total = tf.reduce_sum(targets + probs, axis=spatial_axes)
  return 1 - ((2 * overlap) / total)

def dice_coef(y_true, y_pred, smooth=1e-7, num_classes=19):
    '''
    Soft Dice coefficient between integer class labels and predicted scores.
    Pass to model as metric during compile statement

    y_true is cast to int32 and one-hot encoded into `num_classes` channels;
    y_pred is turned into probabilities with a softmax over its last axis.
    Both are flattened, so a single coefficient is computed over the whole batch
    and over every class -- the background class (label 0) is included.

    `smooth` is added to the denominator only, to avoid division by zero.
    `num_classes` must equal the size of the last axis of y_pred; it defaults
    to 19, the value that used to be hard-coded here.
    '''
    y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'), num_classes=num_classes))
    y_pred = tf.math.softmax(y_pred, axis=-1)
    y_pred_f = K.flatten(y_pred)
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))

def dice_coef_loss(y_true, y_pred):
    '''
    Loss form of the Dice coefficient: one minus dice_coef(y_true, y_pred).
    Pass to model as loss during compile statement
    '''
    score = dice_coef(y_true, y_pred)
    return 1 - score

def tversky(y_true, y_pred, smooth=1, alpha=0.7):
    """Tversky index of softmax(`y_pred`) against `y_true`, summed over all elements.

    False negatives are weighted by `alpha` and false positives by `1 - alpha`;
    `smooth` is added to both numerator and denominator.

    NOTE(review): the element-wise products assume `y_true` has the same shape as
    `y_pred` (one-hot targets) -- confirm before using it with sparse labels.
    """
    probs = tf.math.softmax(y_pred, axis=-1)
    tp = tf.reduce_sum(y_true * probs)
    fn = tf.reduce_sum(y_true * (1 - probs))
    fp = tf.reduce_sum((1 - y_true) * probs)
    denominator = tp + alpha * fn + (1 - alpha) * fp + smooth
    return (tp + smooth) / denominator

def tversky_loss(y_true, y_pred):
    """Return one minus the Tversky index, so a better overlap gives a lower loss."""
    index = tversky(y_true, y_pred)
    return 1 - index

def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    """Return (1 - Tversky index) raised to the power `gamma`."""
    miss = 1 - tversky(y_true, y_pred)
    return K.pow(miss, gamma)

class MeanIOU(tf.keras.metrics.MeanIoU):
    """MeanIoU that accepts per-class scores as predictions.

    The predictions are reduced to class ids with an argmax over the last axis
    before being handed to the parent metric; `y_true` is passed through unchanged.
    """
    def update_state(self, y_true, y_pred, sample_weight=None):
        predicted_ids = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, predicted_ids, sample_weight)
# Mean IoU over the 19 classes, reported next to accuracy.
iou = MeanIOU(19)

# Train with the Dice loss and the default Adam optimizer.
# Alternatives tried earlier: loss=loss, loss=jaccard_loss,
# optimizer=keras.optimizers.Adam(learning_rate=0.1)
model.compile(
    loss=dice_coef_loss,
    optimizer='adam',
    metrics=["accuracy", iou],
)

# Checkpointing: keep only the weights of the epoch with the highest validation accuracy.
# (Google Drive alternative: '/content/gdrive/MyDrive/AdvCV/deeplab_celeb')
checkpoint_path = 'deeplab_dice'
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    callbacks=[cp_callback],
)

In [None]:
### Class-id prediction map (H,W) -> color coded image (H,W,3); take the argmax of the (H,W,19) model output first
# from google.colab.patches import cv2_imshow
'''RGB class mapping'''
# Class id -> display color. Colors are written as RGB triples and reversed with
# np.flip, so the stored values are in BGR order.
# NOTE(review): classes 4 and 17 share the same color and cannot be told apart
# in the rendered image.
color_map = {
    0: [0,0,0],
    1: np.flip([165, 198, 239]),
    2: np.flip([168, 31, 5]),
    3: np.flip([31, 215, 43]),
    4: np.flip([255, 0, 0]),
    5: np.flip([221, 7, 188]),
    6: np.flip([40, 19, 179]),
    7: np.flip([153, 245, 26]),
    8: np.flip([39, 10, 112]),
    9: np.flip([221, 130, 242]),
    10: np.flip([99, 55, 72]),
    11: np.flip([28, 231, 244]),
    12: np.flip([162, 174, 79]),
    13: np.flip([139, 60, 131]),
    14: np.flip([246, 176, 47]),
    15: np.flip([223, 237, 127]),
    16: np.flip([135, 66, 130]),
    17: np.flip([255, 0, 0]),
    18: np.flip([213, 149, 107])
}
def color_prediction(prediction):
  '''Turn a map of class ids into a color image using `color_map`.

  prediction -- integer array of class ids with shape (H, W).
  Returns a uint8 array of shape (H, W, 3). Its size follows the input instead
  of being fixed to 256x256; ids missing from `color_map` stay black.
  '''
  prediction = np.asarray(prediction)
  color_img = np.zeros((*prediction.shape, 3), dtype='uint8')
  for class_id, color in color_map.items():
    color_img[prediction == class_id] = color
  return color_img

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# model = keras.models.load_model('deeplab_celeb_jaccard')
# Pick 10 random validation images (duplicates possible) and show each next to its prediction.
idxs = np.random.randint(len(val_ids), size=10)
print(idxs)
imgs = []
for i in idxs:
  bgr = cv2.imread(data['img'][val_ids[i]])
  bgr = cv2.resize(bgr, (256,256))
  # matplotlib expects RGB, so only the copy used for display is converted.
  imgs.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

  # Feed the network exactly what DataGenerator fed it during training: the BGR
  # image min-max normalized to float32. Passing the raw 0-255 RGB image instead
  # would give the model inputs it was never trained on.
  x = cv2.normalize(bgr, None, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

  pred_mask = model.predict(np.expand_dims(x, axis=0))
  pred_mask = np.squeeze(pred_mask)
  # Per-pixel class id = channel with the highest score.
  pred_mask = np.argmax(pred_mask, axis=-1)

  # color_prediction returns BGR colors; convert for display.
  mask_rgb = color_prediction(pred_mask)
  mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_BGR2RGB)
  imgs.append(mask_rgb)

fig = plt.figure(figsize=(20., 20.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(5, 4),  # 5x4 grid: 10 (image, prediction) pairs
                 axes_pad=0.05,  # pad between axes in inch.
                 )

for ax, im in zip(grid, imgs):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)

plt.show()