# Transfer learning for the multimodal learning and segmentation
#### This notebook trains an autoencoder on stereozoom data and then transfers its encoder weights to a U-Net for multimodal learning and segmentation of OCT images; the segmentations will be used to calculate the healing score.

## workflow:
- 1. [install packages](#install-packages)
- 2. [Libraries imports](#libraries-imports)
- 3. [Data imports for autoencoder](#data-imports-for-autoencoder)
- 4. [Train autoencoder](#train-autoencoder)
- 5. [Transfer weights to unet and training](#transfer-weights-to-unet-and-training)
- 6. [Plotting the predictions](#plotting-the-predictions)

### install packages

In [None]:
! pip install patchify
! pip install segmentation_models

### Libraries imports

In [None]:
import glob
import os
import random

import cv2
import numpy as np
import segmentation_models as sm
import tensorflow as tf
import tensorflow.keras.backend as K

from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from keras.layers import Activation, Concatenate, MaxPool2D
from keras.metrics import MeanIoU
from keras.models import Model
from keras.models import load_model
from matplotlib import pyplot as plt
from patchify import patchify, unpatchify
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import class_weight
from tensorflow.keras.utils import normalize
from tensorflow.keras.utils import normalize, img_to_array, to_categorical
from tqdm import tqdm


### Data imports for autoencoder

In [None]:
# specify parameters and load stereozoom data
SIZE=256
n_classes = 4
img_data=[]
path1 = '/content/drive/MyDrive/Image_dataset/image'
# Sort the listing: os.listdir returns entries in arbitrary order, and a later cell
# keeps only the first 576 images, so an unsorted listing makes that subset irreproducible.
files=sorted(os.listdir(path1))
for i in tqdm(files):
    raw=cv2.imread(os.path.join(path1, i),1)   # flag 1 -> load as a 3-channel colour image
    if raw is None:
        # cv2.imread returns None (it does not raise) for entries it cannot decode,
        # e.g. sub-directories or non-image files; skip those.
        continue
    img=cv2.resize(raw,(SIZE, SIZE))
    img_data.append(img_to_array(img))

In [None]:
# stack the loaded images into a single (N, SIZE, SIZE, 3) array
img_array = np.asarray(img_data).reshape((len(img_data), SIZE, SIZE, 3))
# convert to float32 and divide by 255
img_array = img_array.astype('float32') / 255.
# the first 576 images form the autoencoder training set
img_array2 = img_array[:576]

### Train autoencoder

In [None]:
# double convolution block: (Conv -> BatchNorm -> ReLU) applied twice
def conv_block(input, num_filters):
    """Apply two 3x3 'same'-padded convolutions, each followed by batch
    normalisation and a ReLU activation.

    input:       tensor fed to the first convolution
    num_filters: number of filters used by both convolutions
    returns:     the output tensor of the second ReLU
    """
    x = input
    for _ in range(2):
        x = Conv2D(num_filters, 3, padding="same")(x)
        x = BatchNormalization()(x)   # not in the original network
        x = Activation("relu")(x)
    return x

# encoder stage: double convolution followed by 2x2 max pooling
def encoder_block(input, num_filters):
    """Run conv_block on `input` and downsample the result.

    returns: (features, pooled) - the conv_block output and the same
             tensor after 2x2 max pooling.
    """
    features = conv_block(input, num_filters)
    pooled = MaxPool2D((2, 2))(features)
    return features, pooled

# decoder stage (no skip connection): transposed convolution followed by a double convolution
def decoder_block(input, num_filters):
    """Upsample `input` with a stride-2 2x2 transposed convolution and
    pass the result through conv_block.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    return conv_block(upsampled, num_filters)

# assemble the encoder half of the network
def build_encoder(input_image):
    """Chain four encoder stages (64, 128, 256, 512 filters) and a
    1024-filter bridge.

    Only the pooled output of each stage is passed on; the pre-pooling
    feature maps are not used here.

    returns: the output tensor of the bridge conv_block.
    """
    x = input_image
    for filters in (64, 128, 256, 512):
        _, x = encoder_block(x, filters)

    return conv_block(x, 1024)  # bridge

# assemble the decoder half of the network
def build_decoder(encoded):
    """Chain four decoder stages (512, 256, 128, 64 filters) and finish
    with a 3-filter 3x3 convolution with sigmoid activation.

    returns: the output tensor of the final convolution.
    """
    x = encoded
    for filters in (512, 256, 128, 64):
        x = decoder_block(x, filters)

    return Conv2D(3, 3, padding="same", activation="sigmoid")(x)

In [None]:
# put encoder and decoder together into one autoencoder model
def build_autoencoder(input_shape):
    """Return a Keras Model mapping an input of `input_shape` through
    build_encoder and then build_decoder.
    """
    input_img = Input(shape=input_shape)
    reconstruction = build_decoder(build_encoder(input_img))
    return Model(input_img, reconstruction)

In [None]:
# define decoder blocks for U-Net (with skip connections)
def decoder_block_for_unet(input, skip_features, num_filters):
    """Upsample `input` with a stride-2 2x2 transposed convolution,
    concatenate the result with `skip_features`, and apply conv_block.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

In [None]:
# build the U-Net network with the same dimensions as the autoencoder
def build_unet(input_shape):
    """Build a U-Net whose encoder has the same layer structure as build_encoder.

    The encoder stages and the bridge use the same filter counts as the
    autoencoder (64, 128, 256, 512, 1024); the decoder stages additionally
    concatenate the matching encoder feature maps (skip connections).

    input_shape: shape of one input sample, passed to Input().
    returns:     a Keras Model named "U-Net" whose last layer is a 1x1
                 convolution with `n_classes` softmax channels. `n_classes`
                 is read from the notebook's global namespace.
    """
    inputs = Input(input_shape)

    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024) #Bridge

    d1 = decoder_block_for_unet(b1, s4, 512)
    d2 = decoder_block_for_unet(d1, s3, 256)
    d3 = decoder_block_for_unet(d2, s2, 128)
    d4 = decoder_block_for_unet(d3, s1, 64)

    # one softmax channel per class
    outputs = Conv2D(n_classes, (1, 1), activation='softmax')(d4)

    model = Model(inputs, outputs, name="U-Net")
    # summary() prints the table itself and returns None, so wrapping it in
    # print() only added a stray "None" line to the output.
    model.summary()
    return model

In [None]:
# build and compile the autoencoder model
# The input shape is derived from SIZE instead of `img`, a variable that merely
# happened to be left over from the last iteration of the data-loading loop.
autoencoder_model=build_autoencoder((SIZE, SIZE, 3))
autoencoder_model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])
# summary() prints itself and returns None; no print() wrapper needed
autoencoder_model.summary()

In [None]:
# fit the autoencoder to reconstruct its own input; the trained encoder
# weights are later copied into the U-Net
history = autoencoder_model.fit(
    x=img_array2,
    y=img_array2,
    epochs=100,
    verbose=1,
)

In [None]:
# save the trained autoencoder to an HDF5 file; it is read back with load_model in the next cell
autoencoder_model.save('autoencoder_multimodal_100epochs.h5')

In [None]:
# reload the saved autoencoder (compile=False: it is only used for prediction here)
autoencoder_model = load_model("autoencoder_multimodal_100epochs.h5", compile=False)

# reconstruct one randomly chosen image and show it next to the original
num = random.randint(0, len(img_array2) - 1)
test_img = img_array[num][np.newaxis, ...]   # add a leading batch axis
pred = autoencoder_model.predict(test_img)

panels = (
    ('Original', test_img[0]),
    ('Reconstructed', pred[0].reshape(SIZE, SIZE, 3)),
)
for position, (panel_title, panel_image) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(panel_image)
    plt.title(panel_title)
plt.show()

### transfer weights to unet and training

In [None]:
# Define an encoder-only model, i.e. the autoencoder without its decoder part.
input_shape = (256, 256, 3)
input_img = Input(shape=input_shape)

encoder = build_encoder(input_img)
encoder_model = Model(input_img, encoder)
# summary() prints itself and returns None; print() around it only added "None"
encoder_model.summary()

# Number of layers in the encoder-only model (input layer included); 35 for this architecture.
num_encoder_layers = len(encoder_model.layers)

In [None]:
# Copy the weights of the trained autoencoder's encoder layers into the new encoder-only model.
# The layer count comes from num_encoder_layers rather than a hard-coded 35.
for l1, l2 in zip(encoder_model.layers[:num_encoder_layers], autoencoder_model.layers[:num_encoder_layers]):
    l1.set_weights(l2.get_weights())

# Verify that the weights really match: compare a slice of the first weight array of each model.
autoencoder_weights = autoencoder_model.get_weights()[0][1]
encoder_weights = encoder_model.get_weights()[0][1]
assert np.array_equal(autoencoder_weights, encoder_weights), \
    "encoder weights do not match the autoencoder weights they were copied from"

# Save encoder weights for future comparison
np.save('pretrained_encoder-weights.npy', encoder_weights )


# Check the output of encoder_model on a test image.
# Four 2x2 poolings take 256x256 down to 16x16, and the bridge has 1024 filters,
# so the encoding should be 16x16x1024 per image.
temp_img_path = '/content/drive/MyDrive/Image_dataset/image/106.png'
temp_img = cv2.imread(temp_img_path,1)
if temp_img is None:
    # cv2.imread signals failure by returning None instead of raising
    raise FileNotFoundError(temp_img_path)
temp_img = cv2.resize(temp_img,(256,256))
temp_img = temp_img.astype('float32') / 255.
temp_img = np.expand_dims(temp_img, axis=0)
temp_img_encoded=encoder_model.predict(temp_img)
temp_img_encoded.shape  # displayed as the cell output


In [None]:
# Define a U-Net with the same encoder part as our autoencoder,
# then load the weights of the trained autoencoder into those encoder layers.
input_shape = (256, 256, 3)
unet_model = build_unet(input_shape)

# Layer names of both models, kept for manual inspection.
# The exact names differ between the models even where the layers correspond.
unet_layer_names = [layer.name for layer in unet_model.layers]
autoencoder_layer_names = [layer.name for layer in autoencoder_model.layers]

# Set weights of the encoder part of the U-Net (its first num_encoder_layers layers).
# Each pair is checked to be the same kind of layer first, so a change in layer
# ordering fails loudly instead of silently copying into the wrong layer.
for l1, l2 in zip(unet_model.layers[:num_encoder_layers], autoencoder_model.layers[:num_encoder_layers]):
    assert type(l1).__name__ == type(l2).__name__, \
        f"layer mismatch: {l1.name} ({type(l1).__name__}) vs {l2.name} ({type(l2).__name__})"
    l1.set_weights(l2.get_weights())

unet_model.compile('Adam', loss=sm.losses.categorical_focal_jaccard_loss, metrics=[sm.metrics.iou_score])
unet_model.summary()
print(unet_model.output_shape)

# Saved so the weights can be restored with load_weights() in a later cell.
unet_model.save('unet_model_weights.h5')

In [None]:
# specify parameters for OCT images
# SIZE_X / SIZE_Y are only referenced by the commented-out resize calls in the loading cells
SIZE_X = 512
SIZE_Y = 128*13 # = 1664 (an earlier note here said 1712, which does not match this product)

# patch shape handed to patchify for the 3-channel OCT images
patch_size = (256, 256,3)
# stride between neighbouring patches; half the 256-pixel patch side
window_step = 128 # 128
n_classes=4 

# re-assigned to 16 in the training cell further down
batch_size = 8

In [None]:
# load OCT images and split each one into patches
train_images = []

# Always use a 3-channel patch shape here. `patch_size` is re-assigned to a 2-D
# shape for the masks further down, which would otherwise break a re-run of this cell.
image_patch_size = (*patch_size[:2], 3)

for img_path in sorted(glob.glob(os.path.join("/content/drive/MyDrive/OCT_DATASET/Images", "*.BMP"))):
    img = cv2.imread(img_path, 1)
    if img is None:
        # cv2.imread returns None instead of raising; fail with the offending path.
        # Skipping is not an option: images and masks are paired by their sorted order.
        raise ValueError(f"could not read image: {img_path}")
    p_imgs = patchify(img, image_patch_size, step=window_step).reshape(-1, *image_patch_size) # split image into patches
    train_images.append(p_imgs)
# merge the per-image patch stacks into one array along the first axis
train_images = np.concatenate(train_images)

In [None]:
# sanity check: the patches were reshaped to (-1, 256, 256, 3) and concatenated along the first axis
train_images.shape # check the shape of the images

In [None]:
# 2-D patch size used when patchifying the masks below.
# NOTE: this overwrites the 3-element patch_size assigned earlier for the OCT images.
patch_size=(256,256) # size of the patches

In [None]:
# load OCT masks and split each one into patches
# NOTE(review): assumes the sorted mask file names line up one-to-one with the sorted image file names - confirm
train_masks = [] 
# masks are patchified in 2-D; slicing makes this work whether patch_size
# currently holds the 3-element image shape or the 2-element mask shape
mask_patch_size = tuple(patch_size[:2])
for mask_path in sorted(glob.glob(os.path.join("/content/drive/MyDrive/OCT_DATASET/Mask", "*.png"))):
    mask = cv2.imread(mask_path,cv2.IMREAD_UNCHANGED)
    if mask is None:
        # cv2.imread returns None instead of raising; fail with the offending path
        raise ValueError(f"could not read mask: {mask_path}")
    p_imgs = patchify(mask, mask_patch_size, step=window_step).reshape(-1,*mask_patch_size)  # split into patches 
    train_masks.append(p_imgs)

In [None]:
# Merge the per-image patch stacks along the first axis (exactly as is done for
# train_images), then add a trailing channel axis -> (n_patches, H, W, 1).
# np.array() on the list would instead produce (n_images, n_patches, H, W), so the
# masks would no longer line up index-for-index with the image patches.
# NOTE(review): dividing by 255 produces fractional values, while to_categorical()
# further down works on integer class ids - confirm how the mask files encode the classes.
train_masks = np.expand_dims(np.concatenate(train_masks), 3) /255. # normalize masks

In [None]:
# the first axis must have the same length as train_images: both arrays are indexed with the same idx_lst below
train_masks.shape # check the shape of the masks

In [None]:
# Collect the indices of the patches whose mask holds more than one distinct value.
# Single-valued patches are dropped; those are expected to be background-only,
# since other patches will almost always contain at least 2 classes.
idx_lst = [
    idx
    for idx, patch_mask in enumerate(train_masks)
    if np.unique(patch_mask).size > 1
]
# number of kept patches, shape before filtering, shape after filtering
len(idx_lst), train_images.shape, train_images[idx_lst].shape

In [None]:
# keep only the selected patches, in the images and in the masks
train_images = train_images[idx_lst]
train_masks = train_masks[idx_lst]

In [None]:
# test split: hold out 10% of the patches for testing; random_state is fixed so the split is repeatable
X1, X_test, y1, y_test = train_test_split(train_images,train_masks, test_size = 0.10, random_state = 0)

In [None]:
# train / validation split: 20% of the remaining patches are set aside.
# Despite their names, X_do_not_use / y_do_not_use are passed as validation_data when the U-Net is fitted.
X_train, X_do_not_use, y_train, y_do_not_use = train_test_split(X1, y1, test_size = 0.2, random_state = 0)

In [None]:
# check shapes of the training images and training masks (displayed as the cell output)
X_train.shape, y_train.shape

In [None]:
# one-hot encode the training masks and give them an explicit
# (samples, height, width, n_classes) layout
train_masks_cat = to_categorical(y_train, num_classes=n_classes)
y_train_cat = train_masks_cat.reshape(tuple(y_train.shape[:3]) + (n_classes,))

In [None]:
# one-hot encode the testing masks and give them an explicit
# (samples, height, width, n_classes) layout
test_masks_cat = to_categorical(y_test, num_classes=n_classes)
y_test_cat = test_masks_cat.reshape(tuple(y_test.shape[:3]) + (n_classes,))

In [None]:
# tell segmentation_models to use the tf.keras backend, then display the framework in use
sm.set_framework('tf.keras')
sm.framework()

In [None]:
# loss functions and metrics
# (the compile calls in this notebook pass sm.losses.categorical_focal_jaccard_loss
#  and sm.metrics.iou_score directly rather than total_loss / metrics)
uniform_class_weights = np.full(4, 0.25)   # equal weight for each of the 4 classes
dice_loss = sm.losses.DiceLoss(class_weights=uniform_class_weights)
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

score_threshold = 0.5
metrics = [
    sm.metrics.IOUScore(threshold=score_threshold),
    sm.metrics.FScore(threshold=score_threshold),
]

In [None]:
# training parameters
n_classes=4
# not referenced by the cells shown in this notebook; build_unet hard-codes 'softmax'
activation='softmax'

# learning rate for the optimizer object below
LR = 0.0001
# NOTE: the compile calls further down pass the string 'Adam', not this object,
# so LR does not take effect there
optim = tf.keras.optimizers.Adam(LR)

In [None]:
# build two U-Nets: one left at its random initialisation, one initialised
# with the pretrained weights from the autoencoder
input_shape = (256, 256, 3)
random_wt_unet_model = build_unet(input_shape)

# slice of the first weight array, kept for comparison
random_wt_unet_model_weights = random_wt_unet_model.get_weights()[0][1]

pre_trained_unet_model = build_unet(input_shape)
pre_trained_unet_model.load_weights('unet_model_weights.h5')
pre_trained_unet_model_weights = pre_trained_unet_model.get_weights()[0][1]

#Load previously saved pretrained encoder weights just for comparison with the unet weights (Sanity check)
pretrained_encoder_wts = np.load('pretrained_encoder-weights.npy')

# Compare the arrays element by element. `a.all() == b.all()` only compares two
# booleans (whether each array is free of zeros), so it reported "identical"
# for practically any pair of weight arrays.
if np.array_equal(pre_trained_unet_model_weights, pretrained_encoder_wts):
    print("Both weights are identical")
else: 
    print("Something wrong, weghts are different")

In [None]:
# compile both U-Nets with identical settings so that they differ only in their initial weights
for unet in (random_wt_unet_model, pre_trained_unet_model):
    unet.compile(
        'Adam',
        loss=sm.losses.categorical_focal_jaccard_loss,
        metrics=[sm.metrics.iou_score],
    )

In [None]:
# train the randomly initialised U-Net
# (pre_trained_unet_model is compiled above but is not fitted in this cell)
batch_size=16

# The network ends in an n_classes-channel softmax and is compiled with a
# categorical loss, so the targets have to be one-hot encoded masks rather
# than the single-channel label masks.
y_val_cat = to_categorical(y_do_not_use, num_classes=n_classes)

random_wt_unet_model_history = random_wt_unet_model.fit(X_train, y_train_cat, 
                    verbose=1,
                    batch_size = batch_size,
                    validation_data=(X_do_not_use, y_val_cat), 
                    shuffle=False,
                    epochs=25)

#### Mean IoU score

In [None]:
# get the predictions for the test data
# `model` is never defined in this notebook; evaluate the U-Net that was fitted in
# the training cell (swap in pre_trained_unet_model once that one has been trained too)
y_pred=random_wt_unet_model.predict(X_test)
# most probable class per pixel
y_pred_argmax=np.argmax(y_pred, axis=3)

# check mean Intersection over Union score (IoU)
n_classes = 4
IOU_keras = MeanIoU(num_classes=n_classes)  
# drop the trailing channel axis of the masks so they have the same rank as the argmax output
IOU_keras.update_state(y_test[:,:,:,0], y_pred_argmax)
print("Mean IoU =", IOU_keras.result().numpy())

#### IoU metric per class

In [None]:
# confusion matrix accumulated by the MeanIoU metric
values = np.array(IOU_keras.get_weights()).reshape(n_classes, n_classes)
print(values)

# IoU of class c = diagonal entry / (diagonal + rest of row c + rest of column c)
per_class_iou = []
for c in range(4):
    others = [k for k in range(4) if k != c]
    denominator = values[c, c]
    for k in others:          # remaining entries of row c
        denominator = denominator + values[c, k]
    for k in others:          # remaining entries of column c
        denominator = denominator + values[k, c]
    per_class_iou.append(values[c, c] / denominator)

class1_IoU, class2_IoU, class3_IoU, class4_IoU = per_class_iou

print("IoU for class1 is: ", class1_IoU)
print("IoU for class2 is: ", class2_IoU)
print("IoU for class3 is: ", class3_IoU)
print("IoU for class4 is: ", class4_IoU)

In [None]:
# get the predictions for the test data
# `model` is never defined in this notebook; use the U-Net that was fitted in the training cell
y_pred=random_wt_unet_model.predict(X_test)
# most probable class per pixel
y_pred_argmax=np.argmax(y_pred, axis=3)

In [None]:
# mean Intersection over Union (IoU) of the predicted labels against the test masks
n_classes = 4
IOU_keras = MeanIoU(num_classes=n_classes)
y_true_labels = y_test[:, :, :, 0]   # drop the trailing channel axis of the masks
IOU_keras.update_state(y_true_labels, y_pred_argmax)
mean_iou_value = IOU_keras.result().numpy()
print("Mean IoU =", mean_iou_value)

### Plotting the predictions


In [None]:
# plotting

# pick a random test patch and predict its segmentation
# random.randint includes both end points, so the upper bound must be len - 1
test_img_number = random.randint(0, len(X_test) - 1)
test_img = X_test[test_img_number] 
ground_truth=y_test[test_img_number] # mask for the test image patch
# The U-Net is built for 3-channel input, so feed the full patch (not a single
# channel) and only add the leading batch axis.
test_img_input=np.expand_dims(test_img, 0)
# `model` is never defined in this notebook; use the U-Net that was fitted in the training cell
prediction = random_wt_unet_model.predict(test_img_input)
predicted_img=np.argmax(prediction, axis=3)[0,:,:] 

# plot the test image patch, its mask and the prediction side by side
plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Testing Image')
plt.imshow(test_img[:,:,0], cmap='gray')
plt.subplot(232)
plt.title('Testing Label')
plt.imshow(ground_truth[:,:,0], cmap='jet')
plt.subplot(233)
plt.title('Prediction on test image')
plt.imshow(predicted_img, cmap='jet')
plt.show()