# Import Packages
* numpy
* tensorflow
* matplotlib
* pathlib
* glob
* Tcl
* tensorflow_addons.layers
* RandomNormal
* Model
* Input,Conv2D,Conv2DTranspose,LeakyReLU,Activation,Concatenate
  

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pathlib
import glob 
from tkinter import Tcl
import tensorflow_addons.layers as tfal
from keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Conv2D,Conv2DTranspose,LeakyReLU,Activation,Concatenate

# PATHS OF DATA 

In [None]:
# Root folders of the two MRI modalities (placeholder strings: replace with real paths).
data_dir_t1 = pathlib.Path(r"path of folder that contain T2 images")
data_dir_t2 = pathlib.Path(r"path of folder that contain FLAIR images")

# Report how many PNG files sit exactly one sub-folder below each root.
for modality_label, modality_dir in (("T2 MRI images: ", data_dir_t1),
                                     ("FLAIR MRI images: ", data_dir_t2)):
    print(modality_label, len(list(modality_dir.glob('*/*.png'))))

# Initializing constants

In [None]:
# Global hyper-parameters shared by the data loaders and the training loop.
BUFFER_SIZE = 1000  # NOTE(review): not referenced in the visible code - confirm it is still needed
BATCH_SIZE = 1  # batch size of the training dataset loaders
EPOCHS = 400  # number of passes of the training loop
img_height = 256  # images are resized to (img_height, img_width) on load
img_width = 256

# T2 MRI images Train set


In [None]:
# Training subset (94 % of the files) of the T2 folder, loaded without labels
# so every element is a bare image batch.
tr1_train = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir_t1,
    labels=None,
    validation_split=0.06,
    subset='training',
    seed=123,
    image_size=(img_height, img_width),
    batch_size=BATCH_SIZE,
)

# T2 MRI images Test set


In [None]:
# Held-out subset (remaining 6 % of the files) of the T2 folder.
# `labels` is left at its default, so elements are (images, labels) pairs.
tr1_test = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir_t1,
    validation_split=0.06,
    subset='validation',
    seed=123,
    image_size=(img_height, img_width),
    batch_size=1,
)

# FLAIR MRI images Train set


In [None]:
# Training subset (94 % of the files) of the FLAIR folder, loaded without
# labels; same seed and split fraction as the T2 training loader.
tr2_train = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir_t2,
    labels=None,
    validation_split=0.06,
    subset='training',
    seed=123,
    image_size=(img_height, img_width),
    batch_size=BATCH_SIZE,
)

# FLAIR MRI images Test set


In [None]:
# Held-out subset (remaining 6 % of the files) of the FLAIR folder.
# `labels` is left at its default, so elements are (images, labels) pairs.
tr2_test = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir_t2,
    validation_split=0.06,
    subset='validation',
    seed=123,
    image_size=(img_height, img_width),
    batch_size=1,
)

# DATA LOADER :
## DATA.cache()
* caches the dataset in memory or on disk after the first epoch
## DATA.prefetch(buffer_size=AUTOTUNE)
* overlaps data preprocessing with model execution during training.

In [None]:
# Let tf.data choose the prefetch buffer size at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Cache each dataset after its first full pass and prefetch upcoming batches.
# The cache is taken before normalization is mapped on, so it stores the
# un-normalized images.
tr1_train = tr1_train.cache().prefetch(buffer_size=AUTOTUNE)
tr1_test = tr1_test.cache().prefetch(buffer_size=AUTOTUNE)

tr2_train = tr2_train.cache().prefetch(buffer_size=AUTOTUNE)
tr2_test = tr2_test.cache().prefetch(buffer_size=AUTOTUNE)

# normalizing

In [None]:
def normalize(image):
    """Linearly rescale pixel values so that 0 maps to -1 and 255 maps to 1."""
    return image / 127.5 - 1

In [None]:
# Normalize both MRI modalities to [-1, 1].
# The training datasets yield bare image batches (they were loaded with
# labels=None); the test datasets yield (images, labels) pairs, so the label
# is dropped here and every dataset ends up yielding images only.
tr1_train = tr1_train.map(lambda x: (normalize(x)))
tr2_train = tr2_train.map(lambda x: (normalize(x)))
tr1_test = tr1_test.map(lambda x,_: (normalize(x)))
tr2_test = tr2_test.map(lambda x,_: (normalize(x)))

# sample 

In [None]:
# Take the first batch of each training dataset as a fixed visual reference;
# these samples are reused later for the per-epoch preview images.
sample_tr1 = next(iter(tr1_train))
sample_tr2 = next(iter(tr2_train))

print(sample_tr1.shape)


# Show channel 0 of the first T2 image, mapped back from [-1, 1] to [0, 1].
plt.title('T2')
plt.imshow(sample_tr1[0].numpy()[:, :, 0] * 0.5 + 0.5, cmap='gray')


In [None]:
# Show channel 0 of the first FLAIR image, mapped back from [-1, 1] to [0, 1].
plt.title('FLAIR')
plt.imshow(sample_tr2[0].numpy()[:, :, 0] * 0.5 + 0.5, cmap='gray')

# generator model 
* squeeze_attention_unet

In [None]:
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model

def conv_block(x, num_filters):
    """Apply two consecutive Conv(3x3, same) -> InstanceNorm -> ReLU stages.

    Args:
        x: Input feature-map tensor.
        num_filters: Number of output channels of both convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    features = x
    for _ in range(2):
        features = L.Conv2D(num_filters, 3, padding="same")(features)
        features = tfal.InstanceNormalization(axis=-1)(features)
        features = L.Activation("relu")(features)
    return features


def se_block(x, num_filters, ratio=8):
    """Scale the channels of ``x`` with a squeeze-and-excitation gate.

    The input is globally average-pooled, passed through a bias-free
    two-layer bottleneck (``num_filters // ratio`` ReLU units, then
    ``num_filters`` sigmoid units) and multiplied back onto ``x``.

    Args:
        x: Input feature map; the reshape to ``(1, 1, num_filters)`` requires
            it to have ``num_filters`` channels.
        num_filters: Channel count of ``x`` and of the gate.
        ratio: Reduction factor of the bottleneck layer.

    Returns:
        ``x`` multiplied channel-wise by the sigmoid gate.
    """
    gate_shape = (1, 1, num_filters)
    gate = L.GlobalAveragePooling2D()(x)
    gate = L.Reshape(gate_shape)(gate)
    gate = L.Dense(num_filters // ratio, activation="relu", use_bias=False)(gate)
    gate = L.Dense(num_filters, activation="sigmoid", use_bias=False)(gate)
    gate = L.Reshape(gate_shape)(gate)
    return L.Multiply()([x, gate])

def encoder_block(x, num_filters):
    """Run one encoder stage: conv block, SE gate, then 2x2 max-pooling.

    Args:
        x: Input feature-map tensor.
        num_filters: Channel count produced by this stage.

    Returns:
        A ``(skip, pooled)`` pair: the full-resolution features (used as the
        skip connection) and their 2x down-sampled version.
    """
    skip = se_block(conv_block(x, num_filters), num_filters)
    pooled = L.MaxPool2D((2, 2))(skip)
    return skip, pooled

def decoder_block(x, s, num_filters):
    """Run one decoder stage: upsample, merge the skip, conv block, SE gate.

    Args:
        x: Feature map coming from the previous (coarser) stage.
        s: Skip-connection tensor from the matching encoder stage.
        num_filters: Channel count produced by this stage.

    Returns:
        The SE-gated output feature map.
    """
    upsampled = L.UpSampling2D(interpolation="bilinear")(x)
    merged = L.Concatenate()([upsampled, s])
    refined = conv_block(merged, num_filters)
    return se_block(refined, num_filters)

def squeeze_attention_unet(input_shape=(256, 256, 3)):
    """Build the generator: a 4-level U-Net with an SE gate after every conv block.

    Args:
        input_shape: Shape of one input image, without the batch dimension.

    Returns:
        A Keras ``Model`` named "Squeeze-Attention-UNET" whose output is a
        3-channel, tanh-activated image.
    """
    inputs = L.Input(input_shape)

    # Encoder: keep every stage's pre-pooling output as a skip connection.
    skips = []
    features = inputs
    for width in (64, 128, 256, 512):
        skip, features = encoder_block(features, width)
        skips.append(skip)

    # Bottleneck at the coarsest resolution.
    features = conv_block(features, 1024)
    features = se_block(features, 1024)

    # Decoder: consume the skips from the deepest to the shallowest.
    for width, skip in zip((512, 256, 128, 64), reversed(skips)):
        features = decoder_block(features, skip, width)

    # 1x1 projection to the output image.
    outputs = L.Conv2D(3, (1, 1), activation='tanh')(features)

    return Model(inputs, outputs, name="Squeeze-Attention-UNET")


In [None]:
# Instantiate the generator with the default 256x256x3 input shape.
generator_g = squeeze_attention_unet()

In [None]:
# Print the generator's layer-by-layer summary.
generator_g.summary()

# DISCRIMINATOR 
* PATCH GAN

In [None]:
def downsample(filters, size, apply_norm=True):
    """Build a stride-2 convolutional block that halves the spatial size.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_norm: If true, insert instance normalization after the conv.

    Returns:
        A ``tf.keras.Sequential`` of Conv2D -> [InstanceNorm] -> LeakyReLU.
    """
    block = tf.keras.Sequential()
    block.add(
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        )
    )
    if apply_norm:
        block.add(tfal.InstanceNormalization(axis=-1))
    block.add(tf.keras.layers.LeakyReLU())
    return block

In [None]:
def discriminator():
    """Build the PatchGAN discriminator for 256x256x3 images.

    Returns:
        A ``tf.keras.Model`` mapping an image to a 30x30 grid of raw
        (un-activated) patch scores.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)
    image = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

    features = downsample(64, 4, False)(image)  # (bs, 128, 128, 64)
    features = downsample(128, 4)(features)  # (bs, 64, 64, 128)
    features = downsample(256, 4)(features)  # (bs, 32, 32, 256)

    features = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 34, 34, 256)
    features = tf.keras.layers.Conv2D(
        512, 4, strides=1, kernel_initializer=weight_init, use_bias=False,
    )(features)  # (bs, 31, 31, 512)
    features = tfal.InstanceNormalization()(features)
    features = tf.keras.layers.LeakyReLU()(features)
    features = tf.keras.layers.ZeroPadding2D()(features)  # (bs, 33, 33, 512)

    # NOTE(review): this emits 3 score channels per patch; PatchGAN heads
    # usually emit a single channel - confirm this is intended.
    patch_scores = tf.keras.layers.Conv2D(
        3, 4, strides=1, kernel_initializer=weight_init,
    )(features)  # (bs, 30, 30, 3)
    return tf.keras.Model(inputs=image, outputs=patch_scores)

In [None]:
# Instantiate the single discriminator used by the training step.
discriminator_x = discriminator()

In [None]:
# Print the discriminator's layer-by-layer summary.
discriminator_x.summary()

## Performing Prediction on the untrained model


In [None]:
# Run the still-untrained generator on the T2 sample and show the result
# between the T2 input and the real FLAIR sample.
to_tr2 = generator_g(sample_tr1)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_tr1, to_tr2, sample_tr2]
title = ['tr1', 'To tr2', 'tr2']

for i, (batch, caption) in enumerate(zip(imgs, title)):
    plt.subplot(1, 3, i + 1)
    plt.title(caption)
    # Only the middle (generated) panel gets the extra contrast gain.
    gain = 0.5 if i % 2 == 0 else 0.5 * contrast
    plt.imshow(batch[0].numpy()[:, :, 0] * gain + 0.5, cmap='gray')
plt.show()

# LOSSES 

In [None]:
# Weight of the reconstruction term relative to the adversarial term in
# generator_loss.
LAMBDA = 100.0 

In [None]:
# Binary cross-entropy computed on raw logits; shared by both GAN losses.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator loss for one batch.

    Args:
        disc_real_output: Discriminator logits for real images (target 1).
        disc_generated_output: Discriminator logits for generated images
            (target 0).

    Returns:
        The sum of the two binary cross-entropy terms.
    """
    real_targets = tf.ones_like(disc_real_output)
    fake_targets = tf.zeros_like(disc_generated_output)
    return (loss_object(real_targets, disc_real_output)
            + loss_object(fake_targets, disc_generated_output))

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
    """Return the generator loss and its two components.

    Args:
        disc_generated_output: Discriminator logits for the generated images.
        gen_output: Generated images.
        target: Ground-truth images the generator should reproduce.

    Returns:
        A tuple ``(total, adversarial, reconstruction)`` where
        ``reconstruction`` is the average of the mean absolute error and
        ``1 - SSIM`` (SSIM computed with ``max_val=2.0``), and
        ``total = adversarial + LAMBDA * reconstruction``.
    """
    # Adversarial term: the generator wants its output classified as real.
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    mean_abs_error = tf.reduce_mean(tf.abs(target - gen_output))
    ssim_distance = 1 - tf.reduce_mean(tf.image.ssim(target, gen_output, max_val=2.0))
    reconstruction = (mean_abs_error + ssim_distance) / 2

    return adversarial + (LAMBDA * reconstruction), adversarial, reconstruction

In [None]:
# Separate Adam optimizers (learning rate 2e-4, beta_1 = 0.5) for the two networks.
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5 ) # generator

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5 ) # discriminator

# SAVING CHECK POINTS

In [None]:
checkpoint_path = r"path of folder saving check points"

# Track both networks together with their optimizer states.
ckpt = tf.train.Checkpoint(
    generator_g=generator_g,
    discriminator_x=discriminator_x,
    generator_g_optimizer=generator_g_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
)

# Ref: https://www.tensorflow.org/api_docs/python/tf/train/CheckpointManager
# At most 300 checkpoints are kept on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=300)

# Resume from the most recent checkpoint when one is present.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)

    print(f'Last Check Point: {ckpt_manager.latest_checkpoint}')
    print('Latest checkpoint restored!!')

### SHOWING SAMPLE EVERY EPOCH 

In [None]:
def generate_images(model1, test1, test2, gen_g_loss, disc_x, epoch=None):
    """Plot input, prediction and target side by side and save the figure.

    Args:
        model1: Generator used to translate ``test1``.
        test1: Batch of source images; only the first image is shown.
        test2: Batch of target images; only the first image is shown.
        gen_g_loss: Generator loss printed in the caption.
        disc_x: Discriminator loss printed in the caption.
        epoch: Epoch number used in the output file name. When omitted, the
            module-level variable ``epoch`` is used (the previous implicit
            behavior), falling back to 0 if it is not defined.
    """
    if epoch is None:
        # Backward compatibility: existing callers rely on the training
        # loop's global loop variable instead of passing the epoch in.
        epoch = globals().get('epoch', 0)

    prediction1 = model1(test1)

    # Channel 0 of the first image of each batch, rotated by 3 x 90 degrees.
    test1 = np.rot90(test1[0, :, :, 0], 3)
    test2 = np.rot90(test2[0, :, :, 0], 3)
    prediction1 = np.rot90(prediction1[0, :, :, 0], 3)

    fig = plt.figure(figsize=(10, 10))
    display_list = [test1, prediction1, test2]
    title = ['T2 True', 'FLAIR predicted', 'FLAIR True']

    for i, (image, caption) in enumerate(zip(display_list, title)):
        plt.subplot(1, 3, i + 1)
        plt.title(caption)
        # Map [-1, 1] back to [0, 1] for display.
        plt.imshow(image * 0.5 + 0.5, cmap='gray')
        plt.axis('off')
    plt.text(-600, 300 ,"gen_g_er = {:.3f}, disc_x_er = {:.3f}".format(gen_g_loss, disc_x))
    plt.savefig(r'path of saving images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Release the figure: one is created per epoch and would otherwise
    # accumulate in memory when run outside an inline notebook backend.
    plt.close(fig)


# Train Step

In [None]:
# Ref: https://www.tensorflow.org/guide/function
@tf.function
def train_step(real_x, real_y):
    """Run one optimization step of the generator and the discriminator.

    Args:
        real_x: Batch of source-domain images fed to the generator.
        real_y: Batch of target-domain images the generator should reproduce.

    Returns:
        A pair ``(generator_losses, disc_x_loss)`` where ``generator_losses``
        is the full tuple returned by ``generator_loss`` (its first element
        is the total generator loss) and ``disc_x_loss`` is the discriminator
        loss.
    """
    # persistent=True because the same tape yields two sets of gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y.
        fake_y = generator_g(real_x, training=True)

        # The discriminator scores real and generated target-domain images.
        disc_real_x = discriminator_x(real_y, training=True)
        disc_fake_x = discriminator_x(fake_y, training=True)

        # generator_loss returns a tuple: (total, adversarial, reconstruction).
        total_gen_g_loss = generator_loss(disc_fake_x, fake_y, real_y)
        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)

    # Differentiate ONLY the total loss. Passing the whole tuple would make
    # tape.gradient sum the gradients of all three elements, silently
    # re-weighting the adversarial and reconstruction terms.
    generator_g_gradients = tape.gradient(
        total_gen_g_loss[0], generator_g.trainable_variables)
    discriminator_x_gradients = tape.gradient(
        disc_x_loss, discriminator_x.trainable_variables)
    # A persistent tape holds its resources until it is dropped explicitly.
    del tape

    generator_g_optimizer.apply_gradients(
        zip(generator_g_gradients, generator_g.trainable_variables))
    discriminator_x_optimizer.apply_gradients(
        zip(discriminator_x_gradients, discriminator_x.trainable_variables))

    return total_gen_g_loss, disc_x_loss

# Test step

In [None]:
@tf.function
def test_step(real_x, real_y):
    """Compute the generator and discriminator losses without updating weights.

    No gradient tape is opened because nothing is differentiated here, and
    both networks are called with ``training=False``.

    Args:
        real_x: Batch of source-domain images fed to the generator.
        real_y: Batch of target-domain images the generator should reproduce.

    Returns:
        A pair ``(generator_losses, disc_x_loss)`` where ``generator_losses``
        is the full tuple returned by ``generator_loss`` (its first element
        is the total generator loss) and ``disc_x_loss`` is the discriminator
        loss.
    """
    # Generator G translates X -> Y.
    fake_y = generator_g(real_x, training=False)

    disc_real_x = discriminator_x(real_y, training=False)
    disc_fake_x = discriminator_x(fake_y, training=False)

    total_gen_g_loss = generator_loss(disc_fake_x, fake_y, real_y)
    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)

    return total_gen_g_loss, disc_x_loss

## Lists storing the history of training and validation losses

In [None]:
# Per-epoch loss histories; one single-element row is appended per epoch
# and the lists are written to text files by the training loop.
# history of generator loss (training)
history_gen_1 = []
# history of discriminator loss (training)
history_dis_1 = []
# history of generator loss (validation)
val_g_hist_1 = []
# history of discriminator loss (validation)
val_d_hist_1 = []

# START TRAINING

In [None]:
# Pair the two modalities batch-by-batch once, instead of rebuilding the
# zipped dataset at every epoch.
_train_pairs = tf.data.Dataset.zip((tr1_train, tr2_train))
_val_pairs = tf.data.Dataset.zip((tr1_test, tr2_test))


def _run_epoch(step_fn, paired_batches):
    """Run ``step_fn`` on every batch pair and return the mean losses.

    Args:
        step_fn: ``train_step`` or ``test_step``; must return
            ``(generator_losses, discriminator_loss)`` with the total
            generator loss at index 0 of ``generator_losses``.
        paired_batches: Iterable yielding ``(image_x, image_y)`` batch pairs.

    Returns:
        ``(mean_generator_loss, mean_discriminator_loss)`` as Python floats;
        ``(0.0, 0.0)`` if ``paired_batches`` is empty.
    """
    gen_sum = 0.0
    disc_sum = 0.0
    num_batches = 0
    for image_x, image_y in paired_batches:
        gen_losses, disc_loss = step_fn(image_x, image_y)
        gen_sum = gen_sum + gen_losses[0]
        disc_sum = disc_sum + disc_loss
        num_batches += 1
    if num_batches == 0:
        return 0.0, 0.0
    # Average over the batches actually seen rather than over a hard-coded
    # dataset size, so the result stays correct if the data changes.
    return float(gen_sum / num_batches), float(disc_sum / num_batches)


# NOTE(review): the epoch counter restarts at 1 even when a checkpoint was
# restored, so preview images of an earlier run get overwritten - confirm
# this is acceptable.
for epoch in range(1, EPOCHS+1):
    # ---- training pass ----
    tot_loss_gen, tot_loss_dis = _run_epoch(train_step, _train_pairs)

    history_gen_1.append([tot_loss_gen])
    history_dis_1.append([tot_loss_dis])

    # The full history is rewritten every epoch so an interrupted run keeps it.
    np.savetxt(r'path of saving text/history_gen.txt_1', history_gen_1, fmt='%f')
    np.savetxt(r'path of saving text/history_dis.txt_1', history_dis_1, fmt='%f')

    generate_images(generator_g, sample_tr1, sample_tr2, tot_loss_gen, tot_loss_dis)

    ckpt_save_path = ckpt_manager.save()
    print('Saving checkpoint for epoch', epoch, 'at', ckpt_save_path)

    # ---- validation pass ----
    tot_loss_gen, tot_loss_dis = _run_epoch(test_step, _val_pairs)

    val_g_hist_1.append([tot_loss_gen])
    val_d_hist_1.append([tot_loss_dis])

    np.savetxt(r'path of saving text/val_g_hist_1.txt_1', val_g_hist_1, fmt='%f')
    np.savetxt(r'path of saving text/val_d_hist_1.txt_1', val_d_hist_1, fmt='%f')