# Importing Libraries

In [27]:
import pickle
import numpy as np
import pandas as pd
from PIL import Image
import albumentations as A
from IPython.display import SVG
import matplotlib.pyplot as plt
%matplotlib inline
import os, re, sys, random, shutil, cv2

import segmentation_models as sm
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from keras.models import Model
from keras.optimizers import Adam, Nadam
from keras import applications, optimizers
from keras.applications import InceptionResNetV2
from keras.applications import resnet

from keras.preprocessing.image import ImageDataGenerator
from keras.utils import model_to_dot, plot_model, to_categorical
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger, LearningRateScheduler
from keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, Conv2DTranspose, concatenate, ZeroPadding2D, Dropout

# Data Augmentation

In [3]:
# Folders holding the source images and their masks; the augmentation cell writes its
# augmented copies back into these same folders.
images_dir = './data/images/'
masks_dir = './data/masks/'

# Seed the global Python / NumPy / TensorFlow RNGs so that a Restart & Run All is
# repeatable. GPU kernels may still be non-deterministic, and libraries that keep a
# private RNG need to be seeded through their own API.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [8]:
# Sorted stems (file name without '.png') of the ORIGINAL images only.
# The augmentation cell writes 'aug_image<k>_<id>.png' into the same folder. Without the
# 'image' prefix filter, re-running this cell afterwards (e.g. Restart & Run All) would
# feed those files back in, and their derived 'mask...' path does not exist.
# Originals are assumed to be named 'image<id>.png', matching the 5-character prefix
# stripped by file[5:] in the augmentation cell — TODO confirm against the dataset.
# The array is built once instead of np.append-ing in a loop, which copies the whole
# array on every iteration.
filenames = np.array(sorted(
    os.path.splitext(name)[0]
    for name in os.listdir(images_dir)
    if name.startswith('image') and name.endswith('.png')
))
assert len(filenames) > 0, f"no 'image*.png' source files found in {images_dir}"

In [10]:
# Geometric augmentation, applied to image and mask together in the augmentation loop.
# Every flip and the rotation always fire (p=1.0). Rotate and OpticalDistortion request
# nearest-neighbour interpolation.
_geometric_ops = [
    A.HorizontalFlip(p=1.0),
    A.VerticalFlip(p=1.0),
    A.Rotate(limit=[60, 300], p=1.0, interpolation=cv2.INTER_NEAREST),
    # Exactly one of the two distortions is picked; their p values act as selection weights.
    A.OneOf(
        [
            A.GridDistortion(distort_limit=0.2, p=0.5),
            A.OpticalDistortion(distort_limit=1, shift_limit=0.5,
                                interpolation=cv2.INTER_NEAREST, p=0.5),
        ],
        p=1.0,
    ),
]
transform_1 = A.Compose(_geometric_ops, p=1.0)

# Photometric augmentation: in the augmentation loop only the image goes through this
# pipeline; the mask is reused unchanged.
_photometric_ops = [
    A.RandomBrightnessContrast(brightness_limit=[-0.05, 0.20], contrast_limit=0.2, p=1.0),
    # Single-member OneOf: CLAHE is the only candidate, so its p=0.5 is just a selection weight.
    A.OneOf([A.CLAHE(clip_limit=1.5, tile_grid_size=(8, 8), p=0.5)], p=1.0),
]
transform_2 = A.Compose(_photometric_ops, p=1.0)

In [13]:
def _read_rgb(path):
    """Read `path` with OpenCV and return it in RGB channel order.

    Raises FileNotFoundError instead of letting cv2.cvtColor fail on the None that
    cv2.imread returns for missing or unreadable files.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _write_rgb(path, rgb):
    """Write an RGB array to `path`, converting to the BGR order cv2.imwrite expects."""
    cv2.imwrite(path, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))


def augment_dataset(names, transform, loop_ids, augment_mask=True):
    """Write one augmented image/mask pair per (loop id, source image).

    Args:
        names: source image stems. The first 5 characters are a prefix (assumed 'image');
            the rest is the id shared with the mask file 'mask<id>.png'.
        transform: albumentations pipeline to apply.
        loop_ids: 1-based ids embedded in the output names 'aug_image<k>_<id>.png' /
            'aug_mask<k>_<id>.png'; one "loop" is written per id.
        augment_mask: if True, image and mask go through `transform` together (for
            geometric ops). Otherwise only the image is transformed and the mask is
            written unchanged (for photometric ops).
    """
    for k in loop_ids:
        print(f"loop {k} ...")
        for name in names:
            sample_id = name[5:]
            image = _read_rgb(os.path.join(images_dir, name + '.png'))
            mask = _read_rgb(os.path.join(masks_dir, 'mask' + sample_id + '.png'))

            if augment_mask:
                out = transform(image=image, mask=mask)
                new_image, new_mask = out['image'], out['mask']
            else:
                new_image, new_mask = transform(image=image)['image'], mask

            # Same folders the inputs come from (previously a duplicated hard-coded path).
            _write_rgb(os.path.join(images_dir, f'aug_image{k}_{sample_id}.png'), new_image)
            _write_rgb(os.path.join(masks_dir, f'aug_mask{k}_{sample_id}.png'), new_mask)


# Loops 1-8: geometric augmentation of image and mask together.
augment_dataset(filenames, transform_1, range(1, 9))
# Loops 9-10: photometric augmentation of the image only.
augment_dataset(filenames, transform_2, range(9, 11), augment_mask=False)

loop 1 ...
loop 2 ...
loop 3 ...
loop 4 ...
loop 5 ...
loop 6 ...
loop 7 ...
loop 8 ...
loop 9 ...
loop 10 ...


# Importing Data

In [16]:
# Target spatial size for every sample, given as (width, height) to cv2.resize.
IMAGE_SIZE = (512, 512)
# Lower grey-level bounds of burn-severity classes 1..4, read from the mask's first
# channel; values below the first bound are class 0.
MASK_THRESHOLDS = (32, 95, 159, 223)


def _load_rgb(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Raises FileNotFoundError instead of letting cv2.cvtColor fail on the None that
    cv2.imread returns for missing or unreadable files.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _mask_to_classes(mask_rgb):
    """Quantise the first channel of an RGB mask into uint8 severity class ids 0..4."""
    # np.digitize counts how many thresholds each value reaches:
    # <32 -> 0, [32, 95) -> 1, [95, 159) -> 2, [159, 223) -> 3, >=223 -> 4.
    return np.digitize(mask_rgb[:, :, 0], MASK_THRESHOLDS).astype(np.uint8)


# Sorted for a deterministic sample order. Only PNGs are kept, so stray files
# (e.g. .DS_Store, checkpoint folders) do not crash the loader.
image_names = sorted(name for name in os.listdir(images_dir) if name.endswith('.png'))
# Derive each mask name from its image name ('image<id>' -> 'mask<id>',
# 'aug_image<k>_<id>' -> 'aug_mask<k>_<id>'). This guarantees image_dataset[i] and
# mask_dataset[i] are the same sample. Listing both folders independently with
# os.listdir gives no such guarantee, because its order is arbitrary.
mask_names = [name.replace('image', 'mask', 1) for name in image_names]

print("Loading images ...")
# float32 (Keras' default float type) halves memory compared with the float64 that
# `uint8 / 255.0` would produce.
image_dataset = np.array([
    cv2.resize(_load_rgb(os.path.join(images_dir, name)), IMAGE_SIZE).astype(np.float32) / 255.0
    for name in image_names
])

print("Loading masks ...")
# INTER_NEAREST: cv2's default bilinear resize blends neighbouring label colours into
# intermediate grey levels, which would then be thresholded into the wrong class.
mask_dataset = np.array([
    _mask_to_classes(cv2.resize(_load_rgb(os.path.join(masks_dir, name)), IMAGE_SIZE,
                                interpolation=cv2.INTER_NEAREST))
    for name in mask_names
])
mask_dataset = np.expand_dims(mask_dataset, axis=3)

Loading images ...
Loading masks ...


In [17]:
# Number of burn-severity classes. The mask thresholds yield ids 0..4 and the U-Net
# defaults to 5 output channels, so the count is fixed rather than inferred.
# len(np.unique(mask_dataset)) under-counts whenever a class is absent from the data.
# That either makes to_categorical fail on the largest id, or gives targets with fewer
# channels than the model outputs.
total_classes = 5
assert mask_dataset.max() < total_classes, \
    "mask_dataset contains a class id outside 0..total_classes-1"

X_main = image_dataset
y_main = to_categorical(mask_dataset, num_classes=total_classes)

In [21]:
# Hold out 15% of the samples with a fixed random_state.
# NOTE(review): the dataset contains augmented copies ('aug_image<k>_<id>') next to their
# originals, and this split treats every file independently. Near-duplicates of training
# images can therefore land in the test set and inflate its scores. A group split on
# <id> (e.g. sklearn GroupShuffleSplit) would avoid that. This test set is also passed
# as validation_data to model.fit, so it drives checkpointing and early stopping.
X_train, X_test, y_train, y_test = train_test_split(
    X_main,
    y_main,
    test_size=0.15,
    random_state=100,
)

for part in (X_train, X_test, y_train, y_test):
    print(part.shape)

(617, 512, 512, 3)
(109, 512, 512, 3)
(617, 512, 512, 5)
(109, 512, 512, 5)


# Building the model

In [28]:
def _conv_block(x, n_filters, dropout_rate=0.2):
    """One U-Net stage: 3x3 ReLU conv -> dropout -> 3x3 ReLU conv (same padding)."""
    x = Conv2D(n_filters, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same")(x)
    x = Dropout(dropout_rate)(x)
    return Conv2D(n_filters, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same")(x)


def multi_unet_model(n_classes=5, image_height=512, image_width=512, image_channels=3,
                     filters=(16, 32, 64, 128, 256), dropout_rate=0.2):
    """Build a U-Net for multi-class semantic segmentation.

    Args:
        n_classes: number of channels of the final per-pixel softmax.
        image_height, image_width, image_channels: input shape, without the batch axis.
        filters: filter count per resolution level, from the first encoder stage down to
            the bottleneck. The decoder mirrors it in reverse. The default reproduces the
            original 16-32-64-128-256 network layer for layer.
        dropout_rate: dropout applied between the two convolutions of every stage.

    Returns:
        An uncompiled keras Model mapping (H, W, C) inputs to (H, W, n_classes)
        softmax outputs.

    Raises:
        ValueError: if `filters` is empty, or if height/width are not divisible by
            2 ** (len(filters) - 1). Without that divisibility, the upsampled maps would
            not line up with the skip connections they are concatenated with.
    """
    if not filters:
        raise ValueError("filters must contain at least one level")
    factor = 2 ** (len(filters) - 1)
    if image_height % factor or image_width % factor:
        raise ValueError(
            f"image_height and image_width must be divisible by {factor}, "
            f"got {image_height}x{image_width}"
        )

    inputs = Input((image_height, image_width, image_channels))

    # Encoder: conv block, keep its output for the skip connection, then downsample 2x.
    skips = []
    x = inputs
    for n_filters in filters[:-1]:
        x = _conv_block(x, n_filters, dropout_rate)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck at the lowest resolution.
    x = _conv_block(x, filters[-1], dropout_rate)

    # Decoder: upsample 2x, concatenate the matching encoder output on channels, conv block.
    for n_filters, skip in zip(reversed(filters[:-1]), reversed(skips)):
        x = Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding="same")(x)
        x = concatenate([x, skip])
        x = _conv_block(x, n_filters, dropout_rate)

    # 1x1 convolution to per-pixel class probabilities.
    outputs = Conv2D(n_classes, (1, 1), activation="softmax")(x)
    return Model(inputs=[inputs], outputs=[outputs])


In [33]:
# Reset Keras' global state *before* building, so re-running this cell starts clean
# (e.g. auto-generated layer names restart at 'conv2d' instead of counting upwards).
K.clear_session()

# Size the network from the prepared data, so the input shape and the number of output
# channels always match X_train and the one-hot y_train.
model = multi_unet_model(
    n_classes=y_train.shape[-1],
    image_height=X_train.shape[1],
    image_width=X_train.shape[2],
    image_channels=X_train.shape[3],
)

In [34]:
# Print the layer-by-layer architecture: output shapes, parameter counts and connections.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 512, 512, 16  448         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 dropout (Dropout)              (None, 512, 512, 16  0           ['conv2d[0][0]']                 
                                )                                                             

In [35]:
def exponential_decay(lr0, s):
    """Build an epoch -> learning-rate schedule ``lr0 * 0.1 ** (epoch / s)``.

    The rate starts at ``lr0`` and shrinks smoothly by a factor of 10 every ``s`` epochs.
    The returned callable is suitable for ``LearningRateScheduler``.
    """
    def schedule(epoch):
        # epoch / s = number of "decades" of decay elapsed so far.
        return lr0 * pow(0.1, epoch / s)

    return schedule

# Learning-rate schedule: start at 1e-4 and shrink it 10x every 60 epochs.
exponential_decay_fn = exponential_decay(0.0001, 60)
lr_scheduler = LearningRateScheduler(exponential_decay_fn, verbose=1)

# Save the model to disk each time val_loss reaches a new best.
checkpoint = ModelCheckpoint(filepath='Model-UNet.h5', save_best_only=True,
                             monitor='val_loss', mode='auto', verbose=1)

# Stop after 12 epochs without a >= 0.001 val_loss improvement, then restore the best weights.
earlystop = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=12,
                          mode='auto', verbose=1, restore_best_weights=True)

# Per-epoch metrics as CSV; the file is overwritten on every run (append=False).
csvlogger = CSVLogger(filename="model_training.csv", separator=",", append=False)

callbacks = [checkpoint, earlystop, csvlogger, lr_scheduler]

def iou_coef(y_true, y_pred):
    """Smoothed soft Jaccard index, used as a Keras metric.

    Both tensors are flattened, so every element (all classes, background included, and
    all samples in the batch) is pooled into a single score:

        (sum(t * p) + 1) / (sum(t) + sum(p) - sum(t * p) + 1)

    The +1 terms keep the ratio defined when both tensors are all zeros. With softmax
    probabilities as ``y_pred`` this is a soft surrogate, not a thresholded per-class IoU.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) - overlap
    return (overlap + 1.0) / (union + 1.0)

Generating the loss function
- Total loss = Dice loss + (1 × Focal loss)

In [None]:
# Equal weight for each of the 5 severity classes.
# NOTE(review): segmentation_models appears to multiply per-class Dice scores by these
# weights before averaging. If so, uniform 0.2 weights scale the Dice term down 5x
# relative to the focal term, rather than leaving it unweighted — verify against the
# installed version.
weights = [0.2] * 5

dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()

# Combined objective: region-overlap Dice plus focal loss with weight 1.
total_loss = dice_loss + 1 * focal_loss

In [None]:
# K.clear_session() is intentionally not called here. It resets Keras' global state and
# belongs before the model is built, not between building and compiling it.

model.compile(
    # The initial rate matches the schedule's lr0; LearningRateScheduler sets the rate
    # at the start of every epoch anyway.
    optimizer=Adam(learning_rate=0.0001),
    # Use the Dice + focal objective built in the loss cell. The previous
    # 'categorical_crossentropy' left that documented combined loss unused.
    loss=total_loss,
    metrics=["accuracy", iou_coef],
)

history = model.fit(
    X_train, y_train,
    batch_size=16,
    validation_data=(X_test, y_test),
    epochs=100,
    callbacks=callbacks,
    verbose=1,
)