<a href="https://colab.research.google.com/github/FR-Schwartz/IDS705_Team10/blob/main/10/auto_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
import tensorflow as tf
# Abort early unless TensorFlow reports the first GPU as available.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')

Found GPU at: /device:GPU:0


In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from __future__ import print_function
import keras
from keras.models import Sequential, Model
from tensorflow.keras.utils import to_categorical
from keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten, Input

In [25]:
from google.colab import drive
import os
import pickle
import tensorflow_gcs_config
import shutil
import numpy as np

import tensorflow as tf
from matplotlib import pyplot as plt
import nibabel as nib
import tensorflow_hub as hub
# Mount Google Drive, then work from the project folder.
# PROJECT_DIR: change to the file path on your disk.
PROJECT_DIR = '/content/MyDrive/MyDrive/IDS705_Final'
drive.mount('/content/MyDrive/')
os.chdir(PROJECT_DIR)

Drive already mounted at /content/MyDrive/; to attempt to forcibly remount, call drive.mount("/content/MyDrive/", force_remount=True).


In [28]:
#Helper Functions
def parse_tfrecord(example):
    """Decode one serialized tf.train.Example into a dict of dense tensors.

    Intended for use with ``tf.data.Dataset.map`` when building a dataset
    from TFRecord files.

    Args:
        example: a scalar string tensor holding one serialized example.

    Returns:
        dict with
          'image': float32 tensor of shape [240, 240, 155, 4]
                   (presumably 4 MRI modalities per voxel -- TODO confirm)
          'label': int64 tensor of shape [240, 240, 155]
                   (presumably a voxel-wise segmentation mask -- TODO confirm)
    """
    schema = {
        'image': tf.io.FixedLenFeature([240, 240, 155, 4], tf.float32),
        'label': tf.io.FixedLenFeature([240, 240, 155], tf.int64),
    }
    return tf.io.parse_single_example(example, schema)

def get_image_and_label(features):
    """Split a parsed record into an (image, label) pair.

    Args:
        features: mapping produced by ``parse_tfrecord`` with 'image' and
            'label' entries; any other keys are ignored.

    Returns:
        tuple ``(features['image'], features['label'])``.
    """
    return features['image'], features['label']

def get_image_and_label_auto(features):
  """
  Build an (input, target) pair for autoencoder training.

  The image is used as BOTH the model input and the reconstruction target;
  the 'label' entry of the parsed record (if any) is deliberately ignored.

  Args:
    features: mapping produced by parse_tfrecord; only 'image' is read.

  Returns:
    tuple (image, image) -- the same object twice.
  """
  image = features['image']
  return image, image

def get_dataset(tfrecord_names, pair_fn=None):
  """
  Create an (unbatched) TF dataset of (input, target) pairs from TFRecord files.

  Args:
    tfrecord_names: list of TFRecord file paths.
    pair_fn: function mapping a parsed feature dict (see parse_tfrecord) to
      an (input, target) tuple. Defaults to get_image_and_label_auto
      (image -> image, for the autoencoder); pass get_image_and_label to get
      (image, label) pairs for segmentation instead.

  Returns:
    tf.data.Dataset yielding whatever pair_fn returns, one example at a time.
  """
  if pair_fn is None:
    pair_fn = get_image_and_label_auto
  dataset = (tf.data.TFRecordDataset(tfrecord_names)
             .map(parse_tfrecord)
             .map(pair_fn))
  return dataset

In [29]:
def encoder(input_layer):
    """Downsampling half of the 3D convolutional autoencoder.

    Structure:
      * stem: BatchNormalization, then twice (stride-1 Conv3D with 4 filters
        -> BatchNormalization); its shape is printed as "x1";
      * five stages with 8, 16, 32, 64, 128 filters; each runs a stride-2
        Conv3D -> BatchNormalization -> stride-1 Conv3D -> BatchNormalization
        and prints its output shape as "x2" .. "x6".
    Every Conv3D uses a 3x3x3 kernel, ReLU and 'same' padding, so each stride-2
    conv halves the spatial dims, rounding up (e.g. 155 -> 78, per the printed
    shapes).

    Args:
        input_layer: Keras tensor; this notebook feeds (None, 240, 240, 155, 4).
            Axis order is presumably (batch, height, width, depth, channels).

    Returns:
        Keras tensor from the deepest stage (128 channels).
    """
    print("*****Encoder*****", input_layer)

    def conv3d(filters, stride):
        # Identical hyper-parameters for every conv; only width/stride vary.
        return tf.keras.layers.Conv3D(filters, 3, strides=(stride, stride, stride),
                                      activation='relu', padding='same')

    x = tf.keras.layers.BatchNormalization()(input_layer)
    for _ in range(2):
        x = conv3d(4, 1)(x)
        x = tf.keras.layers.BatchNormalization()(x)
    print("x1", x.shape)

    for tag, filters in zip(("x2", "x3", "x4", "x5", "x6"), (8, 16, 32, 64, 128)):
        # Downsample (stride 2) then refine (stride 1), each followed by BN.
        for stride in (2, 1):
            x = conv3d(filters, stride)(x)
            x = tf.keras.layers.BatchNormalization()(x)
        print(tag, x.shape)

    return x


In [30]:
def decoder(encoded_layer):
    """Upsampling half of the 3D convolutional autoencoder.

    Each of the five stages applies, in order:
      1. Conv3DTranspose, kernel 3, stride 1, 'valid' padding (+2 per spatial dim),
      2. Conv3DTranspose, kernel 3, stride 2, 'same' padding (x2 per spatial dim),
      3. two Cropping3D layers that trim back to the matching encoder shape,
    all with ReLU activations. The crop amounts are hand-tuned so that a
    (None, 8, 8, 5, 128) encoding is restored to (None, 240, 240, 155, 4),
    as the printed shapes show. Stage shapes are printed with the labels
    "x1", "x2", "x4", "x5", "x6".

    Args:
        encoded_layer: Keras tensor returned by ``encoder``.

    Returns:
        Keras tensor with 4 channels and the encoder's input spatial shape.
    """
    print("*****Decoder*****", encoded_layer)

    # (filters, print label, first crop, second crop)
    stages = (
        (64, "x1", ((1, 2), (1, 2), (1, 1)), ((1, 1), (1, 1), (1, 1))),
        (32, "x2", ((1, 1), (1, 1), (1, 1)), ((1, 1), (1, 1), (1, 1))),
        (16, "x4", ((1, 1), (1, 1), (1, 1)), ((1, 1), (1, 1), (1, 2))),
        (8,  "x5", ((1, 1), (1, 1), (1, 1)), ((1, 1), (1, 1), (1, 1))),
        (4,  "x6", ((1, 1), (1, 1), (1, 1)), ((1, 1), (1, 1), (1, 2))),
    )

    x = encoded_layer
    for filters, tag, first_crop, second_crop in stages:
        x = tf.keras.layers.Conv3DTranspose(filters, 3, strides=(1, 1, 1), activation='relu')(x)
        x = tf.keras.layers.Conv3DTranspose(filters, 3, strides=(2, 2, 2), activation='relu',
                                            padding='same')(x)
        x = tf.keras.layers.Cropping3D(cropping=first_crop)(x)
        x = tf.keras.layers.Cropping3D(cropping=second_crop)(x)
        print(tag, x.shape)

    return x
  


In [31]:
# Combine encoder and decoder layers
from keras.models import Model, Input

# Full autoencoder: (batch, 240, 240, 155, 4) volume in -> same-shaped reconstruction out.
input_layer = tf.keras.layers.Input(shape=(240, 240, 155, 4))
e_l = encoder(input_layer)
output = decoder(e_l)

# Sanity check: the hand-tuned decoder crops must restore the input shape exactly.
assert tuple(output.shape) == tuple(input_layer.shape), (output.shape, input_layer.shape)

# Use tf.keras.Model (not keras.models.Model) so the model class comes from the
# same Keras package as the tf.keras layers above; mixing the two packages can
# fail when tf.keras and standalone keras are different installs.
autoencoder = tf.keras.Model(inputs=input_layer, outputs=output)

*****Encoder***** KerasTensor(type_spec=TensorSpec(shape=(None, 240, 240, 155, 4), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'")
x1 (None, 240, 240, 155, 4)
x2 (None, 120, 120, 78, 8)
x3 (None, 60, 60, 39, 16)
x4 (None, 30, 30, 20, 32)
x5 (None, 15, 15, 10, 64)
x6 (None, 8, 8, 5, 128)
*****Decoder***** KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 5, 128), dtype=tf.float32, name=None), name='batch_normalization_25/FusedBatchNormV3:0', description="created by layer 'batch_normalization_25'")
x1 (None, 15, 15, 10, 64)
x2 (None, 30, 30, 20, 32)
x4 (None, 60, 60, 39, 16)
x5 (None, 120, 120, 78, 8)
x6 (None, 240, 240, 155, 4)


In [32]:
import keras.backend as k

class DiceLoss(tf.keras.losses.Loss):
    """Soft Dice loss: 1 - Dice coefficient between y_true and y_pred.

    Per sample, over every non-batch axis (all spatial positions and channels):

        dice = (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)
        loss = 1 - dice

    so a perfect match gives 0 and Keras averages the per-sample losses over
    the batch. Dice is only meaningful for non-negative values.
    NOTE(review): this assumes the stored image intensities are non-negative;
    confirm the TFRecord preprocessing.

    Args:
        smooth: additive constant keeping the ratio defined when both tensors
            are all zeros.
        gama: accepted for backward compatibility but currently unused
            (presumably intended for a focal/exponent variant -- TODO confirm).
    """

    def __init__(self, smooth=1e-6, gama=2):
        super(DiceLoss, self).__init__(name='NDL')
        self.smooth = smooth
        self.gama = gama

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Reduce over everything except the batch axis -> one loss per sample.
        # (The previous `[:, :, -1]` slice picked the last index of axis 2,
        # i.e. a single plane of the volume, not a channel.)
        axes = list(range(1, len(y_pred.shape)))
        inter = tf.reduce_sum(y_true * y_pred, axis=axes)
        denom = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
        dice = (2.0 * inter + self.smooth) / (denom + self.smooth)
        # Return a quantity to MINIMISE: the raw coefficient is higher-is-better.
        return 1.0 - dice

In [33]:
# Compile with Adam at a fixed learning rate and print the layer summary.
lr = 1e-3
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
autoencoder.compile(optimizer=optimizer, loss=DiceLoss())
autoencoder.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 240, 240, 155, 4  0         
                             )]                                  
                                                                 
 batch_normalization_13 (Bat  (None, 240, 240, 155, 4)  16       
 chNormalization)                                                
                                                                 
 conv3d_12 (Conv3D)          (None, 240, 240, 155, 4)  436       
                                                                 
 batch_normalization_14 (Bat  (None, 240, 240, 155, 4)  16       
 chNormalization)                                                
                                                                 
 conv3d_13 (Conv3D)          (None, 240, 240, 155, 4)  436       
                                                           

In [96]:
#Create Train / Val / Test Split
# os.listdir order is filesystem-dependent, so sort it; otherwise the seeded
# split below assigns different subjects on different machines. Keep only
# directories so stray non-directory entries are not treated as subject IDs
# (their Data/train_tf/<name>.tfrecord files would not exist).
DATA_ROOT = "Data/Train"
subfolders = sorted(
    name for name in os.listdir(DATA_ROOT)
    if os.path.isdir(os.path.join(DATA_ROOT, name))
)
np.random.seed(101)
# Independent per-subject draw: proportions are ~60/20/20, not exact.
split = np.random.choice(["Train", "Val", "Test"], len(subfolders), p=[0.6, 0.2, 0.2])
train_ids = [sf for sf, part in zip(subfolders, split) if part == "Train"]
val_ids = [sf for sf, part in zip(subfolders, split) if part == "Val"]
test_ids = [sf for sf, part in zip(subfolders, split) if part == "Test"]

In [97]:
#Create dataset objects
def _tfrecords_for(ids):
    """Map subject folder names to their TFRecord paths under Data/train_tf."""
    return [f'Data/train_tf/{sf}.tfrecord' for sf in ids]

train_dataset = get_dataset(_tfrecords_for(train_ids))
test_dataset = get_dataset(_tfrecords_for(test_ids))
val_dataset = get_dataset(_tfrecords_for(val_ids))
# Small fixed slices of the subject list for quick experiments.
mini_dataset = get_dataset(_tfrecords_for(subfolders[100:228]))
minival_dataset = get_dataset(_tfrecords_for(subfolders[500:508]))

In [98]:
batchsize = 4
shufflesize = 8
# One parsed 240x240x155x4 float32 example is ~143 MB, so the shuffle buffer
# alone holds roughly 1.1 GB.

def _as_train(ds):
    """Shuffle within a `shufflesize` buffer, then batch."""
    return ds.shuffle(shufflesize).batch(batchsize)

def _as_eval(ds):
    """Batch only; evaluation order is left as-is."""
    return ds.batch(batchsize)

mini_dataset_trainable = _as_train(mini_dataset)
minival_dataset_validable = _as_eval(minival_dataset)
train_dataset_trainable = _as_train(train_dataset)
val_dataset_validable = _as_eval(val_dataset)

In [99]:
from keras.callbacks import TensorBoard


In [100]:
epochs = 5
# batch_size/shuffle are not passed: the tf.data pipelines are already batched
# and shuffled, and Keras says not to specify batch_size for datasets (shuffle
# is ignored for them).
auto_enc = autoencoder.fit(
    train_dataset_trainable,
    epochs=epochs,
    validation_data=val_dataset_validable,
    callbacks=[TensorBoard(log_dir='/tmp/autoencoder')],
)

loss = auto_enc.history['loss']
val_loss = auto_enc.history['val_loss']
# Separate name so `epochs` stays an int and this cell can be re-run; size the
# axis from the recorded history and number epochs from 1.
epoch_ids = range(1, len(loss) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_ids, loss, 'bo', label='Training loss')
ax.plot(epoch_ids, val_loss, 'go', label='Validation loss')
ax.set_title('Training and validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()
plt.close(fig)

Epoch 1/5


UnknownError: ignored