In [None]:
import os
import numpy as np
import keras
from matplotlib import pyplot as plt
import glob
import random

train_img_dir = "./NPYcombined/"
train_mask_dir = "./NPYmask/"
val_img_dir = "./Validation/image/"
val_mask_dir = "./Validation/mask/"


def list_npy(directory):
    """Return the sorted names of the ``.npy`` files in ``directory``.

    Sorting matters: ``os.listdir`` returns entries in arbitrary order, and
    the image/mask lists below are paired with each other purely by index.
    Non-``.npy`` entries are dropped so they cannot shift that pairing.
    """
    return sorted(name for name in os.listdir(directory) if name.endswith('.npy'))


img_list = list_npy(train_img_dir)
msk_list = list_npy(train_mask_dir)
val_img_list = list_npy(val_img_dir)
val_mask_list = list_npy(val_mask_dir)

# Images and masks are matched by position, so the counts must agree.
assert len(img_list) == len(msk_list), "training images and masks differ in count"
assert len(val_img_list) == len(val_mask_list), "validation images and masks differ in count"

print(img_list[:5])  # a preview instead of dumping the whole listing
print(len(img_list))

In [None]:
num_images = len(img_list)

# Pick one training sample at random; image and mask are paired by their
# position in img_list / msk_list.
img_num = random.randint(0, num_images - 1)
sample_name = img_list[img_num]
test_img = np.load(train_img_dir + sample_name)
test_mask = np.load(train_mask_dir + msk_list[img_num])
print(test_img.shape)
print(sample_name)

In [None]:
# Collapse the per-class mask channels (axis 3) into one integer label per
# voxel. The result gets a new name so that re-running this cell does not try
# to argmax an already-collapsed mask (which would fail on axis=3).
test_mask_labels = np.argmax(test_mask, axis=3)

# randrange excludes its upper bound; the old randint(0, depth) could return
# depth itself and raise IndexError.
n_slice = random.randrange(test_mask_labels.shape[2])

plt.figure(figsize=(12, 8))
# First three input channels, titled with the modality names used in this notebook.
for position, (channel, title) in enumerate(
        [(0, 'Image flair'), (1, 'Image t1c'), (2, 'Image t2w')], start=1):
    plt.subplot(2, 2, position)
    plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
    plt.title(title)
plt.subplot(2, 2, 4)
plt.imshow(test_mask_labels[:, :, n_slice])
plt.title('Mask')
plt.show()


In [None]:
def load_img(img_dir, img_list):
    """Load the ``.npy`` files named in ``img_list`` and stack them.

    Args:
        img_dir: directory containing the files (with or without a trailing
            path separator).
        img_list: file names inside ``img_dir``. Entries whose extension is
            not ``.npy`` are skipped.

    Returns:
        ``np.array`` of the loaded arrays, stacked along a new first axis.
        If nothing was loaded, this is an empty array.
    """
    images = []
    for image_name in img_list:
        # splitext copes with names that have several dots ("a.v2.npy") or no
        # dot at all; the old split('.')[1] check skipped the former and
        # raised IndexError on the latter.
        if os.path.splitext(image_name)[1] == '.npy':
            images.append(np.load(os.path.join(img_dir, image_name)))
    return np.array(images)

def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):
    """Endless generator of ``(images, masks)`` batches for Keras.

    Walks ``img_list`` / ``mask_list`` in order, ``batch_size`` entries at a
    time, then starts over from the beginning. The last batch of each pass may
    be smaller than ``batch_size``. Images and masks are paired by list
    position, so both lists must be in the same order.

    Yields:
        Tuple ``(X, Y)`` of numpy arrays built by ``load_img``.

    Raises:
        ValueError: on the first ``next()`` in three cases. ``batch_size`` is
            not positive. The two lists differ in length. ``img_list`` is
            empty. Without these checks the loop would spin forever, either
            without yielding or yielding empty or misaligned batches.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(img_list) != len(mask_list):
        raise ValueError(
            f"img_list has {len(img_list)} entries but mask_list has {len(mask_list)}")
    L = len(img_list)
    if L == 0:
        raise ValueError("img_list is empty; there is nothing to yield")
    while True:
        for batch_start in range(0, L, batch_size):
            batch_end = min(batch_start + batch_size, L)
            X = load_img(img_dir, img_list[batch_start:batch_end])
            Y = load_img(mask_dir, mask_list[batch_start:batch_end])
            yield (X, Y)  # a tuple of two numpy arrays with up to batch_size samples

In [None]:
batch_size = 4

train_img_datagen = imageLoader(train_img_dir, img_list,
                                train_mask_dir, msk_list, batch_size)

# Preview a batch from a *separate* generator. train_img_datagen is later
# passed to model.fit; pulling a batch from it here would offset training's
# epoch boundaries by one batch.
preview_datagen = imageLoader(train_img_dir, img_list,
                              train_mask_dir, msk_list, batch_size)
img, msk = next(preview_datagen)

img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
test_mask = msk[img_num]
# Collapse the per-class mask channels (axis 3 of a single sample) into
# integer labels. The new name keeps the cell re-runnable.
test_mask_labels = np.argmax(test_mask, axis=3)

# randrange excludes its upper bound; randint(0, depth) could return depth
# and raise IndexError.
n_slice = random.randrange(test_mask_labels.shape[2])

plt.figure(figsize=(12, 8))
for position, (channel, title) in enumerate(
        [(0, 'Image flair'), (1, 'Image t1c'), (2, 'Image t2w')], start=1):
    plt.subplot(2, 2, position)
    plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
    plt.title(title)
plt.subplot(2, 2, 4)
plt.imshow(test_mask_labels[:, :, n_slice])
plt.title('Mask')
plt.show()

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# Four equal weights (presumably one per class, matching num_classes = 4).
# NOTE(review): wt0..wt3 are not referenced anywhere else in this notebook;
# combined_loss does not use them. Either wire them into the loss or remove them.
wt0, wt1, wt2, wt3 = 0.25,0.25,0.25,0.25
import segmentation_models_3D as sm

def dice_loss(y_true, y_pred):
    """Soft Dice loss pooled over the whole batch.

    Both tensors are flattened completely, so the overlap is accumulated
    across every sample, position and class rather than per class.

    Returns:
        Scalar tensor
        ``1 - (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``.
        The ``+1`` smoothing term keeps the ratio defined when both inputs are empty.
    """
    backend = tf.keras.backend
    smooth = 1.
    truth = backend.flatten(y_true)
    pred = backend.flatten(y_pred)
    overlap = backend.sum(truth * pred)
    denominator = backend.sum(truth) + backend.sum(pred) + smooth
    return 1 - (2. * overlap + smooth) / denominator

def categorical_focal_loss(alpha, gamma):
    """Build a focal-loss function with fixed ``alpha`` and ``gamma``.

    The returned ``focal_loss_fixed(y_true, y_pred)`` proceeds in four steps:
      * clips ``y_pred`` to ``[eps, 1 - eps]`` so the log stays finite;
      * weights positive targets by ``alpha`` and negative ones by ``1 - alpha``;
      * down-weights easy positions by ``(1 - p_t) ** gamma``;
      * sums the per-element loss over the last axis only, so the result keeps
        every other axis of the inputs.
    """
    def focal_loss_fixed(y_true, y_pred):
        backend = tf.keras.backend
        eps = backend.epsilon()
        y_pred = backend.clip(y_pred, eps, 1 - eps)
        y_true = tf.cast(y_true, tf.float32)
        ones = backend.ones_like(y_true)
        not_true = ones - y_true
        # alpha_t / p_t pick the positive-target or negative-target term elementwise.
        alpha_t = y_true * alpha + not_true * (1 - alpha)
        p_t = y_true * y_pred + not_true * (1 - y_pred)
        fl = -alpha_t * backend.pow(ones - p_t, gamma) * backend.log(p_t)
        return backend.sum(fl, axis=-1)

    return focal_loss_fixed

def combined_loss(y_true, y_pred):
    """Soft Dice loss plus focal loss (alpha=0.25, gamma=2.0).

    ``dice_loss`` reduces to a scalar, while the focal term is summed only
    over the last axis. Adding them therefore broadcasts the Dice value onto
    every remaining position of the focal tensor.
    """
    focal_fn = categorical_focal_loss(alpha=0.25, gamma=2.0)
    return dice_loss(y_true, y_pred) + focal_fn(y_true, y_pred)

# Project root. Set the BRATS_PROJECT_DIR environment variable to override it
# instead of editing this user-specific default path.
# NOTE(review): the relative data directories at the top of the notebook are
# listed before this chdir runs, so they resolve against the kernel's starting
# directory. Consider moving this to the first cell.
project_dir = os.environ.get('BRATS_PROJECT_DIR', 'C:/Users/Romir/Desktop/Projects/BraTS/')
os.chdir(project_dir)


In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Dense, Flatten, Reshape, LayerNormalization

kernel_initializer = 'he_uniform'

def kan_layer(x, units):
    """Apply a ReLU ``Dense`` layer with ``units`` outputs to ``x``.

    Despite the name, this is a plain fully connected layer (acting on the
    last axis), not a spline-based Kolmogorov-Arnold layer.
    """
    dense = Dense(units, activation='relu')
    return dense(x)

def tok_kan_block(x, filters):
    """Token-mixing block used in place of a double convolution.

    The sequence is: 1x1x1 Conv3D projection to ``filters`` channels,
    LayerNormalization, ``kan_layer`` (ReLU Dense), 3x3x3 Conv3D, and a final
    LayerNormalization. Both convolutions use padding='same', so the spatial
    size is preserved.
    """
    projected = Conv3D(filters, (1, 1, 1), padding='same')(x)
    normed = LayerNormalization()(projected)
    mixed = kan_layer(normed, filters)
    convolved = Conv3D(filters, (3, 3, 3), padding='same')(mixed)
    return LayerNormalization()(convolved)

def U_KAN(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes):
    """Build a 3D U-Net variant that mixes in ``tok_kan_block`` stages.

    Encoder: four downsampling stages with 16, 32, 64 and 128 filters. The
    third stage uses ``tok_kan_block`` instead of Conv3D/Dropout/Conv3D.
    Bottleneck: ``tok_kan_block`` with 256 filters.
    Decoder: four stride-2 ``Conv3DTranspose`` upsampling steps. Each is
    concatenated with the matching encoder output (skip connection) and
    followed by two convolutions. An extra ``tok_kan_block`` follows the
    64-filter decoder stage.

    Args:
        IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH: spatial size of the input volume.
            Each is halved four times and doubled four times, so they
            presumably need to be divisible by 16 for the skip concatenations
            to line up. TODO confirm.
        IMG_CHANNELS: number of input channels.
        num_classes: number of output channels; softmax is applied across them.

    Returns:
        An uncompiled Keras ``Model`` with input shape
        (IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS) plus the batch axis.

    Side effects:
        Prints intermediate tensor shapes and the model summary.
    """
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))
    s = inputs

    # Contraction (encoder) path
    # 3x3x3 convolution with 16 filters, ReLU activation and He-uniform initialization; padding='same' keeps the spatial size unchanged
    c1 = Conv3D(16, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(s)
    # Dropout randomly zeroes 10% of the activations during training to reduce overfitting
    c1 = Dropout(0.1)(c1)
    c1 = Conv3D(16, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c1)
    # 2x2x2 max pooling halves each spatial dimension (downsampling), reducing the cost of the deeper layers
    p1 = MaxPooling3D((2, 2, 2), padding='same')(c1)
    print("p1 shape:", p1.shape)
    
    c2 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(p1)
    c2 = Dropout(0.1)(c2)
    c2 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c2)
    p2 = MaxPooling3D((2, 2, 2), padding='same')(c2)
    print("p2 shape:", p2.shape)

    # Third encoder stage: tok_kan_block replaces the usual double convolution
    c3 = tok_kan_block(p2, 64)
    p3 = MaxPooling3D((2, 2, 2), padding='same')(c3)
    print("p3 shape:", p3.shape)
     
    c4 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(p3)
    c4 = Dropout(0.2)(c4)
    c4 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c4)
    p4 = MaxPooling3D(pool_size=(2, 2, 2), padding='same')(c4)
    print("p4 shape:", p4.shape)

    # Bottleneck (deepest level), 256 filters
    c5 = tok_kan_block(p4, 256)
    print("c5 shape:", c5.shape)
    
    # Expansive (decoder) path
    # Transposed convolution ("deconvolution") upsamples the feature map
    # A 2x2x2 kernel with stride 2 in each dimension doubles every spatial dimension
    u6 = Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding='same')(c5)
    # Skip connection: concatenate with encoder output c4 to recover spatial detail lost during downsampling
    u6 = concatenate([u6, c4])
    c6 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u6)
    c6 = Dropout(0.2)(c6)
    c6 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c6)
    print("c6 shape:", c6.shape)
     
    u7 = Conv3DTranspose(64, (2, 2, 2), strides=(2, 2, 2), padding='same')(c6)
    u7 = concatenate([u7, c3])
    c7 = Conv3D(64, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u7)
    c7 = Dropout(0.2)(c7)
    c7 = Conv3D(64, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c7)
    print("c7 shape:", c7.shape)

    # Extra tok_kan_block refinement at the 64-filter decoder level
    c7 = tok_kan_block(c7, 64)
    print("tokenized c7 shape:", c7.shape)
     
    u8 = Conv3DTranspose(32, (2, 2, 2), strides=(2, 2, 2), padding='same')(c7)
    u8 = concatenate([u8, c2])
    c8 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u8)
    c8 = Dropout(0.1)(c8)
    c8 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c8)
    print("c8 shape:", c8.shape)
     
    u9 = Conv3DTranspose(16, (2, 2, 2), strides=(2, 2, 2), padding='same')(c8)
    u9 = concatenate([u9, c1])
    c9 = Conv3D(16, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(u9)
    c9 = Dropout(0.1)(c9)
    c9 = Conv3D(16, (3, 3, 3), activation='relu', kernel_initializer=kernel_initializer, padding='same')(c9)
    print("c9 shape:", c9.shape)
     
    # 1x1x1 convolution maps the 16 features to per-class softmax scores
    outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(c9)
     
    model = Model(inputs=[inputs], outputs=[outputs])
    model.summary()
    
    return model

In [None]:
IMG_HEIGHT = 128
IMG_WIDTH = 128
IMG_DEPTH = 128
IMG_CHANNELS = 3
num_classes = 4

model = U_KAN(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes)

model.compile(optimizer=Adam(learning_rate=0.0001), loss=combined_loss,
              metrics=['accuracy', sm.metrics.IOUScore(threshold=0.5)])

# Use ceil division because imageLoader yields a final partial batch when
# len(img_list) is not a multiple of batch_size. With floor division, that
# partial batch spilled into the next "epoch" and shifted every epoch
# boundary. Fewer than batch_size images gave steps_per_epoch == 0.
steps_per_epoch = (len(img_list) + batch_size - 1) // batch_size

In [None]:
# Train for 100 epochs from the training generator, then write the model to disk.
history = model.fit(train_img_datagen,
                    epochs=100,
                    steps_per_epoch=steps_per_epoch,
                    verbose=1)
model.save('brats_2.h5')

In [None]:
LR = 0.0001
optim = tf.keras.optimizers.Adam(LR)
metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]
# NOTE(review): this loads brats_1.h5, not the brats_2.h5 written by the
# training cell. Confirm which checkpoint is meant to be evaluated.
model = tf.keras.models.load_model('brats_1.h5', compile=False)

model.compile(optimizer=optim, loss=combined_loss, metrics=metrics)

# Evaluate on the held-out validation set. Previously this ran on the
# training generator while the printouts claimed "Validation".
val_img_datagen = imageLoader(val_img_dir, val_img_list,
                              val_mask_dir, val_mask_list, batch_size)
# Ceil division covers the final partial batch the generator yields.
val_steps = (len(val_img_list) + batch_size - 1) // batch_size
loss, accuracy, iou_score = model.evaluate(val_img_datagen, steps=val_steps)
print(f'Validation Loss: {loss}')
print(f'Validation Accuracy: {accuracy}')
print(f'Validation IOU Score: {iou_score}')
    

In [None]:
from keras.models import load_model

# Objects needed to deserialize the saved training configuration. The loss is
# registered under two names:
#   * 'dice_loss_plus_1focal_loss', the name this cell originally expected in
#     brats_1.h5;
#   * 'combined_loss', presumably the name recorded for models compiled in this
#     notebook with loss=combined_loss (e.g. brats_2.h5). Confirm.
# NOTE(review): if brats_1.h5 was trained with a different Dice+focal loss,
# combined_loss is substituted for the further training below.
custom_objects = {
    'dice_loss_plus_1focal_loss': combined_loss,
    'combined_loss': combined_loss,
    'iou_score': sm.metrics.IOUScore(threshold=0.5),
}
my_model = load_model('brats_1.h5', custom_objects=custom_objects)

# Fine-tune for one more epoch on the training generator.
# NOTE(review): these fine-tuned weights are never saved; the next cell reloads
# brats_1.h5 from disk.
history2 = my_model.fit(train_img_datagen,
                        steps_per_epoch=steps_per_epoch,
                        epochs=1,
                        verbose=1)

In [None]:
# Inference only, so the training configuration is not needed.
my_model = load_model('brats_1.h5', 
                      compile=False)

from keras.metrics import MeanIoU

batch_size = 4
n_classes = 4

test_img_datagen = imageLoader(val_img_dir, val_img_list, 
                                val_mask_dir, val_mask_list, batch_size)
# Batches in one full pass over the validation set (ceil division, since the
# generator yields a final partial batch).
n_val_batches = (len(val_img_list) + batch_size - 1) // batch_size

# Accumulate IoU over the whole validation set, not just its first batch.
IOU_keras = MeanIoU(num_classes=n_classes)
for _ in range(n_val_batches):
    # The generator protocol method is __next__; the builtin next() calls it.
    test_image_batch, test_mask_batch = next(test_img_datagen)
    # Collapse the class axis (axis 4 of a batch) to integer labels.
    test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)
    test_pred_batch = my_model.predict(test_image_batch, verbose=0)
    test_pred_batch_argmax = np.argmax(test_pred_batch, axis=4)
    # MeanIoU.update_state takes (y_true, y_pred) in that order.
    IOU_keras.update_state(test_mask_batch_argmax, test_pred_batch_argmax)

print("Mean IoU =", IOU_keras.result().numpy())

In [None]:
img_num = 28

# Load one validation volume and its mask from the validation directories
# configured at the top of the notebook, instead of a hard-coded absolute
# user path.
test_img = np.load(os.path.join(val_img_dir, f"image_{img_num}.npy"))
test_mask = np.load(os.path.join(val_mask_dir, f"mask_{img_num}.npy"))
test_mask_argmax = np.argmax(test_mask, axis=3)

# The model takes a batch axis: predict on a batch of one, then drop that axis.
test_img_input = np.expand_dims(test_img, axis=0)
test_prediction = my_model.predict(test_img_input)
test_prediction_argmax = np.argmax(test_prediction, axis=4)[0, :, :, :]

# Plot one random slice of the input (channel 1), the ground truth, and the
# prediction. randrange excludes its upper bound; the old randint(0, depth)
# could return depth and raise IndexError.
n_slice = random.randrange(test_prediction_argmax.shape[2])
plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Testing Image')
plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')
plt.subplot(232)
plt.title('Testing Label')
plt.imshow(test_mask_argmax[:, :, n_slice])
plt.subplot(233)
plt.title('Prediction on test image')
plt.imshow(test_prediction_argmax[:, :, n_slice])
plt.show()