In [None]:
import warnings
# Silence every Python warning for the rest of the session to keep cell
# outputs short. NOTE(review): this also hides deprecation warnings from
# TensorFlow/Keras; consider narrowing the filter while debugging.
warnings.filterwarnings(action="ignore")

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, Dropout,
                                     Conv2DTranspose, MaxPool2D, Concatenate)
from functools import partial
import matplotlib.pyplot as plt

from utils_brats import (NUM_SLICES, IMG_SIZE, CLASSES)
from utils_brats import (dice_coef, precision, sensitivity, specificity, 
                 dice_coef_necrotic, dice_coef_edema ,dice_coef_enhancing)

In [None]:
 # Show the GPUs visible to TensorFlow (an empty list means CPU-only).
 # Uses the stable API; the `tf.config.experimental` alias is deprecated.
 tf.config.list_physical_devices('GPU')

In [None]:
# Visual sanity check: draw two records from the validation TFRecord and show
# one slice of each as Flair, T1ce, and T1ce with the label map overlaid.
# NOTE(review): `parse_records` is not imported in the visible cells; confirm
# where it is defined (presumably `utils_brats`) so Restart & Run All works.
val_preview = (
    tf.data.TFRecordDataset(r"/kaggle/input/brats-dataset-vir/valset.tfrecord")
    .shuffle(8)
    .take(2)
    .map(parse_records)
)

for x, y in val_preview:
    # Drop the leading size-1 axis so the first axis indexes slices.
    x, y = tf.squeeze(x, 0), tf.squeeze(y, 0)

    slice_num = 50  # slice shown in every panel
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 15))

    # Panels 1 and 2: the two input channels in grayscale.
    for ax, channel, title in ((ax1, 0, "Flair"), (ax2, 1, "T1ce")):
        ax.imshow(x[slice_num, ..., channel], cmap='gray')
        ax.axis(False)
        ax.set_title(title)

    # Panel 3: channel 1 as background with the argmax of the label's last
    # axis drawn semi-transparently on top.
    ax3.imshow(x[slice_num, ..., 1], cmap='gray', alpha=0.8)
    ax3.imshow(tf.argmax(y, -1)[slice_num], cmap='OrRd', alpha=0.7)
    ax3.axis(False)
    ax3.set_title("Tumour segmented")

    plt.tight_layout()
    plt.show()
    print()

In [None]:
class DataLoader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves parsed batches from a TFRecord file.

    Batches are drawn sequentially from a shuffled ``tf.data`` pipeline, so
    the index passed to ``__getitem__`` is ignored: every call simply returns
    the next batch of the current pass over the file.

    NOTE(review): ``parse_records`` is not imported in the visible cells;
    confirm where it is defined (presumably ``utils_brats``).

    Args:
        tfrecord: Path (or list of paths) to the TFRecord file(s) to read.
        batch_size: Number of serialized records per batch. ``__getitem__``
            squeezes axis 0 of the parsed tensors, which only succeeds when
            that axis has size 1 (the notebook uses ``batch_size=1``).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or the file holds no
            records.
    """

    def __init__(self, tfrecord, batch_size):
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Unshuffled source dataset. It is never reassigned, so the pipeline
        # does not grow from epoch to epoch.
        self.dataset = tf.data.TFRecordDataset(tfrecord)
        # TFRecord files carry no length, so count records with one full pass.
        self.size = int(self.dataset.reduce(0, lambda count, _: count + 1).numpy())
        if self.size == 0:
            raise ValueError(f"No records found in {tfrecord!r}")

        self.batch_size = batch_size
        self.mini_batch = None
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (remainder dropped)."""
        return self.size // self.batch_size

    def on_epoch_end(self):
        """Start a new, freshly shuffled pass over the records.

        The shuffle is applied to the raw source dataset each time rather
        than stacked onto the previous epoch's already-shuffled dataset;
        stacking would add one more full-size shuffle buffer per epoch.
        """
        shuffled = self.dataset.shuffle(buffer_size=self.size)
        self.mini_batch = iter(shuffled.batch(self.batch_size, drop_remainder=True))

    def __getitem__(self, idx):
        """Return the next ``(x, y)`` pair; ``idx`` is ignored.

        ``x`` is divided by its own maximum and both tensors have their
        leading axis squeezed away.
        """
        try:
            batch = next(self.mini_batch)
        except StopIteration:
            # Iterator ran dry before an explicit epoch reset: reshuffle and
            # start a new pass.
            self.on_epoch_end()
            batch = next(self.mini_batch)
        x, y = parse_records(batch)
        # NOTE(review): yields NaN/inf if a volume is all zeros (max == 0).
        return tf.squeeze(x / tf.reduce_max(x), 0), tf.squeeze(y, 0)

In [None]:
# One TFRecord per batch (batch_size=1) for both the training and the
# validation split.
trainloader = DataLoader(
    tfrecord="/kaggle/input/brats-dataset-vir/trainset.tfrecord",
    batch_size=1,
)
valloader = DataLoader(
    tfrecord="/kaggle/input/brats-dataset-vir/valset.tfrecord",
    batch_size=1,
)

In [None]:
def U_net2d(input_shape, classes):
    """Build a 2D U-Net with four down/up-sampling stages.

    Encoder stages use 32, 64, 128 and 256 filters (two 3x3 ReLU convolutions
    followed by 2x2 max pooling), the bottleneck uses 512 filters, and the
    decoder mirrors the encoder with 2x2 stride-2 transposed convolutions and
    skip connections concatenated on the channel axis.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``. Because
            the input is pooled by 2 four times, spatial sizes should be
            divisible by 16 for the skip concatenations to line up.
        classes: Number of output channels of the final 1x1 softmax
            convolution.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'2D-U-NET'``.
    """
    encoder_filters = (32, 64, 128, 256)

    def double_conv(tensor, filters):
        # Two consecutive 3x3 same-padded ReLU convolutions.
        for _ in range(2):
            tensor = Conv2D(filters, kernel_size=3,
                            activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(shape=input_shape)
    x = inputs

    # Encoder: keep each stage's pre-pooling output for the skip connection.
    skips = []
    for filters in encoder_filters:
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPool2D(pool_size=2)(x)

    # Bottleneck.
    x = double_conv(x, 512)

    # Decoder: upsample, concatenate the matching skip, then refine.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = Conv2DTranspose(filters, kernel_size=2, strides=2,
                            padding='same', activation='relu')(x)
        x = Concatenate(axis=-1)([skip, x])
        x = double_conv(x, filters)

    # Per-pixel class probabilities.
    outputs = Conv2D(classes, kernel_size=1, activation='softmax')(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name='2D-U-NET')


In [None]:
# Two input channels (titled Flair and T1ce in the preview cell); one output
# channel per entry in CLASSES.
unet = U_net2d(
    input_shape=(IMG_SIZE, IMG_SIZE, 2),
    classes=len(CLASSES),
)

In [None]:
# Print the layer-by-layer architecture and parameter counts.
unet.summary()

In [None]:
class _OneHotMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU for one-hot targets and per-class probability predictions.

    ``tf.keras.metrics.MeanIoU`` expects integer class ids. The model emits
    softmax probabilities and is trained with categorical cross-entropy
    (one-hot targets), so both tensors are reduced to class ids with an
    argmax over the last axis before the confusion matrix is updated.
    Feeding them in unconverted produces a meaningless IoU.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            tf.argmax(y_true, axis=-1),
            tf.argmax(y_pred, axis=-1),
            sample_weight,
        )


# name="mean_io_u" keeps the log/history key that the stock MeanIoU metric
# would have produced. Reloading a saved, compiled model needs this class in
# `custom_objects`, as it already does for the custom dice/precision functions.
unet.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[
        'accuracy',
        _OneHotMeanIoU(num_classes=len(CLASSES), name="mean_io_u"),
        dice_coef, precision, sensitivity, specificity,
        dice_coef_necrotic, dice_coef_edema, dice_coef_enhancing,
    ],
)

In [None]:
# Keep the model with the lowest validation loss on disk, and stop training
# once validation loss has not improved for 5 consecutive epochs.
# NOTE(review): despite the file name, ModelCheckpoint saves the full model
# here because `save_weights_only` is left at its default.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="2dunet_weights.h5",
    monitor="val_loss",
    save_best_only=True,
)
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    verbose=1,
)
callbackks = [checkpoint_cb, early_stop_cb]

In [None]:
# Train for up to 30 epochs; the callbacks checkpoint the best model and may
# stop training early.
# NOTE(review): DataLoader.__getitem__ ignores the index Keras passes in, so
# sample order is decided by the loader's own shuffle, not by fit().
unet.fit(trainloader, validation_data=valloader, epochs=30, callbacks=callbackks)

In [None]:
# Save the model as it stands after the last training epoch. The best
# validation-loss model is stored separately by the ModelCheckpoint callback
# ("2dunet_weights.h5").
unet.save("2dunet_vir.h5")