# 3D CNN

In [1]:
import tensorflow as tf
# import tensorflow_datasets as tfds

import nibabel as nib
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from scipy import ndimage
from pathlib import Path
from functools import partial
import math

from time import strftime

#from tensorflow.train import BytesList, FloatList, Int64List
#from tensorflow.train import Feature, Features, Example

import sys
sys.path.append(r"/Users/LennartPhilipp/Desktop/Uni/Prowiss/Code/Brain_Mets_Classification")

import brain_mets_classification.custom_funcs as funcs

from tqdm import tqdm

## load data from TFRecord file

In [2]:
# TFRecord file holding the serialized patient examples that are loaded below.
path_to_tfr = "/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/TFRecords/patient_data_2classes.tfrecord"
# Alternative file, kept for quick switching (presumably a smaller test set -- confirm):
#path_to_tfr = "/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/TFRecords/testing_patient_data_2classes.tfrecord"


# Fix the global random seed so that runs are repeatable.
tf.keras.utils.set_random_seed(42)

In [6]:
# Schema used to decode one serialized example from the TFRecord file.
# Every entry is a fixed-length feature; all but "image" fall back to a
# default value when the key is missing from a record.
feature_description = {
    # float32 values with fixed shape [149, 185, 155, 4]
    "image": tf.io.FixedLenFeature([149, 185, 155, 4], tf.float32),
    # two int64 values (presumably a one-hot encoding -- TODO confirm)
    "sex": tf.io.FixedLenFeature([2], tf.int64, default_value=[0,0]),
    # scalar int64
    "age": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    # scalar int64; used as the training label further down
    "primary": tf.io.FixedLenFeature([], tf.int64, default_value=0),
}

def parse(serialize_patient):
    """Decode one serialized example into (image, sex, age, primary).

    The record is parsed with the module-level `feature_description`; the
    image is reshaped to [149, 185, 155, 4] before it is returned.
    """
    features = tf.io.parse_single_example(serialize_patient, feature_description)
    volume = tf.reshape(features["image"], [149, 185, 155, 4])
    return volume, features["sex"], features["age"], features["primary"]

# Read the GZIP-compressed TFRecord file and decode every record with `parse`,
# giving a dataset of (image, sex, age, primary) tuples.
dataset = tf.data.TFRecordDataset([path_to_tfr], compression_type="GZIP")
parsed_dataset = dataset.map(parse)

# Display brain slice
# numpy_image = parsed_dataset.get_single_element()[0].numpy()
# plt.imshow(numpy_image[80,:,:,0], cmap = "inferno")

# split dataset into train, validation and test

#########################################################

#Calculate sizes for train, validation, and test sets
# total_samples = sum(1 for _ in parsed_dataset)
# train_size = int(0.8 * total_samples)
# val_size = int(0.1 * total_samples)
# test_size = total_samples - train_size - val_size
# Count the records by iterating once over the parsed dataset.
total_samples = sum(1 for _ in parsed_dataset)
train_size = int(0.5 * total_samples)
val_size = int(0.25 * total_samples)
# The test split receives whatever is left after the integer truncation above.
test_size = total_samples - train_size - val_size

print(f"Training size: {train_size}")
print(f"Validation size: {val_size}")
print(f"Testing size: {test_size}")

# Shuffle and split dataset.
# `shuffle` reshuffles on every iteration by default. The three subsets below
# are iterated independently, so with the default each of them would be cut
# from a *different* permutation and the same patient could end up in more
# than one split. Freezing the permutation (fixed seed, no reshuffling) keeps
# train / validation / test disjoint.
dataset = parsed_dataset.shuffle(buffer_size=200,
                                 seed=42,
                                 reshuffle_each_iteration=False)
train_dataset = dataset.take(train_size).prefetch(buffer_size = tf.data.AUTOTUNE)
remainder_dataset = dataset.skip(train_size).prefetch(buffer_size = tf.data.AUTOTUNE)
val_dataset = remainder_dataset.take(val_size).prefetch(buffer_size = tf.data.AUTOTUNE)
test_dataset = remainder_dataset.skip(val_size).prefetch(buffer_size = tf.data.AUTOTUNE)

# Example usage of datasets
# print("Train dataset size:", sum(1 for _ in train_dataset))
# print("Validation dataset size:", sum(1 for _ in val_dataset))
# print("Test dataset size:", sum(1 for _ in test_dataset))

#############################################################

# train_images = tf.Variable(initial_value=tf.zeros((149, 185, 155, 4)), trainable=False)
# train_ages = tf.Variable(initial_value=tf.zeros((0,), dtype=tf.float32), trainable=False)
# train_sexes = tf.Variable(initial_value=tf.zeros((0,), dtype=tf.int64), trainable=False)
# train_primaries = tf.Variable(initial_value=tf.zeros((0,), dtype=tf.int64), trainable=False)

def split_dataset(dataset):
    """Collect a dataset of (image, sex, age, primary) tuples into four tensors.

    Every element of `dataset` is read once; the values of each field are
    stacked along a new leading axis.

    Returns:
        (images, sexes, ages, primaries) as stacked tensors, in that order.
    """
    field_order = ("image", "sex", "age", "primary")
    collected = {field: [] for field in field_order}
    for image, sex, age, primary in dataset:
        collected["image"].append(image)
        collected["age"].append(age)
        collected["sex"].append(sex)
        collected["primary"].append(primary)
    stacked = [tf.stack(collected[field]) for field in field_order]
    return stacked[0], stacked[1], stacked[2], stacked[3]

# Load each split fully into memory as stacked tensors
# (order of the returned values: images, sex, ages, primaries).
train_images, train_sex, train_ages, train_primaries = split_dataset(train_dataset)
val_images, val_sex, val_ages, val_primaries = split_dataset(val_dataset)
test_images, test_sex, test_ages, test_primaries = split_dataset(test_dataset)

Training size: 4
Validation size: 2
Testing size: 2


In [7]:
# Sanity check: shape of the stacked training images (leading axis = number of samples).
print(train_images.shape)

(4, 149, 185, 155, 4)


Write simple CNN and then go from there

In [4]:
# File path handed to the ModelCheckpoint callback defined below.
path_to_callback = "/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/logs/callback"

## Custom Callbacks and building blocks

### Callbacks

- Checkpoint (= save best model)
- Early Stopping
- Tensorboard (not currently working)

In [14]:
def get_run_logdir(root_logdir="/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/logs/tensorboard"):
    """Return a timestamped run directory path below `root_logdir`.

    The last path component has the form ``run_YYYY_MM_DD_HH_MM_SS`` and is
    built from the current local time. Nothing is created on disk.
    """
    run_name = strftime("run_%Y_%m_%d_%H_%M_%S")
    return Path(root_logdir).joinpath(run_name)

# Timestamped directory that receives this run's TensorBoard logs.
run_logdir = get_run_logdir()

# Keep only the weights of the epoch with the highest validation accuracy.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=path_to_callback,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
)

# Stop after 10 epochs without improvement and roll back to the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10,
    restore_best_weights=True,
)

# Write TensorBoard logs, including histograms every epoch.
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=run_logdir,
    histogram_freq=1,
)

1Cycle Scheduler

In [15]:
K = tf.keras.backend

class ExponentialLearningRate(tf.keras.callbacks.Callback):
    """Multiply the optimizer's learning rate by `factor` after every batch.

    While doing so it records, per batch, the learning rate that was used
    (`rates`) and the loss of that batch (`losses`), so that the two can be
    plotted against each other to pick a learning rate.

    Args:
        factor: multiplicative step applied to the learning rate after each batch.
    """

    def __init__(self, factor):
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []
        # Defined here as well so the attribute always exists.
        self.sum_of_epoch_losses = 0

    def on_epoch_begin(self, epoch, logs=None):
        # The running loss reported by Keras restarts with every epoch.
        self.sum_of_epoch_losses = 0

    def on_batch_end(self, batch, logs=None):
        mean_epoch_loss = logs["loss"]  # the epoch's mean loss so far
        # Recover this batch's own loss from the running mean:
        # mean * number_of_batches gives the running sum, and the difference
        # to the previous running sum is the loss of the current batch.
        new_sum_of_epoch_losses = mean_epoch_loss * (batch + 1)
        batch_loss = new_sum_of_epoch_losses - self.sum_of_epoch_losses
        self.sum_of_epoch_losses = new_sum_of_epoch_losses
        self.rates.append(K.get_value(self.model.optimizer.learning_rate))
        self.losses.append(batch_loss)
        K.set_value(self.model.optimizer.learning_rate,
                    self.model.optimizer.learning_rate * self.factor)
        
def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=1e-4,
                       max_rate=1):
    """Sweep the learning rate exponentially from `min_rate` to `max_rate`.

    The model is trained for `epochs` epochs while an
    `ExponentialLearningRate` callback raises the learning rate after every
    batch. The model's weights and its original learning rate are restored
    afterwards -- also when training is interrupted or raises.

    Args:
        model: compiled Keras model whose optimizer exposes `learning_rate`.
        X, y: training inputs and targets passed on to `model.fit`.
        epochs: number of epochs used for the sweep.
        batch_size: batch size passed on to `model.fit`.
        min_rate: learning rate at the start of the sweep.
        max_rate: learning rate reached at the end of the sweep.

    Returns:
        (rates, losses): the per-batch learning rates and batch losses.
    """
    init_weights = model.get_weights()
    iterations = math.ceil(len(X) / batch_size) * epochs
    # Constant factor that takes min_rate to max_rate in `iterations` steps.
    factor = (max_rate / min_rate) ** (1 / iterations)
    init_lr = K.get_value(model.optimizer.learning_rate)
    K.set_value(model.optimizer.learning_rate, min_rate)
    exp_lr = ExponentialLearningRate(factor)
    try:
        model.fit(X, y, epochs=epochs, batch_size=batch_size,
                  callbacks=[exp_lr])
    finally:
        # Always undo the sweep, even after a KeyboardInterrupt or an error,
        # so the model is never left with a huge learning rate or half-trained
        # weights.
        K.set_value(model.optimizer.learning_rate, init_lr)
        model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses

def plot_lr_vs_loss(rates, losses):
    """Plot loss against learning rate on a logarithmic x-axis.

    A horizontal line marks the lowest loss. The y-axis runs from 0 to the
    first loss plus the lowest loss.
    """
    plt.plot(rates, losses, "b")
    plt.gca().set_xscale('log')
    lowest_loss = min(losses)
    y_upper = losses[0] + lowest_loss
    lowest_rate, highest_rate = min(rates), max(rates)
    plt.hlines(lowest_loss, lowest_rate, highest_rate, color="k")
    plt.axis([lowest_rate, highest_rate, 0, y_upper])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
    plt.grid()

# USAGE:
# batch_size = 128
# rates, losses = find_learning_rate(model, X_train, y_train, epochs=1,
#                                    batch_size=batch_size)
# plot_lr_vs_loss(rates, losses)

# 1CycleScheduler
#https://arxiv.org/abs/1803.09820

class OneCycleScheduler(tf.keras.callbacks.Callback):
    """Piecewise-linear 1cycle learning-rate schedule, applied per batch.

    The learning rate rises linearly from `start_lr` to `max_lr` during the
    first `half_iteration` batches, falls back to `start_lr` during the next
    `half_iteration` batches, and then decays linearly to `last_lr` over the
    remaining batches.

    Args:
        iterations: total number of batches the schedule spans.
        max_lr: peak learning rate.
        start_lr: initial learning rate; defaults to `max_lr / 10`.
        last_iterations: number of batches in the final decay phase;
            defaults to `iterations // 10 + 1`.
        last_lr: final learning rate; defaults to `start_lr / 1000`.
    """

    def __init__(self, iterations, max_lr=1e-3, start_lr=None,
                 last_iterations=None, last_lr=None):
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        self.iterations = iterations
        self.max_lr = max_lr
        self.start_lr = start_lr or max_lr / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_lr = last_lr or self.start_lr / 1000
        # Number of batches seen so far, counted across epochs.
        self.iteration = 0

    def _interpolate(self, iter1, iter2, lr1, lr2):
        """Linearly interpolate between (iter1, lr1) and (iter2, lr2) at the current iteration."""
        return (lr2 - lr1) * (self.iteration - iter1) / (iter2 - iter1) + lr1

    def on_batch_begin(self, batch, logs=None):
        if self.iteration < self.half_iteration:
            # Phase 1: warm up from start_lr to max_lr.
            lr = self._interpolate(0, self.half_iteration, self.start_lr,
                                   self.max_lr)
        elif self.iteration < 2 * self.half_iteration:
            # Phase 2: come back down from max_lr to start_lr.
            lr = self._interpolate(self.half_iteration, 2 * self.half_iteration,
                                   self.max_lr, self.start_lr)
        else:
            # Phase 3: final decay from start_lr to last_lr.
            lr = self._interpolate(2 * self.half_iteration, self.iterations,
                                   self.start_lr, self.last_lr)
        self.iteration += 1
        K.set_value(self.model.optimizer.learning_rate, lr)

# USAGE
# n_epochs = 25
# onecycle = OneCycleScheduler(math.ceil(len(X_train) / batch_size) * n_epochs,
#                              max_lr=0.1)
# history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=batch_size,
#                     validation_data=(X_valid, y_valid),
#                     callbacks=[onecycle])

### Initializers, Optimizers, etc.

In [11]:
# He-normal weight initializer (the variable name is spelled `intializer`).
intializer = tf.keras.initializers.HeNormal()
# Activation, passed by name to the layers of the simple model below.
activation_func = "mish"
#optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-3) # this is a placeholder, change to Nesterov or AdamW
# SGD with Nesterov momentum; this optimizer instance is used when the simple model is compiled.
optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=0.001, momentum=0.9, nesterov=True)

### Building Blocks

In [13]:
# MCDropout
# https://arxiv.org/abs/1506.02142

class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that is always called in training mode.

    The `training` argument supplied by the caller is ignored; the parent
    layer is always invoked with ``training=True``.
    """

    def call(self, inputs, training=False):
        # Deliberately override whatever `training` value was passed in.
        force_training = True
        return super().call(inputs, training=force_training)
    


### ResNeXt Blocks
courtesy of https://github.com/taki0112/ResNeXt-Tensorflow/tree/master\
original ResNeXt paper: https://arxiv.org/abs/1611.05431

In [4]:
# Hyperparameters for the ResNeXt building blocks.
# NOTE(review): none of these constants is referenced by the helper functions
# visible in this cell; they appear to be carried over from the referenced
# ResNeXt implementation -- confirm which ones are still needed.
cardinality = 8 # number of splits per block
blocks = 3 # number of residual blocks (split + transition)

# The bare string below is a no-op expression that serves as a block comment.
"""
So, the total number of layers is (3*blokcs)*residual_layer_num + 2
because, blocks = split(conv 2) + transition(conv 1) = 3 layer
and, first conv layer 1, last dense layer 1
thus, total number of layers = (3*blocks)*residual_layer_num + 2
"""

depth = 64 # number of output channels

batch_size = 2 # original value: 128
iteration = 391
# 128 * 391 ~ 50,000 (original batch size times iterations)

test_iteration = 10

total_epochs = 300

def conv_layer(input, filters, kernel, stride, padding='SAME', layer_name="3Dconv"):
    """Apply a bias-free 3D convolution to `input` and return the output tensor.

    Args:
        input: tensor to convolve.
        filters: number of output channels.
        kernel: kernel size (int or 3-tuple).
        stride: stride (int or 3-tuple).
        padding: padding mode passed on to `Conv3D`.
        layer_name: name scope under which the layer is built.

    Returns:
        The convolved tensor.
    """
    with tf.name_scope(layer_name):
        # A Keras layer is configured in its constructor and then *called* on
        # the tensor. The constructor takes no `inputs` argument, so the
        # tensor must not be passed there.
        conv = tf.keras.layers.Conv3D(filters = filters,
                                      use_bias = False,
                                      kernel_size = kernel,
                                      strides = stride,
                                      padding = padding)
        network = conv(input)
        return network
    
def Global_Average_Pooling(x, name='Global_avg_pooling'):
    """Apply 3D global average pooling to `x` and return the pooled tensor.

    Args:
        x: tensor to pool.
        name: layer name; pass a different one when the helper is used more
            than once in the same model, because Keras layer names must be
            unique within a model.
    """
    # Build the layer first, then call it on the tensor. Previously `x` was
    # passed to the constructor and an un-applied layer object was returned.
    return tf.keras.layers.GlobalAveragePooling3D(name=name)(x)

def Average_pooling(x, pool_size=[2,2,2], stride=2, padding='SAME'):
    """Apply 3D average pooling to `x` and return the pooled tensor.

    Args:
        x: tensor to pool.
        pool_size: pooling window per spatial axis (only read, never modified).
        stride: stride of the pooling window.
        padding: padding mode passed on to `AveragePooling3D`.
    """
    # The layer constructor has no `inputs` argument; configure the layer and
    # then call it on the tensor.
    pooling = tf.keras.layers.AveragePooling3D(pool_size=pool_size, strides=stride, padding=padding)
    return pooling(x)

def Relu(x):
    """Apply the Mish activation to `x`.

    Despite its name this helper does not compute ReLU; it calls
    `tf.keras.activations.mish`.
    """
    activated = tf.keras.activations.mish(x)
    return activated

def Concatenation(layers) :
    """Concatenate a list of tensors along the last (channel) axis.

    Args:
        layers: list of tensors whose shapes match on every axis except the last.

    Returns:
        The concatenated tensor.
    """
    # `Concatenate` takes the axis in its constructor and the tensor list when
    # it is called. Passing the list to the constructor collided with `axis`.
    return tf.keras.layers.Concatenate(axis=-1)(layers) # originally axis = 3

def Linear(x, class_num, name='linear') :
    """Apply a bias-free dense layer with `class_num` units to `x`.

    Args:
        x: input tensor.
        class_num: number of output units.
        name: layer name; pass a different one when the helper is used more
            than once in the same model, because Keras layer names must be
            unique within a model.

    Returns:
        The output tensor of the dense layer.
    """
    # The layer constructor has no `inputs` argument; configure the layer and
    # then call it on the tensor.
    return tf.keras.layers.Dense(units=class_num, use_bias=False, name=name)(x)


## Image Augmentation

In [None]:
# Random augmentation pipeline: flip, brightness, contrast, rotation, translation.
# It is defined here but not applied anywhere in the visible cells.
# NOTE(review): these Keras preprocessing layers are documented for 2D images
# (3D/4D tensors), while the volumes in this notebook are 5D batches of shape
# (batch, 149, 185, 155, 4) -- confirm they behave as intended before use.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip(mode = "horizontal"),
    # brightness shift drawn from (-0.2, 0.5); pixel values are declared to lie in [0, 1]
    tf.keras.layers.RandomBrightness(factor = (-0.2, 0.5), value_range=(0, 1)),
    tf.keras.layers.RandomContrast(0.5),
    # rotation factor is a fraction of a full turn; edge values are repeated to fill gaps
    tf.keras.layers.RandomRotation(factor = (-0.07, 0.07), fill_mode = "nearest"),
    tf.keras.layers.RandomTranslation(
        height_factor = 0.025,
        width_factor = 0.05,
        fill_mode = "nearest"
    )
])

simple model

In [14]:
# loss: sparse categorical crossentropy -- the labels given to fit() are the
# scalar int64 `primary` values and the network ends in a 2-unit softmax
# (assumes `primary` holds the class ids 0 / 1 -- TODO confirm against the TFRecord writer)
# TODO: set class weight for underrepresented classes

batch_norm_layer = tf.keras.layers.BatchNormalization()
conv_1_layer = tf.keras.layers.Conv3D(filters = 64, kernel_size = 7, input_shape = [149, 185, 155, 4], strides=(2,2,2), activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
max_pool_1_layer = tf.keras.layers.MaxPooling3D(pool_size = (2,2,2))
conv_2_layer = tf.keras.layers.Conv3D(filters = 64, kernel_size = 7, strides=(2,2,2), activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
max_pool_2_layer = tf.keras.layers.MaxPooling3D(pool_size = (2,2,2))
# Collapse the remaining spatial axes so the dense head sees one feature
# vector per sample. Without this, Dense acts on the last axis only and the
# model output keeps its spatial dimensions instead of being (batch, 2).
flatten_layer = tf.keras.layers.Flatten()
dense_1_layer = tf.keras.layers.Dense(100, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
dropout_1_layer = tf.keras.layers.Dropout(0.5)
dense_2_layer = tf.keras.layers.Dense(100, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
dropout_2_layer = tf.keras.layers.Dropout(0.5)
output_layer = tf.keras.layers.Dense(2, activation="softmax")

# Define inputs
input_image = tf.keras.layers.Input(shape=train_images.shape[1:])

# TODO: concatenate input sex and input age

# Wire the layers together (functional API).
batch_norm = batch_norm_layer(input_image)
conv_1 = conv_1_layer(batch_norm)
max_pool_1 = max_pool_1_layer(conv_1)
conv_2 = conv_2_layer(max_pool_1)
max_pool_2 = max_pool_2_layer(conv_2)
flatten = flatten_layer(max_pool_2)
dense_1 = dense_1_layer(flatten)
dropout_1 = dropout_1_layer(dense_1)
dense_2 = dense_2_layer(dropout_1)
dropout_2 = dropout_2_layer(dense_2)
output = output_layer(dropout_2)



model = tf.keras.Model(inputs = input_image, outputs = [output])
# Mean squared error (and RMSE) between integer class ids and softmax
# probabilities is not a meaningful classification objective; use the
# cross-entropy for integer labels. "accuracy" is kept because the
# checkpoint callback monitors "val_accuracy".
model.compile(loss="sparse_categorical_crossentropy", optimizer=optimizer, metrics = ["accuracy"])
model.summary()

# tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)



Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 149, 185, 155,    0         
                             4)]                                 
                                                                 
 batch_normalization_2 (Bat  (None, 149, 185, 155, 4   16        
 chNormalization)            )                                   
                                                                 
 conv3d_4 (Conv3D)           (None, 72, 90, 75, 64)    87872     
                                                                 
 max_pooling3d_4 (MaxPoolin  (None, 36, 45, 37, 64)    0         
 g3D)                                                            
                                                                 
 conv3d_5 (Conv3D)           (None, 15, 20, 16, 64)    1404992   
                                                           

In [13]:
if __name__=="__main__":
    # Train the simple model one sample at a time; validation uses the
    # validation split, and the callbacks handle checkpointing, early stopping
    # and TensorBoard logging.
    history = model.fit(
        train_images,
        train_primaries,
        epochs=20,
        batch_size=1,
        validation_data=(val_images, val_primaries),
        callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
    )

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20

KeyboardInterrupt: 

## relatively complex model, image only

In [None]:
# To-use:
# image augmentation
# local response normalization
# 1-cycle scheduling
# ResNet blocks
# normalization layers

In [38]:
# Stand-alone example: a small model that takes only the tabular inputs.
# NOTE(review): this cell rebinds `model` and `output_layer`, which the simple
# model cell above also defines -- rerun that cell before training the CNN again.
# Symbolic inputs: `sex` as a vector of length 2, `age` as a vector of length 1.
sex_input = tf.keras.Input(shape=(2,))
age_input = tf.keras.Input(shape=(1,))

# Concatenate the inputs into one feature vector of length 3
concatenated_inputs = tf.keras.layers.concatenate([sex_input, age_input])

# Continue building your model using the concatenated inputs
# For example:
# output_layer = SomeLayer()(concatenated_inputs)
# model = tf.keras.Model(inputs=[sex_input, age_input], outputs=output_layer)

# Example of using the concatenated inputs in a model:
# one hidden layer with 64 ReLU units followed by a single sigmoid output unit.
# `output_layer` first holds the hidden tensor and is then rebound to the final output.
output_layer = tf.keras.layers.Dense(64, activation='relu')(concatenated_inputs)
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(output_layer)

# Define the model with concatenated inputs
model = tf.keras.Model(inputs=[sex_input, age_input], outputs=output_layer)

# Compile the model for binary classification
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Example usage:
# model.fit([sex_data, age_data], target_labels, epochs=num_epochs, batch_size=batch_size)
