# 3DRDNN

Main pipeline for my Master's research.

## Goals of this notebook

Prepare a pipeline for training any 3D DNN on CT data (the current run uses a 2D U-Net on 2D slices):
1) data loader
2) DNN
3) Training
4) Results comparison

In [1]:
# Testing reading GPU
from distutils.version import LooseVersion
import warnings
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
import datetime
# Report the installed TensorFlow version and whether TensorFlow can see a GPU.
print(f"TensorFlow Version: {tf.__version__}")

# gpu_device_name() returns an empty string when no GPU device is visible.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print(gpu_name)
else:
    warnings.warn("No GPU found. Please ensure you have installed TensorFlow correctly")

TensorFlow Version: 2.10.1
/device:GPU:0


In [2]:
# datasets for tf
from data_preprocessing import get_dataset_large

batch_size = 32
SHUFFLE_BUFFER = 20_000  # number of elements tf.data holds in memory to shuffle from

# Training pipeline: shuffle individual elements, batch, and overlap input loading with training.
train_slices = get_dataset_large("data/LITS_TFRecords_2D/train/")
dataset = (
    train_slices
    .shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation pipeline: kept in a fixed (unshuffled) order.
valid_dataset = (
    get_dataset_large("data/LITS_TFRecords_2D/valid/")
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Peek at one batch to confirm shapes. Iterate the *unshuffled* training source: pulling from
# the shuffled pipeline would first fill the entire SHUFFLE_BUFFER just to print a shape.
for sample in train_slices.batch(batch_size).take(1):
    print(sample[0].shape)
    print(sample[1].shape)
for sample in valid_dataset.take(1):
    print(sample[0].shape)
    print(sample[1].shape)


initalised with path data\LITS_Challenge\Training_Batch_2
files: 103,103
initalised with path data\LITS_Challenge\Training_Batch_1
files: 28,28
['data/LITS_TFRecords_2D/train\\images0.tfrecords', 'data/LITS_TFRecords_2D/train\\images1.tfrecords', 'data/LITS_TFRecords_2D/train\\images2.tfrecords', 'data/LITS_TFRecords_2D/train\\images3.tfrecords', 'data/LITS_TFRecords_2D/train\\images4.tfrecords', 'data/LITS_TFRecords_2D/train\\images5.tfrecords', 'data/LITS_TFRecords_2D/train\\images6.tfrecords', 'data/LITS_TFRecords_2D/train\\images7.tfrecords']
['data/LITS_TFRecords_2D/valid\\images0.tfrecords', 'data/LITS_TFRecords_2D/valid\\images1.tfrecords', 'data/LITS_TFRecords_2D/valid\\images2.tfrecords', 'data/LITS_TFRecords_2D/valid\\images3.tfrecords', 'data/LITS_TFRecords_2D/valid\\images4.tfrecords', 'data/LITS_TFRecords_2D/valid\\images5.tfrecords', 'data/LITS_TFRecords_2D/valid\\images6.tfrecords']
(32, 256, 256, 1)
(32, 256, 256, 2)
(32, 256, 256, 1)
(32, 256, 256, 2)


In [3]:
# Network set-up
from utils import models

# Build the network through the project's model factory (2D U-Net, 256 px input, 16 base features).
model = models.model_call(
    model_name="2DUNET",
    px=256,
    features=16,
)

In [4]:
# Metrics and training
# think about cutting down this unet model
from utils import losses
 
adam = tf.keras.optimizers.Adam(
    learning_rate=0.00001,  # too big LR or model, need to check - first the LR
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
    amsgrad=False,
    name='Adam',
)
# Explicit metric names: with the defaults, Keras logs these as "precision", "precision_1"
# and "recall", which makes the all-channel precision indistinguishable from the
# class-1 precision in the training log and TensorBoard.
# precision_all has no class_id, so it is computed over all output channels.
precision_all = tf.keras.metrics.Precision(thresholds=0.5, name="precision_all")
precision = tf.keras.metrics.Precision(thresholds=0.5, class_id=1, name="precision_class1")
recall = tf.keras.metrics.Recall(thresholds=0.5, class_id=1, name="recall_class1")
# Alternative loss under consideration: losses.weighted_categorical_crossentropy_with_fpr(axis=-1)
model.compile(
    loss="categorical_crossentropy",
    optimizer=adam,
    metrics=[precision_all, precision, recall],
)
model.summary(positions=[.33, .66, .78, 1.])

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape                    Param #     Connected to          
 input_1 (InputLayer)           [(None, 256, 256, 1)]           0           []                    
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 16)            160         ['input_1[0][0]']     
                                                                                                  
 conv2d_1 (Conv2D)              (None, 256, 256, 16)            2320        ['conv2d[0][0]']      
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 128, 128, 16)            0           ['conv2d_1[0][0]']    
                                                                                              

In [5]:
### learning rate schedule

def scheduler(epoch, lr, base_lr=1e-5, decay=0.75, step=15):
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Every ``step`` epochs (excluding epoch 0) the learning rate is reset to
    ``base_lr * decay ** (epoch // step)``, i.e. a step-wise exponential decay.
    On all other epochs the current ``lr`` is passed through unchanged.

    Args:
        epoch: Zero-based epoch index supplied by Keras.
        lr: Current optimizer learning rate supplied by Keras.
        base_lr: Starting learning rate; should match the optimizer's initial LR.
        decay: Multiplicative factor applied once per completed ``step`` epochs.
        step: Number of epochs between decay steps.

    Returns:
        The learning rate to use for this epoch, as a Python float.
    """
    if epoch > 0 and epoch % step == 0:
        # The original multiplied by the step count (0.75 * k), which *raised* the LR
        # from the second step onward; exponentiate so the LR keeps decaying.
        lr = base_lr * decay ** (epoch // step)
    # NOTE(review): from epoch 500 on, this factor multiplies the *current* lr every epoch,
    # so the decay compounds; it has no effect for runs shorter than 500 epochs. Confirm intent.
    return float(lr * 0.75 ** np.floor(epoch / 500))

In [6]:
# Results
# (load a checkpoint here if needed)
#
#################################

# add tensorboard
from tensorflow.keras.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    ReduceLROnPlateau,
    TerminateOnNaN,
)
# One TensorBoard log directory per run, named by the launch timestamp.
log_dir = f"logs/fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,  # record weight histograms every epoch
)

# training
# NOTE(review): LearningRateScheduler sets the LR at the start of every epoch, and `scheduler`
# passes the current LR through on most epochs. ReduceLROnPlateau reductions therefore persist
# only until the next multiple-of-15 epoch, where the schedule overwrites them.
callbacks = [
    # EarlyStopping(patience=10, verbose=1) is currently disabled.
    tf.keras.callbacks.LearningRateScheduler(scheduler),
    ReduceLROnPlateau(factor=0.1, patience=10, min_lr=0.0000001, verbose=1),
    ModelCheckpoint(
        # Forward slashes work on Windows and POSIX alike (as in the other paths in this
        # notebook); the backslash form only resolves to a subdirectory on Windows.
        "models/2DUNET_liver_v2/{epoch:02d}-{val_loss:.4f}.hdf5",
        verbose=1,
        save_best_only=True,
        save_weights_only=False,
    ),
    TerminateOnNaN(),
    tensorboard_callback,
]
epochs = 100

history = model.fit(
    dataset,
    epochs=epochs,
    validation_data=valid_dataset,
    callbacks=callbacks,
    initial_epoch=0,
)

Epoch 1/100
    500/Unknown - 139s 194ms/step - loss: 0.4718 - precision: 0.7718 - precision_1: 0.1099 - recall: 0.3272
Epoch 1: val_loss improved from inf to 0.15533, saving model to models\2DUNET_liver_v2\01-0.1553.hdf5
Epoch 2/100
Epoch 2: val_loss improved from 0.15533 to 0.14100, saving model to models\2DUNET_liver_v2\02-0.1410.hdf5
Epoch 3/100
Epoch 3: val_loss improved from 0.14100 to 0.12978, saving model to models\2DUNET_liver_v2\03-0.1298.hdf5
Epoch 4/100


In [None]:
# Visual check on validation samples: input slice | predicted channel 1 | ground-truth channel 1.
# NOTE(review): `loader_valid` is not defined anywhere in this notebook; it must come from an
# earlier or deleted cell, so this cell fails on a fresh kernel. TODO: define it explicitly above.
import itertools

new_model = tf.keras.models.load_model("models/2DUNET_liver_16_256/50-0.0155-0.92-0.87.hdf5")

N_SKIP = 40  # index of the first sample to display
N_MAX = 80   # stop before this sample index

for x in itertools.islice(loader_valid.data_generator_2d_liver(), N_SKIP, N_MAX):
    preds = new_model.predict(tf.reshape(x[0], [1, 256, 256, 1]), verbose=0)
    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(x[0])
    axes[0].set_title("input")
    axes[1].imshow(preds[0, :, :, 1])
    axes[1].set_title("prediction (ch 1)")
    axes[2].imshow(x[1][:, :, 1])
    axes[2].set_title("ground truth (ch 1)")
    plt.show()
    plt.close(fig)  # avoid accumulating 40 open figures