In [None]:
# Random seed for repeatability. Seed every global RNG the notebook can reach:
# Python's `random` module (imported by this notebook and possibly used by the
# augmentation helpers — NOTE(review): confirm), NumPy's legacy global RNG, and
# TensorFlow's global seed.
# NOTE(review): GPU kernels can still be nondeterministic; bit-exact reruns would
# additionally need tf.config.experimental.enable_op_determinism() (TF >= 2.8).
ranosedo = 2
import random
random.seed(ranosedo)
from numpy.random import seed
seed(ranosedo)
import tensorflow
tensorflow.random.set_seed(ranosedo)

In [None]:
import IPython.display as Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras.models import Model
from keras.layers import Activation
from keras.utils.vis_utils import plot_model
from numba import cuda, jit, njit
import nibabel as nib
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
import random
from scipy import ndimage
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import datetime
from tensorflow.python.keras.callbacks import TensorBoard
from sklearn.metrics import roc_curve, auc
from matplotlib import pyplot
from sklearn.metrics import confusion_matrix
import itertools
import csv
import seaborn as sns
import scipy.stats as stats
from scipy.stats import skew 

In [None]:
from preprocessingfunctions import *
from ResNet3D import Resnet3DBuilder

#### Define data loaders.

In [None]:
# Requires df_train / df_val (with 'FILEPATH' and 'labels' columns) and
# process_scan / train_preprocessing / validation_preprocessing to be defined.
# NOTE(review): none are defined in this notebook — presumably they come from
# `from preprocessingfunctions import *`; confirm.
batch_size = 2
# Let tf.data size the prefetch buffer. The previous value, len(df_val) batches,
# tied memory use to the validation-set size for no functional reason.
prefetch_size = tf.data.AUTOTUNE


def _make_loader(frame):
    """Load every scan listed in ``frame`` into memory, paired with its label.

    Parameters
    ----------
    frame : pandas.DataFrame
        Needs a 'FILEPATH' column (each entry is passed to ``process_scan``)
        and a 'labels' column.

    Returns
    -------
    tf.data.Dataset
        Unbatched (volume, label) pairs in the row order of ``frame``.
    """
    volumes = np.array([process_scan(path) for path in frame['FILEPATH']])
    # tolist() + np.array matches the original list-comprehension dtype inference.
    labels = np.array(frame['labels'].tolist())
    return tf.data.Dataset.from_tensor_slices((volumes, labels))


train_loader = _make_loader(df_train)
validation_loader = _make_loader(df_val)

# Augment on the fly during training; shuffle over the whole training set each epoch.
train_dataset = (
    train_loader.shuffle(len(df_train))
        .map(train_preprocessing)
        .batch(batch_size)
        .prefetch(prefetch_size)
)

# Pull one training batch. NOTE(review): nothing in this cell consumes it —
# presumably kept for inspecting a sample batch.
iterator_t = tf.compat.v1.data.make_one_shot_iterator(train_dataset)
next_element_train = iterator_t.get_next()

# Only rescale for the validation set. Deliberately NOT shuffled: validation
# metrics do not depend on order, and a fixed order keeps predictions aligned
# with the rows of df_val (e.g. for ROC curves / confusion matrices).
validation_dataset = (
    validation_loader
        .map(validation_preprocessing)
        .batch(batch_size)
        .prefetch(prefetch_size)
)

# Pull one validation batch (see the note on the training iterator above).
iterator_v = tf.compat.v1.data.make_one_shot_iterator(validation_dataset)
next_element_validate = iterator_v.get_next()

#### Transfer learning: Retrain a pretrained model with new data

In [None]:
# Pretrained model to fine-tune. The same filename is used both for loading here
# and for `model.save(model_name)` at the end of training.
model_name = 'ResNet.h5'
old_model = tf.keras.models.load_model(model_name)

# final_height / final_width / final_depth must already be defined.
# NOTE(review): not set in this notebook — presumably exported by preprocessingfunctions.
volume_shape = (final_height, final_width, final_depth, 1)  # trailing 1 = single channel
input_volume_size = tf.keras.Input(volume_shape)  # NOTE(review): not used in this cell

# 1e-16 makes the regularization penalty effectively zero. NOTE(review): confirm intended.
regularization_factor = 1e-16
# Second argument presumably the number of outputs (one unit, matching the
# binary_crossentropy loss used later) — confirm against Resnet3DBuilder.
model = Resnet3DBuilder.build_resnet_18(volume_shape, 1, regularization_factor)

def updateweights(model, old_model):
    """Copy weights layer-by-layer from ``old_model`` into ``model``.

    Layers are paired by position, skipping index 0 of each model
    (NOTE(review): presumably the input layer, which has no weights — confirm).
    The transfer is best-effort: a layer whose weights cannot be set is reported
    and keeps its current weights, and the loop continues.

    Parameters
    ----------
    model : keras.Model
        Model receiving the weights; its layers are modified in place.
    old_model : keras.Model
        Model providing the weights.

    Returns
    -------
    The same ``model`` object, for convenient reassignment.
    """
    new_layers = model.layers[1:]
    old_layers = old_model.layers[1:]
    # zip() stops at the shorter list, so surplus layers would otherwise be skipped silently.
    if len(new_layers) != len(old_layers):
        print("Layer count mismatch (excluding the first layer): {} vs {}; "
              "only the first {} layers are paired.".format(
                  len(new_layers), len(old_layers),
                  min(len(new_layers), len(old_layers))))
    for layer, old_layer in zip(new_layers, old_layers):
        try:
            layer.set_weights(old_layer.get_weights())
        except Exception as exc:  # not bare: KeyboardInterrupt/SystemExit must propagate
            print("Weights transfer failed! for layer {}: {}".format(layer.name, exc))
    return model

# Initialise the freshly built ResNet-18 with the pretrained weights from old_model
# (best-effort, layer by layer; failures are printed, not raised).
model = updateweights(model,old_model)

#### Start Training

In [None]:
import os
# NOTE(review): TensorFlow was already imported in the first cell, so raising the
# C++ log level here may not silence everything; setting it before the first TF
# import is the reliable place.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Legacy TF1-style session configuration, registered with the Keras backend.
# NOTE(review): device_count={'GPU': 0} allows zero GPU devices in this session,
# which makes allow_growth moot; confirm whether CPU-only is really intended.
config = tf.compat.v1.ConfigProto(
    device_count={'GPU': 0},
    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True),
)
sess = tf.compat.v1.Session(config=config)
K.set_session(sess)

import wandb
from wandb.keras import WandbCallback

# Run hyperparameters.
initial_learning_rate = 1e-5
metricr = 'accuracy'

# Start the Weights & Biases run that WandbCallback logs into during fit().
wandb.init(project="Classifier",name='20220604_MainRun_121x145x121')

# NOTE(review): with staircase=True and decay_steps=1e11 the step exponent stays
# 0 for any realistic training length, so this is effectively a constant rate of
# initial_learning_rate.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000000000,
    decay_rate=0.99,
    staircase=True,
)

# NOTE(review): beta_1=0.999 (Keras default 0.9) and epsilon=1e-18 (default 1e-7)
# are unusual; a tiny epsilon can yield very large steps when the second-moment
# estimate is near zero. Kept unchanged to preserve results.
model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=lr_schedule,
        beta_1=0.999,
        epsilon=1e-18,
        amsgrad=True,
    ),
    metrics=[metricr],
)


# NOTE(review): run_eagerly=True skips tf.function compilation and slows training
# considerably; keep only if something in the model or pipeline needs eager mode.
model.run_eagerly = True

epochs = 300
# NOTE(review): absolute, user-specific path — consider moving to a config constant.
checkpoint_filepath =  '/mnt/j6/m258195/python_m258195/1style_transfer/Custom/checkpoints/cp-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_filepath)
os.makedirs(checkpoint_dir, exist_ok=True)

# Pick "best" checkpoints on the held-out split. Because validation_data is passed
# to fit(), Keras logs the validation metric under the 'val_' prefix. Monitoring the
# training metric would reward overfitting.
checkpoint_monitor = 'val_' + metricr
checkpoint_cb = keras.callbacks.ModelCheckpoint(filepath = checkpoint_filepath,
                                                monitor = checkpoint_monitor,
                                                verbose = 2,
                                                save_weights_only=True,
                                                save_freq = 'epoch',
                                                save_best_only=True,
                                                mode = 'max')

# Optional early stopping after epochs // 2 epochs without improvement.
# Off by default so the full `epochs` are trained; set True to enable.
use_early_stopping = False
early_stopping_cb = keras.callbacks.EarlyStopping(monitor=checkpoint_monitor,
                                                  patience=epochs // 2,
                                                  mode='max')

fit_callbacks = [WandbCallback(), checkpoint_cb]
if use_early_stopping:
    fit_callbacks.append(early_stopping_cb)

# Baseline checkpoint: the pretrained-initialised weights before any training.
model.save_weights(checkpoint_filepath.format(epoch=0))
print('Training begins...')
# batch_size, shuffle and use_multiprocessing are not passed. For tf.data inputs,
# batching and shuffling come from the pipeline and Keras does not use these
# arguments. The old batch_size=batch_size*8 also misstated the real batch size.
model_run = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    verbose=1,
    callbacks=fit_callbacks,
)

# NOTE(review): model_name is also the file the pretrained model was loaded from,
# so this overwrites it; re-running the notebook would fine-tune the already
# fine-tuned model. Consider saving under a different name.
model.save(model_name)