In [3]:
import tensorflow_mri as tfmr
import numpy as np
from sklearn.model_selection import train_test_split
from glob import glob
import tensorflow as tf
import tensorflow.keras.backend as K
import random
import matplotlib.pyplot as plt
from skimage import exposure
import scipy
import os

# Expose only CUDA device 0 to this process.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before
# TensorFlow initializes its GPU devices; setting it before `import tensorflow`
# is the safe placement.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

2024-08-22 14:49:33.478245: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [None]:
import neptune
from neptune.new.integrations.tensorflow_keras import NeptuneCallback

# continue_training=True resumes the Neptune run whose ID is model_name (and the
# checkpoint saved under models/{model_name}); False creates a new run and uses
# its ID as model_name.
continue_training = True
model_name = ' '

# Read the API token from the environment instead of hard-coding it in the
# notebook, where it would end up in version control and shared copies.
NEPTUNE_API_TOKEN = os.environ.get("NEPTUNE_API_TOKEN")

if continue_training:
    run = neptune.init_run(
        project=" ",
        api_token=NEPTUNE_API_TOKEN,
        with_id=model_name)
else:
    run = neptune.init_run(
        project=" ",
        api_token=NEPTUNE_API_TOKEN,
    )
    # Use the documented sys/id field rather than indexing into run.__dict__,
    # whose attribute order is a private implementation detail.
    model_name = run["sys/id"].fetch()  # new model

In [None]:
# One .npy file per patient (image and mask stacked on the last axis). Paths are
# kept without the extension because CustomDataGen appends ".npy" when loading.
# Sorting removes any dependence on filesystem listing order.
patients = sorted(
    os.path.splitext(path)[0] for path in glob('data/Pairs71/*.npy')
)
# A fixed random_state keeps the same patients in the validation set on every
# run. Without it, resuming training (continue_training=True) would draw a new
# split and leak previously trained-on patients into validation.
train_patients, val_patients = train_test_split(patients, test_size=0.3, random_state=42)

In [None]:
class CustomDataGen():
    """Yield (image, one-hot mask) pairs for a list of patient files.

    Each entry of ``patients`` is a path without the ``.npy`` suffix. The
    loaded array holds the image in channel 0 and the mask in channel 1 of
    its last axis.
    """

    def __init__(self, patients, cohort):
        # Shuffle a private copy so the caller's list (e.g. train_patients or
        # val_patients) is not reordered as a side effect.
        self.patients = list(patients)
        random.shuffle(self.patients)
        # NOTE(review): cohort is stored but never read; the same augmentation
        # is applied to every cohort ('train', 'val', 'test').
        self.cohort = cohort

    def data_generator(self):
        """Load, augment and one-hot encode each patient in turn."""
        for patient in self.patients:
            image_mask = np.load(f"{patient}.npy")
            image = image_mask[..., 0]
            mask = image_mask[..., 1]
            image = normalize(image)
            image = aug_down_gamma_up(image)
            image = image[..., np.newaxis]
            mask = mask[..., np.newaxis]
            # Background channel: 1 wherever the mask value is not exactly 1.
            bkg = np.where(np.sum(mask, -1) == 1, 0, 1)
            # Channel 0 = background, channel 1 = foreground mask.
            mask = np.concatenate([bkg[..., np.newaxis], mask], -1)
            # Re-normalize after augmentation, which changes the intensity statistics.
            image = normalize(image)
            yield image, mask.astype('uint8')

    def get_gen(self):
        """Return a fresh generator over all patients."""
        return self.data_generator()
          
def normalize(image):
    """Z-score ``image``: subtract its mean and divide by its standard deviation.

    A constant input (standard deviation exactly 0) yields an all-zero array of
    the same shape and dtype instead of dividing by zero.
    """
    sigma = np.std(image)
    if sigma == 0:
        return np.zeros_like(image)
    return (image - np.mean(image)) / sigma

def random_gamma(img2):
    """Gamma-adjust a short random run of consecutive slices along axis 0.

    Between 0 and 3 consecutive slices are chosen. Each chosen slice is clipped
    at 0 (negative values dropped) and passed through ``exposure.adjust_gamma``
    with gamma = round(Beta(1, 5) * 0.2 + 0.5, 1), i.e. 0.5-0.7. ``img2`` is
    modified in place and also returned.
    """
    how_many = random.randint(0, 3)
    if how_many == 0:
        return img2
    first = random.randint(0, img2.shape[0] - 4)
    for idx in range(first, first + how_many):
        clipped = np.clip(img2[idx, :, :], 0, None)
        gamma = round(np.random.beta(1, 5) * 0.2 + 0.5, 1)
        img2[idx, :, :] = exposure.adjust_gamma(clipped, gamma)
    return img2

def random_dark(image):
    """Gamma-adjust (gamma 1.2-1.7) a short random run of slices along axis 0.

    Between 0 and 2 consecutive slices are chosen. Each chosen slice is clipped
    at 0 and passed through ``exposure.adjust_gamma`` with gamma drawn uniformly
    from [1.2, 1.7] and rounded to one decimal. ``image`` is modified in place
    and also returned.
    """
    how_many = random.randint(0, 2)
    if how_many == 0:
        return image
    first = random.randint(0, image.shape[0] - 5)
    for idx in range(first, first + how_many):
        clipped = np.clip(image[idx, :, :], 0, None)
        gamma = round(random.uniform(1.2, 1.7), 1)
        image[idx, :, :] = exposure.adjust_gamma(clipped, gamma)
    return image

def aug_down_gamma_up(image, target_depth=144):
    """Simulate a thick-slice acquisition plus random intensity changes.

    The volume is down-sampled along axis 0 by a random factor. A few of the
    remaining slices are then gamma-adjusted (random_gamma, random_dark), and
    the result is up-sampled back along axis 0. Finally axis 0 is cropped or
    zero-padded at its end to ``target_depth``.

    Args:
        image: 3-D volume; axis 0 is the axis that is down/up-sampled.
        target_depth: length of axis 0 in the result. Defaults to 144 (the
            previously hard-coded value); pass None to keep the input's own
            length. Only the image passes through here, so any accompanying
            mask keeps its original length and this should match it.

    Returns:
        Augmented volume with ``target_depth`` entries along axis 0.
    """
    # Import the submodule explicitly; a bare `import scipy` does not
    # guarantee that `scipy.ndimage` is loaded in every SciPy release.
    from scipy import ndimage

    if target_depth is None:
        target_depth = image.shape[0]
    # Down-sampling factor = slice spacing / pixel spacing: slice spacing varies
    # between 8-12 mm and pixel spacing is 1.6 mm, giving 5.0-7.5.
    resolu = round(random.uniform(5.0, 7.5), 1)
    adj_image = ndimage.zoom(image, (1 / resolu, 1, 1), order=1, mode='constant')
    gamma_image = random_gamma(adj_image)
    dark_image = random_dark(gamma_image)
    image3d = ndimage.zoom(dark_image, (resolu, 1, 1), order=3, mode='constant')
    # Zooming down and back up need not return exactly the original length, so
    # crop or zero-pad the end of axis 0 to the requested depth.
    if image3d.shape[0] >= target_depth:
        image3d = image3d[:target_depth]
    else:
        pad = target_depth - image3d.shape[0]
        pad_width = ((0, pad),) + ((0, 0),) * (image3d.ndim - 1)
        image3d = np.pad(image3d, pad_width, mode='constant', constant_values=0)
    return image3d

In [None]:
# Two output channels, matching CustomDataGen's one-hot masks:
# channel 0 = background, channel 1 = foreground.
output_channel = 2
batch_size = 1
# Spatial dimensions are left unspecified (None) so volumes of any size are
# accepted. The volumes used here are (144, 116, 96), chosen to match the
# external test images.
input_shape = [None, None, None, 1]
output_shape = [None, None, None, output_channel]

train_gen = CustomDataGen(train_patients, 'train').get_gen
val_gen = CustomDataGen(val_patients, 'val').get_gen

output_signature = (
    tf.TensorSpec(shape=input_shape, dtype=tf.float32),
    tf.TensorSpec(shape=output_shape, dtype=tf.float32),
)

train_ds = (
    tf.data.Dataset.from_generator(train_gen, output_signature=output_signature)
    .shuffle(42, seed=42, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = (
    tf.data.Dataset.from_generator(val_gen, output_signature=output_signature)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Pull a single batch from the training pipeline to sanity-check shapes and content.
X, y = next(iter(train_ds))

In [None]:
# Have a look on the training data: the first volume of the batch at index 47
# of the last spatial axis. Image on the left, foreground mask channel on the right.
slice_idx = 47
fig, axs = plt.subplots(1, 2)
print(X.shape)
print(y.shape)
axs[0].imshow(X[0, ..., slice_idx, 0], cmap='gray')
axs[1].imshow(y[0, ..., slice_idx, 1], cmap='gray')

In [None]:
# Have a look on the training data: index 45 of the last spatial axis for the
# first two training patients, each read through its own single-patient pipeline.
for patient in train_patients[:2]:
    patient_ds = tf.data.Dataset.from_generator(
        CustomDataGen([patient], 'train').get_gen,
        output_signature=output_signature,
    )
    X, y = next(iter(patient_ds))
    print(patient.split('/')[-1])
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(X[..., 45, 0], cmap='gray')  # image
    axs[1].imshow(y[..., 45, 1], cmap='gray')  # foreground mask channel

In [None]:
def iou(y_true, y_pred, dtype=tf.float32):
    """Soft intersection-over-union over all foreground channels.

    Channel 0 (background) is dropped. The remaining channels of both tensors
    are flattened and compared as continuous values, so predicted probabilities
    are used directly without thresholding. Returns 0 when the union is 0.
    """
    pred_fg = tf.cast(tf.convert_to_tensor(y_pred)[..., 1:], dtype)
    true_fg = tf.cast(y_true[..., 1:], pred_fg.dtype)

    # Flattening makes the squeeze of size-1 axes unnecessary.
    true_flat = tf.reshape(true_fg, [-1])
    pred_flat = tf.reshape(pred_fg, [-1])

    intersection = tf.reduce_sum(true_flat * pred_flat)
    union = tf.reduce_sum(true_flat) + tf.reduce_sum(pred_flat) - intersection
    return tf.math.divide_no_nan(intersection, union)

def dice_coef(y_true, y_pred, const=K.epsilon()):
    """Soft Dice coefficient over all foreground channels.

    Channel 0 (background) is excluded. The remaining channels are flattened and
    compared as continuous values:

        Dice = (2*TP + const) / (2*TP + FP + FN + const)

    where 2*TP + FP + FN equals sum(y_true) + sum(y_pred).

    Args:
        y_true: one-hot ground truth, channels last.
        y_pred: predicted per-channel scores, channels last.
        const: smoothing term added to BOTH numerator and denominator, so a
            volume with empty ground truth and empty prediction scores 1
            instead of dividing by zero.

    Returns:
        Scalar Dice coefficient tensor.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes: TensorFlow does not auto-promote in elementwise ops.
    y_true = tf.cast(y_true, y_pred.dtype)

    # Flatten the foreground channels of the N-D tensors.
    y_true_pos = tf.reshape(y_true[..., 1:], [-1])
    y_pred_pos = tf.reshape(y_pred[..., 1:], [-1])

    true_pos = tf.reduce_sum(y_true_pos * y_pred_pos)
    # 2*TP + FP + FN == sum(y_true) + sum(y_pred)
    denominator = tf.reduce_sum(y_true_pos) + tf.reduce_sum(y_pred_pos)

    return (2.0 * true_pos + const) / (denominator + const)

In [None]:
# Resume from the saved checkpoint or build a fresh 3D U-Net.
if continue_training:
    # compile=False: the model is compiled below with a new Adam optimizer and
    # the custom metrics, so optimizer state is not carried over.
    model = tf.keras.models.load_model(f'models/{model_name}', compile=False)
else:
    # Reset Keras global state before building the new model. (The previous
    # unused tf.keras.Input created before this call was dead code.)
    tf.keras.backend.clear_session()
    model = tfmr.models.UNet3D(filters=[64, 128, 256],
                               kernel_size=3,
                               out_activation='softmax',
                               out_channels=output_channel,
                               use_batch_norm=True)

model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              metrics=[dice_coef, iou])

In [None]:
from keras.callbacks import EarlyStopping, ModelCheckpoint
# Select the checkpoint and stop training on validation loss. validation_data
# is passed to fit(); tracking the training loss would keep overwriting the
# checkpoint as the model fits the training set, and would rarely trigger
# early stopping.
es = EarlyStopping(monitor='val_loss',
                   mode='min',
                   verbose=1,
                   patience=10)
mc = ModelCheckpoint(f'models/{model_name}',
                     save_best_only=True,
                     monitor='val_loss',
                     mode='min')
neptune_callback = NeptuneCallback(run=run)
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=300,
                    callbacks=[es, mc, neptune_callback])

# NOTE(review): with no .h5 suffix the checkpoint path is expected to be a
# SavedModel directory. Directories go through upload_files() (a file set)
# rather than upload(), which takes a single file. Confirm against the
# installed Keras and Neptune versions.
run['model'].upload_files(f'models/{model_name}')

In [None]:
# Always evaluate with the checkpoint saved by ModelCheckpoint
# (save_best_only=True) rather than whatever weights the last epoch left in
# memory. This covers both fresh and resumed runs.
model = tf.keras.models.load_model(f'models/{model_name}', compile=False)
print(model_name)

In [None]:
# Gather every (image, mask) pair yielded for one validation patient and stack
# them into arrays with a leading batch axis for model.predict.
patient = val_patients[3]
test_gen = CustomDataGen([patient], 'test').get_gen()
pairs = list(test_gen)
X_test = np.stack([img for img, _ in pairs])
y_test = np.stack([msk for _, msk in pairs])

In [None]:
def get_one_hot(targets, nb_classes):
    """One-hot encode integer class indices along a new last axis.

    Args:
        targets: integer class indices in [0, nb_classes), as an array or any
            nested sequence accepted by ``np.asarray``.
        nb_classes: number of classes, i.e. the length of the new last axis.

    Returns:
        Float array of shape ``targets.shape + (nb_classes,)``.
    """
    # Convert first so list inputs work; the original read .shape off the raw argument.
    targets = np.asarray(targets)
    # Row k of the identity matrix is the one-hot vector for class k.
    one_hot = np.eye(nb_classes)[targets.reshape(-1)]
    return one_hot.reshape(targets.shape + (nb_classes,))

In [None]:
y_pred = model.predict(X_test)
# Hard labels: argmax over the class axis (np.argmax returns the index of the
# maximum), re-encoded with as many classes as the model outputs. This keeps
# y_pred aligned channel-for-channel with y_test; the previous hard-coded 3
# added an extra, always-empty channel.
y_pred = get_one_hot(np.argmax(y_pred, axis=-1), y_pred.shape[-1])

In [None]:
# Compare, at index 45 of the last spatial axis: input image, ground-truth
# foreground channel, predicted foreground channel.
i = 0  # first volume in X_test / y_test / y_pred
depth_slice = 45
fig, axs = plt.subplots(1, 3, figsize=(12, 6))
axs[0].imshow(X_test[i, ..., depth_slice, 0], cmap='gray')
axs[1].imshow(y_test[i, ..., depth_slice, 1], cmap='gray')
axs[2].imshow(y_pred[i, ..., depth_slice, 1], cmap='gray')

In [None]:
from matplotlib import animation

In [None]:
#transverse view
for patient in val_patients:
    X_test = []
    y_test = []
    test_gen   = CustomDataGen([patient], 'test').get_gen()
    for X, y in test_gen:
        X_test.append(X)
        y_test.append(y)
    X_test = np.stack(X_test)
    y_test = np.stack(y_test)
    y_pred = model.predict(X_test)
    y_pred = get_one_hot(np.argmax(y_pred,axis = -1), 2)

    fig, axs = plt.subplots(1,2, figsize = (7,5))
    frames = []
    for i in range(y_pred.shape[1]):
        p1 = axs[0].imshow(X_test[0,i,...,0],cmap = 'gray')
        p2 = axs[1].imshow(X_test[0,i,...,0],cmap = 'gray')
        p3 = axs[0].imshow(y_test[0,i,...,-1],alpha=y_test[0,i,...,-1] * 0.7,cmap = 'jet') #ground truth
        p4 = axs[1].imshow(y_pred[0,i,...,-1],alpha = y_pred[0,i,...,-1] * 0.7,cmap = 'Blues') #prediction
        frames.append([p1,p2,p3,p4])
    fig.tight_layout()
    ani = animation.ArtistAnimation(fig, frames)
    ani.save(f"...gif", fps=y_pred.shape[1])
    run[f"..."].upload(f"...gif")
    plt.close()