In [None]:
!pip install classification-models-3D

In [None]:
import math, re, os, gc
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
from tensorflow import keras
from functools import partial
from sklearn.model_selection import train_test_split
import tensorflow_addons as tfa
import matplotlib.ticker as mtick
import tensorflow.experimental.numpy as tnp

In [None]:
import tensorflow as tf
# Probe for a TPU. The resolver needs no arguments when the TPU_NAME
# environment variable is set, which is always the case on Kaggle.
# A ValueError from the probe is treated as "no TPU available".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # Default distribution strategy in Tensorflow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

AUTOTUNE = tf.data.experimental.AUTOTUNE
print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

def seed_all(s):
    """Seed every random number generator this notebook touches.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    TensorFlow's global seed, and sets the ``TF_CUDNN_DETERMINISTIC`` and
    ``PYTHONHASHSEED`` environment variables.

    Args:
        s: Integer seed value.

    Note:
        Defining this function does not seed anything; it has to be called
        explicitly (for example ``seed_all(seed)``).
    """
    # `random` is not part of the notebook's import cell, so import it here;
    # without this the first statement below raises NameError.
    import random

    random.seed(s)
    np.random.seed(s)
    tf.random.set_seed(s)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    os.environ['PYTHONHASHSEED'] = str(s)

# Also used as the op-level seed of the random draws in the augmentation cell.
seed = 42

In [None]:
BATCH_SIZE = 8 * strategy.num_replicas_in_sync # Global batch size: 8 samples per replica times the number of replicas in sync

In [None]:
# NOTE(review): PIXEL_MEAN is not referenced by any active code in this notebook;
# presumably it belongs to the commented-out zero_center step - confirm or remove.
PIXEL_MEAN = 0.25
# Intensity window used by normalize(): MIN_BOUND maps to 0 and MAX_BOUND maps to 1.
MIN_BOUND = -1000.0
MAX_BOUND = 400.0

from scipy import ndimage

def normalize(image):
    """Rescale intensities from [MIN_BOUND, MAX_BOUND] to [0, 1] and clip.

    Args:
        image: Float tensor of raw intensities.

    Returns:
        Tensor of the same shape with values linearly rescaled so that
        MIN_BOUND maps to 0 and MAX_BOUND maps to 1; anything outside that
        window is clipped to the nearest bound.
    """
    image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
    # The previous mask-multiply zeroed every value >= 1 (intensities at or
    # above MAX_BOUND turned black) instead of saturating them at 1.
    return tf.clip_by_value(image, 0.0, 1.0)

In [None]:
# DO NOT RUN

# NOTE(review): the IMG_* constants are only referenced by a commented-out
# reshape in the parse functions; the active code reshapes to [32, 64, 64, 1].
IMG_DEPTH = 128
IMG_HEIGHT = 512
IMG_WIDTH = 512
IMG_CHANNELS = 3

# Same values as the definitions in the normalisation cell above.
MIN_BOUND = -1000.0
MAX_BOUND = 400.0
# One entry per output class; len(CLASSES) is the one-hot depth and the model's output size.
CLASSES = [0, 1, 2, 3, 4]

# Context (per-example) features of each serialized SequenceExample.
feature_description = {
    'scan_id': tf.io.FixedLenFeature([], tf.string),
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'num_channels': tf.io.FixedLenFeature([], tf.int64),
    # Serialized tensor; decoded with tf.io.parse_tensor in the parse functions.
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'internalStructure': tf.io.FixedLenFeature([], tf.int64),
    'calcification': tf.io.FixedLenFeature([], tf.int64),
    'sphericity': tf.io.FixedLenFeature([], tf.int64),
    'margin': tf.io.FixedLenFeature([], tf.int64),
    'lobulation': tf.io.FixedLenFeature([], tf.int64),
    'spiculation': tf.io.FixedLenFeature([], tf.int64),
    'texture': tf.io.FixedLenFeature([], tf.int64),
}

# Sequence (variable-length) features of each serialized SequenceExample.
feature_list_description = {
    'annotation_indices': tf.io.FixedLenSequenceFeature([], tf.int64),
}

# Scalar context features that are forwarded to the model next to the image.
_NODULE_ATTRIBUTES = (
    'internalStructure',
    'calcification',
    'sphericity',
    'margin',
    'lobulation',
    'spiculation',
    'texture',
)


def _parse_example(example_proto, include_identifiers=False):
    """Decode one serialized SequenceExample into ``(features, label)``.

    Shared implementation behind `_parse_image_function` and
    `_parse_image_function_testing`, which previously duplicated this logic.

    Args:
        example_proto: Scalar string tensor holding a serialized SequenceExample
            that matches `feature_description` / `feature_list_description`.
        include_identifiers: When True, also return ``'scan_id'`` and
            ``'annotation_indices'`` in the feature dict.

    Returns:
        A tuple ``(features, label)``. ``features`` maps ``'image'`` to a
        float32 tensor of shape ``[32, 64, 64, 1]`` and each name in
        `_NODULE_ATTRIBUTES` to its int64 scalar. ``label`` is a one-hot
        float tensor of length ``len(CLASSES)``.
    """
    context, sequences = tf.io.parse_single_sequence_example(
        example_proto,
        sequence_features=feature_list_description,
        context_features=feature_description,
    )

    volume = tf.io.parse_tensor(context['image'], out_type=tf.float32)
    # First reshape with the dimensions stored in the record (fails if the
    # payload does not match its own header), then pin the static shape the
    # model expects.
    volume = tf.reshape(
        volume, (context['num_channels'], context['height'], context['width']))
    volume = tf.reshape(volume, [32, 64, 64, 1])

    # The stored label is shifted down by one before one-hot encoding
    # (presumably labels are 1..len(CLASSES) - confirm against the dataset).
    label = tf.cast(context['label'], tf.int32)
    label = tf.one_hot(label - 1, depth=len(CLASSES))

    features = {'image': volume}
    for name in _NODULE_ATTRIBUTES:
        features[name] = context[name]
    if include_identifiers:
        features['scan_id'] = context['scan_id']
        features['annotation_indices'] = sequences['annotation_indices']
    return features, label


def _parse_image_function(example_proto):
    """Parse a training/validation record into ``(model_inputs, one_hot_label)``.

    Args:
        example_proto: Serialized SequenceExample (scalar string tensor).

    Returns:
        ``(features, label)`` where ``features`` holds ``'image'`` plus the
        seven scalar nodule attributes, and ``label`` is one-hot encoded.
    """
    return _parse_example(example_proto, include_identifiers=False)


def _parse_image_function_testing(example_proto):
    """Parse a record for inspection, keeping identifying metadata.

    Same as `_parse_image_function`, but the feature dict additionally
    contains ``'scan_id'`` and ``'annotation_indices'``.

    Args:
        example_proto: Serialized SequenceExample (scalar string tensor).

    Returns:
        ``(features, label)`` as described above.
    """
    return _parse_example(example_proto, include_identifiers=True)

def read_tf_dataset(storage_file_path):
    """Open GZIP-compressed TFRecord file(s) and parse every record.

    Args:
        storage_file_path: Path (or paths) accepted by tf.data.TFRecordDataset.

    Returns:
        A dataset of ``(features, label)`` pairs produced by
        `_parse_image_function`.
    """
    records = tf.data.TFRecordDataset(storage_file_path, compression_type="GZIP")
    return records.map(_parse_image_function)

In [None]:
# Augmentations - Source: https://www.kaggle.com/code/sreevishnudamodaran/rsna-3d-clahe-voxels-tpu-3d-augmentations/notebook

from tensorflow_addons.image import utils as img_utils

# Augmentation settings read by build_augmenter(). The last element of each
# tuple is the probability of applying that transform; 0.0 disables it.
FLIP = 0.5 # @params: probability
CONTRAST = (0.3, 1.3, 0.5) # @params: (minval, maxval, probability)
BRIGHTNESS = (0.4, 0.4) # @params: (max delta, probability)
GAMMA = (0.8, 1.2, 0.25) # @params: (minval, maxval, probability)
ROTATE = (20, 0.5) # @params: (max angle in degrees, probability)
RANDOM_CROP = (32, 32, 0.0) # @params: (min_width, min_height, probability)
CUTOUT = ((4, 4), 4, 0.0) # @params: ((mask_dim0, mask_dim1), max_num_holes, probability)
BLUR = ([4, 4], 4, 0.0) # @params: ((filter_dim0, filter_dim1), max_sigma, probability)

def random_rotate3D(voxel, limit=90, p=0.5):
    """Rotate ``voxel['image']`` by one random angle, with probability `p`.

    Args:
        voxel: Feature dict; ``voxel['image']`` is updated in place.
        limit: The angle is an integer number of degrees drawn from
            ``[-limit, limit)``.
        p: Probability of applying the rotation.

    Returns:
        The (possibly rotated) image tensor.
    """
    if tf.random.uniform(()) < p:
        degrees = tf.random.uniform((), minval=-limit, maxval=limit,
                                    dtype=tf.int32)
        # tfa.image.rotate expects radians.
        radians = tf.cast(degrees, tf.float32) * (math.pi / 180)
        voxel['image'] = tfa.image.rotate(voxel['image'], radians,
                                          interpolation='nearest',
                                          fill_mode='constant',
                                          fill_value=0.0)
    return voxel['image']

def random_resized_crop3D(voxel, min_width, min_height, p=0.5):
    """Crop a random window from axes 1-2 and resize it back, with probability `p`.

    Args:
        voxel: Feature dict; ``voxel['image']`` is updated in place.
        min_width: Lower bound of the random crop width (axis 2).
        min_height: Lower bound of the random crop height (axis 1).
        p: Probability of applying the crop.

    Returns:
        The (possibly cropped and resized) image tensor, with the same
        axis-1/axis-2 size as the input.
    """
    if tf.random.uniform(()) < p:
        voxel_shape = voxel['image'].shape
        full_height, full_width = voxel_shape[1], voxel_shape[2]
        assert full_height >= min_height
        assert full_width >= min_width

        # Window size first, then a top-left corner that keeps it in bounds.
        width = tf.random.uniform((), minval=min_width, maxval=full_width,
                                  dtype=tf.int32)
        height = tf.random.uniform((), minval=min_height, maxval=full_height,
                                   dtype=tf.int32)
        x = tf.random.uniform((), minval=0, maxval=full_width - width,
                              dtype=tf.int32)
        y = tf.random.uniform((), minval=0, maxval=full_height - height,
                              dtype=tf.int32)

        cropped = voxel['image'][:, y:y+height, x:x+width, :]
        voxel['image'] = tf.image.resize(cropped, voxel_shape[1:3],
                                         method='lanczos5')
    return voxel['image']

def random_cutout3D(voxel, mask_shape=(10, 10), num_holes=20, p=0.5):
    """Zero out random rectangular holes in ``voxel['image']``, with probability `p`.

    The same holes are cut through every index of axis 0 and axis 3.

    Args:
        voxel: Feature dict; ``voxel['image']`` is updated in place.
        mask_shape: ``(size along axis 1, size along axis 2)`` of each hole.
        num_holes: Exclusive upper bound on the number of holes; the count is
            drawn from ``[1, num_holes)``.
        p: Probability of applying the cutout.

    Returns:
        The (possibly masked) image tensor.
    """
    if tf.random.uniform(()) < p:
        voxel_shape = voxel['image'].shape
        assert voxel_shape[1] >= mask_shape[0]
        assert voxel_shape[2] >= mask_shape[1]

        holes = tf.random.uniform((), minval=1, maxval=num_holes,
                                  dtype=tf.int32)
        mask_size = tf.constant([mask_shape[0], mask_shape[1]])
        mask = tf.Variable((lambda : tf.ones(voxel_shape)),
                           trainable=False)
        
        for i in tf.range(holes):
            # x is an offset along axis 2, y an offset along axis 1.
            x = tf.random.uniform((), minval=0,
                                  maxval=voxel_shape[2],
                                  dtype=tf.int32)
            y = tf.random.uniform((), minval=0,
                                  maxval=voxel_shape[1],
                                  dtype=tf.int32)
            mask_endx = tf.add(x, mask_size[1])
            mask_endy = tf.add(y, mask_size[0])
            # Index axis 1 with y and axis 2 with x. The previous code had the
            # two swapped, which only went unnoticed for square inputs/masks.
            mask[:, y:mask_endy,
                 x:mask_endx, :].assign(tf.zeros_like(mask[:, y:mask_endy,
                                                        x:mask_endx, :]))
        voxel['image'] = tf.multiply(voxel['image'], mask)
        mask.assign(tf.ones(voxel_shape))
    return voxel['image']

def _get_gaussian_kernel(sigma, filter_shape):
    """Return a normalised 1-D Gaussian kernel.

    Args:
        sigma: Scalar tensor, standard deviation of the Gaussian.
        filter_shape: Integer kernel length.

    Returns:
        1-D tensor of dtype ``sigma.dtype`` whose entries sum to 1
        (softmax of ``-x**2 / (2 * sigma**2)`` over the tap offsets).
    """
    offsets = tf.range(-filter_shape // 2 + 1, filter_shape // 2 + 1)
    squared = tf.cast(offsets ** 2, sigma.dtype)
    return tf.nn.softmax(-squared / (2.0 * (sigma ** 2)))

def random_gaussian_blur3D(voxel, filter_shape=[5, 5], max_sigma=3, p=0.5):
    """Blur ``voxel['image']`` with a 2-D Gaussian kernel, with probability `p`.

    The kernel is applied as a depthwise convolution over axes 1-2, so the
    tensor is treated as a batch of 2-D images along axis 0.

    Args:
        voxel: Feature dict; ``voxel['image']`` is updated in place.
        filter_shape: ``[size along axis 1, size along axis 2]`` of the kernel.
        max_sigma: Exclusive upper bound of the integer sigma, which is drawn
            from ``[3, max(max_sigma, 4))``.
        p: Probability of applying the blur.

    Returns:
        The (possibly blurred) image tensor.
    """
    if tf.random.uniform(()) < p:
        # An integer uniform draw needs minval < maxval. With minval fixed at
        # 3, the default max_sigma=3 used to give an empty range; flooring the
        # bound at 4 keeps every previously valid call unchanged and makes
        # max_sigma <= 3 fall back to sigma = 3.
        sigma = tf.random.uniform((), minval=3,
                                  maxval=tf.maximum(max_sigma, 4),
                                  dtype=tf.int32)
        filter_shape = tf.constant(filter_shape)
        channels = voxel['image'].shape[-1]
        sigma = tf.cast(sigma, voxel['image'].dtype)
        gaussian_kernel_x = _get_gaussian_kernel(sigma,
                                                 filter_shape[1])
        gaussian_kernel_x = gaussian_kernel_x[tf.newaxis, :]
        gaussian_kernel_y = _get_gaussian_kernel(sigma,
                                                 filter_shape[0])
        gaussian_kernel_y = gaussian_kernel_y[:, tf.newaxis]
        # Outer product of the two 1-D kernels gives the separable 2-D kernel.
        gaussian_kernel_2d = tf.matmul(gaussian_kernel_y,
                                       gaussian_kernel_x)
        gaussian_kernel_2d = gaussian_kernel_2d[:, :, tf.newaxis,
                                                tf.newaxis]
        # One copy of the kernel per input channel for the depthwise conv.
        gaussian_kernel_2d = tf.tile(gaussian_kernel_2d,
                                     tf.constant([1, 1, channels, 1]))
        voxel['image'] = tf.nn.depthwise_conv2d(input=voxel['image'],
                                       filter=gaussian_kernel_2d,
                                       strides=(1, 1, 1, 1),
                                       padding="SAME",
                                       )
    return voxel['image']

def build_augmenter(with_labels=True):
    '''
    Build the augmentation function applied to each dataset element.

    Every transform is applied once to the whole ``voxel['image']`` tensor,
    so all slices of a volume receive the same transformation. Each transform
    is gated by its own probability taken from the module-level constants
    (FLIP, BRIGHTNESS, CONTRAST, GAMMA, ROTATE, RANDOM_CROP, CUTOUT, BLUR).

    Args:
        with_labels: If True, return a function taking ``(voxel, label)`` and
            returning ``(augmented_voxel, label)``; otherwise return a function
            taking and returning only ``voxel``.

    Returns:
        The augmentation callable described above. It updates
        ``voxel['image']`` in place and leaves the other dict entries untouched.
    ''' 
    def augment(voxel):
        # NOTE(review): aug_seed is drawn but never used below.
        aug_seed = tf.random.uniform((2,), minval=1, maxval=9999, dtype=tf.int32)
        # Flip: vertical or horizontal with equal probability.
        if tf.random.uniform(()) < FLIP:
            if tf.random.uniform(()) < 0.5:
                voxel['image'] = tf.image.flip_up_down(voxel['image'])
            else:
                voxel['image'] = tf.image.flip_left_right(voxel['image'])
                
        # `seed` in the draws below is the module-level seed value.
        if tf.random.uniform(()) < BRIGHTNESS[1]:
            voxel['image'] = tf.image.adjust_brightness(
                voxel['image'], tf.random.uniform((), minval=0.0,
                                         maxval=BRIGHTNESS[0],
                                         seed=seed))
        if tf.random.uniform(()) < CONTRAST[2]:
            voxel['image'] = tf.image.adjust_contrast(
                voxel['image'], tf.random.uniform((), minval=CONTRAST[0],
                                         maxval=CONTRAST[1],
                                         seed=seed))
        if tf.random.uniform(()) < GAMMA[2]:
            voxel['image'] = tf.image.adjust_gamma(
                voxel['image'], tf.random.uniform((), minval=GAMMA[0],
                                         maxval=GAMMA[1],
                                         seed=seed))
        # Each helper below applies its own probability check internally.
        voxel['image'] = random_rotate3D(voxel, limit=ROTATE[0],
                                p=ROTATE[1])
        voxel['image'] = random_resized_crop3D(voxel, RANDOM_CROP[0],
                                      RANDOM_CROP[1],
                                      p=RANDOM_CROP[2])
        voxel['image'] = random_cutout3D(voxel, mask_shape=CUTOUT[0],
                                num_holes=CUTOUT[1],
                                p=CUTOUT[2])
        voxel['image'] = random_gaussian_blur3D(voxel,
                                       filter_shape=BLUR[0],
                                       max_sigma=BLUR[1],
                                       p=BLUR[2])

        # Replace NaNs with zeros (they show up in place of black pixels in some images).
        # NOTE(review): a likely source is adjust_gamma raising negative values
        # (left by the contrast adjustment) to a fractional power - confirm.
        voxel['image'] = tf.where(tf.math.is_nan(voxel['image']),
                         tf.zeros_like(voxel['image']), voxel['image'])
        # Clamp to [0, 1] and make sure the result is float32.
        voxel['image'] = tnp.maximum(tnp.array([0.]), voxel['image'])
        voxel['image'] = tnp.minimum(tnp.array([1.]), voxel['image'])
        voxel['image'] = tf.cast(voxel['image'], tf.float32)
        return voxel
    
    def augment_with_labels(voxel, label):
        # The label passes through unchanged.
        return augment(voxel), label
    
    return augment_with_labels if with_labels else augment

In [None]:
def _load_tfrecords(filenames, parse_fn, ordered=False):
    """Read GZIP-compressed TFRecord files and parse every record with `parse_fn`.

    Shared implementation behind `load_dataset` and `load_testing_dataset`,
    which previously duplicated this logic.

    Args:
        filenames: TFRecord file paths.
        parse_fn: Function mapping a serialized record to a dataset element.
        ordered: If False, determinism is switched off in the dataset options
            so elements may be produced out of order.

    Returns:
        The parsed (unbatched) tf.data.Dataset.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE, compression_type="GZIP")
    dataset = dataset.with_options(options)
    return dataset.map(parse_fn, num_parallel_calls=AUTOTUNE)

def load_dataset(filenames, ordered=False):
    """Load records as ``(model_inputs, one_hot_label)`` pairs.

    Args:
        filenames: TFRecord file paths.
        ordered: Keep the deterministic element order when True.

    Returns:
        Dataset parsed with `_parse_image_function`.
    """
    return _load_tfrecords(filenames, _parse_image_function, ordered=ordered)

def load_testing_dataset(filenames, ordered=False):
    """Load records including ``scan_id`` and ``annotation_indices``.

    Args:
        filenames: TFRecord file paths.
        ordered: Keep the deterministic element order when True.

    Returns:
        Dataset parsed with `_parse_image_function_testing`.
    """
    return _load_tfrecords(filenames, _parse_image_function_testing, ordered=ordered)

def get_training_dataset(dataset, do_aug=True, shuffleBuffer=2048):
    """Turn a parsed dataset into an endless, shuffled, batched training stream.

    Args:
        dataset: Dataset of ``(features, label)`` pairs.
        do_aug: Apply the augmenter from `build_augmenter` when True.
        shuffleBuffer: Size of the shuffle buffer.

    Returns:
        Dataset that repeats forever, is shuffled, batched by BATCH_SIZE and
        prefetched.
    """
    autotune = tf.data.experimental.AUTOTUNE

    if do_aug:
        dataset = dataset.map(build_augmenter(with_labels=True),
                              num_parallel_calls=autotune)

    return (
        dataset
        .repeat()
        .shuffle(shuffleBuffer)
        .batch(BATCH_SIZE)
        .prefetch(autotune)
    )

def get_val_dataset(dataset, do_aug=False):
    """Batch and prefetch a parsed dataset for validation or prediction.

    Args:
        dataset: Dataset of ``(features, label)`` pairs.
        do_aug: Accepted for symmetry with `get_training_dataset`; it has no
            effect, no augmentation is ever applied here.

    Returns:
        Dataset batched by BATCH_SIZE and prefetched.
    """
    return dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)

In [None]:
from kaggle_datasets import KaggleDatasets
# Resolve the GCS location of the 'lidcidri-tfrecords8' Kaggle dataset; the bare
# expression on the last line displays the path as the cell output.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('lidcidri-tfrecords8')
GCS_DS_PATH

In [None]:
# Glob patterns for the training and validation shards in the dataset bucket.
train_tf_gcs = f"{GCS_DS_PATH}/train*.tfrec"
val_tf_gcs = f"{GCS_DS_PATH}/val*.tfrec"

# Expand the patterns into concrete file lists.
train_tf_files = tf.io.gfile.glob(train_tf_gcs)
val_tf_files = tf.io.gfile.glob(val_tf_gcs)

print("Train TFrecord Files:", len(train_tf_files))
print("Val TFrecord Files:", len(val_tf_files))

In [None]:
# Parsed but still unbatched datasets; ordered=False (the default) lets
# elements come back in non-deterministic order.
train_dataset = load_dataset(train_tf_files)
val_dataset = load_dataset(val_tf_files)

In [None]:
# # del image, mask
# # gc.collect()

# fig = plt.figure(figsize=(25, 25))

# Sanity check: show the first slice of one sample from the first training
# batch and print that batch's labels.
for img, label in get_training_dataset(train_dataset, do_aug=False):
    # Sample 50 only exists when the batch holds more than 50 elements
    # (BATCH_SIZE is 8 per replica), so clamp the index to the batch size.
    sample_idx = min(50, img['image'].shape[0] - 1)
    plt.imshow(img['image'][sample_idx][0], cmap=plt.cm.gray)
    print(label)
    break

In [None]:
# Sum the integers found in characters [-8:-6] of each file's base name,
# i.e. the two characters right before the last six.
# NOTE(review): presumably file names end in "NN.tfrec" with NN the number of
# examples in the shard - confirm against the dataset.
# NOTE(review): NUM_FILES is not used by the visible code; STEPS_PER_EPOCH
# hard-codes 2104 instead.
NUM_FILES = sum([int(file.split('/')[-1][-8:-6]) for file in train_tf_files])

In [None]:
from classification_models_3D.tfkeras import Classifiers

# Names of the model inputs (a list, despite the variable name). They match
# the keys of the feature dict returned by _parse_image_function; 'image' is
# the volume, the rest are scalar inputs.
input_dict = ['image',
            'internalStructure',
            'calcification',
            'sphericity',
            'margin',
            'lobulation',
            'spiculation',
            'texture']

def create_model(input_shape, num_classes, model_arch):
    """Build the two-branch classifier: a 3-D backbone plus scalar inputs.

    The single-channel volume goes through a 3x3x3 convolution that produces
    three channels, then through the `model_arch` backbone from
    classification_models_3D (built with ``weights='imagenet'`` and no top),
    global average pooling and dropout. The pooled features are concatenated
    with the scalar inputs and fed to a softmax layer.

    Args:
        input_shape: Shape of the volume without the channel axis.
        num_classes: Number of output classes.
        model_arch: Backbone name understood by ``Classifiers.get``.

    Returns:
        The (uncompiled) tf.keras.Model. Its summary is printed as a side effect.
    """
    # One Input per name in `input_dict`: the volume plus one scalar per attribute.
    inputs = {
        name: (
            tf.keras.layers.Input((*input_shape, 1), name='image', dtype=tf.float32)
            if name == 'image'
            else tf.keras.Input(shape=(1,), name=name, dtype=tf.float32)
        )
        for name in input_dict
    }

    scalar_tensors = [tensor for name, tensor in inputs.items() if name != 'image']
    numeric_data = tf.keras.layers.Concatenate()(scalar_tensors)

    # Expand the single channel to the three channels the backbone is built for.
    to_three_channels = tf.keras.layers.Conv3D(
        3, (3, 3, 3), strides=(1, 1, 1), padding='same', use_bias=True)
    img_input = to_three_channels(inputs['image'])

    net, _ = Classifiers.get(model_arch)
    backbone = net(input_shape=(*input_shape, 3), include_top=False,
                   weights='imagenet')
    x = backbone(img_input)
    x = tf.keras.layers.GlobalAveragePooling3D()(x)
    x = tf.keras.layers.Dropout(rate=0.5)(x)

    x = tf.keras.layers.Concatenate()([x, numeric_data])

    # Cast output to float32 for numerical stability
    classifier = tf.keras.layers.Dense(num_classes, activation='softmax',
                                       dtype='float32')
    outputs = classifier(x)

    model = tf.keras.Model(inputs, outputs)
    model.summary()

    return model

In [None]:
class CategoricalFocalLossLabelSmoothing(tf.keras.losses.Loss):
    """Categorical focal loss computed on smoothed predictions.

    Predictions are first blended towards the uniform distribution
    (``(1 - ls) * y_pred + ls / classes``) and clipped away from 0 and 1; the
    focal-weighted cross entropy is then summed over axis 1.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor.
        alpha: Constant weighting factor.
        ls: Smoothing amount applied to the predictions.
        classes: Number of classes used in the smoothing term.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``,
            ``reduction``).
    """

    def __init__(self, gamma=2.0, alpha=0.25, ls=0.1, classes=5.0, **kwargs):
        super(CategoricalFocalLossLabelSmoothing, self).__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha
        self.ls = ls
        self.classes = classes
        
    def focal_loss(self, y_true, y_pred, gamma, alpha, ls, classes):
        """Return the per-sample focal loss.

        Args:
            y_true: One-hot targets, classes on axis 1.
            y_pred: Predicted probabilities, same shape as `y_true`.
            gamma, alpha, ls, classes: See the class docstring.

        Returns:
            Tensor with axis 1 summed out (one loss value per sample).
        """
        # Epsilon keeps log() and its gradient finite at the clip boundaries.
        epsilon = K.epsilon()
        # Smooth the predictions towards the uniform distribution.
        y_pred_ls = (1 - ls) * y_pred + ls / classes
        # Clip the prediction value
        y_pred_ls = K.clip(y_pred_ls, epsilon, 1.0-epsilon)
        # Calculate cross entropy
        cross_entropy = -y_true*K.log(y_pred_ls)
        # Weight = weighting factor (alpha) times modulating factor (1 - p)^gamma
        weight = alpha * y_true * K.pow((1-y_pred_ls), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum over the class axis
        loss = K.sum(loss, axis=1)
        return loss
        
    def call(self, y_true, y_pred):
        """Keras entry point: evaluate `focal_loss` with the stored hyper-parameters."""
        return self.focal_loss(y_true, y_pred, gamma=self.gamma, alpha=self.alpha, ls=self.ls, classes=self.classes)

    def get_config(self):
        """Return the serialisable config, including the loss hyper-parameters.

        Without this override only the base-class settings are stored, so a
        reloaded loss would silently fall back to the default gamma/alpha/ls.
        """
        config = super(CategoricalFocalLossLabelSmoothing, self).get_config()
        config.update({
            'gamma': self.gamma,
            'alpha': self.alpha,
            'ls': self.ls,
            'classes': self.classes,
        })
        return config

In [None]:
# BATCH_SIZE = 1
model_arch = 'vgg19'
EPOCHS = 50
STEPS_PER_EPOCH = 2104 // BATCH_SIZE
input_shape = (32, 64, 64)

# Model, optimizer, loss and metrics are all created inside the strategy scope.
with strategy.scope():
    model = create_model(input_shape, len(CLASSES), model_arch)

    # Learning rate starts at 1e-3 and is multiplied by 0.96 every 10000 steps.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-3,
        decay_steps=10000,
        decay_rate=0.96,
        staircase=True,
    )

    optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule)
#     loss = CategoricalFocalLossLabelSmoothing(gamma=1.5, alpha=0.1, ls=0.3, classes=5.0)
    loss = tf.keras.losses.CategoricalCrossentropy()
    metrics = [tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.AUC()]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

In [None]:
# Write a left-to-right diagram of the model graph to model.jpg.
tf.keras.utils.plot_model(model, to_file="model.jpg", rankdir='LR')

In [None]:
# Shape of the epoch-wise learning-rate curve: linear ramp, plateau, decay.
START_LR = 1e-9
MAX_LR = 1e-4
MIN_LR = 1e-9
LR_RAMP = 5
LR_SUSTAIN = 3
LR_DECAY = 0.90

def lrfn(epoch):
    """Return the learning rate for `epoch` (0-based).

    * ``epoch < LR_RAMP``: linear warm-up from START_LR towards MAX_LR.
    * next LR_SUSTAIN epochs: constant MAX_LR.
    * afterwards: exponential decay from MAX_LR towards MIN_LR with factor
      LR_DECAY per epoch.
    """
    plateau_end = LR_RAMP+LR_SUSTAIN
    if LR_RAMP > 0 and epoch < LR_RAMP:
        slope = (MAX_LR-START_LR)/(LR_RAMP*1.0)
        lr = slope*epoch+START_LR
    elif epoch < plateau_end:
        lr = MAX_LR
    else:
        span = (MAX_LR-MIN_LR)
        lr = span*LR_DECAY**(epoch-LR_RAMP-LR_SUSTAIN)+MIN_LR
    return lr
    
@tf.function
def lrfn_tffun(epoch):
    """tf.function-wrapped version of `lrfn`; returns the learning rate for `epoch`."""
    return lrfn(epoch)

# Plot the learning-rate curve produced by lrfn over all training epochs.
fig, ax = plt.subplots(figsize=(14, 5))
rng = list(range(EPOCHS))
ax.plot(rng, [lrfn(epoch) for epoch in rng], marker='o')
ax.ticklabel_format(axis="y", style="plain")
# Show the tick labels in scientific notation with one decimal.
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
plt.show()

# Epoch-wise learning-rate scheduler built on lrfn.
# NOTE(review): lr_callback is not in the callbacks list passed to model.fit
# below, and the optimizer is built with its own ExponentialDecay schedule -
# confirm which schedule is intended.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda epoch: lrfn_tffun(epoch), verbose=1)

# Keep only the checkpoint with the highest validation categorical accuracy.
model_save = tf.keras.callbacks.ModelCheckpoint('./model.h5',
                             save_best_only = True, 
                             monitor = 'val_categorical_accuracy', 
                             mode = 'max', verbose = 1)

# Stop once validation categorical accuracy has not improved for 10 epochs.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_categorical_accuracy', patience=10, mode='max', verbose=1)

In [None]:
# Print the layer summary again (create_model already prints it once when the model is built).
model.summary()

In [None]:
# Train on the endless augmented stream; steps_per_epoch defines the epoch
# length because the training dataset repeats forever. Validation runs on the
# un-augmented, batched validation set.
history = model.fit(
    get_training_dataset(train_dataset, do_aug=True),
    steps_per_epoch = STEPS_PER_EPOCH,
    epochs = EPOCHS,
    callbacks = [model_save,
                 early_stop,
                ],
    validation_data = get_val_dataset(val_dataset),
    verbose=1
)

In [None]:
# Reload the checkpoint written by the ModelCheckpoint callback (best
# val_categorical_accuracy), replacing the in-memory model.
with strategy.scope():
    model = tf.keras.models.load_model('./model.h5')

In [None]:
# history.history
# Reload the validation set with ordered=True so that row i of `pred`
# corresponds to the i-th validation record. This rebinds `val_dataset`.
val_dataset = load_dataset(val_tf_files, ordered=True)
valds = get_val_dataset(val_dataset)
pred = model.predict(valds, verbose=1)

In [None]:
# Collect every validation volume, its one-hot label and its scan id as NumPy
# values, in the same (ordered) sequence used for `pred`.
imgs = []
labels = []
scan_ids = []
# Parsed with the testing parser so each element also carries 'scan_id'.
val_dataset = load_testing_dataset(val_tf_files, ordered=True)
for input_data, label in val_dataset.as_numpy_iterator():
    imgs.append(input_data['image'])
    labels.append(label)
    scan_ids.append(input_data['scan_id'])

In [None]:
history_frame = pd.DataFrame(history.history)

# Losses (the first epoch is skipped).
fig = history_frame.loc[1:, ['loss', 'val_loss']].plot(title="Training and Validation Losses", figsize=(20, 10), fontsize=15)
plt.legend(["Training Loss", "Validation Loss"], fontsize=15)
fig.axes.title.set_size(20)
plt.xlabel('Epochs', fontsize=18)
plt.savefig('losses.jpg')

# Accuracies.
fig = history_frame.loc[:, ['categorical_accuracy', 'val_categorical_accuracy']].plot(title="Training and Validation Accuracies", xlabel="Epochs", figsize=(20, 10), fontsize=15)
plt.legend(["Training Accuracy", "Validation Accuracy"], fontsize=15)
fig.axes.title.set_size(20)
plt.xlabel('Epochs', fontsize=18)
plt.savefig('accuracies.jpg')

# AUC. Keras names the metric 'auc', 'auc_1', 'auc_2', ... depending on how
# many AUC metrics were created in this kernel, so look the column up instead
# of hard-coding 'auc_1' (which raises KeyError on a fresh Restart & Run All).
auc_key = next(
    (col for col in history_frame.columns if col == 'auc' or col.startswith('auc_')),
    None,
)
if auc_key is None or f'val_{auc_key}' not in history_frame.columns:
    print("No AUC metric found in the training history; skipping the AUC plot.")
else:
    fig = history_frame.loc[:, [auc_key, f'val_{auc_key}']].plot(title="Training and Validation AUC", xlabel="Epochs", figsize=(20, 10), fontsize=15)
    plt.legend(["Training AUC", "Validation AUC"], fontsize=15)
    fig.axes.title.set_size(20)
    plt.xlabel('Epochs', fontsize=18)
    plt.savefig('auc.jpg')


In [None]:
# Index of the validation sample to display, and the grid layout
# (nrows * ncols axes, one per slice of the volume).
tumor_id = 63
ncols = 4
nrows = 8

f, plots = plt.subplots(nrows, ncols, figsize=(10, 25))
title = f.suptitle(f'{scan_ids[tumor_id].decode("utf-8")}\n Label: {labels[tumor_id].argmax() + 1} | Prediction: {pred[tumor_id].argmax() + 1}', fontsize='20', fontweight ="bold")

# Fill the grid row by row: slice `index` goes to row index // ncols,
# column index % ncols.
for index, slice_img in enumerate(imgs[tumor_id]):
    i, j = divmod(index, ncols)
    cell = plots[i, j]
    cell.axis('off')
    cell.set_title(f"Slice Number: {index}")
    cell.imshow(slice_img, cmap=plt.cm.gray)

title.set_y(1)
f.tight_layout()
plt.savefig(f'{scan_ids[tumor_id].decode("utf-8")}.jpg')