In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import Callback
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tensorflow import keras
import tensorflow_addons as tfa
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Model
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
import pathlib
import random
import nibabel as nib
import imageio
import os, random, json, PIL, shutil, re, imageio, glob

## parameters

**Because the model has already been trained, `EPOCHS` is set to 21 here for presentation purposes.**

**In total, the model has been trained for 120 epochs.**

In [None]:
# Image geometry: single-channel 256x256 slices.
HEIGHT, WIDTH, CHANNELS = 256, 256, 1

# Data pipeline and training loop.
# NOTE(review): BUFFER_SIZE, GENERATOR_LR and DISCRIMINATOR_LR are not referenced by the
# visible cells (the pipelines use shuffle(512) and the LR schedule hard-codes 2e-4).
BUFFER_SIZE = 5_000
BATCH_SIZE = 4
EPOCHS = 21  # shortened run for presentation; the full training used 120 epochs

# Architecture and optimisation.
TRANSFORMER_BLOCKS = 6  # residual blocks in the generator bottleneck
GENERATOR_LR = DISCRIMINATOR_LR = 2e-4

## preprocess datasets

In [None]:
def _decode_gray_png(image):
    """Decode PNG bytes into a single-channel (H, W, 1) tensor."""
    return tf.image.decode_png(image, channels=1)


def _pad_top_left(image):
    """Zero-pad at the bottom/right up to HEIGHT x WIDTH (raises if the image is larger)."""
    return tf.image.pad_to_bounding_box(image, offset_height=0, offset_width=0,
                                        target_height=HEIGHT, target_width=WIDTH)


def _resize_and_scale(image):
    """Resize to HEIGHT x WIDTH (this also casts to float32) and map [0, 255] to [-1, 1]."""
    image = tf.image.resize(image, [HEIGHT, WIDTH])
    return (image - 127.5) / 127.5


def preprocess_image_T2(image):
    """Preprocess an encoded T2 MRI slice: pad, resize, scale to [-1, 1], rotate 270 degrees, flip vertically.

    Args:
        image: scalar string tensor holding PNG bytes.

    Returns:
        float32 tensor of shape (HEIGHT, WIDTH, 1) in [-1, 1].
    """
    image = _resize_and_scale(_pad_top_left(_decode_gray_png(image)))
    # Scaling is element-wise, so applying it before the spatial re-orientation gives
    # exactly the same values as scaling afterwards.
    image = tf.image.rot90(image, k=3)  # rotate 270 degrees
    return tf.image.flip_up_down(image)


def preprocess_image(image):
    """Preprocess an encoded training image: resize directly (no padding) and scale to [-1, 1].

    The resize does not preserve the aspect ratio.
    """
    return _resize_and_scale(_decode_gray_png(image))


def preprocess_image_test(image):
    """Preprocess an encoded test image: pad to HEIGHT x WIDTH, resize and scale to [-1, 1]."""
    return _resize_and_scale(_pad_top_left(_decode_gray_png(image)))


def load_and_preprocess_image(path):
    """Read the file at `path` and apply `preprocess_image`."""
    return preprocess_image(tf.io.read_file(path))


def load_and_preprocess_image_test(path):
    """Read the file at `path` and apply `preprocess_image_test`."""
    return preprocess_image_test(tf.io.read_file(path))


def load_and_preprocess_image_T2(path):
    """Read the file at `path` and apply `preprocess_image_T2`."""
    return preprocess_image_T2(tf.io.read_file(path))

In [None]:
data_root = pathlib.Path('../input/ixit2-slices/image slice-T2')
# Deterministic order, then keep every 10th slice starting at index 6 (the indices i with
# (i + 4) % 10 == 0), capped at 500 files.
all_image_paths_T2 = sorted(data_root.glob('*/*'))[6::10]
all_image_paths_T2 = [str(path) for path in all_image_paths_T2[:500]]
image_count = len(all_image_paths_T2)
ds_T2 = tf.data.Dataset.from_tensor_slices(all_image_paths_T2)
# Cache the decoded images once, reshuffle individual images on every pass, then batch
# and repeat forever. Calling .cache() after .repeat() would try to cache an infinite
# stream (unbounded memory growth), and shuffling after .batch() would only permute
# whole batches of consecutive, sorted slices.
dataset_T2 = (
    ds_T2.map(load_and_preprocess_image_T2)
    .cache()
    .shuffle(512)
    .batch(BATCH_SIZE)
    .repeat()
)
dataset_T2_test = ds_T2.map(load_and_preprocess_image_T2).batch(1)

In [None]:
data_root = pathlib.Path('../input/head-ct-hemorrhage/head_ct')
all_image_paths_ct = [str(path) for path in sorted(data_root.glob('*/*'))]
image_count = len(all_image_paths_ct)
ds_ct = tf.data.Dataset.from_tensor_slices(all_image_paths_ct)
# Cache decoded images once, reshuffle individual images each pass, then batch and
# repeat. Caching after .repeat() would never complete and grows memory without bound.
dataset_ct = (
    ds_ct.map(load_and_preprocess_image)
    .cache()
    .shuffle(512)
    .batch(BATCH_SIZE)
    .repeat()
)
dataset_ct_test = ds_ct.map(load_and_preprocess_image).batch(1)

## Building blocks (encoder, residual and decoder)

In [None]:
# Both initializers sample from a normal distribution with mean 0 and stddev 0.02:
# one for convolution kernels, one for the InstanceNormalization scale (gamma).
conv_initializer = tf.random_normal_initializer(stddev=0.02, mean=0.0)
gamma_initializer = tf.keras.initializers.RandomNormal(stddev=0.02, mean=0.0)
    
def encoder_block(input_layer, filters, size=3, strides=2, apply_instancenorm=True, activation=None, name='block_x'):
    """Downsampling block: Conv2D (no bias) -> optional InstanceNormalization -> activation.

    Args:
        input_layer: Keras tensor to transform.
        filters: number of convolution filters.
        size: convolution kernel size.
        strides: convolution stride ('same' padding, so 2 halves the spatial size).
        apply_instancenorm: whether to insert a tfa InstanceNormalization layer.
        activation: Keras layer (or callable) applied last. Defaults to a new
            `layers.ReLU()` created for this call.
        name: suffix of the convolution layer name (`encoder_{name}`).

    Returns:
        The output Keras tensor.
    """
    # Create the default activation per call. A `layers.ReLU()` default argument is
    # evaluated only once, so every call omitting `activation` would share one layer.
    if activation is None:
        activation = layers.ReLU()

    block = layers.Conv2D(filters, size,
                          strides=strides,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=conv_initializer,
                          name=f'encoder_{name}')(input_layer)

    if apply_instancenorm:
        block = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(block)

    return activation(block)

def transformer_block(input_layer, size=3, strides=1, name='block_x'):
    """Residual block: Conv2D -> ReLU -> Conv2D, added back onto the input.

    Both convolutions keep the input's channel count and use 'same' padding. The
    skip addition therefore only lines up when `strides` is 1 (the default). No
    normalization layer is applied inside the block.

    Returns:
        Keras tensor equal to conv_2(relu(conv_1(input_layer))) + input_layer.
    """
    channels = input_layer.shape[-1]

    out = input_layer
    for idx in (1, 2):
        out = layers.Conv2D(channels, size, strides=strides, padding='same', use_bias=False,
                            kernel_initializer=conv_initializer,
                            name=f'transformer_{name}_{idx}')(out)
        if idx == 1:
            out = layers.ReLU()(out)

    return layers.Add()([out, input_layer])

def decoder_block(input_layer, filters, size=3, strides=2, apply_instancenorm=True, name='block_x'):
    """Upsampling block: Conv2DTranspose (no bias) -> optional InstanceNormalization -> ReLU.

    With 'same' padding and the default stride of 2, the spatial size doubles. The
    transposed-convolution layer is named `decoder_{name}`.
    """
    upsample = layers.Conv2DTranspose(
        filters,
        size,
        strides=strides,
        padding='same',
        use_bias=False,
        kernel_initializer=conv_initializer,
        name=f'decoder_{name}',
    )
    x = upsample(input_layer)

    if apply_instancenorm:
        norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
        x = norm(x)

    return layers.ReLU()(x)

# Resize-convolution upsampling (alternative to decoder_block; not used by generator_fn).
def decoder_rc_block(input_layer, filters, size=3, strides=1, apply_instancenorm=True, name='block_x'):
    """Bilinear 2x resize followed by Conv2D (no bias), optional InstanceNormalization and ReLU.

    The convolution layer is named `decoder_{name}`.
    """
    target_size = (input_layer.shape[1] * 2, input_layer.shape[2] * 2)
    x = tf.image.resize(images=input_layer, size=target_size, method='bilinear')

    x = layers.Conv2D(filters, size,
                      strides=strides,
                      padding='same',
                      use_bias=False,
                      kernel_initializer=conv_initializer,
                      name=f'decoder_{name}')(x)

    if apply_instancenorm:
        x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)

    return layers.ReLU()(x)

## generator

In [None]:
def generator_fn(height=HEIGHT, width=WIDTH, channels=CHANNELS, transformer_blocks=TRANSFORMER_BLOCKS, output_channels=1):
    """Build the encoder / residual-bottleneck / decoder generator used for both directions.

    Architecture:
    * three encoder blocks (strides 1, 2, 2);
    * `transformer_blocks` residual blocks at 1/4 resolution;
    * two transposed-conv decoder blocks with encoder-decoder skip concatenations;
    * a 7x7 tanh output convolution.

    Args:
        height, width, channels: input image shape. Height and width are expected to be
            divisible by 4, otherwise the skip concatenations cannot line up.
        transformer_blocks: number of residual blocks in the bottleneck.
        output_channels: channels of the generated image (default 1, the previously
            hard-coded value).

    Returns:
        Uncompiled Keras `Model`. It maps (bs, height, width, channels) to
        (bs, height, width, output_channels), with values in [-1, 1] from tanh.
    """
    inputs = layers.Input(shape=[height, width, channels], name='input_image')

    # Encoder
    enc_1 = encoder_block(inputs, 64,  7, 1, apply_instancenorm=False, activation=layers.ReLU(), name='block_1') # (bs, 256, 256, 64)
    enc_2 = encoder_block(enc_1, 128, 3, 2, apply_instancenorm=True, activation=layers.ReLU(), name='block_2')   # (bs, 128, 128, 128)
    enc_3 = encoder_block(enc_2, 256, 3, 2, apply_instancenorm=True, activation=layers.ReLU(), name='block_3')   # (bs, 64, 64, 256)

    # Residual ("transformer") bottleneck
    x = enc_3
    for n in range(transformer_blocks):
        x = transformer_block(x, 3, 1, name=f'block_{n+1}') # (bs, 64, 64, 256)

    # Decoder with encoder-decoder skip connections
    x_skip = layers.Concatenate(name='enc_dec_skip_1')([x, enc_3])

    dec_1 = decoder_block(x_skip, 128, 3, 2, apply_instancenorm=True, name='block_1') # (bs, 128, 128, 128)
    x_skip = layers.Concatenate(name='enc_dec_skip_2')([dec_1, enc_2])

    dec_2 = decoder_block(x_skip, 64,  3, 2, apply_instancenorm=True, name='block_2') # (bs, 256, 256, 64)
    x_skip = layers.Concatenate(name='enc_dec_skip_3')([dec_2, enc_1])

    outputs = layers.Conv2D(output_channels, 7,
                            strides=1, padding='same',
                            kernel_initializer=conv_initializer,
                            use_bias=False,
                            activation='tanh',
                            name='decoder_output_block')(x_skip) # (bs, 256, 256, output_channels)

    return Model(inputs, outputs)


sample_generator = generator_fn()
sample_generator.summary()

## discriminator 

In [None]:
def discriminator_fn(height=HEIGHT, width=WIDTH, channels=CHANNELS):
    """Build a PatchGAN-style discriminator that outputs a grid of real/fake logits.

    It stacks four 4x4 encoder blocks with strides 2, 2, 2, 1 and LeakyReLU(0.2), using
    instance norm on all but the first. A 4x4 'valid' convolution then reduces the
    result to a single channel with no activation. For a 256x256 input the output
    shape is (bs, 29, 29, 1).
    """
    inputs = layers.Input(shape=[height, width, channels], name='input_image')

    # (filters, stride, apply_instancenorm, block name) for each encoder stage.
    stages = (
        (64, 2, False, 'block_1'),   # (bs, 128, 128, 64)
        (128, 2, True, 'block_2'),   # (bs, 64, 64, 128)
        (256, 2, True, 'block_3'),   # (bs, 32, 32, 256)
        (512, 1, True, 'block_4'),   # (bs, 32, 32, 512)
    )
    features = inputs
    for n_filters, stride, use_norm, block_name in stages:
        features = encoder_block(features, n_filters, 4, stride,
                                 apply_instancenorm=use_norm,
                                 activation=layers.LeakyReLU(0.2),
                                 name=block_name)

    logits = layers.Conv2D(1, 4, strides=1, padding='valid',
                           kernel_initializer=conv_initializer)(features)  # (bs, 29, 29, 1)

    return Model(inputs, logits)


sample_discriminator = discriminator_fn()
sample_discriminator.summary()

## create generator and discriminator

In [None]:
# Generators: T1_generator maps domain-T2 images to domain T1; T2_generator does the reverse.
T1_generator, T2_generator = generator_fn(), generator_fn()

# Discriminators: each separates real from generated images of its own domain (T1 / T2).
T1_discriminator, T2_discriminator = discriminator_fn(), discriminator_fn()

## Load saved models (to continue training)

To train from scratch instead, skip this cell (or convert it to a Markdown cell).

In [None]:
# Resume from previously saved SavedModels when available. Otherwise, keep the freshly
# initialised networks from the previous cell. Set RESUME_TRAINING = False to force a new run.
RESUME_TRAINING = True
MODEL_DIR = '../input/mrict-model'

if RESUME_TRAINING and os.path.isdir(MODEL_DIR):
    T1_generator = tf.keras.models.load_model(f'{MODEL_DIR}/T1/generate_mri_3_22')
    T1_discriminator = tf.keras.models.load_model(f'{MODEL_DIR}/T1/discriminate_mri_3_22')
    T2_generator = tf.keras.models.load_model(f'{MODEL_DIR}/T2/generate_ct_3_22')
    T2_discriminator = tf.keras.models.load_model(f'{MODEL_DIR}/T2/discriminate_ct_3_22')
else:
    print(f'Not loading checkpoints (RESUME_TRAINING={RESUME_TRAINING}, {MODEL_DIR} exists: '
          f'{os.path.isdir(MODEL_DIR)}); training from freshly initialised models.')

## Build the CycleGAN model

In [None]:
class CycleGan(keras.Model):
    """CycleGAN with two generators and two discriminators, trained via a custom `train_step`.

    `T1_gen` translates domain-T2 images into domain T1, and `T2_gen` translates the
    other way. `T1_disc` / `T2_disc` score images of their own domain.
    `lambda_cycle` weights the cycle-consistency loss. It is also passed to the
    identity loss.
    """

    def __init__(
        self,
        T1_generator,
        T2_generator,
        T1_discriminator,
        T2_discriminator,
        lambda_cycle=10,
    ):
        super().__init__()
        self.T1_gen, self.T2_gen = T1_generator, T2_generator
        self.T1_disc, self.T2_disc = T1_discriminator, T2_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        T1_gen_optimizer,
        T2_gen_optimizer,
        T1_disc_optimizer,
        T2_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Store one optimizer per network plus the four loss callables used by `train_step`."""
        super().compile()
        self.T1_gen_optimizer, self.T2_gen_optimizer = T1_gen_optimizer, T2_gen_optimizer
        self.T1_disc_optimizer, self.T2_disc_optimizer = T1_disc_optimizer, T2_disc_optimizer
        self.gen_loss_fn, self.disc_loss_fn = gen_loss_fn, disc_loss_fn
        self.cycle_loss_fn, self.identity_loss_fn = cycle_loss_fn, identity_loss_fn

    def train_step(self, batch_data):
        """Perform one optimisation step for all four networks on a `(real_T1, real_T2)` batch.

        Returns:
            Dict with keys "T1_gen_loss", "T2_gen_loss", "T1_disc_loss" and
            "T2_disc_loss".
        """
        real_T1, real_T2 = batch_data

        with tf.GradientTape(persistent=True) as tape:
            # Round trips: T2 -> T1 -> T2 and T1 -> T2 -> T1.
            fake_T1 = self.T1_gen(real_T2, training=True)
            cycled_T2 = self.T2_gen(fake_T1, training=True)
            fake_T2 = self.T2_gen(real_T1, training=True)
            cycled_T1 = self.T1_gen(fake_T2, training=True)

            # Identity mappings: each generator applied to an image already in its output domain.
            same_T1 = self.T1_gen(real_T1, training=True)
            same_T2 = self.T2_gen(real_T2, training=True)

            # Discriminator logits for real and generated images.
            logits_real_T1 = self.T1_disc(real_T1, training=True)
            logits_real_T2 = self.T2_disc(real_T2, training=True)
            logits_fake_T1 = self.T1_disc(fake_T1, training=True)
            logits_fake_T2 = self.T2_disc(fake_T2, training=True)

            # The shared cycle term enters both generator objectives.
            cycle_term = (self.cycle_loss_fn(real_T1, cycled_T1, self.lambda_cycle)
                          + self.cycle_loss_fn(real_T2, cycled_T2, self.lambda_cycle))

            total_T1_gen_loss = (self.gen_loss_fn(logits_fake_T1) + cycle_term
                                 + self.identity_loss_fn(real_T1, same_T1, self.lambda_cycle))
            total_T2_gen_loss = (self.gen_loss_fn(logits_fake_T2) + cycle_term
                                 + self.identity_loss_fn(real_T2, same_T2, self.lambda_cycle))

            T1_disc_loss = self.disc_loss_fn(logits_real_T1, logits_fake_T1)
            T2_disc_loss = self.disc_loss_fn(logits_real_T2, logits_fake_T2)

        # (loss, network, optimizer) per network. All gradients are taken before any
        # update is applied.
        plan = (
            (total_T1_gen_loss, self.T1_gen, self.T1_gen_optimizer),
            (total_T2_gen_loss, self.T2_gen, self.T2_gen_optimizer),
            (T1_disc_loss, self.T1_disc, self.T1_disc_optimizer),
            (T2_disc_loss, self.T2_disc, self.T2_disc_optimizer),
        )
        grads = [tape.gradient(loss, net.trainable_variables) for loss, net, _ in plan]
        for net_grads, (_, net, opt) in zip(grads, plan):
            opt.apply_gradients(zip(net_grads, net.trainable_variables))

        return {
            "T1_gen_loss": total_T1_gen_loss,
            "T2_gen_loss": total_T2_gen_loss,
            "T1_disc_loss": T1_disc_loss,
            "T2_disc_loss": T2_disc_loss
        }

**Loss functions**

In [None]:
def discriminator_loss(real, generated):
    """Discriminator loss: half the sum of the BCE terms for real and generated logits.

    Real logits are scored against ones and generated logits against zeros. The
    per-element BCE is returned without reduction.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    real_term = bce(tf.ones_like(real), real)
    fake_term = bce(tf.zeros_like(generated), generated)
    return (real_term + fake_term) * 0.5

In [None]:
def generator_loss(generated):
    """Adversarial generator loss: unreduced BCE of the generated logits against all-ones targets."""
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    targets = tf.ones_like(generated)
    return bce(targets, generated)

In [None]:
 def calc_cycle_loss(real_image, cycled_image, LAMBDA):
    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

    return LAMBDA * loss1

In [None]:
def identity_loss(real_image, same_image, LAMBDA):
    """Identity loss: (LAMBDA * 0.5) times the mean absolute difference between an image and its identity mapping."""
    half_weight = LAMBDA * 0.5
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
    return half_weight * mean_abs_diff

## Learning-rate schedule (updated at every training step)

In [None]:
@tf.function
def linear_schedule_with_warmup(step):
    """Learning rate for a given optimizer step: linear warmup, constant hold, then linear decay.

    Phases, with the constants below:
      * step < warmup_steps: linear ramp from lr_start to lr_max. This is effectively a
        no-op here because lr_start == lr_max and warmup_steps == 1.
      * the next hold_max_steps (80% of total_steps): constant lr_max.
      * afterwards: linear decay from lr_max that reaches 0 at total_steps, clamped
        below by lr_min.

    Args:
        step: current step. When it is a tensor (the optimizers pass
            `tf.cast(iterations, tf.float32)`), AutoGraph converts the `if`/`elif` below
            into graph conditionals. A plain Python number is traced separately for
            each distinct value.

    Returns:
        The learning rate for `step`, as a scalar.
    """
    lr_start   = 2e-4
    lr_max     = 2e-4
    lr_min     = 0.
    
    # max(500, 500) == 500; presumably the number of training images per domain
    # (the T2 path list is capped at 500) — TODO confirm. Must agree with the
    # steps_per_epoch passed to fit().
    steps_per_epoch = int(max(500, 500)//BATCH_SIZE)
    total_steps = EPOCHS * steps_per_epoch
    warmup_steps = 1
    hold_max_steps = total_steps * 0.8
    
    if step < warmup_steps:
        lr = (lr_max - lr_start) / warmup_steps * step + lr_start
    elif step < warmup_steps + hold_max_steps:
        lr = lr_max
    else:
        # Linear decay: equals lr_max at the end of the hold phase and 0 at total_steps.
        lr = lr_max * ((total_steps - step) / (total_steps - warmup_steps - hold_max_steps))
        if lr_min is not None:
            # Past total_steps the expression above turns negative; clamp it.
            lr = tf.math.maximum(lr_min, lr)

    return lr

# Preview the schedule over the whole run. steps_per_epoch must agree with the value
# used inside linear_schedule_with_warmup and with the steps_per_epoch passed to fit().
steps_per_epoch = int(500 // BATCH_SIZE)  # == int(max(500, 500) // BATCH_SIZE)
total_steps = EPOCHS * steps_per_epoch
rng = list(range(0, total_steps, 50))
# Pass float32 tensors, as the optimizers do. The tf.function is then traced once,
# instead of retraced for every distinct Python int. Results are converted back to
# floats for plotting and formatting.
y = [float(linear_schedule_with_warmup(tf.constant(step, dtype=tf.float32))) for step in rng]

sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(rng, y)
ax.set(title='Learning-rate schedule', xlabel='training step', ylabel='learning rate')
print(f'{EPOCHS} total epochs and {steps_per_epoch} steps per epoch')
print(f'Learning rate schedule: {y[0]:.3g} to {max(y):.3g} to {y[-1]:.3g}')

In [None]:
# Each optimizer's learning rate is a zero-argument callable. It feeds that optimizer's
# own step counter (`iterations`, cast to float32) into linear_schedule_with_warmup.
# The lambdas refer to optimizer names that are only bound a few lines below. This works
# because a lambda body is evaluated lazily, so keep these definitions in this cell.
lr_T1_gen = lambda: linear_schedule_with_warmup(tf.cast(T1_generator_optimizer.iterations, tf.float32))
lr_T2_gen = lambda: linear_schedule_with_warmup(tf.cast(T2_generator_optimizer.iterations, tf.float32))
    
# Generator optimizers: Adam with beta_1 = 0.5.
T1_generator_optimizer = optimizers.Adam(learning_rate=lr_T1_gen, beta_1=0.5)
T2_generator_optimizer = optimizers.Adam(learning_rate=lr_T2_gen, beta_1=0.5)

# Discriminator optimizers: same schedule and Adam settings as the generators.
lr_T1_disc = lambda: linear_schedule_with_warmup(tf.cast(T1_discriminator_optimizer.iterations, tf.float32))
lr_T2_disc = lambda: linear_schedule_with_warmup(tf.cast(T2_discriminator_optimizer.iterations, tf.float32))
    
T1_discriminator_optimizer = optimizers.Adam(learning_rate=lr_T1_disc, beta_1=0.5)
T2_discriminator_optimizer = optimizers.Adam(learning_rate=lr_T2_disc, beta_1=0.5)

# Create GAN: wrap the four networks (lambda_cycle keeps its default of 10).
gan_model = CycleGan(T1_generator, T2_generator, 
                     T1_discriminator, T2_discriminator)

# Attach the optimizers and the loss functions defined in the loss-function cells.
gan_model.compile(T1_gen_optimizer=T1_generator_optimizer,
                      T2_gen_optimizer=T2_generator_optimizer,
                      T1_disc_optimizer=T1_discriminator_optimizer,
                      T2_disc_optimizer=T2_discriminator_optimizer,
                      gen_loss_fn=generator_loss,
                      disc_loss_fn=discriminator_loss,
                      cycle_loss_fn=calc_cycle_loss,
                      identity_loss_fn=identity_loss)

## Callback: every 7 epochs, plot sample translations and save the figure and models

In [None]:
# Callbacks
class GANMonitor(Callback):
    """A callback to generate and save images after each epoch"""

    def __init__(self, num_img=4, T1_path='T1', T2_path='T2'):
        self.num_img = num_img
        self.T1_path = T1_path
        self.T2_path = T2_path
        # Create directories to save the generate images
        if not os.path.exists(self.T1_path):
            os.makedirs(self.T1_path)
        if not os.path.exists(self.T2_path):
            os.makedirs(self.T2_path)

    def on_epoch_end(self, epoch, logs=None):
        if epoch%7 == 0:
            fig = plt.figure(figsize=(12,8))
            for i, img in enumerate(dataset_ct_test.take(self.num_img)):
                prediction = T1_generator(img, training=False).numpy()
                prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
                cycle = T2_generator(prediction, training=False)
                plt.subplot(6,4,i +1)
                plt.imshow(img[0], cmap = 'gray')
                plt.axis('off')
                plt.subplot(6,4,i +5)
                plt.imshow(prediction[0], cmap = 'gray')
                plt.axis('off')
                plt.subplot(6,4,i +9)
                plt.imshow(cycle[0], cmap = 'gray')
                plt.axis('off')
            
       
            for i, img in enumerate(dataset_T2_test.take(self.num_img)):
                prediction = T2_generator(img, training=False).numpy()
                prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
                cycle = T1_generator(prediction, training=False)
                plt.subplot(6,4,i + 13)
                plt.imshow(img[0], cmap = 'gray')
                plt.axis('off')
                plt.subplot(6,4,i + 17)
                plt.imshow(prediction[0], cmap = 'gray')
                plt.axis('off')
                plt.subplot(6,4,i + 21)
                plt.imshow(cycle[0], cmap = 'gray')
                plt.axis('off')
            plt.show()
            plt.savefig(f'{self.T1_path}/visualiation_{i}_{epoch+1}')
        if epoch%7 == 0:
            tf.keras.models.save_model(T1_generator, f'{self.T1_path}/generate_mri_{i}_{epoch+1}')
            tf.keras.models.save_model(T1_discriminator, f'{self.T1_path}/discriminate_mri_{i}_{epoch+1}')
            tf.keras.models.save_model(T2_generator, f'{self.T2_path}/generate_ct_{i}_{epoch+1}')
            tf.keras.models.save_model(T2_discriminator, f'{self.T2_path}/discriminate_ct_{i}_{epoch+1}')

## train model

In [None]:
# Each element is (T2-slice MRI batch, CT batch). CycleGan.train_step unpacks it as
# (real_T1, real_T2), so in the model code "T1" is the MRI domain and "T2" the CT domain.
# Both inputs repeat forever, so epoch length is set by steps_per_epoch in fit().
gan_ds = tf.data.Dataset.zip((dataset_T2, dataset_ct))

In [None]:
# Train. Both input datasets repeat indefinitely, so the epoch length is given explicitly
# and matches the schedule's steps_per_epoch (max(500, 500) // BATCH_SIZE).
history = gan_model.fit(
    gan_ds,
    epochs=EPOCHS,
    steps_per_epoch=500 // BATCH_SIZE,
    callbacks=[GANMonitor()],
).history

## show results

In [None]:
# Evaluation pairs: (T2 MRI slice, CT slice), batch size 1 each.
# NOTE(review): the two datasets come from independent folders zipped by position, so
# the pairs are presumably not co-registered. PSNR/SSIM against them should be read with
# care — TODO confirm. shuffle() reshuffles on every iteration by default, so each
# .take() in the evaluation cell sees a different order.
test_ds = tf.data.Dataset.zip((dataset_T2_test, dataset_ct_test)).shuffle(10) 

In [None]:
def _evaluate_translation(generator, test_ds, source_index, source_title, generated_title,
                          header, n_show=8, n_eval=100):
    """Plot `n_show` translations, then print mean/std of PSNR and SSIM over `n_eval` pairs.

    Args:
        generator: Keras model translating the source image.
        test_ds: dataset of `(image_T1, image_T2)` pairs, batch size 1.
        source_index: which pair element is fed to `generator` (0 or 1). The other
            element is used as the reference image for the metrics.
        source_title, generated_title: subplot titles for the two figure rows.
        header: line printed above the metrics.
        n_show: number of examples to plot.
        n_eval: number of pairs used for the metrics.

    Images are in [-1, 1] (preprocessing scaling and the tanh output), hence max_val=2.
    """
    reference_index = 1 - source_index

    plt.figure(figsize=(16, 5))
    for col, pair in enumerate(test_ds.take(n_show)):
        source = pair[source_index]
        generated = generator(source, training=False)
        plt.subplot(2, n_show, col + 1)
        plt.imshow(source[0], cmap='gray')
        plt.title(source_title)
        plt.axis('off')
        plt.subplot(2, n_show, col + 1 + n_show)
        plt.imshow(generated[0], cmap='gray')
        plt.title(generated_title)
        plt.axis('off')

    psnr, ssim = [], []
    for pair in test_ds.take(n_eval):
        source, reference = pair[source_index], pair[reference_index]
        generated = generator(source, training=False)
        psnr.append(tf.image.psnr(generated, reference, max_val=2))
        ssim.append(tf.image.ssim(generated, reference, max_val=2))

    print(header)
    print('PSNR ={}, std ={}'.format(np.mean(psnr), np.std(psnr, ddof=1)))
    print('SSIM ={}, std ={}'.format(np.mean(ssim), np.std(ssim, ddof=1)))
    plt.show()


def estimation_generator_T1(T1_generator, test_ds, n_show=8, n_eval=100):
    """CT -> MRI: feed the CT element (`image_T2`) to `T1_generator` and compare against `image_T1`."""
    _evaluate_translation(T1_generator, test_ds, 1, 'CT image', 'Generated \n MRI image',
                          '========CT to MRI =========', n_show, n_eval)


def estimation_generator_T2(T2_generator, test_ds, n_show=8, n_eval=100):
    """MRI -> CT: feed the MRI element (`image_T1`) to `T2_generator` and compare against `image_T2`."""
    _evaluate_translation(T2_generator, test_ds, 0, 'MRI image', 'Generated \n CT image',
                          '========MRI to CT =========', n_show, n_eval)

In [None]:
# Evaluate both directions: MRI -> CT first, then CT -> MRI. Each call iterates test_ds
# afresh, so it sees its own shuffled order.
estimation_generator_T2(T2_generator, test_ds)
estimation_generator_T1(T1_generator, test_ds)