# Autoencoder pretraining for the multitask model

1. Import the unlabeled data
2. Train the autoencoder model
3. Save the encoder weights for later use in the multitask model


## Requirements
| Library | Version | Build |
| :---        |    :----:   |   ------:  |
| cudatoolkit |   9.0 | h13b8566_0 |
| cudnn |                     7.6.5 |                cuda9.0_0 |  
|ipykernel|                 5.3.4|            py37h5ca1d4c_0|    
|ipython |                  7.18.1|           py37h5ca1d4c_0|    
|jupyter_client|            6.1.7|                      py_0|    
|jupyter_core|              4.6.3|                    py37_0|    
|keras-applications|        1.0.8|                      py_1|  
|keras-preprocessing|       1.1.0|                      py_1 | 
|matplotlib|                3.3.3|                    pypi_0|    
|matplotlib-base|           3.3.2|            py37h817c723_0|  
|nibabel|                   3.2.1|                    pypi_0|    
|numpy|                     1.19.2|           py37h54aff64_0|
|opencv|                    3.4.2|            py37h6fd60c2_1|  
|pandas|                    1.1.3|            py37he6710b0_0|  
|pillow|                    8.0.1|            py37he98fc37_0|  
|py-xgboost|                0.90|             py37he6710b0_1|    
|python|                    3.7.9|                h7579374_0|  
|scikit-image     |         0.17.2|                   pypi_0|    
|scikit-learn     |         0.23.2|           py37h0573a6f_0|    
|scipy            |         1.5.2  |          py37h0b6359f_0|  
|seaborn          |         0.11.0 |                    py_0|  
|tensorboard     |          1.14.0 |          py37hf484d3e_0|  
|tensorflow     |           1.14.0 |         gpu_py37hae64822_0|  
|tensorflow-gpu|            1.14.0 |              h0d30ee6_0|  

In [None]:
# Importing important files
import pandas as pd
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf

from tensorflow import keras

from sklearn.preprocessing import LabelEncoder


from tensorflow.keras import backend as K
import nibabel as nib
import cv2
import time
from skimage.transform import resize
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit, StratifiedKFold
from tensorflow.keras.utils import to_categorical

from prep_data import get_roi, get_all_subjects, get_subject, get_subject_list, pad_to_shape, remove_padding
from PatchGenerator import PatchGenerator

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping, TensorBoard
from tensorflow.keras.applications.resnet50 import preprocess_input
import models

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping, TensorBoard
# import models
from tensorflow.keras import Model
from tensorflow.keras.applications import ResNet50,DenseNet169,InceptionResNetV2,VGG16
from tensorflow.keras.layers import Conv2DTranspose

from tensorflow.keras.layers import Flatten,concatenate,Input,Activation, GlobalAveragePooling2D,GlobalMaxPooling2D, Dense, Conv2D, MaxPooling2D, UpSampling2D, Dropout, BatchNormalization, Lambda, AveragePooling2D

In [None]:
# Load the unlabeled image array (a NumPy .npy file) used to extend the training set.
unlabeled_path = '../Data/Nimhans_new_complete_3mod_X.npy'
x_train_1 = np.load(unlabeled_path)


In [None]:
# Inspect the shape of the unlabeled array x_train_1.
print(x_train_1.shape)

In [None]:
# Stack the unlabeled samples after x_train along the sample axis (axis 0).
# NOTE(review): x_train is not defined in any visible cell; it must come from an
# earlier cell/session, otherwise this line raises NameError.
x_train_a = np.concatenate((x_train, x_train_1), axis=0)

In [None]:
# Display the shape of the combined array (x_train followed by x_train_1).
x_train_a.shape

In [None]:
# Quick visual check of sample index 1 of the combined (not yet rescaled) array.
plt.imshow(x_train_a[1,:])

In [None]:
def autoEncoder(ip):
    """
    Build the convolutional encoder-decoder graph on top of the tensor ``ip``.

    Encoder: three stages of [Conv2D x2 -> MaxPool(2) -> BatchNorm] with 64, 128
    and 256 filters, a 512-filter bottleneck conv, then a final MaxPool(2) that
    produces the latent tensor ``encoded``.
    Decoder: Conv2DTranspose stages that upsample back to the input resolution,
    each concatenated with the matching encoder feature map, and a final
    1-channel sigmoid Conv2D.

    Layer names are kept stable because weights are reloaded elsewhere with
    ``load_weights(..., by_name=True)``.

    Args:
        ip: Keras input tensor. The notebook uses ``Input(shape=(128, 128, 1))``.
            Height and width must be divisible by 16 so that the upsampled
            decoder maps line up with the encoder maps they are concatenated with.

    Returns:
        Single-channel, sigmoid-activated tensor with the same spatial size as ``ip``.

    NOTE(review): every encoder stage is also concatenated into the decoder
    (U-Net style skips), so the reconstruction can bypass the 8x8 bottleneck.
    If the goal is a compact latent code for the multitask model, confirm
    these skips are intended.
    """
    # ---- Encoder --------------------------------------------------------
    # The input shape comes from ``ip``. An ``input_shape=`` argument on a layer
    # that is called on an existing tensor has no effect, so none is passed.
    conv0 = Conv2D(64, (3, 3), activation='relu', padding='same', name="EncoderLayer_0_Conv2D_0")(ip)
    conv0 = Conv2D(64, (3, 3), activation='relu', padding='same', name="EncoderLayer_0_Conv2D_1")(conv0)
    pool0 = MaxPooling2D((2, 2), padding='same', name="EncoderLayer_0_Pool")(conv0)
    norm0 = BatchNormalization(name="EncoderLayer_0_BatchNorm")(pool0)

    conv1 = Conv2D(128, (3, 3), activation='relu', padding='same', name="EncoderLayer_1_Conv2D_0")(norm0)
    conv1 = Conv2D(128, (3, 3), activation='relu', padding='same', name="EncoderLayer_1_Conv2D_1")(conv1)
    pool1 = MaxPooling2D((2, 2), padding='same', name="EncoderLayer_1_Pool")(conv1)
    norm1 = BatchNormalization(name="EncoderLayer_1_BatchNorm")(pool1)

    conv2 = Conv2D(256, (3, 3), activation='relu', padding='same', name="EncoderLayer_2_Conv2D_0")(norm1)
    conv2 = Conv2D(256, (3, 3), activation='relu', padding='same', name="EncoderLayer_2_Conv2D_1")(conv2)
    pool2 = MaxPooling2D((2, 2), padding='same', name="EncoderLayer_2_Pool")(conv2)
    norm2 = BatchNormalization(name="EncoderLayer_2_BatchNorm")(pool2)

    conv3 = Conv2D(512, (3, 3), activation='relu', padding='same', name="BottleNeckLayer_3_Conv2D_0")(norm2)
    # NOTE(review): dropout is applied to the raw conv3 activations. There is no
    # BatchNorm at this stage (a BatchNorm on conv3 that was never wired into
    # the graph has been removed as dead code). Confirm this is the intended design.
    drop3 = Dropout(0.5, name="BottleNeckLayer_3_Dropout")(conv3)

    # ---- Latent space ---------------------------------------------------
    # For a (128, 128, 1) input this is (8, 8, 512), i.e. 32768 values.
    encoded = MaxPooling2D((2, 2), padding='same', name="BottleNeckLayer_0_Pool")(conv3)

    # ---- Decoder --------------------------------------------------------
    # Shapes in the comments assume a (128, 128, 1) input.
    # 8x8 -> 16x16, joined with the 16x16x512 bottleneck features.
    deconv0 = Conv2DTranspose(512, (3, 3), activation='relu', padding='same', strides=2, name="DecoderLayer_0_DeConv_0")(encoded)
    merge0 = concatenate([drop3, deconv0], name="DecoderLayer_0_concatenate")
    deconv0 = Conv2DTranspose(512, (3, 3), activation='relu', padding='same', name="DecoderLayer_0_DeConv_1")(merge0)
    denorm0 = BatchNormalization(name="DecoderLayer_0_BatchNorm")(deconv0)

    # Joined with the 16x16x256 encoder map, then 16x16 -> 32x32.
    deconv1 = Conv2DTranspose(256, (3, 3), activation='relu', padding='same', name="DecoderLayer_1_DeConv_0")(denorm0)
    merge1 = concatenate([norm2, deconv1], name="DecoderLayer_1_concatenate")
    deconv1 = Conv2DTranspose(256, (3, 3), activation='relu', padding='same', strides=2, name="DecoderLayer_1_DeConv_1")(merge1)
    denorm1 = BatchNormalization(name="DecoderLayer_1_BatchNorm")(deconv1)

    # Joined with the 32x32x128 encoder map, then 32x32 -> 64x64.
    deconv2 = Conv2DTranspose(128, (3, 3), activation='relu', padding='same', name="DecoderLayer_2_DeConv_0")(denorm1)
    merge2 = concatenate([norm1, deconv2], name="DecoderLayer_2_concatenate")
    deconv2 = Conv2DTranspose(128, (3, 3), activation='relu', padding='same', strides=2, name="DecoderLayer_2_DeConv_1")(merge2)
    denorm2 = BatchNormalization(name="DecoderLayer_2_BatchNorm")(deconv2)

    # Joined with the 64x64x64 encoder map, then 64x64 -> 128x128.
    deconv3 = Conv2DTranspose(64, (3, 3), activation='relu', padding='same', name="DecoderLayer_3_DeConv_0")(denorm2)
    merge3 = concatenate([norm0, deconv3], name="DecoderLayer_3_concatenate")
    deconv3 = Conv2DTranspose(64, (3, 3), activation='relu', padding='same', strides=2, name="DecoderLayer_3_DeConv_1")(merge3)
    denorm3 = BatchNormalization(name="DecoderLayer_3_BatchNorm")(deconv3)

    # Single-channel reconstruction in (0, 1).
    decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name="FinalConv")(denorm3)

    return decoded

In [None]:
# Attach the encoder-decoder graph to a 128x128 single-channel input and wrap it in a Model.
ip1 = Input(shape=(128, 128, 1))
reconstruction = autoEncoder(ip1)
AutoEncoder1 = Model(inputs=ip1, outputs=reconstruction)
AutoEncoder1.summary()

In [None]:
# Render the model graph (layer names and output shapes) to model_plot.png.
# NOTE(review): plot_model presumably needs pydot and graphviz, which are not in
# the requirements table; confirm they are installed.
from tensorflow.keras.utils import plot_model
plot_model(AutoEncoder1, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

In [None]:
# Soft Dice coefficient used (via dice_coef_loss) as the reconstruction loss.
def dice_coef(y_true, y_pred, smooth=1):
    """
    Soft Dice coefficient, computed per sample.

    Dice = (2*|X & Y|) / (|X| + |Y|)
         = 2*sum(|A*B|) / (sum(A^2) + sum(B^2))
    ref: https://arxiv.org/pdf/1606.04797v1.pdf

    Each sample is flattened before summing, so the overlap is measured over
    the whole image. The autoencoder output is (batch, H, W, 1). Summing only
    over ``axis=-1`` of that tensor would reduce over the single channel and
    give a per-pixel score instead of a per-image Dice. For already-flat
    (batch, features) inputs the result is unchanged.

    Args:
        y_true: target tensor with a leading batch dimension.
        y_pred: prediction tensor, same shape as ``y_true``.
        smooth: additive smoothing term; keeps the ratio defined when both
            inputs are all zeros.

    Returns:
        Tensor of shape (batch,) holding one Dice score per sample.
    """
    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)
    intersection = K.sum(K.abs(y_true_f * y_pred_f), axis=-1)
    denominator = K.sum(K.square(y_true_f), axis=-1) + K.sum(K.square(y_pred_f), axis=-1)
    return (2. * intersection + smooth) / (denominator + smooth)

def dice_coef_loss(y_true, y_pred):
    """Return the Dice loss, i.e. one minus ``dice_coef(y_true, y_pred)``."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

In [None]:
# Adam optimizer with the Dice loss.
# NOTE(review): the targets are the (presumably [0, 1]-scaled) inputs themselves,
# so an 'accuracy' metric is unlikely to be meaningful here. Consider a
# reconstruction metric instead (this would change the CSV log columns).
AutoEncoder1.compile(
    optimizer='adam',
    loss=dice_coef_loss,
    metrics=['accuracy'],
)

In [None]:
# NOTE(review): train_datagen / test_datagen are not defined in any visible cell,
# and train_steps / test_steps are not used by the fit() call on in-memory arrays.
# This cell raises NameError unless those generators exist in the session.
train_steps = len(train_datagen)
test_steps = len(test_datagen)
print(f"{train_steps} {test_steps}")

In [None]:
# Training callbacks: per-epoch CSV log, best-val_loss checkpoint, and TensorBoard.
log_dir = 'logs'
weights_dir = 'weights'
# The CSV logger and checkpoint write files into these folders; create them up
# front so a missing directory cannot abort training at the first write.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)

filename = os.path.join(log_dir, 'AutoEncoder_2_complete.csv')
filepath = os.path.join(weights_dir, 'AutoEncoder_2_complete.hdf5')
csv_log = CSVLogger(filename, separator=',', append=True)
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True)
# Built but disabled in callbacks_list below.
# NOTE(review): 'acc' must match a logged metric key (it may be 'acc' or
# 'accuracy' depending on the TF version), and training accuracy is a weak
# signal for a reconstruction model.
rl = ReduceLROnPlateau(monitor='acc', patience=5, min_delta=0.001, cooldown=5, factor=0.1)
tb = TensorBoard('./logs', histogram_freq=0)
callbacks_list = [
    csv_log,
    checkpoint,
    # rl,
    tb,
]

In [None]:
# Optionally warm-start from a previously saved weights file, matching layers by name.
# NOTE(review): this reads 'AutoEncoder_2_all.hdf5', while the checkpoint callback
# writes 'AutoEncoder_2_complete.hdf5'. Confirm which file is intended.
filepath1 = os.path.join('weights', 'AutoEncoder_2_all.hdf5')
have_pretrained = os.path.exists(filepath1)
if have_pretrained:
    AutoEncoder1.load_weights(filepath1, by_name=True)

In [None]:
# Rescale to float32 with a /255 divisor to match the decoder's sigmoid output range
# (assumes raw intensities in 0..255; TODO confirm for the .npy inputs).
# NOTE(review): x_test is not defined in any visible cell.
scale = 255.
X_train = x_train_a.astype(np.float32) / scale
X_test = x_test.astype(np.float32) / scale

In [None]:
# Quick visual check of sample index 3 after the /255 rescaling.
plt.imshow(X_train[3,:])

In [None]:
# Train the autoencoder to reproduce its own input (targets == inputs) and
# validate on the X_test reconstructions.
# NOTE(review): EarlyStopping is imported but not in callbacks_list, so all
# 1000 epochs run; only the best val_loss weights are checkpointed.
epochs = 1000
AutoEncoder1.fit(X_train, X_train,
                epochs = epochs,
                batch_size = 64,
                shuffle=True,
                validation_data=(X_test, X_test),
                callbacks=callbacks_list)

In [None]:
# Reconstruct the validation set, then show originals (top row) above their
# reconstructions (bottom row) for X_test indices 121..130.
decoded_imgs = AutoEncoder1.predict(X_test)

n = 10
first_sample = 120
plt.figure(figsize=(20, 4))
for col in range(1, n + 1):
    sample = first_sample + col
    for row, source in enumerate((X_test, decoded_imgs)):
        ax = plt.subplot(2, n, row * n + col)
        plt.imshow(source[sample])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

In [None]:
# Persist the trained autoencoder: the full model (architecture + weights) and a
# weights-only file.
# NOTE(review): this saves the whole autoencoder, not just the encoder. The
# multitask model presumably picks out the encoder layers by name; verify.
model_path = '../Model/TCGA_AutoEncoder_complete_1.h5'
weights_path = 'weights/TCGA_AutoEncoder_complete_1.hdf5'
# Create the parent folders first so a long training run is not lost at the save step.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
AutoEncoder1.save(model_path)
AutoEncoder1.save_weights(weights_path)

In [None]:
# Rescale the unlabeled array alone with the same /255 divisor used for X_train.
# NOTE(review): despite its name, x_train_a_ is derived from x_train_1, not x_train_a.
x_train_a_ = x_train_1.astype(np.float32) / 255.

In [None]:
# Reconstruct the unlabeled array, then show originals (top row) above their
# reconstructions (bottom row) for indices 26..35.
# NOTE(review): x_train_1 was appended into X_train for training, so these are
# training-set reconstructions, not held-out ones.
decoded_imgs_mid = AutoEncoder1.predict(x_train_a_)

n = 10
first_sample = 25
plt.figure(figsize=(20, 4))
for col in range(1, n + 1):
    sample = first_sample + col
    for row, source in enumerate((x_train_a_, decoded_imgs_mid)):
        ax = plt.subplot(2, n, row * n + col)
        plt.imshow(source[sample])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()