In [None]:
import glob
import os
import shutil
from PIL import Image 
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import pandas as pd
import nibabel as nib

from modules.scandata import MriScan, MriSlice, TumourSegmentation, ScanType, ScanPlane

In [None]:
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
#import tensorflow_datasets as tfds
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from IPython.display import clear_output

In [None]:
# Show which logical TPU devices this runtime can see (cell output).
tf.config.list_logical_devices(device_type='TPU')

In [None]:
# Attach to the locally connected TPU, initialise it, and create the distribution
# strategy that later cells use (strategy.scope()) to build and compile models.
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(cluster_spec_or_resolver=cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver=cluster_resolver)
strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver=cluster_resolver)


In [None]:
# Shared data-loading configuration.
batch_size = 32      # images per batch
buffer_size = 1000   # shuffle-buffer size (dataset elements)
img_height, img_width = 240, 240  # slice size in pixels
data_dir = os.path.join(
    'data', 'UPENN-GBM', 'slice_classification_common_stratify_healthysegmented', 'train'
)


In [None]:
# Training subset (80%) of the slice-classification images. seed and
# validation_split must match the validation cell so the subsets do not overlap.
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    subset="training",
    validation_split=0.2,
    seed=123,
    image_size=(img_height, img_width),
    color_mode="rgba",
    batch_size=batch_size,
)

In [None]:
# Validation subset (20%); same seed/split as the training cell so the subsets are disjoint.
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    subset="validation",
    validation_split=0.2,
    seed=123,
    image_size=(img_height, img_width),
    color_mode="rgba",
    batch_size=batch_size,
)

In [None]:
# Class names (sub-directory names of data_dir) in the order of their integer labels.
class_names = train_ds.class_names
print(class_names)

In [None]:
# Calculate class weights for weighting accuracy
# Gather every integer label from the (batched) training dataset into one array.
ds_classes = np.concatenate([batch_labels.numpy() for _, batch_labels in train_ds])

# 'balanced' weights are n_samples / (n_classes * count(class)); key them by label.
present_classes = np.unique(ds_classes)
balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=ds_classes,
)
class_weight = dict(zip(present_classes, balanced_weights))


In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Cache the decoded batches after the first pass and overlap loading with compute.
# Use the configured buffer_size rather than a duplicated literal.
# NOTE: train_ds is already batched, so this shuffle reorders whole batches.
train_ds = train_ds.cache().shuffle(buffer_size).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [None]:
# Rescales 8-bit pixel values [0, 255] to [0, 1] (not used by the cells below).
normalization_layer = layers.Rescaling(1.0 / 255)

In [None]:
# Number of slice-classification classes (one per class sub-directory).
num_classes = len(class_names)

In [None]:
# Border width in pixels on each side, and the slice size left once it is removed.
margin = 8
scaled_height, scaled_width = (img_height - 2 * margin, img_width - 2 * margin)

In [None]:
# Load the saved slice classifier (MobileNetV2-based) under the TPU strategy scope
# so its variables are created by that strategy.
classifier_path = 'models/mobilenetv2_bs16_pat5_del0.001_dr0.25_lr0.0001_save-2022-12-02-09:29:15'
with strategy.scope():
    trained_down_model = tf.keras.models.load_model(classifier_path)


In [None]:
# First validation batch (raw RGBA images and integer class labels) for a quick check.
val_test_batch_ims, val_test_batch_true_classes = next(iter(val_ds))

In [None]:
# Classifier outputs for that batch.
val_preds = trained_down_model.predict(x=val_test_batch_ims)

In [None]:
from sklearn.metrics import classification_report

# Softmax does not move the arg-max, so the predicted class is the same whether
# val_preds holds logits or probabilities.
val_test_probs = tf.nn.softmax(val_preds)
val_pred_class = list(np.argmax(val_test_probs, axis=-1))
print(classification_report(val_test_batch_true_classes, val_pred_class))

In [None]:
with strategy.scope():
    # The MobileNetV2 backbone nested inside the trained classifier.
    pretrained_layer_name = 'mobilenetv2_1.00_224'
    base_model = trained_down_model.get_layer(pretrained_layer_name)

    # Use the activations of these layers as encoder outputs / skip connections.
    # Sizes assume the backbone sees a 224x224 input, as its name and the decoder's
    # 7x7 -> 112x112 up_stack suggest -- TODO confirm with down_stack.summary().
    layer_names = [
        'block_1_expand_relu',   # 112x112
        'block_3_expand_relu',   # 56x56
        'block_6_expand_relu',   # 28x28
        'block_13_expand_relu',  # 14x14
        'block_16_project',      # 7x7
    ]
    base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

    # Feature-extraction (encoder) model: backbone input -> list of the activations above.
    # It shares layer objects with the classifier's backbone.
    down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

    # Freeze the encoder so that only the decoder is trained at first.
    down_stack.trainable = False

In [None]:
# Layer-by-layer summary of the frozen encoder.
down_stack.summary()

In [None]:
# Diagram of the encoder graph with tensor shapes.
tf.keras.utils.plot_model(model=down_stack, show_shapes=True)

In [None]:

# Define each layer block for the upsampling (decoder) branch.
def upsample(filters, size, norm_type='batchnorm', apply_dropout=False, dropout_rate=0.5):
    """Build one decoder block that doubles the spatial size of its input.

    Layer order: Conv2DTranspose (stride 2, 'same') => BatchNormalization
    => Dropout => ReLU.

    Args:
      filters: Number of output filters of the transposed convolution.
      size: Kernel size of the transposed convolution.
      norm_type: 'batchnorm' (case-insensitive) adds BatchNormalization. Any other
        value -- including 'instancenorm' -- adds NO normalization layer, because
        instance normalization is not available in this notebook.
      apply_dropout: If True, adds a Dropout layer after the normalization.
      dropout_rate: Fraction of units dropped when apply_dropout is True.

    Returns:
      A tf.keras.Sequential model implementing the block.
    """
    # Small-variance Gaussian initialisation for the transposed-conv kernel.
    initializer = tf.random_normal_initializer(0., 0.02)

    block = tf.keras.Sequential()
    # No bias: the usual choice when BatchNormalization follows (its shift replaces it).
    block.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False))

    if norm_type.lower() == 'batchnorm':
        block.add(tf.keras.layers.BatchNormalization())

    if apply_dropout:
        block.add(tf.keras.layers.Dropout(dropout_rate))

    block.add(tf.keras.layers.ReLU())

    return block


In [None]:
# Decoder blocks, applied from the bottleneck upwards; each doubles height and width.
up_stack = [
    upsample(n_filters, 3)
    for n_filters in (512,   # 7x7 -> 14x14
                      256,   # 14x14 -> 28x28
                      128,   # 28x28 -> 56x56
                      64)    # 56x56 -> 112x112
]

In [None]:
# Names of classifier layers 1-3, which unet_model re-applies in front of the encoder.
front_layers = trained_down_model.layers[1:4]
for front_layer in front_layers:
    print(front_layer.name)

In [None]:
def unet_model(output_channels: int, input_shape=None, pad=None):
    """Build a U-Net that reuses the trained classifier's front layers and encoder.

    Graph: Input -> trained_down_model.layers[1:4] -> down_stack (frozen encoder,
    5 outputs) -> up_stack decoder with skip concatenations -> final stride-2
    Conv2DTranspose (per-pixel logits) -> ZeroPadding2D.

    Args:
      output_channels: Number of per-pixel output channels (class logits).
      input_shape: (height, width, channels) of the input. Defaults to
        (img_height, img_width, 4), i.e. RGBA slices at the configured size.
      pad: Zero padding (pixels) added on every side of the decoder output.
        Defaults to `margin`; 224 + 2*8 = 240 restores the input size, assuming
        the front layers crop the border by `margin` -- TODO confirm.

    Returns:
      An uncompiled tf.keras.Model.
    """
    if input_shape is None:
        input_shape = (img_height, img_width, 4)
    if pad is None:
        pad = margin

    inputs = tf.keras.layers.Input(shape=input_shape)

    # Re-apply the classifier's layers 1-3 (those between its input and the backbone,
    # presumably preprocessing -- TODO confirm) before the encoder.
    x = inputs
    for layer in trained_down_model.layers[1:4]:
        x = layer(x)

    # Downsampling through the encoder; the deepest activation is the bottleneck.
    skips = down_stack(x)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling, concatenating the matching encoder activation after each block.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Last layer: stride-2 transposed conv to per-pixel logits,
    # 112x112 -> 224x224 for the 224x224 backbone input assumed above.
    x = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2, padding="same"
    )(x)

    # Zero-pad the border back to the full input size.
    x = tf.keras.layers.ZeroPadding2D(pad)(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


In [None]:
OUTPUT_CLASSES = 5  # one output channel per segmentation label

with strategy.scope():
    # Build and compile under the TPU strategy.
    model = unet_model(output_channels=OUTPUT_CLASSES)
    # The model outputs raw logits and masks hold integer labels -> sparse CE from logits.
    adam = tf.keras.optimizers.Adam(learning_rate=0.00005)
    sparse_ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=adam, loss=sparse_ce, metrics=['accuracy'])

In [None]:
# Layer-by-layer summary of the assembled U-Net.
model.summary()

In [None]:
# Diagram of the U-Net graph with tensor shapes.
tf.keras.utils.plot_model(model=model, show_shapes=True)

In [None]:
# Intended colours for segmentation labels 0..4, in label order.
colour_list = [
    'w',        # 0
    '#d73027',  # 1
    '#91bfdb',  # 2
    '0.8',      # 3 (light grey)
    '#fee090',  # 4
]
cmap = ListedColormap(colour_list)

def display(display_list):
    """Show an input image and segmentation masks side by side.

    NOTE: this shadows IPython's built-in `display` for the rest of the notebook.

    Args:
      display_list: [input_image, true_mask] or [input_image, true_mask,
        predicted_mask]; at most three entries (one per title). Masks hold integer
        labels with shape (H, W) or (H, W, 1).
    """
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        if i == 0:
            # array_to_img rescales the normalised image to displayable 0-255 values.
            plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
        else:
            # Plot labels directly with a fixed colour range so each label always gets
            # the same colour. array_to_img would min-max rescale every mask on its own,
            # so colours would depend on which labels happen to be present.
            plt.imshow(np.squeeze(np.asarray(display_list[i])), cmap=cmap,
                       vmin=0, vmax=cmap.N - 1, interpolation='nearest')
        plt.axis('off')
    plt.show()
  
def create_mask(pred_mask):
    """Convert a batch of per-pixel logits into the label mask of its first item.

    Args:
      pred_mask: Model output, assumed (batch, H, W, n_classes).

    Returns:
      (H, W, 1) tensor of arg-max labels for batch element 0.
    """
    batch_labels = tf.math.argmax(pred_mask, axis=-1)
    return batch_labels[0][..., tf.newaxis]

def show_predictions(dataset=None, num=1):
    """Plot input / true mask / predicted mask for sample images.

    Args:
      dataset: Optional batched dataset of (images, masks). When given, the first
        item of each of the first `num` batches is shown. When None, the global
        `sample_image` / `sample_mask` pair is used instead.
      num: Number of batches to take from `dataset`.
    """
    if dataset is not None:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            # Display inside the loop so every taken batch is shown, not just the last.
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])

In [None]:
train_image_dir = os.path.join('data','UPENN-GBM','slice_segmentation_stratify_healthy','train','image_data')
train_map_dir = os.path.join('data','UPENN-GBM','slice_segmentation_stratify_healthy','train','map_data')

# Count pixels of each segmentation label over all training maps (used for the
# per-pixel sample weights). NOTE: this includes maps later split off for validation.
n_seg_classes = OUTPUT_CLASSES
pixel_counts = np.zeros(n_seg_classes, dtype=np.int64)
for map_file in os.listdir(train_map_dir):
    seg_map = tf.io.read_file(os.path.join(train_map_dir, map_file))
    seg_map = tf.io.decode_png(seg_map, channels=1)

    label_counts = np.bincount(seg_map.numpy().ravel())
    if label_counts.size > n_seg_classes:
        raise ValueError(
            f'{map_file}: label {label_counts.size - 1} outside 0..{n_seg_classes - 1}'
        )
    pixel_counts[:label_counts.size] += label_counts
print(pixel_counts)

In [None]:
# Pair every training image with its segmentation map: a map file has the image's
# name with 'allseq' replaced by 'map'. os.listdir order is arbitrary, so sort it to
# make the file order (and any seeded train/val split) reproducible.
image_filenames = sorted(os.listdir(train_image_dir))
map_filenames = [filename.replace('allseq', 'map') for filename in image_filenames]
image_filepaths = [os.path.join(train_image_dir, filename) for filename in image_filenames]
map_filepaths = [os.path.join(train_map_dir, filename) for filename in map_filenames]

# Fail early if an image has no map, rather than training on mismatched pairs.
missing_maps = [path for path in map_filepaths if not os.path.exists(path)]
if missing_maps:
    raise FileNotFoundError(
        f'{len(missing_maps)} segmentation maps are missing, e.g. {missing_maps[0]}'
    )

In [None]:
# Hold out 20% of the image/map pairs for validation. The fixed random_state makes
# the split reproducible (same seed as the classification datasets).
train_image_filepaths, val_image_filepaths, train_map_filepaths, val_map_filepaths = train_test_split(
    image_filepaths,
    map_filepaths,
    test_size=0.2,
    random_state=123,
)
    

In [None]:
# Datasets of (image_path, map_path) pairs built directly from the paired lists.
# from_tensor_slices keeps each image aligned with its own map. Dataset.list_files
# expands each path as a glob pattern and may reorder or drop entries (e.g. a missing
# file, or names with glob characters such as '['), so zipping two independently
# listed datasets could silently pair an image with the wrong map.
train_data = tf.data.Dataset.from_tensor_slices((train_image_filepaths, train_map_filepaths))
val_data = tf.data.Dataset.from_tensor_slices((val_image_filepaths, val_map_filepaths))

In [None]:
def scaler_0_1(x):
    """Map 8-bit intensities from [0, 255] onto [0, 1]."""
    scaled = x / 255.0
    return scaled

def scaler_neg1_1(x):
    """Map 8-bit intensities from [0, 255] onto [-1, 1]."""
    halved = x / 127.5
    return halved - 1

def alter_segmap(x):
    """Relabel 4 as 3 in a uint8 segmentation map; other labels are unchanged.

    (Not used by the cells below.)
    """
    replacement = tf.constant(3, dtype=tf.uint8)
    return tf.where(tf.equal(x, 4), replacement, x)

def read_image_map(image, seg_map):
    """Load one (image, segmentation map) pair from PNG file paths.

    Args:
      image: File path (string) of the 4-channel PNG slice.
      seg_map: File path (string) of the single-channel PNG label map.

    Returns:
      (image, seg_map): the image as float32 scaled to [-1, 1] by scaler_neg1_1,
      and the decoded uint8 map with shape (H, W, 1).
    """
    image_pixels = tf.io.decode_png(tf.io.read_file(image), channels=4)
    map_pixels = tf.io.decode_png(tf.io.read_file(seg_map), channels=1)
    # Change scaler below to scaler_0_1 to get initial values between 0 and 1.
    # NOTE(review): the classifier was evaluated on raw 0-255 images above; confirm
    # that trained_down_model.layers[1:4] (reused by unet_model) expect [-1, 1] input.
    return scaler_neg1_1(tf.cast(image_pixels, 'float32')), map_pixels


In [None]:
# Training pipeline: cache and shuffle the filename pairs, decode in parallel
# (order-preserving), batch, repeat indefinitely (epoch length is set by
# steps_per_epoch in fit), and prefetch.
train_batch = (
    train_data.cache()
    .shuffle(buffer_size)
    .map(read_image_map, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# Validation pipeline: shuffle the filename pairs *before* decoding, so the shuffle
# buffer holds paths instead of decoded float32 images (about 0.9 GB for 1000
# 240x240x4 slices). The shuffle is kept because fit only uses validation_steps
# batches per epoch.
val_batch = (
    val_data.shuffle(buffer_size)
    .map(read_image_map, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Pull one batch from the training pipeline for a visual sanity check.
images, masks = next(iter(train_batch))

In [None]:
# Show one training slice (undoing the [-1, 1] scaling) next to its label map.
num = 20
fig, ax = plt.subplots(1, 2)
image_ax, mask_ax = ax
image_ax.imshow(tf.cast((images[num] + 1) * 127.5, 'uint8'))
mask_ax.imshow(masks[num])

In [None]:
# One batch from the *validation* pipeline (previously drawn from train_batch by mistake).
val_images, val_masks = next(iter(val_batch))

In [None]:
# Show one slice of that batch (undoing the [-1, 1] scaling) next to its label map.
vnum = 7
fig, ax = plt.subplots(1, 2)
image_ax, mask_ax = ax
image_ax.imshow(tf.cast((val_images[vnum] + 1) * 127.5, 'uint8'))
mask_ax.imshow(val_masks[vnum])

In [None]:
# Take one training batch and display its first slice with the true mask.
images, masks = next(iter(train_batch.take(1)))
sample_image, sample_mask = images[0], masks[0]
display([sample_image, sample_mask])
print(images.shape, masks.shape)

In [None]:
# Prediction on one slice of the batch above, before any training (baseline).
slice_num = 0
sample_image = images[slice_num]
sample_mask = masks[slice_num]
show_predictions()

In [None]:
# Labels present in the sample mask and their pixel counts.
np.unique(sample_mask, return_counts=True)

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots a sample prediction at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Uses show_predictions() with the global sample_image / sample_mask.
        show_predictions()
        # Keras passes a 0-based epoch index; report it 1-based.
        message = '\nSample Prediction after epoch {}\n'.format(epoch + 1)
        print(message)

In [None]:
# Calculate per-label pixel weights: inversely proportional to pixel frequency and
# normalised so that sum(weights) == 1.0. A label with zero pixels gets weight 0
# (it never occurs in the masks) instead of an inf that would turn all weights NaN.
class_pixel_totals = np.asarray(pixel_counts, dtype=np.float64)
has_pixels = class_pixel_totals > 0
if not np.any(has_pixels):
    raise ValueError('pixel_counts is empty or all zero; cannot compute class weights')
weights = np.zeros_like(class_pixel_totals)
np.divide(1.0, class_pixel_totals, out=weights, where=has_pixels)
weights = weights / np.sum(weights)


def add_sample_weights(image, label):
    """Attach a per-pixel sample-weight map to an (image, label) batch.

    Each pixel's weight is weights[label], so rarer labels count more in the loss.

    Args:
      image: Image batch, returned unchanged.
      label: Integer label mask batch, e.g. (batch, H, W, 1).

    Returns:
      (image, label, sample_weights), with sample_weights shaped like label.
    """
    # Use each pixel's label as an index into `weights`.
    sample_weights = tf.gather(weights, indices=tf.cast(label, tf.int32))
    return image, label, sample_weights

In [None]:
# Inspect the normalised per-label pixel weights.
weights

In [None]:
# Stop phase-1 training after 3 epochs without val_loss improvement.
earlystopping = tf.keras.callbacks.EarlyStopping(patience=3, monitor='val_loss')

In [None]:
# Number of image/map pairs in the training and validation splits.
print(len(train_image_filepaths),len(val_image_filepaths))

In [None]:
# Phase 1: train the decoder with the encoder frozen (Adam, lr 5e-5 -- see compile cell).
# Epoch and validation sizes come from the actual split instead of hard-coded counts
# (68076 / 17019 for the original data).
TRAIN_LENGTH = len(train_image_filepaths)
VAL_LENGTH = len(val_image_filepaths)
EPOCHS = 40
VAL_SUBSPLITS = 5  # validate on about 1/VAL_SUBSPLITS of the validation set per epoch
VALIDATION_STEPS = max(1, VAL_LENGTH // batch_size // VAL_SUBSPLITS)
STEPS_PER_EPOCH = max(1, TRAIN_LENGTH // batch_size)

model_history = model.fit(
    train_batch.map(add_sample_weights),
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    validation_data=val_batch,
    callbacks=[DisplayCallback(), earlystopping],
)

In [None]:
# First validation batch (scaled images, uint8 label maps) for inspecting predictions.
val_images, val_maps = next(iter(val_batch))

In [None]:
# Check the container type of the validation mask batch.
type(val_maps)

In [None]:
# Label map of validation slice 2, plotted with a fixed colour range so each label
# 0..cmap.N-1 keeps its colour from colour_list. Without vmin/vmax, imshow stretches
# the colour map over the labels present in this one map.
plt.imshow(tf.cast(val_maps[2, :, :, 0], 'uint8'), cmap=cmap,
           vmin=0, vmax=cmap.N - 1, interpolation='nearest')

In [None]:
# Keep validation slice 2 and its map as batches of size 1.
single_img = val_images[2][tf.newaxis, ...]
single_map = val_maps[2][tf.newaxis, ...]


In [None]:
# One-element dataset whose only element is the (image, map) batch of size 1.
single_ds = tf.data.Dataset.from_tensors((single_img, single_map))


In [None]:
# Prediction on the held-out validation slice after phase-1 training.
show_predictions(single_ds)

In [None]:
# Fine-tuning phase: allow 10 epochs without val_loss improvement before stopping.
earlystopping = tf.keras.callbacks.EarlyStopping(patience=10, monitor='val_loss')

In [None]:
# Partially relax model: unfreeze everything, then re-freeze the first
# `fix_below_layer` layers of the MobileNetV2 backbone. down_stack was built from
# the backbone's own layer objects, so freezing them here also freezes them in the U-Net.
base_model = trained_down_model.get_layer('mobilenetv2_1.00_224')

model.trainable = True

fix_below_layer = 100
for layer in base_model.layers[:fix_below_layer]:
    layer.trainable = False

# Recompile (needed for trainable changes to take effect) with a lower learning rate.
with strategy.scope():
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001,),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

In [None]:
# Phase 2: continue training from the epoch after phase 1 stopped, up to epoch 50.
resume_epoch = model_history.epoch[-1] + 1
model_history_fullrelax = model.fit(
    train_batch.map(add_sample_weights),
    validation_data=val_batch,
    epochs=50,
    initial_epoch=resume_epoch,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    callbacks=[DisplayCallback(), earlystopping],
)

In [None]:
# Prediction on the same held-out validation slice after fine-tuning.
show_predictions(single_ds)