<a href="https://colab.research.google.com/github/FR-Schwartz/IDS705_Team10/blob/main/10_code/auto_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
import os
import pickle
import tensorflow_gcs_config
import shutil
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
import nibabel as nib
import tensorflow_hub as hub
# Mount Google Drive inside the Colab VM, then make the project folder the
# working directory so the relative "Data/..." paths used later resolve.
DRIVE_MOUNT_POINT = '/content/MyDrive/'
PROJECT_DIR = '/content/MyDrive/MyDrive/IDS705_Final'  # change to file path on your disk
drive.mount(DRIVE_MOUNT_POINT)
os.chdir(PROJECT_DIR)

Mounted at /content/MyDrive/


In [2]:
#Create Train / Val / Test Split
# os.listdir() returns entries in arbitrary order, so sort them: otherwise the
# seeded draw below (and any positional slicing of `subfolders`) can assign
# different subjects to each partition from one run/machine to the next.
subfolders = sorted(os.listdir("Data/Train"))
np.random.seed(101)
# One independent draw per subject, so the 60/20/20 proportions are only approximate.
split = np.random.choice(["Train","Val","Test"], len(subfolders), p=[0.6, 0.2, 0.2])
train_ids = [sf for sf, part in zip(subfolders, split) if part == "Train"]
val_ids = [sf for sf, part in zip(subfolders, split) if part == "Val"]
test_ids = [sf for sf, part in zip(subfolders, split) if part == "Test"]

In [3]:
#Helper Functions
def parse_tfrecord(example):
  """
  Decode one serialized tf.train.Example into a dict of tensors.

  The record is parsed against a fixed schema: a float32 'image' feature of
  shape [240, 240, 155, 4] and an int64 'label' feature of shape
  [240, 240, 155]. Returns the dict produced by tf.io.parse_single_example,
  keyed by 'image' and 'label'.
  """
  schema = {
      'image': tf.io.FixedLenFeature([240, 240, 155, 4], tf.float32),
      'label': tf.io.FixedLenFeature([240, 240, 155], tf.int64),
  }
  return tf.io.parse_single_example(example, schema)

def get_image_and_label(features):
  """
  Turn a parsed-record mapping into an (image, label) tuple.

  `features` must provide the keys 'image' and 'label'; any other keys
  are ignored.
  """
  return features['image'], features['label']

def get_dataset(tfrecord_names):
  """
  Build a tf.data pipeline of (image, label) pairs from TFRecord files.

  `tfrecord_names` is passed straight to tf.data.TFRecordDataset. Each raw
  record is decoded with parse_tfrecord and then unpacked into a tuple with
  get_image_and_label. The returned dataset is neither shuffled nor batched.
  """
  raw_records = tf.data.TFRecordDataset(tfrecord_names)
  parsed_records = raw_records.map(parse_tfrecord)
  return parsed_records.map(get_image_and_label)

In [4]:
#Create dataset objects
def _tfrecord_paths(ids):
  """Map subject folder names to the paths of their .tfrecord files."""
  return [f'/content/MyDrive/MyDrive/IDS705_Final/Data/train_tf/{sf}.tfrecord' for sf in ids]

# Datasets for the three partitions of the split.
train_dataset = get_dataset(_tfrecord_paths(train_ids))
test_dataset = get_dataset(_tfrecord_paths(test_ids))
val_dataset = get_dataset(_tfrecord_paths(val_ids))
# Smaller datasets taken as fixed positional slices of `subfolders`
# (independent of the Train/Val/Test assignment above).
mini_dataset = get_dataset(_tfrecord_paths(subfolders[100:228]))
minival_dataset = get_dataset(_tfrecord_paths(subfolders[500:508]))

In [5]:
def enco(Input_shape):
  """
  Build the encoder half of the autoencoder on top of `Input_shape`.

  Despite its name, `Input_shape` is the input tensor the first layer is
  applied to. The encoder is six stages of two Conv3D layers each (3x3x3
  kernel, ReLU, 'same' padding), every convolution followed by
  BatchNormalization. The first stage (4 filters) keeps the resolution; each
  later stage (8, 16, 32, 64, 128 filters) opens with a stride-2 convolution
  that downsamples before a stride-1 convolution.

  Returns the output of the final BatchNormalization layer.
  """
  print("*****Encoder*****", Input_shape)

  def conv_bn(tensor, filters, stride):
    # One Conv3D -> BatchNormalization unit.
    conv_out = tf.keras.layers.Conv3D(
        filters, 3, strides=(stride, stride, stride), activation='relu', padding='same')(tensor)
    return tf.keras.layers.BatchNormalization()(conv_out)

  # (filters, stride of the stage's first convolution)
  stage_plan = [(4, 1), (8, 2), (16, 2), (32, 2), (64, 2), (128, 2)]

  encoded_layer = Input_shape
  for filters, first_stride in stage_plan:
    encoded_layer = conv_bn(encoded_layer, filters, first_stride)
    encoded_layer = conv_bn(encoded_layer, filters, 1)

  return encoded_layer


In [6]:
def deco(encoded_layer):
  """
  Build the decoder half of the autoencoder on top of `encoded_layer`.

  Five upsampling stages (128, 64, 32, 16, 8 filters) each apply a stride-2
  Conv3DTranspose followed by a stride-1 Conv3DTranspose; a final stage
  applies two stride-1 Conv3DTranspose layers with 4 filters. Every
  transposed convolution uses a 3x3x3 kernel, ReLU and 'same' padding and is
  followed by BatchNormalization.

  Returns the output of the last BatchNormalization layer (4 channels, no
  softmax applied).

  NOTE(review): the five stride-2 stages scale each spatial dimension by 32,
  so the 8x8x5 code produced for a 240x240x155 input comes back as
  256x256x160, not 240x240x155. Confirm and crop/pad before comparing the
  output against the labels.
  """
  print("*****Decoder*****", encoded_layer)

  def deconv_bn(tensor, filters, stride):
    # One Conv3DTranspose -> BatchNormalization unit.
    deconv_out = tf.keras.layers.Conv3DTranspose(
        filters, 3, strides=(stride, stride, stride), activation='relu', padding='same')(tensor)
    return tf.keras.layers.BatchNormalization()(deconv_out)

  decoded = encoded_layer
  for filters in (128, 64, 32, 16, 8):
    decoded = deconv_bn(decoded, filters, 2)  # upsample
    decoded = deconv_bn(decoded, filters, 1)  # refine at the new resolution

  # An extra BatchNormalization used to be created here, but its result was
  # immediately overwritten and never reached the output, so it is dropped.
  decoded = deconv_bn(decoded, 4, 1)
  output = deconv_bn(decoded, 4, 1)

  return output



In [7]:
# Combine Encoder and Decoder layers
# Take Model/Input from the Keras that ships with TensorFlow so they come from
# the same package as the tf.keras.layers used inside enco() and deco().
from tensorflow.keras import Input, Model

input_img = Input(shape=(240,240,155,4))
e_l = enco(input_img)
output = deco(e_l)


autoencoder = Model(inputs = input_img, outputs = output)

*****Encoder***** KerasTensor(type_spec=TensorSpec(shape=(None, 240, 240, 155, 4), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'")
*****Decoder***** KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 5, 128), dtype=tf.float32, name=None), name='batch_normalization_11/FusedBatchNormV3:0', description="created by layer 'batch_normalization_11'")


In [8]:
# Compile the Model

# The network ends in BatchNormalization rather than a softmax, so its outputs
# are unnormalized scores. from_logits=True tells the loss to apply the
# softmax itself instead of treating those scores as probabilities.
autoencoder.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
autoencoder.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 240, 240, 155, 4  0         
                             )]                                  
                                                                 
 conv3d (Conv3D)             (None, 240, 240, 155, 4)  436       
                                                                 
 batch_normalization (BatchN  (None, 240, 240, 155, 4)  16       
 ormalization)                                                   
                                                                 
 conv3d_1 (Conv3D)           (None, 240, 240, 155, 4)  436       
                                                                 
 batch_normalization_1 (Batc  (None, 240, 240, 155, 4)  16       
 hNormalization)                                                 
                                                             

In [9]:
batchsize = 4
shufflesize = 8

def _batched(dataset, shuffle):
  """Group `dataset` into batches of `batchsize`, shuffling first with a
  `shufflesize`-element buffer when `shuffle` is true."""
  if shuffle:
    dataset = dataset.shuffle(shufflesize)
  return dataset.batch(batchsize)

# Training inputs are shuffled; validation inputs keep their order.
mini_dataset_trainable = _batched(mini_dataset, shuffle=True)
minival_dataset_validable = _batched(minival_dataset, shuffle=False)
train_dataset_trainable = _batched(train_dataset, shuffle=True)
val_dataset_validable = _batched(val_dataset, shuffle=False)

In [None]:
# Both inputs are tf.data.Datasets that are already shuffled/batched, so
# batch_size and shuffle are not passed to fit(). verbose accepts 0, 1 or 2;
# 1 shows the per-batch progress bar with the running loss.
history = autoencoder.fit(mini_dataset_trainable, epochs = 2, verbose=1, validation_data=minival_dataset_validable)

Epoch 1/2


In [None]:
# #DICE Coefficient definition used as a metric in compilation
# import keras.backend as k
# def dice_coef(y_pred, y_true):
#     inter = k.sum(y_true[:,:,-1]*y_pred[:,:,-1])
#     return (2*inter + 1)/(k.sum(y_true[:,:,-1]) + k.sum(y_pred[:,:,-1]) + 1)