In [None]:
Arxiv Link to original paper: <a href="https://arxiv.org/abs/1505.04597">U-Net: Convolutional Networks for Biomedical Image Segmentation</a>

In [None]:
# Mount Google Drive under /content/drive; later cells read the dataset
# archives from, and write model checkpoints to, paths below this mount point.
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

In [None]:
## Imports
import zipfile
import os
import sys
import random

import matplotlib
from sklearn.manifold import TSNE
import numpy as np
import cv2
import matplotlib.pyplot as plt
from IPython.display import clear_output
import pandas as pd
from sklearn.decomposition import PCA
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from keras import layers
import matplotlib.image as mpimg
import glob
from os import listdir
from os.path import isfile, join
from sklearn.metrics import roc_curve, auc,roc_auc_score, f1_score
from sklearn.model_selection import GridSearchCV
from itertools import cycle
import seaborn as sns
import statistics as st
import math
import datetime
from matplotlib.ticker import MaxNLocator
import keras.backend as K
import plotly.express as px 

## Seeding
# The seeding functions must be *called*. Assigning to them (as in
# `np.random.seed = seed`) only rebinds an attribute: no generator is seeded
# and, for NumPy, the real `np.random.seed` function is overwritten.
seed = 42
random.seed(seed)         # Python's built-in RNG
np.random.seed(seed)      # NumPy's legacy global RNG
tf.random.set_seed(seed)  # TensorFlow's global seed

# 1.4 Code

In [None]:
# Unpack the training archive from Drive into /content/.
train_zip_path = '/content/drive/MyDrive/Elizabeth_PhD_Folder/lfe_images/train.zip'
with zipfile.ZipFile(train_zip_path, mode='r') as zip_ref:
    zip_ref.extractall(path='/content/')

In [None]:
# Unpack the test archive from Drive into /content/.
test_zip_path = '/content/drive/MyDrive/Elizabeth_PhD_Folder/lfe_images/test.zip'
with zipfile.ZipFile(test_zip_path, mode='r') as zip_ref:
    zip_ref.extractall(path='/content/')

## Plotting Accuracy

In [None]:
# Define some useful functions
class PlotLossAccuracy(keras.callbacks.Callback):
    """Keras callback that redraws live training curves after every epoch.

    Three side-by-side panels are drawn: loss, accuracy and IoU, each showing
    the training and the validation series. Values are read from the ``logs``
    dict passed by Keras under the keys ``loss``/``val_loss``,
    ``accuracy``/``val_accuracy`` and ``IoU``/``val_IoU``. ``f1_metric`` /
    ``val_f1_metric`` are recorded as well but not plotted. A key that is
    absent from ``logs`` is recorded as ``None``.
    """

    def on_train_begin(self, logs=None):
        """Reset every recorded history at the start of training."""
        self.i = 0            # number of epochs seen so far
        self.x = []           # x-axis values (epoch counter)
        self.acc = []
        self.losses = []
        self.val_losses = []
        self.val_acc = []

        self.iou = []
        self.val_iou = []

        self.dice = []
        self.val_dice = []

        self.logs = []        # raw logs dict of every epoch

    def _plot_panel(self, position, train_values, val_values,
                    train_label, val_label, ylabel, title):
        """Draw one train/validation curve pair in subplot ``position``."""
        plt.subplot(position)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.plot(self.x, train_values, label=train_label)
        plt.plot(self.x, val_values, label=val_label)
        plt.ylabel(ylabel, fontsize=25)
        plt.xlabel('epoch', fontsize=25)
        plt.title(title, fontsize=28)
        # Epoch counters are whole numbers, so force integer ticks on x only.
        # The y axis keeps matplotlib's default locator: an integer-only
        # locator would leave a [0, 1] metric with at most the ticks 0 and 1.
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
        plt.legend(fontsize=18)

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and redraw the three panels.

        Args:
            epoch: epoch index supplied by Keras (unused; an internal counter
                provides the x-axis values).
            logs: dict of metric name -> value for the finished epoch.
        """
        # Avoid a shared mutable default argument; tolerate logs=None.
        logs = logs if logs is not None else {}

        self.logs.append(logs)
        self.x.append(int(self.i))
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.acc.append(logs.get('accuracy'))
        self.val_acc.append(logs.get('val_accuracy'))

        self.iou.append(logs.get('IoU'))
        self.val_iou.append(logs.get('val_IoU'))

        self.dice.append(logs.get('f1_metric'))
        self.val_dice.append(logs.get('val_f1_metric'))

        self.i += 1

        # Replace the previous epoch's figure instead of stacking new ones.
        clear_output(wait=True)
        plt.figure(figsize=(32, 6))

        self._plot_panel(131, self.losses, self.val_losses,
                         "train loss", "validation loss",
                         'loss', 'Model Loss')
        self._plot_panel(132, self.acc, self.val_acc,
                         "training accuracy", "validation accuracy",
                         'accuracy', 'Model Accuracy')
        # Legend entries now name the plotted metric (they previously read
        # "training accuracy" / "validation accuracy" on the IoU panel).
        self._plot_panel(133, self.iou, self.val_iou,
                         "training IoU", "validation IoU",
                         'IoU', 'IoU')
        plt.show()


## Data Generator

In [None]:
class DataGen(keras.utils.Sequence):
    """Keras ``Sequence`` that yields ``(image, mask)`` batches for training.

    Every sample id is a sub-directory of ``path`` containing
    ``images/<id>.npy``, one or more ``masks/*.npy`` files and
    ``traj/<id>.npy``. The image is resized to ``(image_h, image_w)`` and
    the first four trajectory values are appended as four constant planes.

    Args:
        ids: list of sample-id strings (directory names under ``path``).
        path: root directory holding one sub-directory per sample id.
        batch_size: number of samples per batch (the last batch may be smaller).
        image_w: target width in pixels.
        image_h: target height in pixels.
    """

    def __init__(self, ids, path, batch_size, image_w, image_h):
        # Initialise the Keras base class; recent Keras versions warn when a
        # Sequence/PyDataset subclass skips this call.
        super().__init__()
        self.ids = ids
        self.path = path
        self.batch_size = batch_size
        self.image_w = image_w
        self.image_h = image_h
        self.on_epoch_end()

    def _resize(self, array):
        """Resize ``array`` to ``(image_h, image_w)`` with a Keras Resizing layer."""
        resize = tf.keras.Sequential([layers.Resizing(self.image_h, self.image_w)])
        return resize(array)

    def __load__(self, id_name):
        """Load one sample from disk.

        Returns:
            ``(im_all_channels, mask)``: the resized image with four constant
            trajectory planes concatenated on the last axis, and an integer
            0/1 mask of shape ``(image_h, image_w, 1)``.
        """
        sample_dir = os.path.join(self.path, id_name)
        image_path = os.path.join(sample_dir, "images", id_name) + ".npy"
        mask_dir = os.path.join(sample_dir, "masks")
        traj_path = os.path.join(sample_dir, "traj", id_name) + ".npy"

        # NOTE(review): allow_pickle=True can execute arbitrary code when
        # loading untrusted files; only use it on data you produced yourself.
        image = self._resize(np.load(image_path, allow_pickle=True))

        # Merge every mask file into a single binary mask. Starting from
        # zeros and taking the element-wise maximum keeps all masks (the
        # previous loop kept only the last file listed) and leaves `mask`
        # defined even when the masks folder is empty.
        mask = np.zeros((self.image_h, self.image_w, 1), dtype=int)
        for name in os.listdir(mask_dir):
            raw_mask = np.load(os.path.join(mask_dir, name), allow_pickle=True)
            raw_mask = np.reshape(raw_mask, (raw_mask.shape[0], raw_mask.shape[1], 1))
            resized = self._resize(raw_mask)[:, :, 0]
            # Resizing interpolates, so re-binarise: any positive value -> 1.
            binary = np.where(resized > 0, 1, 0).reshape(self.image_h, self.image_w, 1)
            mask = np.maximum(mask, binary)

        # The per-sample label file is no longer read here: its value was
        # never used by this generator.

        # traj[0..3] (lat_s, lat_m, lt_s, lt_m) are each broadcast to a full
        # (image_h, image_w, 1) plane so they can be stacked as channels.
        traj = np.load(traj_path, allow_pickle=True)
        plane_shape = (self.image_h, self.image_w, 1)
        traj_planes = [np.full(plane_shape, traj[k]) for k in range(4)]

        im_all_channels = np.concatenate([image] + traj_planes, axis=2)
        return im_all_channels, mask

    def gaussian_noise(self, total_array, mask_array):
        """Add Gaussian noise to channels 0 and 1 of one sample.

        The same noise draw (mean 0, stddev 0.25) is added to channel 0
        (``flux``) and channel 1 (``pol``), and both are clipped to [0, 1].
        The remaining channels are passed through unchanged.

        Returns:
            ``(im, mask_array)`` reshaped to ``(1, image_h, image_w, 6)`` and
            ``(1, image_h, image_w, 1)``; the input must therefore have
            exactly 6 channels.

        NOTE(review): not called by ``__getitem__`` at the moment.
        """
        flux = total_array[:, :, 0:1]
        pol = total_array[:, :, 1:2]
        noise = tf.random.normal(shape=tf.shape(flux), mean=0.0, stddev=0.25, dtype=tf.float32)
        flux = flux + noise
        flux = np.clip(flux, 0, 1)
        pol = pol + noise
        pol = np.clip(pol, 0, 1)

        im = tf.concat([flux, pol, total_array[:, :, 2:]], axis=2)
        im = np.reshape(im, (1, self.image_h, self.image_w, 6))
        mask_array = np.reshape(mask_array, (1, self.image_h, self.image_w, 1))

        return im, mask_array

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, masks)`` arrays.

        The slice is computed from the fixed ``batch_size``; slicing clips at
        the end of ``ids``, so the final batch is simply shorter. The
        previous code overwrote ``self.batch_size`` for the last batch and
        then sliced with the new value, which selected the wrong samples and
        changed ``__len__`` for every later epoch.
        """
        start = index * self.batch_size
        files_batch = self.ids[start:start + self.batch_size]

        samples = [self.__load__(id_name) for id_name in files_batch]
        im = np.array([sample[0] for sample in samples])
        mask = np.array([sample[1] for sample in samples])

        # TODO: augmentation (horizontal flip / gaussian_noise applied to
        # randomly chosen samples) is currently disabled.
        return im, mask

    def on_epoch_end(self):
        """Hook called at the end of each epoch; intentionally a no-op."""
        pass

    def __len__(self):
        """Number of batches per epoch, counting a final partial batch."""
        return int(np.ceil(len(self.ids)/float(self.batch_size)))

In [None]:
class ValidGen(keras.utils.Sequence):
    """Keras ``Sequence`` that yields ``(image, mask)`` batches for validation.

    Every sample id is a sub-directory of ``path`` containing
    ``images/<id>.npy``, one or more ``masks/*.npy`` files and
    ``traj/<id>.npy``. The image is resized to ``(image_h, image_w)`` and
    the first four trajectory values are appended as four constant planes.

    Args:
        ids: list of sample-id strings (directory names under ``path``).
        path: root directory holding one sub-directory per sample id.
        batch_size: number of samples per batch (the last batch may be smaller).
        image_w: target width in pixels.
        image_h: target height in pixels.
    """

    def __init__(self, ids, path, batch_size, image_w, image_h):
        # Initialise the Keras base class; recent Keras versions warn when a
        # Sequence/PyDataset subclass skips this call.
        super().__init__()
        self.ids = ids
        self.path = path
        self.batch_size = batch_size
        self.image_w = image_w
        self.image_h = image_h
        self.on_epoch_end()

    def _resize(self, array):
        """Resize ``array`` to ``(image_h, image_w)`` with a Keras Resizing layer."""
        resize = tf.keras.Sequential([layers.Resizing(self.image_h, self.image_w)])
        return resize(array)

    def __load__(self, id_name):
        """Load one sample from disk.

        Returns:
            ``(im_all_channels, mask)``: the resized image with four constant
            trajectory planes concatenated on the last axis, and an integer
            0/1 mask of shape ``(image_h, image_w, 1)``.
        """
        sample_dir = os.path.join(self.path, id_name)
        image_path = os.path.join(sample_dir, "images", id_name) + ".npy"
        mask_dir = os.path.join(sample_dir, "masks")
        traj_path = os.path.join(sample_dir, "traj", id_name) + ".npy"

        # NOTE(review): allow_pickle=True can execute arbitrary code when
        # loading untrusted files; only use it on data you produced yourself.
        image = self._resize(np.load(image_path, allow_pickle=True))

        # Merge every mask file into a single binary mask. Starting from
        # zeros and taking the element-wise maximum keeps all masks (the
        # previous loop kept only the last file listed) and leaves `mask`
        # defined even when the masks folder is empty.
        mask = np.zeros((self.image_h, self.image_w, 1), dtype=int)
        for name in os.listdir(mask_dir):
            raw_mask = np.load(os.path.join(mask_dir, name), allow_pickle=True)
            raw_mask = np.reshape(raw_mask, (raw_mask.shape[0], raw_mask.shape[1], 1))
            resized = self._resize(raw_mask)[:, :, 0]
            # Resizing interpolates, so re-binarise: any positive value -> 1.
            binary = np.where(resized > 0, 1, 0).reshape(self.image_h, self.image_w, 1)
            mask = np.maximum(mask, binary)

        # The per-sample label file is no longer read here: its value was
        # never used by this generator.

        # traj[0..3] (lat_s, lat_m, lt_s, lt_m) are each broadcast to a full
        # (image_h, image_w, 1) plane so they can be stacked as channels.
        traj = np.load(traj_path, allow_pickle=True)
        plane_shape = (self.image_h, self.image_w, 1)
        traj_planes = [np.full(plane_shape, traj[k]) for k in range(4)]

        im_all_channels = np.concatenate([image] + traj_planes, axis=2)
        return im_all_channels, mask

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, masks)`` arrays.

        The slice is computed from the fixed ``batch_size``; slicing clips at
        the end of ``ids``, so the final batch is simply shorter. The
        previous code overwrote ``self.batch_size`` for the last batch and
        then sliced with the new value, which selected the wrong samples and
        changed ``__len__`` for every later epoch.
        """
        start = index * self.batch_size
        files_batch = self.ids[start:start + self.batch_size]

        samples = [self.__load__(id_name) for id_name in files_batch]
        im = np.array([sample[0] for sample in samples])
        mask = np.array([sample[1] for sample in samples])
        return im, mask

    def on_epoch_end(self):
        """Hook called at the end of each epoch; intentionally a no-op."""
        pass

    def __len__(self):
        """Number of batches per epoch, counting a final partial batch."""
        return int(np.ceil(len(self.ids)/float(self.batch_size)))

## UNet Model

## Different Convolutional Blocks

In [None]:
train_path = "/content/train/"
epochs = 60
batch_size = 16
# Filter counts per U-Net level: f[0]..f[3] for the encoder/decoder stages,
# f[4] for the bottleneck.
f = [64,128, 256, 512,1024]


## Training Ids
# Names of the sub-directories directly under train_path (one per sample).
random_ids = next(os.walk(train_path))[1]

count=3000
ids=[str(i).zfill(3) for i in np.arange(count)]
# Keep only the ids that exist on disk, in ascending numeric order. The set
# gives O(1) membership tests instead of scanning the directory list per id.
available_ids = set(random_ids)
total_ids = [i for i in ids if i in available_ids]


## Validation Data Size
# test_size=.35 holds out 35% of total_ids for validation (65% for training).
# NOTE(review): an earlier comment here described a 75% / 25% split, which
# does not match the value used -- confirm which split is intended.
# stratify needs exactly one label per entry of total_ids, in the same order
# -- TODO confirm train_label.npy is ordered like total_ids.
train_label=np.load('/content/drive/MyDrive/Elizabeth_PhD_Folder/lfe_images/train_label.npy',allow_pickle=True)
train_ids, valid_ids = train_test_split(total_ids, test_size=.35, random_state=42,stratify=train_label)

In [None]:
def down_block(x, filters, do, kernel_size=(3, 3), padding='same', strides=1):
    """Encoder stage: Conv2D -> Dropout -> Conv2D, followed by 2x2 max pooling.

    Args:
        x: input tensor.
        filters: number of filters for both convolutions.
        do: dropout rate applied between the two convolutions.
        kernel_size, padding, strides: forwarded to both Conv2D layers.

    Returns:
        ``(c, p)``: the second convolution's output (used as the skip
        connection) and its max-pooled version.
    """
    conv_options = dict(padding=padding, strides=strides, activation="relu")

    features = keras.layers.Conv2D(filters, kernel_size, **conv_options)(x)
    features = keras.layers.Dropout(do)(features)
    features = keras.layers.Conv2D(filters, kernel_size, **conv_options)(features)
    pooled = keras.layers.MaxPool2D((2, 2), (2, 2))(features)

    return features, pooled

def up_block(x, skip, filters, do, kernel_size=(3, 3), padding='same', strides=1):
    """Decoder stage: 2x upsample, concatenate the skip, Conv2D -> Dropout -> Conv2D.

    Args:
        x: tensor coming from the previous (deeper) stage.
        skip: encoder tensor concatenated with the upsampled ``x``.
        filters: number of filters for both convolutions.
        do: dropout rate applied between the two convolutions.
        kernel_size, padding, strides: forwarded to both Conv2D layers.

    Returns:
        Output tensor of the second convolution.
    """
    conv_options = dict(padding=padding, strides=strides, activation="relu")

    upsampled = keras.layers.UpSampling2D((2, 2))(x)
    merged = keras.layers.Concatenate()([upsampled, skip])
    features = keras.layers.Conv2D(filters, kernel_size, **conv_options)(merged)
    features = keras.layers.Dropout(do)(features)
    features = keras.layers.Conv2D(filters, kernel_size, **conv_options)(features)

    return features

def bottleneck(x, filters, do, kernel_size=(3, 3), padding='same', strides=1):
    """Bottleneck stage: Conv2D -> Dropout -> Conv2D, with no pooling or upsampling.

    Args:
        x: input tensor.
        filters: number of filters for both convolutions.
        do: dropout rate applied between the two convolutions.
        kernel_size, padding, strides: forwarded to both Conv2D layers.

    Returns:
        Output tensor of the second convolution.
    """
    conv_options = dict(padding=padding, strides=strides, activation="relu")

    features = keras.layers.Conv2D(filters, kernel_size, **conv_options)(x)
    features = keras.layers.Dropout(do)(features)
    features = keras.layers.Conv2D(filters, kernel_size, **conv_options)(features)
    return features

In [None]:
def UNet(do):
    """Build the U-Net: four encoder stages, a bottleneck and four decoder stages.

    Reads the module-level ``image_h``, ``image_w``, ``channels`` and the
    filter list ``f`` at call time (``f[0]``..``f[3]`` for the encoder and
    decoder stages, ``f[4]`` for the bottleneck).

    Args:
        do: dropout rate passed to every block.

    Returns:
        An uncompiled ``keras.models.Model`` whose output is a 1-channel
        sigmoid map produced by a 1x1 convolution.
    """
    inputs = keras.layers.Input((image_h, image_w, channels))

    # Encoder: remember each stage's pre-pooling output for the skip links.
    skips = []
    x = inputs
    for level in range(4):
        skip, x = down_block(x, f[level], do)
        skips.append(skip)

    x = bottleneck(x, f[4], do)

    # Decoder: walk back up, pairing each stage with its matching skip.
    for level in (3, 2, 1, 0):
        x = up_block(x, skips[level], f[level], do)

    outputs = keras.layers.Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)
    return keras.models.Model(inputs, outputs)

## Hyperparameters

In [None]:
def dice_coef(y_true, y_pred, smooth=100):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Evaluates ``(2 * sum(y_true * y_pred) + smooth) /
    (sum(y_true) + sum(y_pred) + smooth)`` after flattening the inputs.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor.
        smooth: constant added to numerator and denominator; it keeps the
            ratio defined when both masks are empty.

    Returns:
        Scalar float32 tensor.
    """
    # Cast both inputs to float32 so integer masks can be passed directly;
    # multiplying an integer tensor by a float one would otherwise fail.
    y_true_f = K.flatten(tf.cast(y_true, dtype=tf.float32))
    y_pred_f = K.flatten(tf.cast(y_pred, dtype=tf.float32))
    intersection = K.sum(y_true_f * y_pred_f)
    dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return dice

def dice_coef_loss(y_true, y_pred):
    """Dice loss: ``1 - dice_coef``, with ``y_true`` cast to float32 first."""
    dice = dice_coef(tf.cast(y_true, dtype=tf.float32), y_pred)
    return 1 - dice

In [None]:
# Input geometry; UNet() reads image_h, image_w and channels as globals, so
# they must be defined before the UNet(do) call below.
image_w = 128
image_h = 384
# presumably 2 image channels + the 4 trajectory planes appended by the
# data generators -- TODO confirm against the stored image arrays
channels=6
# Dropout rate used inside every convolution block.
do=0.4
model = UNet(do)
optname='Adam'
lr =1e-4
opt=tf.keras.optimizers.Adam(learning_rate=lr,name=optname)
# name='IoU' makes the metric appear in the training logs as 'IoU' /
# 'val_IoU', the keys read by the PlotLossAccuracy callback.
binary_iou= tf.keras.metrics.BinaryIoU(name='IoU')
model.summary()
model.compile(optimizer=opt, loss='binary_crossentropy',
              metrics=['accuracy', binary_iou])

## Training the model

In [None]:
# Folder on Drive where trained models are stored.
root = '/content/drive/MyDrive/Elizabeth_PhD_Folder/Colab_Notebooks/Models/'
max_filter = f[-1]
# Encode the optimizer, learning rate and filter range in the model's name.
label_parts = [
    'UNET_', str(optname),
    'LR', str(lr),
    '_6steps_minfilter', str(f[0]),
    'maxfilter', str(max_filter),
    '_binaryloss',
]
model_label = ''.join(label_parts)

In [None]:
# Both generators read from train_path; they differ only in the id subset.
gen_kwargs = dict(batch_size=batch_size, image_w=image_w, image_h=image_h)
train_gen = DataGen(train_ids, train_path, **gen_kwargs)
valid_gen = ValidGen(valid_ids, train_path, **gen_kwargs)

# Floor division: only full batches are counted in the per-epoch step totals.
train_steps = len(train_ids) // batch_size
valid_steps = len(valid_ids) // batch_size

# Live loss / accuracy / IoU plot, refreshed after each epoch.
pltCallBack = PlotLossAccuracy()

In [None]:
root = '/content/drive/MyDrive/Elizabeth_PhD_Folder/Colab_Notebooks/Models/'
checkpoint_filepath = root + model_label
# Checkpoint the full model (not just the weights) during training.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    monitor='val_accuracy',
    mode='auto',
    save_best_only=False,
)

# Fit Model
epochs = 60
history = model.fit(
    train_gen,
    validation_data=valid_gen,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    epochs=epochs,
    verbose=1,
    callbacks=[model_checkpoint_callback, pltCallBack],
)