In [1]:
import os
import numpy as np
from Brats2020_Segmentation_Data_Generator import imageLoader
import tensorflow as tf
from matplotlib import pyplot as plt
from keras import backend as K
import random
from glob import glob
import cv2
from matplotlib import pyplot as plt
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, LeakyReLU, BatchNormalization, Concatenate, Activation, Dropout,Conv3DTranspose
from tensorflow import keras

from tensorflow.keras.optimizers import Adam
from IPython import display
from tqdm.auto import tqdm



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_img_dir = "BraTS2020_TrainingData/input_data_128/train/images/"
train_mask_dir = "BraTS2020_TrainingData/input_data_128/train/masks/"

val_img_dir = "BraTS2020_TrainingData/input_data_128/val/images/"
val_mask_dir = "BraTS2020_TrainingData/input_data_128/val/masks/"

train_img_list=os.listdir(train_img_dir)
train_mask_list = os.listdir(train_mask_dir)

val_img_list=os.listdir(val_img_dir)
val_mask_list = os.listdir(val_mask_dir)

In [3]:
batch_size = 1

train_dataset = imageLoader(train_img_dir, train_img_list, 
                                train_mask_dir, train_mask_list, batch_size)

val_dataset = imageLoader(val_img_dir, val_img_list, 
                                val_mask_dir, val_mask_list, batch_size)

In [4]:
input_images, target_images=train_dataset.__next__()

In [5]:
input_images.shape

(1, 128, 128, 128, 4)

In [6]:
target_images.shape

(1, 128, 128, 128, 4)

In [7]:
# Print shape information
print("Shape of input images:", input_images.shape)
print("Shape of target images:", target_images.shape)

Shape of input images: (1, 128, 128, 128, 4)
Shape of target images: (1, 128, 128, 128, 4)


In [8]:
num_classes=4


In [9]:
def build_discriminator():
    source_image= Input(shape=(128, 128, 128, 4))
    target_image= Input(shape=(128, 128, 128, num_classes))
    
    cat= Concatenate()([source_image, target_image]) ## this makes the Gan conditional
    d= Conv3D(64, kernel_size=4, strides=2, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(cat)
    d=LeakyReLU(0.3)(d)
    d=Dropout(0.2)(d)
    
    
    d=Conv3D(128, kernel_size=4, strides=2, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(d)
    d= BatchNormalization()(d)
    d=LeakyReLU(0.3)(d)
    d=Dropout(0.2)(d)
    
    
    d=Conv3D(256, kernel_size=4, strides=2, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(d)
    d= BatchNormalization()(d)
    d=LeakyReLU(0.3)(d)
    d=Dropout(0.2)(d)
    
    
    d=Conv3D(512, kernel_size=4, strides=2, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(d)
    d= BatchNormalization()(d)
    d=LeakyReLU(0.2)(d)
    
    
    
    d=Conv3D(1, kernel_size=4, strides=1, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(d)
    
    
    
    return Model([source_image, target_image], d)

In [10]:
discriminator = build_discriminator()
discriminator.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                28, 4)]                                                           
                                                                                                  
 input_2 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                28, 4)]                                                           
                                                                                                  
 concatenate (Concatenate)      (None, 128, 128, 12  0           ['input_1[0][0]',                
                                8, 8)                             'input_2[0][0]']            

In [11]:
def encoder_block(inputs, filters, batchnorm=True):
    g= Conv3D(filters, kernel_size=4, strides=2, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(inputs)
    if batchnorm:
        g= BatchNormalization()(g)
                                
    g= LeakyReLU(0.2)(g)
    g = Dropout(0.2)(g)
    return g

In [12]:
def decoder_block(inputs, skips, filters, dropout=True):
    g= Conv3DTranspose(filters, kernel_size=4, strides=2, padding="same", 
                        kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(inputs)
    g= BatchNormalization()(g)
    g= LeakyReLU(0.3)(g)
    g= Concatenate()([g, skips])
    g=Dropout(0.2)(g)
    
    return g    

In [13]:
def bottleneck(inputs, filters):
    x= Conv3D(filters, kernel_size=4, strides=2, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(inputs)
    x=BatchNormalization()(x)
    x= LeakyReLU(0.3)(x)
    
    for i in range(4):
        y=Conv3D(filters, kernel_size=4, strides=1, padding="same", 
                     kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), 
                     use_bias=False)(x)
        x = BatchNormalization()(y)
        x= LeakyReLU(0.3)(x)
        x = Concatenate()([x, y])
        print(x.shape)
        
    return x

In [14]:
def build_generator():
    input_image=Input(shape=(128, 128, 128, 4))
    
    d1= encoder_block(input_image, 64, batchnorm=False) 
    print(d1.shape)
    d2= encoder_block(d1, 128) 
    print(d2.shape)
    d3= encoder_block(d2, 256) 
    print(d3.shape)
    
    
    neck= bottleneck(d3, 512)#8x8x1024
    
    
    u1= decoder_block(neck, d3, 256) 
    print(u1.shape)
    u2 = decoder_block(u1, d2, 128) 
    print(u2.shape)
    u3= decoder_block(u2, d1,64) 
    print(u3.shape)
    
    
    
    final_conv= Conv3DTranspose(num_classes, kernel_size=4, strides=2, padding="same", 
                        kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), use_bias=False)(u3)
    print(final_conv.shape)
    
    out= Activation("softmax")(final_conv)
    print(out.shape)
    return Model(input_image, out)

In [15]:
generator = build_generator()
generator.summary()

(None, 64, 64, 64, 64)
(None, 32, 32, 32, 128)
(None, 16, 16, 16, 256)
(None, 8, 8, 8, 1024)
(None, 8, 8, 8, 1024)
(None, 8, 8, 8, 1024)
(None, 8, 8, 8, 1024)
(None, 16, 16, 16, 512)
(None, 32, 32, 32, 256)
(None, 64, 64, 64, 128)
(None, 128, 128, 128, 4)
(None, 128, 128, 128, 4)
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                28, 4)]                                                           
                                                                                                  
 conv3d_5 (Conv3D)              (None, 64, 64, 64,   16384       ['input_3[0][0]']                
                                64)                                                               
         

                                                                                                  
 concatenate_3 (Concatenate)    (None, 8, 8, 8, 102  0           ['leaky_re_lu_10[0][0]',         
                                4)                                'conv3d_11[0][0]']              
                                                                                                  
 conv3d_12 (Conv3D)             (None, 8, 8, 8, 512  33554432    ['concatenate_3[0][0]']          
                                )                                                                 
                                                                                                  
 batch_normalization_9 (BatchNo  (None, 8, 8, 8, 512  2048       ['conv3d_12[0][0]']              
 rmalization)                   )                                                                 
                                                                                                  
 leaky_re_

In [16]:
#####Losses

In [17]:
def generalized_dice_loss(y_true, y_pred):
    smooth = 1e-5  # Smoothing factor to avoid division by zero
    
    # Flatten the true and predicted tensors
    y_true_flat = tf.keras.backend.flatten(y_true)
    y_pred_flat = tf.keras.backend.flatten(y_pred)
    
    # Compute the weights for each class based on inverse of class frequency
    weights = 1.0 / (tf.reduce_sum(y_true_flat, axis=0)**2 + smooth)
    # Compute the intersection and sum of true and predicted tensors
    intersection = weights*(tf.reduce_sum(y_true_flat * y_pred_flat, axis=0))
    intersection= tf.reduce_sum(intersection)
    
    sum_true = tf.reduce_sum(y_true_flat, axis=0)
    sum_pred = tf.reduce_sum(y_pred_flat, axis=0)
    union=tf.reduce_sum(weights*(sum_true+sum_pred))
    
    # Compute the dice coefficient for each class
    dice_coefficient = (2.0 * intersection + smooth) / (union + smooth)
    generalized_dice_loss= 1- dice_coefficient
    return generalized_dice_loss
    

In [18]:
# Discriminator loss ###L2 loss between real and fake outputs
def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = tf.reduce_mean(tf.math.pow(tf.ones_like(disc_real_output)- disc_real_output,2))
    fake_loss = tf.reduce_mean(tf.math.pow(tf.zeros_like(disc_generated_output)- disc_generated_output,2))
    total_disc_loss = 0.5*(real_loss + fake_loss)

    return total_disc_loss

In [19]:
# Generator loss
def generator_loss(disc_generated_output, generated_images, target):
    gan_loss= tf.reduce_mean(tf.math.pow(tf.ones_like(disc_generated_output)-disc_generated_output,2))
    dice_loss= generalized_dice_loss(target, generated_images)
    total_gen_loss= 5*dice_loss + gan_loss

    return total_gen_loss

In [20]:
# Instantiate generator and discriminator
generator = build_generator()
discriminator = build_discriminator()

(None, 64, 64, 64, 64)
(None, 32, 32, 32, 128)
(None, 16, 16, 16, 256)
(None, 8, 8, 8, 1024)
(None, 8, 8, 8, 1024)
(None, 8, 8, 8, 1024)
(None, 8, 8, 8, 1024)
(None, 16, 16, 16, 512)
(None, 32, 32, 32, 256)
(None, 64, 64, 64, 128)
(None, 128, 128, 128, 4)
(None, 128, 128, 128, 4)


In [21]:
# Optimizers
generator_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5, beta_2=0.999)

In [22]:
@tf.function
def train_step(input_image, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate fake image
        generated_images = generator(input_image, training=True)

        # Discriminator outputs
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, generated_images], training=True)

        # Calculate losses
        gen_total_loss = generator_loss(disc_generated_output, generated_images, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # Calculate gradients
    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Apply gradients
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    return gen_total_loss, disc_loss

In [23]:
num_epochs=50
total_images = 275
steps_per_epoch = total_images // batch_size

In [25]:
for epoch in range(num_epochs):
    print("Epoch:", epoch+1)
    
    # Create a progress bar for the training dataset
    progress_bar = tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}")
    
    # Iterate over the training dataset
    for step in progress_bar:
        # Get the next batch of input and target images
        input_images, target_images = next(train_dataset)
        
        # Perform a single training step
        gen_loss, disc_loss = train_step(input_images, target_images)
        
        # Update the progress bar description with the current losses
        progress_bar.set_postfix({"Generator Loss": gen_loss.numpy(), "Discriminator Loss": disc_loss.numpy()})

Epoch: 1


Epoch 1: 100%|██████████| 275/275 [05:12<00:00,  1.14s/it, Generator Loss=0.45, Discriminator Loss=0.325] 


Epoch: 2


Epoch 2: 100%|██████████| 275/275 [05:04<00:00,  1.11s/it, Generator Loss=0.426, Discriminator Loss=0.39] 


Epoch: 3


Epoch 3: 100%|██████████| 275/275 [05:03<00:00,  1.10s/it, Generator Loss=1.71, Discriminator Loss=0.93]   


Epoch: 4


Epoch 4: 100%|██████████| 275/275 [05:05<00:00,  1.11s/it, Generator Loss=0.67, Discriminator Loss=0.24]   


Epoch: 5


Epoch 5: 100%|██████████| 275/275 [05:05<00:00,  1.11s/it, Generator Loss=0.951, Discriminator Loss=0.108] 


Epoch: 6


Epoch 6: 100%|██████████| 275/275 [05:04<00:00,  1.11s/it, Generator Loss=0.678, Discriminator Loss=0.182] 


Epoch: 7


Epoch 7: 100%|██████████| 275/275 [05:03<00:00,  1.10s/it, Generator Loss=1.26, Discriminator Loss=0.146]  


Epoch: 8


Epoch 8: 100%|██████████| 275/275 [05:14<00:00,  1.14s/it, Generator Loss=0.505, Discriminator Loss=0.148] 


Epoch: 9


Epoch 9: 100%|██████████| 275/275 [05:08<00:00,  1.12s/it, Generator Loss=0.633, Discriminator Loss=0.12]  


Epoch: 10


Epoch 10: 100%|██████████| 275/275 [05:04<00:00,  1.11s/it, Generator Loss=0.456, Discriminator Loss=0.195] 


Epoch: 11


Epoch 11: 100%|██████████| 275/275 [05:20<00:00,  1.17s/it, Generator Loss=0.387, Discriminator Loss=0.28]  


Epoch: 12


Epoch 12: 100%|██████████| 275/275 [05:12<00:00,  1.13s/it, Generator Loss=0.498, Discriminator Loss=0.174] 


Epoch: 13


Epoch 13: 100%|██████████| 275/275 [05:12<00:00,  1.14s/it, Generator Loss=0.438, Discriminator Loss=0.131] 


Epoch: 14


Epoch 14: 100%|██████████| 275/275 [05:16<00:00,  1.15s/it, Generator Loss=0.277, Discriminator Loss=0.265] 


Epoch: 15


Epoch 15: 100%|██████████| 275/275 [05:16<00:00,  1.15s/it, Generator Loss=0.207, Discriminator Loss=0.317] 


Epoch: 16


Epoch 16: 100%|██████████| 275/275 [05:15<00:00,  1.15s/it, Generator Loss=0.323, Discriminator Loss=0.307] 


Epoch: 17


Epoch 17: 100%|██████████| 275/275 [05:09<00:00,  1.13s/it, Generator Loss=0.871, Discriminator Loss=0.188] 


Epoch: 18


Epoch 18: 100%|██████████| 275/275 [05:08<00:00,  1.12s/it, Generator Loss=0.453, Discriminator Loss=0.23]  


Epoch: 19


Epoch 19: 100%|██████████| 275/275 [05:10<00:00,  1.13s/it, Generator Loss=0.604, Discriminator Loss=0.223] 


Epoch: 20


Epoch 20: 100%|██████████| 275/275 [04:59<00:00,  1.09s/it, Generator Loss=1.31, Discriminator Loss=0.46]   


Epoch: 21


Epoch 21: 100%|██████████| 275/275 [05:05<00:00,  1.11s/it, Generator Loss=0.187, Discriminator Loss=0.236] 


Epoch: 22


Epoch 22: 100%|██████████| 275/275 [05:01<00:00,  1.10s/it, Generator Loss=0.254, Discriminator Loss=0.304] 


Epoch: 23


Epoch 23: 100%|██████████| 275/275 [05:02<00:00,  1.10s/it, Generator Loss=0.162, Discriminator Loss=0.249] 


Epoch: 24


Epoch 24: 100%|██████████| 275/275 [05:03<00:00,  1.11s/it, Generator Loss=0.273, Discriminator Loss=0.245] 


Epoch: 25


Epoch 25: 100%|██████████| 275/275 [05:04<00:00,  1.11s/it, Generator Loss=0.254, Discriminator Loss=0.276]


Epoch: 26


Epoch 26: 100%|██████████| 275/275 [05:01<00:00,  1.10s/it, Generator Loss=0.245, Discriminator Loss=0.27] 


Epoch: 27


Epoch 27: 100%|██████████| 275/275 [05:02<00:00,  1.10s/it, Generator Loss=0.336, Discriminator Loss=0.232] 


Epoch: 28


Epoch 28: 100%|██████████| 275/275 [05:03<00:00,  1.10s/it, Generator Loss=0.308, Discriminator Loss=0.236] 


Epoch: 29


Epoch 29: 100%|██████████| 275/275 [05:04<00:00,  1.11s/it, Generator Loss=0.448, Discriminator Loss=0.196] 


Epoch: 30


Epoch 30: 100%|██████████| 275/275 [05:02<00:00,  1.10s/it, Generator Loss=0.353, Discriminator Loss=0.232] 


Epoch: 31


Epoch 31: 100%|██████████| 275/275 [05:03<00:00,  1.10s/it, Generator Loss=0.457, Discriminator Loss=0.1]   


Epoch: 32


Epoch 32: 100%|██████████| 275/275 [05:02<00:00,  1.10s/it, Generator Loss=0.293, Discriminator Loss=0.246] 


Epoch: 33


Epoch 33: 100%|██████████| 275/275 [05:02<00:00,  1.10s/it, Generator Loss=0.237, Discriminator Loss=0.426] 


Epoch: 34


Epoch 34: 100%|██████████| 275/275 [05:04<00:00,  1.11s/it, Generator Loss=0.978, Discriminator Loss=0.371] 


Epoch: 35


Epoch 35: 100%|██████████| 275/275 [04:58<00:00,  1.08s/it, Generator Loss=0.712, Discriminator Loss=0.255] 


Epoch: 36


Epoch 36: 100%|██████████| 275/275 [05:43<00:00,  1.25s/it, Generator Loss=0.545, Discriminator Loss=0.199] 


Epoch: 37


Epoch 37: 100%|██████████| 275/275 [05:01<00:00,  1.10s/it, Generator Loss=0.919, Discriminator Loss=0.182] 


Epoch: 38


Epoch 38: 100%|██████████| 275/275 [05:01<00:00,  1.10s/it, Generator Loss=0.777, Discriminator Loss=0.0778]


Epoch: 39


Epoch 39: 100%|██████████| 275/275 [05:02<00:00,  1.10s/it, Generator Loss=0.443, Discriminator Loss=0.306] 


Epoch: 40


Epoch 40: 100%|██████████| 275/275 [05:03<00:00,  1.10s/it, Generator Loss=0.29, Discriminator Loss=0.326]  


Epoch: 41


Epoch 41: 100%|██████████| 275/275 [05:03<00:00,  1.10s/it, Generator Loss=0.198, Discriminator Loss=0.253] 


Epoch: 42


Epoch 42: 100%|██████████| 275/275 [05:01<00:00,  1.09s/it, Generator Loss=0.401, Discriminator Loss=0.124] 


Epoch: 43


Epoch 43: 100%|██████████| 275/275 [05:00<00:00,  1.09s/it, Generator Loss=0.889, Discriminator Loss=0.215] 


Epoch: 44


Epoch 44: 100%|██████████| 275/275 [04:59<00:00,  1.09s/it, Generator Loss=0.327, Discriminator Loss=0.32]  


Epoch: 45


Epoch 45: 100%|██████████| 275/275 [04:59<00:00,  1.09s/it, Generator Loss=0.119, Discriminator Loss=0.399] 


Epoch: 46


Epoch 46: 100%|██████████| 275/275 [04:59<00:00,  1.09s/it, Generator Loss=0.365, Discriminator Loss=0.154] 


Epoch: 47


Epoch 47: 100%|██████████| 275/275 [04:58<00:00,  1.09s/it, Generator Loss=0.253, Discriminator Loss=0.287] 


Epoch: 48


Epoch 48: 100%|██████████| 275/275 [05:00<00:00,  1.09s/it, Generator Loss=0.796, Discriminator Loss=0.294] 


Epoch: 49


Epoch 49: 100%|██████████| 275/275 [04:58<00:00,  1.08s/it, Generator Loss=0.379, Discriminator Loss=0.209] 


Epoch: 50


Epoch 50: 100%|██████████| 275/275 [04:58<00:00,  1.08s/it, Generator Loss=0.512, Discriminator Loss=0.227] 


In [26]:
# Save the trained generator model
from tensorflow.keras.models import save_model
# Save the trained generator model
generator.save("generator50")
# Save the trained discriminator model
discriminator.save("discriminator50")





INFO:tensorflow:Assets written to: generator50\assets


INFO:tensorflow:Assets written to: generator50\assets






INFO:tensorflow:Assets written to: discriminator50\assets


INFO:tensorflow:Assets written to: discriminator50\assets
