In [None]:
import glob
import os
import shutil
from PIL import Image 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import nibabel as nib

from modules.scandata import MriScan, MriSlice, TumourSegmentation, ScanType, ScanPlane

In [None]:
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
#import tensorflow_datasets as tfds
from sklearn.utils.class_weight import compute_class_weight

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

In [None]:
# Show the logical TPU devices TensorFlow can see in this runtime.
tf.config.list_logical_devices(device_type='TPU')

In [None]:
# Connect to the TPU (tpu='local' — presumably a TPU attached to this VM; confirm),
# initialise it, and build the distribution strategy whose scope the model cells use.
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(cluster_spec_or_resolver=cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver=cluster_resolver)
strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver=cluster_resolver)


In [None]:
# Input pipeline configuration: images are loaded at img_height x img_width.
img_height = img_width = 240
batch_size = 64
data_dir = os.path.join(
    'data', 'UPENN-GBM', 'slice_classification_common_stratify', 'train',
)


In [None]:
# Training subset (80%) of the images under data_dir. validation_split and seed must
# match the val_ds call so the two subsets come from the same split.
train_ds = tf.keras.utils.image_dataset_from_directory(
    directory=data_dir,
    subset="training",
    validation_split=0.2,
    seed=123,
    color_mode="rgba",
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

In [None]:
# Validation subset (20%) of the images under data_dir. validation_split and seed must
# match the train_ds call so the two subsets come from the same split.
val_ds = tf.keras.utils.image_dataset_from_directory(
    directory=data_dir,
    subset="validation",
    validation_split=0.2,
    seed=123,
    color_mode="rgba",
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

In [None]:
# Class names of the training split; later cells index this list by integer label.
class_names = train_ds.class_names
class_names

In [None]:
# Calculate class weights for weighting the training loss (passed to fit as class_weight)
# Collect every training label in one pass over train_ds (not yet cached at this point)
# so the weights reflect the real class balance of the training split.
ds_classes = np.concatenate([batch_classes.numpy() for _, batch_classes in train_ds])

present_classes = np.unique(ds_classes)
balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=ds_classes,
)

# Plain Python int keys / float values instead of NumPy scalars.
class_weight = {int(c): float(w) for c, w in zip(present_classes, balanced_weights)}

# Fail early (rather than inside fit) if a class has no training samples.
missing_classes = sorted(set(range(len(class_names))) - set(class_weight))
assert not missing_classes, f"classes {missing_classes} have no training samples"


In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Cache after the first pass and prefetch ahead of the model; only the training
# stream is shuffled (buffer of 1000 batches).
train_ds = (
    train_ds
    .cache()
    .shuffle(1000)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
# Scales pixel values by 1/255 (8-bit [0, 255] -> [0, 1]).
# NOTE(review): not referenced by the model cells shown here, which use their own
# Rescaling layer — confirm before removing.
normalization_layer = layers.Rescaling(1.0 / 255)

In [None]:
# Number of output logits of the Dense prediction head (one per class).
num_classes = len(class_names)

In [None]:
# Pixels removed from every border by the model's Cropping2D layer, and the
# resulting spatial size seen by the backbone.
margin = 8
scaled_height, scaled_width = img_height - 2 * margin, img_width - 2 * margin

In [None]:
# Build layers for model


In [None]:

with strategy.scope():
    # Layers with weights are created under the TPU strategy scope so their
    # variables are created by that distribution strategy.
    crop_layer = tf.keras.layers.Cropping2D(margin)
    # Scale pixel values by 1/255 (8-bit [0, 255] inputs -> [0, 1]).
    rescale_initial = tf.keras.layers.Rescaling(1./255)
    # Learned 1x1 conv mixing the 4 RGBA channels into the 3 channels the ImageNet
    # backbone takes; tanh bounds the output to [-1, 1] (presumably to match the
    # input range MobileNetV2's ImageNet weights expect — TODO confirm).
    conv_4to3_channel = tf.keras.layers.Conv2D(3, 1, padding='same', activation='tanh')
    # Keras image shapes are (height, width, channels).
    trained_base_model = tf.keras.applications.MobileNetV2(
        input_shape=(scaled_height, scaled_width, 3),
        include_top=False,
        weights='imagenet'
    )
    global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
    # One logit per class (the loss below uses from_logits=True).
    prediction_layer = tf.keras.layers.Dense(num_classes)

    # Stage 1 trains only the new layers around the pretrained backbone.
    trained_base_model.trainable = False

    inputs = tf.keras.Input(shape=(img_height, img_width, 4))
    x = crop_layer(inputs)
    x = rescale_initial(x)
    x = conv_4to3_channel(x)
    # training=False runs the backbone in inference mode (e.g. BatchNorm uses its
    # stored statistics), also after it is unfrozen for fine-tuning later.
    x = trained_base_model(x, training=False)
    x = global_average_layer(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = prediction_layer(x)

    # Stop once val_loss has not improved by at least 0.001 for 5 epochs.
    earlystopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        min_delta=0.001)

    model_fixed_base = tf.keras.Model(inputs, outputs)
    model_fixed_base.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy']
    )

In [None]:
# Layer-by-layer summary (shapes, parameter counts) of the frozen-backbone classifier.
model_fixed_base.summary()

In [None]:
# Stage 1: train only the layers around the frozen backbone. small_model_epochs is an
# upper bound; the earlystopping callback may end training sooner.
small_model_epochs = 80
history_model_fixed_base = model_fixed_base.fit(
    train_ds,
    validation_data=val_ds,
    epochs=small_model_epochs,
    class_weight=class_weight,
    callbacks=[earlystopping],
)

In [None]:
# Sub-model returning the outputs of the first four layers in construction order
# (input, crop, rescale, 4->3 channel conv), used for visualisation.
layer_outputs = [
    lyr.output for lyr in model_fixed_base.layers[:4]
]
vis_model = tf.keras.models.Model(inputs=model_fixed_base.input, outputs=layer_outputs)

In [None]:
#((activations[-1][1,:,:,:]+1)*127.5).astype('uint8')

In [None]:
# Materialise one training batch as a list so every later `for images, labels in batch`
# loop sees the SAME images: train_ds reshuffles on each iteration, so a lazy
# `train_ds.take(1)` would yield a different batch every time it is looped over.
batch = list(train_ds.take(1))

plt.figure(figsize=(10, 40))
for images, labels in batch:
    # One predict call for the whole batch instead of one call per image.
    activations = vis_model.predict(images)
    conv_act = activations[-1]  # output of the 4->3 channel conv layer
    n_show = min(len(images), 16 * 4)  # the subplot grid below is 16 x 4
    for i in range(n_show):
        plt.subplot(16, 4, i + 1)
        # tanh output lies in [-1, 1]; map it to [0, 255] for display.
        plt.imshow(((conv_act[i] + 1) * 127.5).astype('uint8'))
        plt.title(class_names[int(labels[i])], fontsize=6)
        plt.axis("off")

In [None]:
# Stage-1 learning curves. Early stopping can end training before small_model_epochs,
# so the x-axis uses the epoch indices fit() actually recorded.
acc = history_model_fixed_base.history['accuracy']
val_acc = history_model_fixed_base.history['val_accuracy']

loss = history_model_fixed_base.history['loss']
val_loss = history_model_fixed_base.history['val_loss']

epochs_range = history_model_fixed_base.epoch

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(8, 8))

ax_acc.plot(epochs_range, acc, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc, label='Validation Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')

ax_loss.plot(epochs_range, loss, label='Training Loss')
ax_loss.plot(epochs_range, val_loss, label='Validation Loss')
ax_loss.legend(loc='upper right')
ax_loss.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
plt.show()

In [None]:
# Make the backbone trainable; the next cell re-freezes its first `fine_tune_at` layers.
trained_base_model.trainable = True

In [None]:
# Freeze the first `fine_tune_at` backbone layers; layers from that index on keep
# their current trainable flag (set True in the previous cell).
fine_tune_at = 100

frozen_layers = trained_base_model.layers[:fine_tune_at]
for layer in frozen_layers:
    layer.trainable = False

In [None]:
with strategy.scope():
    # Recompile so the changed trainable flags take effect; the learning rate is 10x
    # smaller than stage 1, presumably to avoid disturbing the pretrained weights.
    model_fixed_base.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

In [None]:
# Stage 2: fine-tune the unfrozen top of the backbone for up to fine_tuning_epochs
# more epochs. Stage 1 may have stopped early, so numbering resumes right after the
# last epoch it actually ran, and the budget is counted from there (not from
# small_model_epochs).
fine_tuning_epochs = 100
fine_tune_initial_epoch = history_model_fixed_base.epoch[-1] + 1
total_epochs = fine_tune_initial_epoch + fine_tuning_epochs
history_fine_tuning = model_fixed_base.fit(
    train_ds,
    validation_data=val_ds,
    epochs=total_epochs,
    initial_epoch=fine_tune_initial_epoch,
    class_weight=class_weight,
    callbacks=[earlystopping],
)

In [None]:
# Sub-model over the first four layers (input, crop, rescale, 4->3 channel conv) of
# model_fixed_base; it is built from that model's layers, so it shares their weights.
layer_outputs = [
    lyr.output for lyr in model_fixed_base.layers[:4]
]
vis_model = tf.keras.models.Model(inputs=model_fixed_base.input, outputs=layer_outputs)

In [None]:
# 4->3 channel conv output after fine-tuning, for the batch taken in the first
# visualisation cell.
plt.figure(figsize=(10, 40))
for images, labels in batch:
    # One predict call for the whole batch instead of one call per image.
    activations = vis_model.predict(images)
    conv_act = activations[-1]
    n_show = min(len(images), 16 * 4)  # the subplot grid below is 16 x 4
    for i in range(n_show):
        plt.subplot(16, 4, i + 1)
        # tanh output lies in [-1, 1]; map it to [0, 255] for display.
        plt.imshow(((conv_act[i] + 1) * 127.5).astype('uint8'))
        plt.title(class_names[int(labels[i])], fontsize=6)
        plt.axis("off")

In [None]:
# Stage-2 validation accuracy, plotted against the epoch numbers fit() recorded
# (they continue on from stage 1).
acc_fine = history_fine_tuning.history['val_accuracy']
loss_fine = history_fine_tuning.history['val_loss']

fig, ax = plt.subplots()
ax.plot(history_fine_tuning.epoch, acc_fine)
ax.set(title='Fine-tuning validation accuracy', xlabel='Epoch', ylabel='Validation accuracy');

In [None]:
# Unfreeze the entire backbone for the final stage, then report the result compactly
# instead of printing one flag per layer.
trained_base_model.trainable = True
frozen_layer_names = [layer.name for layer in trained_base_model.layers if not layer.trainable]
n_layers = len(trained_base_model.layers)
print(f"{n_layers - len(frozen_layer_names)}/{n_layers} backbone layers trainable")
if frozen_layer_names:
    print("still frozen:", frozen_layer_names)

In [None]:
with strategy.scope():
    # Recompile so the fully-unfrozen backbone's trainable flags take effect.
    model_fixed_base.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
  

In [None]:
# Stage 3: train with the whole backbone unfrozen for up to full_relax_epochs more
# epochs. Stage 2 may have stopped early, so numbering resumes right after the last
# epoch it actually ran and the budget is counted from there.
full_relax_epochs = 100
full_relax_initial_epoch = history_fine_tuning.epoch[-1] + 1
total_epochs = full_relax_initial_epoch + full_relax_epochs
# NOTE: reuses the name history_fine_tuning, replacing the stage-2 history object.
history_fine_tuning = model_fixed_base.fit(
    train_ds,
    validation_data=val_ds,
    epochs=total_epochs,
    initial_epoch=full_relax_initial_epoch,
    class_weight=class_weight,
    callbacks=[earlystopping],
)

In [None]:
# Sub-model over the first four layers (input, crop, rescale, 4->3 channel conv),
# then plot the conv output after the full-backbone stage.
layer_outputs = [layer.output for layer in model_fixed_base.layers[:4]]
vis_model = tf.keras.models.Model(
    inputs=model_fixed_base.input,
    outputs=layer_outputs
)

plt.figure(figsize=(10, 40))
for images, labels in batch:
    # One predict call for the whole batch instead of one call per image.
    activations = vis_model.predict(images)
    n_show = min(len(images), 16 * 4)  # the subplot grid below is 16 x 4
    for i in range(n_show):
        # Per-image slice of every output (batch dim of 1 kept); the channel-range
        # cell that follows reads `activation` for the last plotted image.
        activation = [layer_act[i:i + 1] for layer_act in activations]
        plt.subplot(16, 4, i + 1)
        # tanh output lies in [-1, 1]; map it to [0, 255] for display.
        plt.imshow(((activation[-1][0] + 1) * 127.5).astype('uint8'))
        plt.title(class_names[int(labels[i])], fontsize=6)
        plt.axis("off")

In [None]:
# Min/max of each of the 3 conv output channels after mapping tanh's [-1, 1] to uint8.
conv_channels = activation[-1]
for channel in range(3):
    as_uint8 = ((conv_channels[..., channel] + 1) * 127.5).astype('uint8')
    print(as_uint8.min(), as_uint8.max())