In [None]:
# Mount Google Drive and switch to the project folder when running on Google Colab.
# Outside Colab the google.colab package is not installed; the cell then does nothing,
# so the notebook can still be executed in a local Jupyter kernel.
import os

try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    os.chdir('/content/drive/MyDrive/FCN')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.compat.v1.losses import softmax_cross_entropy
from tensorflow.keras import regularizers, initializers, Input, Model
from tensorflow.keras.layers import Activation, Add, Conv2D, Conv2DTranspose, Dropout, Lambda, MaxPool2D, UpSampling2D
from tensorflow.keras.metrics import MeanIoU

#import data, models

# for auto-reloading external modules
%load_ext autoreload
%autoreload 2

In [None]:
def split_dataset(DATASET_PATH='rats_data', holdout=0.8, random_state=None):
    '''
    Collect (image, label) PNG paths from the rats dataset tree and split them.

    Expected layout:
        <DATASET_PATH>/<treatment>/<day>/<animal>/           for treatments CIC and PDX
        <DATASET_PATH>/<treatment>/<dose>/<day>/<animal>/    for every other treatment
    Each animal folder is expected to hold <animal>.png (image) and <animal>_label.png (label).

    args:
        DATASET_PATH = dataset root directory (str)
        holdout = fraction of samples kept for TRAINING (float in (0, 1)); the rest is the test set
        random_state = seed forwarded to train_test_split (int or None). None draws a fresh
                       random seed (previous behaviour); pass an int for a reproducible split.
    return:
        X_train, X_test, y_train, y_test -- lists of image / label file paths
    raises:
        ValueError if holdout is not strictly between 0 and 1
    '''
    if not 0 < holdout < 1:
        raise ValueError(f'holdout must be strictly between 0 and 1, got {holdout}')

    def _subdirs(*parts):
        # Sorted so the collected order -- and therefore the split for a fixed random_state --
        # does not depend on os.listdir's arbitrary ordering. Plain files (e.g. .DS_Store) are
        # skipped instead of being descended into.
        root = os.path.join(DATASET_PATH, *parts)
        return sorted(name for name in os.listdir(root)
                      if os.path.isdir(os.path.join(root, name)))

    images, labels = [], []
    for treatment in _subdirs():
        if treatment in ('CIC', 'PDX'):
            groups = [(treatment,)]  # these treatments have no dose level
        else:
            groups = [(treatment, dose) for dose in _subdirs(treatment)]
        for group in groups:
            for day in _subdirs(*group):
                for animal in _subdirs(*group, day):
                    path = os.path.join(DATASET_PATH, *group, day, animal)
                    images.append(os.path.join(path, animal + '.png'))
                    labels.append(os.path.join(path, animal + '_label.png'))

    if random_state is None:
        random_state = np.random.randint(0, 1000000)
    return train_test_split(images, labels, test_size=1 - holdout, random_state=random_state)


In [None]:
def vgg16(weight_decay=0, dropout=0.5):
    '''
    Build the convolutional VGG16 backbone used by the FCN heads.

    Five stages of 3x3 'same'-padded ReLU convolutions, each closed by a 2x2 / stride-2
    max-pool, followed by the "convolutionized" fully connected layers conv6 (7x7) and
    conv7 (1x1), both with 4096 filters and followed by dropout.

    args:
        weight_decay = L2 regularization factor on every conv kernel (float), 0 by default
        dropout = dropout rate after conv6 and conv7 (float), 0.5 by default
    return:
        Keras Model from a 3-channel input of any spatial size to the 'drop-conv7' output
    '''
    def conv(tensor, filters, name, kernel_size=(3, 3), **extra):
        # All backbone convolutions share padding, activation and regularization settings.
        layer = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       padding='same',
                       activation='relu',
                       kernel_regularizer=regularizers.L2(l2=weight_decay),
                       name=name,
                       **extra)
        return layer(tensor)

    # (filters, number of convolutions) for stages 1..5;
    # layers are named Conv<stage>-<k> and Pool<stage>.
    stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    inputs = Input(shape=(None, None, 3), name='input')
    x = inputs
    for stage, (filters, n_convs) in enumerate(stages, start=1):
        for k in range(1, n_convs + 1):
            x = conv(x, filters, 'Conv{}-{}'.format(stage, k))
        x = MaxPool2D(pool_size=(2, 2), strides=(2, 2), name='Pool{}'.format(stage))(x)

    # Fully connected layers expressed as convolutions
    x = conv(x, 4096, 'conv6', kernel_size=(7, 7), strides=(1, 1))
    x = Dropout(rate=dropout, name='drop-conv6')(x)
    x = conv(x, 4096, 'conv7', kernel_size=(1, 1), strides=(1, 1))
    x = Dropout(rate=dropout, name='drop-conv7')(x)

    return Model(inputs, x)



def fcn32s(vgg16, weight_decay=0, classes=3):
    '''
    FCN-32s head: score the conv7 features, upsample them 32x, and apply a softmax.

    Args:
        vgg16: VGG16 keras model; must contain a layer named 'drop-conv7'
        weight_decay = L2 regularization factor (float), weight_decay=0 by default
        classes = number of output classes (int), classes=3 by default
    returns:
        keras model mapping vgg16.input to per-pixel class probabilities (last axis = classes)
    '''
    x = Conv2D(filters=classes,
               kernel_size=(1,1),
               strides=(1,1),
               padding='same',
               activation='linear',
               kernel_regularizer=regularizers.L2(l2=weight_decay),
               name='score-conv7')(vgg16.get_layer('drop-conv7').output)

    # Weightless bilinear upsampling back towards input resolution.
    x = UpSampling2D(size=(32,32), interpolation='bilinear', name='upsample-32')(x)

    # crossentropy() takes log(y_pred), so the head must emit probabilities, not raw linear
    # scores (log of a negative score is NaN). This replaces the former randomly initialised
    # 1x1 conv: placed after the linear score-conv7 and a linear upsample, it added no
    # expressiveness. The name 'FCN32s' is kept so get_layer('FCN32s') still resolves.
    x = Activation('softmax', name='FCN32s')(x)

    return Model(vgg16.input, x)



def fcn16s(vgg16, fcn32, weight_decay=0, classes=21):
    '''
    FCN-16s head: fuse 2x-upsampled conv7 scores with Pool4 scores, upsample 16x, softmax.

    Args:
        vgg16: VGG16 custom keras model; must contain layers 'drop-conv7' and 'Pool4'
        fcn32: FCN32 custom keras model; only its input tensor is used
        weight_decay = L2 regularization factor (float), weight_decay=0 by default
        classes = number of output classes (int), classes=21 by default
    returns:
        keras model mapping fcn32.input to per-pixel class probabilities (last axis = classes);
        the fused scores before the final upsampling are the output of layer 'step4'
    '''
    x = UpSampling2D(size=(2,2), interpolation='bilinear')(vgg16.get_layer('drop-conv7').output)

    x = Conv2D(filters=classes,
               kernel_size=(1,1),
               strides=(1,1),
               padding='same',
               activation='linear',
               kernel_regularizer=regularizers.L2(l2=weight_decay),
               name='upsample-conv7')(x)

    # Zero kernel (and default zero bias): the Pool4 branch contributes nothing at first,
    # so the net starts from the unmodified conv7 predictions.
    y = Conv2D(filters=classes,
               kernel_size=(1,1),
               strides=(1,1),
               padding='same',
               activation='linear',
               kernel_initializer=initializers.Zeros(),
               kernel_regularizer=regularizers.L2(l2=weight_decay),
               name='score-pool4')(vgg16.get_layer('Pool4').output)

    m = Add(name='step4')([x, y])  # fusion

    m = UpSampling2D(size=(16,16), interpolation='bilinear', name='upsample-16')(m)

    # Softmax output because crossentropy() takes log(y_pred). The previous version also built
    # an unconnected Conv2D that reused the name 'FCN16s'; the name now belongs to this layer,
    # so get_layer('FCN16s') still resolves.
    m = Activation('softmax', name='FCN16s')(m)

    return Model(fcn32.input, m)

    

def fcn8s(vgg16, fcn16, weight_decay=0):
    '''
    FCN-8s head: fuse 2x-upsampled 'step4' scores with Pool3 scores, upsample 8x, softmax.

    Args:
        vgg16: VGG16 custom keras model; must contain a layer named 'Pool3'
        fcn16: FCN16 custom keras model; must contain the fusion layer 'step4'
        weight_decay = L2 regularization factor (float), weight_decay=0 by default
    returns:
        keras model mapping fcn16.input to per-pixel class probabilities. The number of
        classes is taken from the channel count of fcn16's 'step4' output.
    '''
    step4 = fcn16.get_layer('step4').output
    classes = step4.shape[-1]

    # step4 already holds class scores at Pool4 (1/16) resolution. A bilinear 2x upsample
    # brings them to Pool3 (1/8) resolution. The layer is named 'upsample-step4' so existing
    # get_layer('upsample-step4') calls still resolve.
    x = UpSampling2D(size=(2,2), interpolation='bilinear', name='upsample-step4')(step4)

    # Zero-initialised like the Pool4 branch of FCN-16s: the net starts from FCN-16s' scores.
    y = Conv2D(filters=classes,
               kernel_size=(1,1),
               strides=(1,1),
               padding='same',
               activation='linear',
               kernel_initializer=initializers.Zeros(),
               kernel_regularizer=regularizers.L2(l2=weight_decay),
               name='score-pool3')(vgg16.get_layer('Pool3').output)

    m = Add(name='step3')([x, y])  # fusion

    # Upsample the fused step3 scores (the previous version re-upsampled step4 here, which
    # discarded the Pool3 branch and produced a half-resolution output).
    m = UpSampling2D(size=(8,8), interpolation='bilinear', name='upsample-8')(m)
    m = Activation('softmax', name='FCN8s')(m)

    return Model(fcn16.input, m)

# Models

In [None]:
## VGG16 base model -- the backbone shared by the FCN32 / FCN16 / FCN8 heads built below
vgg_model = vgg16(dropout=0.2, weight_decay=1e-6)

In [None]:
## FCN32
fcn32 = fcn32s(vgg_model, weight_decay=1e-6)

# Nothing is frozen: the 32x upsampling is a weightless bilinear UpSampling2D layer, so there
# is no learned upsampling kernel to hold fixed. Freezing a freshly initialised Conv2D would
# instead pin it to its random initial weights for the whole run.

In [None]:
## FCN16
fcn16 = fcn16s(vgg_model, fcn32, weight_decay=1e-6)

# Nothing is frozen. 'upsample-conv7' is the randomly initialised 1x1 Conv2D that scores the
# conv7 features; freezing it would fix that path to a random projection. The upsampling itself
# is done by weightless bilinear UpSampling2D layers, which have nothing to freeze.

In [None]:
## FCN8
fcn8 = fcn8s(vgg_model, fcn16, weight_decay=1e-6)

# Nothing is frozen: bilinear UpSampling2D layers have no weights, and freezing a freshly
# initialised Conv2D would keep it at its random initial values for the whole run.

# Training

In [None]:
# Loss and metrics that ignore border pixels.

# Adapted from: https://github.com/kevinddchen/Keras-FCN/blob/main/models.py

def crossentropy(y_true, y_pred_onehot):
    '''
    Pixel-wise cross-entropy that ignores border pixels (class = -1, i.e. 255 in uint8 labels).

    Args:
        y_true: integer class-index map; assumed to have a trailing channel axis of size 1 so
                it broadcasts against the class axis -- TODO confirm for this dataset
        y_pred_onehot: per-pixel class probabilities (e.g. softmax output), classes on the last axis
    Returns:
        scalar cross-entropy summed over non-border pixels and divided by their count
        (0 instead of NaN when every pixel is a border pixel)
    '''
    # Derive the class count from the prediction instead of hard-coding 21, so the loss also
    # works for heads with a different number of output channels.
    n_classes = tf.shape(y_pred_onehot)[-1]
    y_true = tf.cast(y_true, tf.int32)
    n_valid = tf.math.reduce_sum(tf.cast(y_true != 255, tf.float32))
    # A border label (255) matches no class index, so its one-hot row is all zeros.
    y_true_onehot = tf.cast(tf.equal(y_true, tf.range(n_classes)), tf.float32)
    total = tf.reduce_sum(-y_true_onehot * tf.math.log(y_pred_onehot + 1e-7))
    return tf.math.divide_no_nan(total, n_valid)

In [None]:
def pixelacc(y_true, y_pred_onehot):
    '''
    Pixel accuracy over non-border pixels (border class = -1, i.e. 255 in uint8 labels).

    Matches are counted over all pixels. A border label of 255 can never equal an argmax index
    as long as the model has fewer than 255 classes, so in practice only valid pixels can
    match. The count is divided by the number of non-border pixels.
    '''
    labels = tf.cast(y_true, tf.int32)[..., 0]
    predicted = tf.argmax(y_pred_onehot, axis=-1, output_type=tf.int32)
    n_correct = tf.reduce_sum(tf.cast(tf.equal(labels, predicted), tf.float32))
    n_valid = tf.math.reduce_sum(tf.cast(y_true != 255, tf.float32))
    return n_correct / n_valid

In [None]:
class MyMeanIoU(keras.metrics.MeanIoU):
    '''Mean IoU that ignores border pixels (class = -1, i.e. 255 in uint8 labels).'''
    def update_state(self, y_true, y_pred_onehot, sample_weight=None):
        '''
        Accumulate a confusion matrix over non-border pixels.

        Args:
            y_true: integer class-index map, border pixels labelled 255
            y_pred_onehot: per-pixel class scores/probabilities, classes on the last axis
            sample_weight: optional per-pixel weights, flattened like the labels
        '''
        y_pred = tf.argmax(y_pred_onehot, axis=-1)
        ## Shift every class up by one and send border pixels to index 0. Row/column 0 of the
        ## (num_classes + 1)-sized confusion matrix is dropped below. tf.where is used instead of
        ## relying on uint8 wrap-around (255 + 1 == 0), which fails for int32 or float labels.
        y_true = tf.cast(y_true, tf.int64)
        y_true = tf.where(y_true == 255, tf.zeros_like(y_true), y_true + 1)
        y_pred = y_pred + 1
        ## Flatten everything to 1-D as tf.math.confusion_matrix requires.
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])
        if sample_weight is not None:
            sample_weight = tf.reshape(tf.cast(sample_weight, self._dtype), [-1])
        ## calculate confusion matrix with one extra class
        current_cm = tf.math.confusion_matrix(
            y_true,
            y_pred,
            self.num_classes+1,
            weights=sample_weight,
            dtype=self._dtype)
        return self.total_cm.assign_add(current_cm[1:, 1:])

In [None]:
## Pick the FCN variant to train (fcn32, fcn16 or fcn8) and print its layer summary
model = fcn32
model.summary()

In [None]:
## Train and test datasets: holdout=0.8 keeps 80% of the samples for training, 20% for test
X_train, X_test, y_train, y_test = split_dataset(
    DATASET_PATH='/content/drive/MyDrive/rats_data',
    holdout=0.8,
)

In [None]:
# Read every training image and convert OpenCV's BGR channel order to RGB.
X_train_images = []
for image_path in X_train:
    img = cv2.imread(image_path)
    if img is None:  # cv2.imread returns None (it does not raise) on a missing/unreadable file
        raise FileNotFoundError(f'cv2.imread could not read {image_path}')
    X_train_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

In [None]:
# Read every test image and convert OpenCV's BGR channel order to RGB.
X_test_images = []
for image_path in X_test:
    img = cv2.imread(image_path)
    if img is None:  # cv2.imread returns None (it does not raise) on a missing/unreadable file
        raise FileNotFoundError(f'cv2.imread could not read {image_path}')
    X_test_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

In [None]:
# Read every training label (BGR -> RGB, same as the images).
# NOTE(review): labels are kept as 3-channel RGB arrays, but crossentropy/pixelacc/MyMeanIoU
# treat y_true as a single-channel class-index map (pixelacc reads y_true[..., 0]). Confirm the
# label encoding; cv2.IMREAD_GRAYSCALE plus a trailing channel axis may be what is needed.
y_train_images = []
for label_path in y_train:
    img = cv2.imread(label_path)
    if img is None:  # cv2.imread returns None (it does not raise) on a missing/unreadable file
        raise FileNotFoundError(f'cv2.imread could not read {label_path}')
    y_train_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

In [None]:
# Read every test label (BGR -> RGB, same as the images).
# NOTE(review): labels are kept as 3-channel RGB arrays, but crossentropy/pixelacc/MyMeanIoU
# treat y_true as a single-channel class-index map (pixelacc reads y_true[..., 0]). Confirm the
# label encoding; cv2.IMREAD_GRAYSCALE plus a trailing channel axis may be what is needed.
y_test_images = []
for label_path in y_test:
    img = cv2.imread(label_path)
    if img is None:  # cv2.imread returns None (it does not raise) on a missing/unreadable file
        raise FileNotFoundError(f'cv2.imread could not read {label_path}')
    y_test_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

In [None]:
# Sanity check: the first few training images (left column) next to their labels (right column).
# The grid is sized from n, and n is capped by the number of loaded images.
n = min(4, len(X_train_images))
fig, axes = plt.subplots(n, 2, figsize=(25, 25), squeeze=False)
for i in range(n):
    axes[i, 0].imshow(X_train_images[i])
    axes[i, 0].set_title(f'image {i}')
    axes[i, 1].imshow(y_train_images[i])
    axes[i, 1].set_title(f'label {i}')
plt.show()

In [None]:
# Stack each list of images into a single tensor. tf.convert_to_tensor needs every array in a
# list to share the same shape, so all images (and all labels) must be the same size.
X_train_images, y_train_images, X_test_images, y_test_images = (
    tf.convert_to_tensor(arrays)
    for arrays in (X_train_images, y_train_images, X_test_images, y_test_images)
)

In [None]:
## compile
opt = keras.optimizers.Adam(learning_rate=1e-4)
loss = crossentropy
# MyMeanIoU builds its confusion matrix from argmax(y_pred), so num_classes must equal the
# model's number of output channels. A hard-coded 2 goes out of range for the 3-class FCN32 head.
n_classes = model.output_shape[-1]
metrics = [loss,
           pixelacc,
           MyMeanIoU(num_classes=n_classes, name='meanIoU')]
model.compile(optimizer=opt, loss=loss, metrics=metrics)

In [None]:
# Train one image per step for 20 epochs; no validation data is passed, so only
# training-set loss and metrics are logged.
history = model.fit(
    X_train_images,
    y_train_images,
    epochs=20,
    batch_size=1,
    verbose=1,
)