In [None]:
import glob
import os
import shutil
from PIL import Image 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import nibabel as nib

from modules.scandata import MriScan, MriSlice, TumourSegmentation, ScanType, ScanPlane

In [None]:
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
#import tensorflow_datasets as tfds
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from IPython.display import clear_output

In [None]:
tf.config.list_logical_devices('TPU')

In [None]:
# Attach to the TPU runtime on this host (tpu='local'), initialise it, and
# build the distribution strategy. Later cells create and compile their models
# inside `strategy.scope()`.
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)


In [None]:
# Shared configuration for the slice-classification dataset loaders.
batch_size = 64   # images per batch produced by the directory loaders
img_height = 240  # loaders resize every image to img_height x img_width
img_width = 240
# Root folder passed to image_dataset_from_directory
# (presumably one sub-folder per class — confirm against the data layout).
data_dir = os.path.join('data','UPENN-GBM','slice_classification_common_stratify','train')


In [None]:
# Training portion of the images under data_dir. The split fraction and seed
# are the same as in the validation loader; images are decoded as 4-channel
# RGBA and resized to img_height x img_width.
train_ds = tf.keras.utils.image_dataset_from_directory(
    directory=data_dir,
    subset="training",
    validation_split=0.2,
    seed=123,
    color_mode="rgba",
    batch_size=batch_size,
    image_size=(img_height, img_width),
)

In [None]:
# Validation portion of the images under data_dir. The split fraction and seed
# are the same as in the training loader; images are decoded as 4-channel RGBA
# and resized to img_height x img_width.
val_ds = tf.keras.utils.image_dataset_from_directory(
    directory=data_dir,
    subset="validation",
    validation_split=0.2,
    seed=123,
    color_mode="rgba",
    batch_size=batch_size,
    image_size=(img_height, img_width),
)

In [None]:
# Read the class labels from the loader here, before a later cell rebinds
# train_ds to its cache()/shuffle()/prefetch() wrapper.
class_names = train_ds.class_names
print(class_names)

In [None]:
 # Compute 'balanced' per-class weights from the training labels; the result
 # is passed to fit(class_weight=...) in the training cells.
# Collect the integer label of every training image, one array per batch.
label_batches = [batch_labels.numpy() for _, batch_labels in train_ds]
ds_classes = np.concatenate(label_batches)

# Distinct class ids, in the order used for both the weights and the dict keys.
unique_classes = np.unique(ds_classes)

balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=unique_classes,
    y=ds_classes,
)

# Mapping {class id: weight}.
class_weight = dict(zip(unique_classes, balanced_weights))


In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Cache the loaded batches, shuffle the training stream with a buffer of 1000
# elements (the loaders already batch, so elements here are whole batches),
# and prefetch.
# NOTE(review): train_ds / val_ds are rebound in place, so re-running this cell
# stacks another cache/shuffle/prefetch on top of the previous one.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [None]:
# Rescaling layer with factor 1/255.
# NOTE(review): no later cell in this notebook references normalization_layer;
# the classifier builds its own Rescaling(1./255) layer.
normalization_layer = layers.Rescaling(1./(2**8-1))

In [None]:
num_classes = len(class_names)

In [None]:
# Number of pixels cropped from every edge by Cropping2D(margin) in the model.
margin = 8
# Size left after cropping: 240 - 2*8 = 224, used as the MobileNetV2 input size.
scaled_height = img_height - 2*margin
scaled_width = img_width - 2*margin

In [None]:
# Build layers for model


In [None]:

with strategy.scope():
    # --- Layers --------------------------------------------------------------
    # Trim `margin` pixels from every edge so the loaded slices shrink to the
    # scaled_height x scaled_width input the backbone is built for.
    crop_layer = tf.keras.layers.Cropping2D(margin)
    #rescale_initial = tf.keras.layers.Rescaling(1./127.5, offset=-1)
    # Divide pixel values by 255 (assumes the loader yields 0-255 values).
    rescale_initial = tf.keras.layers.Rescaling(1./255)
    # Learned 1x1 convolution that mixes the 4 input channels down to the
    # 3 channels MobileNetV2 expects; tanh bounds its output to [-1, 1].
    conv_4to3_channel = tf.keras.layers.Conv2D(3,1,padding='same', activation='tanh')
    trained_base_model = tf.keras.applications.MobileNetV2(
        # Keras image tensors are (height, width, channels). The original
        # passed (width, height), which only worked because the images are
        # square.
        input_shape=(scaled_height,scaled_width,3),
        include_top=False,
        weights='imagenet'
    )
    global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
    # No activation: the head emits logits (the loss uses from_logits=True).
    prediction_layer = tf.keras.layers.Dense(num_classes)

    # Stage 1: keep the ImageNet backbone frozen and train only the new layers.
    trained_base_model.trainable = False

    # --- Graph ---------------------------------------------------------------
    inputs = tf.keras.Input(shape=(img_height, img_width, 4))
    x = crop_layer(inputs)
    x = rescale_initial(x)
    x = conv_4to3_channel(x)
    # training=False keeps the backbone in inference mode even after it is
    # unfrozen in later cells.
    x = trained_base_model(x, training=False)
    x = global_average_layer(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = prediction_layer(x)

    # Shared by every fit() call in this notebook: stop once val_loss has not
    # improved by at least 0.001 for 3 consecutive epochs.
    earlystopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
         patience=3,
         min_delta=0.001)

    model_fixed_base = tf.keras.Model(inputs, outputs)
    model_fixed_base.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy']
    )

In [None]:
model_fixed_base.summary()

In [None]:
tf.keras.utils.plot_model(model_fixed_base, show_shapes=True)

In [None]:
# Stage 1 training: backbone frozen, at most small_model_epochs epochs
# (the EarlyStopping callback may end the run sooner). class_weight applies the
# per-class weights computed from the training labels.
small_model_epochs=80
history_model_fixed_base = model_fixed_base.fit(
  train_ds,
  validation_data=val_ds,
  epochs=small_model_epochs,
  class_weight=class_weight,
  callbacks=[earlystopping],
)

In [None]:
# Probe model exposing the outputs of the classifier's first four layers
# (input, crop, rescale, 4->3 channel convolution) for visualisation.
probed_layers = model_fixed_base.layers[:4]
layer_outputs = [probed.output for probed in probed_layers]
vis_model = tf.keras.models.Model(inputs=model_fixed_base.input, outputs=layer_outputs)

In [None]:
#((activations[-1][1,:,:,:]+1)*127.5).astype('uint8')

In [None]:
# Show the learned 4->3 channel projection for one batch of training images.
plt.figure(figsize=(10, 40))
# Materialise the batch once. train_ds is shuffled, so a lazy
# `train_ds.take(1)` would hand back different images every time it is
# iterated, and the later visualisation cells that reuse `batch` could not be
# compared with this one. A list of (images, labels) tuples iterates the same
# way as the original dataset did.
batch = list(train_ds.take(1))
#activations = vis_model.predict(batch)
for images, labels in batch:
  # Use the real batch length instead of assuming 64 images (the 16x4 grid
  # still caps the display at 64).
  for i in range(min(64, int(images.shape[0]))):
    image = np.expand_dims(images[i], axis=0)
    activation = vis_model.predict(image)
    ax = plt.subplot(16, 4, i + 1)
    # activation[-1] is the tanh output in [-1, 1]; map it to 0-255 for display.
    plt.imshow(((activation[-1][0,:,:,:]+1)*127.5).astype('uint8'))
    plt.title(class_names[labels[i]], fontsize=6)
    plt.axis("off")

In [None]:
# Training curves for stage 1 (frozen backbone).
acc = history_model_fixed_base.history['accuracy']
val_acc = history_model_fixed_base.history['val_accuracy']

loss = history_model_fixed_base.history['loss']
val_loss = history_model_fixed_base.history['val_loss']

# EarlyStopping may end training before small_model_epochs, so size the x-axis
# from the epochs actually recorded rather than from the requested maximum.
epochs_range = range(len(acc))

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

In [None]:
trained_base_model.trainable = True

In [None]:
# Index of the first backbone layer that stays trainable during fine-tuning.
fine_tune_at = 100

# Re-freeze every backbone layer below that index.
frozen_layers = trained_base_model.layers[:fine_tune_at]
for frozen_layer in frozen_layers:
  frozen_layer.trainable = False

In [None]:
# Recompile after changing the backbone's trainable flags, with a learning rate
# ten times lower than in stage 1 (0.00001 instead of 0.0001).
with strategy.scope():
  model_fixed_base.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy']
  )

In [None]:
# Stage 2: fine-tune the upper backbone layers, continuing the epoch count
# from where stage 1 stopped.
fine_tuning_epochs=100
# History.epoch holds the 0-based indices of the epochs that ran, so the next
# epoch to run is last index + 1 (using the last index itself repeats an
# epoch number).
initial_epoch = history_model_fixed_base.epoch[-1] + 1
# fit() treats `epochs` as the index of the final epoch, not a count. Anchor it
# to the epoch actually reached so that early stopping in stage 1 does not
# silently enlarge the fine-tuning budget beyond fine_tuning_epochs.
total_epochs = initial_epoch + fine_tuning_epochs
history_fine_tuning = model_fixed_base.fit(
  train_ds,
  validation_data=val_ds,
  epochs=total_epochs,
  initial_epoch=initial_epoch,
  class_weight=class_weight,
  callbacks=[earlystopping],
)

In [None]:
# Rebuild the probe model over the classifier's first four layers
# (input, crop, rescale, 4->3 channel convolution).
probed_layers = model_fixed_base.layers[:4]
layer_outputs = [probed.output for probed in probed_layers]
vis_model = tf.keras.models.Model(inputs=model_fixed_base.input, outputs=layer_outputs)

In [None]:
# Same visualisation as after stage 1, now with the fine-tuned weights.
plt.figure(figsize=(10, 40))
for images, labels in batch:
  # Use the real batch length instead of assuming 64 images (the 16x4 grid
  # still caps the display at 64).
  for i in range(min(64, int(images.shape[0]))):
    image = np.expand_dims(images[i], axis=0)
    activation = vis_model.predict(image)
    ax = plt.subplot(16, 4, i + 1)
    # activation[-1] is the tanh output in [-1, 1]; map it to 0-255 for display.
    plt.imshow(((activation[-1][0,:,:,:]+1)*127.5).astype('uint8'))
    plt.title(class_names[labels[i]], fontsize=6)
    plt.axis("off")

In [None]:
# Validation accuracy / loss recorded during the fine-tuning stage.
acc_fine = history_fine_tuning.history['val_accuracy']
# NOTE(review): loss_fine is collected but only accuracy is plotted below.
loss_fine = history_fine_tuning.history['val_loss']
plt.plot(acc_fine)

In [None]:
# Stage 3: mark the whole backbone trainable again, then print each backbone
# layer's trainable flag as a check.
trained_base_model.trainable = True
for layer in trained_base_model.layers:
    print(layer.trainable)

In [None]:
# Recompile after changing the backbone's trainable flags; same optimizer
# settings as the fine-tuning stage (Adam, learning rate 0.00001).
with strategy.scope():
  model_fixed_base.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy']
  )
  

In [None]:
# Stage 3: train with the whole backbone unfrozen, continuing the epoch count
# from where the fine-tuning stage stopped.
full_relax_epochs=100
# History.epoch holds the 0-based indices of the epochs that ran, so the next
# epoch to run is last index + 1.
initial_epoch = history_fine_tuning.epoch[-1] + 1
# Plain assignment (instead of `+=`) keeps total_epochs from growing every time
# this cell is re-run; fit() treats `epochs` as the index of the final epoch.
total_epochs = initial_epoch + full_relax_epochs
# NOTE(review): this overwrites the stage-2 history object; re-running the cell
# therefore continues from this stage's own last epoch.
history_fine_tuning = model_fixed_base.fit(
  train_ds,
  validation_data=val_ds,
  epochs=total_epochs,
  initial_epoch=initial_epoch,
  class_weight=class_weight,
  callbacks=[earlystopping],
)

In [None]:
# Rebuild the probe model over the classifier's first four layers and repeat
# the channel-projection visualisation with the fully relaxed weights.
layer_outputs = [layer.output for layer in model_fixed_base.layers[:4]]
vis_model = tf.keras.models.Model(
    inputs=model_fixed_base.input, 
    outputs=layer_outputs
)

plt.figure(figsize=(10, 40))
#batch = train_ds.take(1)
#activations = vis_model.predict(batch)
for images, labels in batch:
  # Use the real batch length instead of assuming 64 images (the 16x4 grid
  # still caps the display at 64).
  for i in range(min(64, int(images.shape[0]))):
    image = np.expand_dims(images[i], axis=0)
    activation = vis_model.predict(image)
    ax = plt.subplot(16, 4, i + 1)
    # activation[-1] is the tanh output in [-1, 1]; map it to 0-255 for display.
    plt.imshow(((activation[-1][0,:,:,:]+1)*127.5).astype('uint8'))
    plt.title(class_names[labels[i]], fontsize=6)
    plt.axis("off")

In [None]:
# Per-channel min / max of the 0-255 projection, computed from the `activation`
# left over by the most recent visualisation loop.
last_projection = ((activation[-1] + 1) * 127.5).astype('uint8')
for ch in range(3):
    channel_values = last_projection[:, :, :, ch]
    print(channel_values.min(), channel_values.max())

In [None]:
model_fixed_base.summary()

In [None]:
tf.keras.utils.plot_model(trained_base_model)

In [None]:
with strategy.scope():
    # Use the activations of these layers
    # Name under which the MobileNetV2 backbone sits inside model_fixed_base.
    pretrained_layer_name='mobilenetv2_1.00_224'
    # Encoder taps, from highest to lowest resolution. The sizes below are for
    # the 224x224 backbone input used here and mirror the 7 -> 112 progression
    # documented on up_stack.
    layer_names = [
        'block_1_expand_relu',   # 112x112
        'block_3_expand_relu',   # 56x56
        'block_6_expand_relu',   # 28x28
        'block_13_expand_relu',  # 14x14
        'block_16_project',      # 7x7
    ]
    base_model_outputs = [
        model_fixed_base.get_layer(pretrained_layer_name)
        .get_layer(name).output for name in layer_names
    ]

    # Create the feature extraction model
    down_stack = tf.keras.Model(
        inputs=model_fixed_base.get_layer(pretrained_layer_name).input, 
        outputs=base_model_outputs
    )

    # NOTE(review): down_stack is built from the backbone's own tensors, so it
    # presumably shares (and here re-freezes) the layers fine-tuned above —
    # confirm this is intended.
    down_stack.trainable = False

In [None]:
down_stack.summary()

In [None]:
#for layer in model_fixed_base.get_layer(pretrained_layer_name).layers[:-4]:
#    down_stack.get_layer(layer.name).set_weights(layer.get_weights())

In [None]:
#down_stack.set_weights(model_fixed_base.get_layer(pretrained_layer_name).get_weights())

In [None]:
tf.keras.utils.plot_model(down_stack, show_shapes=True)

In [None]:

# Define each layer block for upbranch
def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
  """Build one decoder stage that upsamples with a stride-2 transposed conv.

  Layer order: Conv2DTranspose => [BatchNormalization] => [Dropout] => ReLU

  Args:
    filters: number of output filters of the transposed convolution.
    size: kernel size of the transposed convolution.
    norm_type: normalisation selector (case-insensitive). Only 'batchnorm'
      adds a layer; any other value, including 'instancenorm', results in a
      stage with no normalisation.
    apply_dropout: if True, a Dropout(0.5) layer is inserted before the ReLU.

  Returns:
    A tf.keras.Sequential containing the stage's layers.
  """
  kernel_init = tf.random_normal_initializer(0., 0.02)

  # The transposed convolution carries no bias term (use_bias=False).
  stage_layers = [
      tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=kernel_init,
                                      use_bias=False)
  ]

  wants_batchnorm = norm_type.lower() == 'batchnorm'
  if wants_batchnorm:
    stage_layers.append(tf.keras.layers.BatchNormalization())
  #elif norm_type.lower() == 'instancenorm':
  #  stage_layers.append(InstanceNormalization())

  if apply_dropout:
    stage_layers.append(tf.keras.layers.Dropout(0.5))

  stage_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(stage_layers)


In [None]:
# Decoder stages, each doubling the resolution with a 3x3 transposed conv:
# 7x7 -> 14x14 -> 28x28 -> 56x56 -> 112x112.
decoder_filters = (512, 256, 128, 64)
up_stack = [upsample(n_filters, 3) for n_filters in decoder_filters]

In [None]:
for layer in model_fixed_base.layers[1:4]:
    print(layer.name)

In [None]:
tensor = tf.convert_to_tensor(np.random.rand(64,240,240,4))
tensor

In [None]:
#model.layers[2].input

In [None]:
def unet_model(output_channels: int):
    """Assemble the segmentation U-Net on top of the trained classifier.

    The classifier's crop, rescale and 4->3 channel layers
    (model_fixed_base.layers[1:4]) are reused as the stem, `down_stack`
    provides the encoder features and `up_stack` the decoder stages.

    Args:
        output_channels: number of segmentation classes, i.e. channels of the
            per-pixel logits produced by the final layer.

    Returns:
        A tf.keras.Model mapping (img_height, img_width, 4) images to logits
        of shape (img_height, img_width, output_channels).
    """
    # Add layers from classification model. The input size comes from the same
    # configuration the classifier was built with, so the two cannot drift apart.
    inputs = tf.keras.layers.Input(shape=[img_height, img_width, 4])
    x = inputs
    for layer in model_fixed_base.layers[1:4]:
        x = layer(x)

    # Downsampling through the model: the last output is the bottleneck, the
    # remaining ones become skip connections (coarsest first).
    skips = down_stack(x)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # This is the last layer of the model
    last_conv_trans = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2, padding="same"
    )  # 112x112 -> 224x224

    x = last_conv_trans(x)

    # Undo the stem's crop so the logits line up with the uncropped masks; the
    # padded border carries all-zero logits.
    x = tf.keras.layers.ZeroPadding2D(margin)(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


In [None]:
# Number of segmentation classes (channels of the per-pixel logits).
OUTPUT_CLASSES = 4
with strategy.scope():
    model = unet_model(output_channels=OUTPUT_CLASSES)
    segmentation_optimizer = tf.keras.optimizers.Adam(learning_rate=0.00001)
    # The network outputs raw logits, hence from_logits=True.
    segmentation_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        optimizer=segmentation_optimizer,
        loss=segmentation_loss,
        metrics=['accuracy'],
    )

In [None]:
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
def display(display_list):
    """Plot the given arrays side by side in a single row.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask', so at most three items are supported.

    Args:
        display_list: sequence of image-like arrays accepted by
            tf.keras.utils.array_to_img.
    """
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(15, 15))
    for idx, panel in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(panel_titles[idx])
        plt.imshow(tf.keras.utils.array_to_img(panel))
        plt.axis('off')
    plt.show()
  
def create_mask(pred_mask):
    """Turn a batch of per-pixel class scores into a mask for its first image.

    Args:
        pred_mask: batch of predictions whose last axis holds the class scores.

    Returns:
        The arg-max class index per pixel for the first batch element, with a
        trailing axis of length 1.
    """
    class_ids = tf.math.argmax(pred_mask, axis=-1)
    first_mask = class_ids[..., tf.newaxis][0]
    return first_mask

def show_predictions(dataset=None, num=1):
    """Display input, true mask and predicted mask.

    Args:
        dataset: optional dataset of (image batch, mask batch) pairs. When
            given, the first image of each of the first `num` batches is shown
            with its predicted mask. When None, the module-level
            `sample_image` / `sample_mask` pair is used instead.
        num: number of batches to take from `dataset`.
    """
    # Explicit None check: do not rely on the truthiness of a dataset object.
    if dataset is not None:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            # Display inside the loop so every taken batch is shown (the
            # original displayed only the last one and raised NameError on an
            # empty dataset).
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
            create_mask(model.predict(sample_image[tf.newaxis, ...]))])

In [None]:
# Load every training slice together with its segmentation map.
images = []
maps = []

train_image_dir = os.path.join('data','UPENN-GBM','slice_segmentation_stratify','train','image_data')
train_map_dir = os.path.join('data','UPENN-GBM','slice_segmentation_stratify','train','map_data')

# Running pixel totals for the four classes (index == class id after the
# remapping below); used later to derive per-class loss weights.
pixel_counts = [0,0,0,0]

# os.listdir returns entries in arbitrary order; sorting keeps the order of
# `images` / `maps` (and any seeded split of them) reproducible across runs.
for image_file in sorted(os.listdir(train_image_dir)):
    # The map file shares the image's name with 'allseq' replaced by 'map'.
    map_file = image_file.replace('allseq', 'map')
    map_path = os.path.join(train_map_dir, map_file)
    if not os.path.exists(map_path):
        raise FileNotFoundError((image_file, map_file))

    image = tf.io.read_file(os.path.join(train_image_dir, image_file))
    image = tf.io.decode_png(image, channels=4)
    seg_map = tf.io.read_file(map_path)
    seg_map = tf.io.decode_png(seg_map, channels=1).numpy()

    # Convert map to make class integers contiguous
    seg_map[seg_map==4] = 3

    # Count pixel classes
    class_ids, class_counts = np.unique(seg_map, return_counts=True)
    for class_id, class_count in zip(class_ids, class_counts):
        pixel_counts[int(class_id)] += int(class_count)

    images.append(image)
    maps.append(tf.convert_to_tensor(seg_map))

In [None]:
# 80/20 train/validation split of the loaded slices. A fixed random_state makes
# the split reproducible (123 matches the seed used for the classification
# loaders).
# NOTE(review): this cell must run before later cells rebind `images` as a loop
# variable; re-running it afterwards would split the wrong object.
train_images, val_images, train_maps, val_maps = train_test_split(
    images, maps, test_size=0.2, random_state=123
)
train_images = tf.convert_to_tensor(train_images)
train_maps = tf.convert_to_tensor(train_maps)
val_images = tf.convert_to_tensor(val_images)
val_maps = tf.convert_to_tensor(val_maps)

In [None]:
val_images.shape

In [None]:
def scaler_0_1(x):
    """Divide by 255, mapping 8-bit pixel values from [0, 255] to [0, 1]."""
    max_intensity = 255.0
    return x / max_intensity

def scaler_neg1_1(x):
    """Map 8-bit pixel values from [0, 255] to [-1, 1] (x / 127.5 - 1)."""
    half_range = 127.5
    scaled = x / half_range
    return scaled - 1

def create_dataset(img, seg_map, scaler):
    """Cast the images to float32, apply `scaler`, and pair them with the maps.

    Args:
        img: image tensor to convert and scale.
        seg_map: segmentation maps, returned unchanged.
        scaler: callable applied to the float32 images.

    Returns:
        Tuple (scaled images, seg_map).
    """
    img_float = tf.cast(img, tf.float32)
    return scaler(img_float), seg_map

In [None]:
# Feed raw 0-255 pixel values: the U-Net reuses the classifier's
# Rescaling(1./255) layer as part of its stem, so scaling here as well (the
# original applied scaler_neg1_1) would normalise the inputs twice and give the
# reused layers a value range they were never trained on. create_dataset still
# casts the images to float32.
train_data = tf.data.Dataset.from_tensor_slices(
    create_dataset(train_images,train_maps,lambda pixels: pixels)
)
val_data = tf.data.Dataset.from_tensor_slices(
    create_dataset(val_images,val_maps,lambda pixels: pixels)
)

In [None]:
BUFFER_SIZE = 1000
BATCH_SIZE = 64

In [None]:
# Training input pipeline: cache, shuffle individual samples, batch, repeat
# indefinitely (fit() bounds each epoch with steps_per_epoch) and prefetch.
shuffled_train = train_data.cache().shuffle(BUFFER_SIZE)
batched_train = shuffled_train.batch(BATCH_SIZE).repeat()
train_batch = batched_train.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
val_batch = val_data.batch(BATCH_SIZE)

In [None]:
# Pull one training batch and keep its first image / mask as the fixed sample
# that show_predictions() uses when called without a dataset.
# NOTE(review): the loop variable rebinds `images`, which earlier held the list
# of loaded slices.
for images, masks in train_batch.take(1):
    sample_image, sample_mask = images[0], masks[0]
    display([sample_image, sample_mask])
    print(images.shape, masks.shape)

In [None]:
show_predictions()

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell output, show the prediction and print the epoch.

        Args:
            epoch: 0-based index of the epoch that just finished.
            logs: metrics dict supplied by Keras (unused).
        """
        finished_epoch = epoch + 1
        clear_output(wait=True)
        show_predictions()
        print ('\nSample Prediction after epoch {}\n'.format(finished_epoch))

In [None]:
# Per-class pixel weights: inverse pixel frequency, normalised to sum to 1.
inverse_frequency = 1.0 / np.array(pixel_counts)
weights = inverse_frequency / np.sum(inverse_frequency)

def add_sample_weights(image, label):
  """Attach a per-pixel weight map to an (image, label) pair.

  Each pixel's weight is looked up in the module-level `weights` array using
  the pixel's class label as the index.

  Args:
    image: input image tensor, passed through unchanged.
    label: integer class labels, passed through unchanged.

  Returns:
    Tuple (image, label, sample_weights), where sample_weights has the same
    shape as `label`.
  """
  class_indices = tf.cast(label, tf.int32)
  sample_weights = tf.gather(weights, indices=class_indices)

  return image, label, sample_weights

In [None]:
#lr 0.00001
TRAIN_LENGTH=68076
EPOCHS = 40
VAL_SUBSPLITS = 5
VALIDATION_STEPS = 17019//BATCH_SIZE//VAL_SUBSPLITS
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

model_history = model.fit(
    train_batch.map(add_sample_weights), 
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    validation_data=val_batch,
    callbacks=[DisplayCallback()],
)

In [None]:
model_history.history

In [None]:
#val_maps.shape
plt.imshow(val_images[10,:,:,1],cmap='gist_gray')

In [None]:
single_img = tf.expand_dims(val_images[10],0)
single_map = tf.expand_dims(val_maps[10],0)


In [None]:
single_img.shape

In [None]:
single_ds = tf.data.Dataset.from_tensor_slices(([single_img],[single_map]))


In [None]:
for simg, smap in single_ds:
    print(simg.shape, smap.shape)

In [None]:
show_predictions(single_ds)

In [None]:
weight = tf.constant([2,10,100])

In [None]:
arr = np.array([[0,0,0],[0,1,0],[0,1,2]])

In [None]:
print(arr)

In [None]:
sample_weights = tf.gather(weight, indices=tf.cast(arr, tf.int32))
sample_weights