# mount google drive

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so that
# files stored on Drive can be read through ordinary filesystem paths.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# install and import the packages needed for modelling

In [None]:
# Install keras-contrib straight from its GitHub repository (-q = quiet output).
# NOTE(review): the install is not pinned to a commit or tag, so re-running the
# notebook later may pull a different revision.
!pip install -q git+https://www.github.com/keras-team/keras-contrib.git

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for keras-contrib (setup.py) ... [?25l[?25hdone


In [None]:
from random import random
import random
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import *
from tensorflow.keras import Model
from tensorflow.keras.optimizers import *
from numpy import asarray
from numpy import *
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from matplotlib import pyplot
from tensorflow.keras.layers import Dropout
import tensorflow as tf
import numpy as np

# build discriminator model

In [None]:

# define the discriminator model
def define_discriminator(image_shape):
	"""Build and compile the patch discriminator.

	The network is a stack of 4x4 convolutions: four stride-2 stages
	(64, 128, 256, 512 filters), one stride-1 stage with 512 filters, and a
	final single-filter convolution that produces the patch output map.
	Every stage except the first is followed by instance normalisation, and
	every stage except the last by LeakyReLU(0.2).

	Args:
		image_shape: shape of one input image, passed to ``Input(shape=...)``.

	Returns:
		A compiled Keras ``Model`` mapping an image to the patch output,
		trained with MSE loss weighted by 0.5 and Adam(2e-4, beta_1=0.5).
	"""
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_image = Input(shape=image_shape)
	# C64 (no normalisation on the first stage)
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	d = LeakyReLU(alpha=0.2)(d)
	# C128, C256, C512: stride-2 downsampling stages
	for n_filters in (128, 256, 512):
		d = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
		d = InstanceNormalization(axis=-1)(d)
		d = LeakyReLU(alpha=0.2)(d)
	# second last output layer (default stride, keeps the spatial size)
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	# define model
	model = Model(in_image, patch_out)
	# compile model; `learning_rate` replaces the deprecated `lr` alias, which
	# newer Keras optimizers no longer accept
	model.compile(loss='mse', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss_weights=[0.5])
	return model


# build resnet block

In [None]:
# generate a resnet block
def resnet_block(n_filters, input_layer):
	"""Append one residual-style block to ``input_layer``.

	The block applies conv3x3 -> instance norm -> ReLU -> conv3x3 -> instance
	norm and then merges the result with the block input using
	``Concatenate()`` (not an element-wise add), so the returned tensor
	carries the new feature maps together with the input's feature maps.

	Args:
		n_filters: number of filters for both 3x3 convolutions.
		input_layer: Keras tensor the block is attached to.

	Returns:
		The concatenated Keras tensor.
	"""
	# both convolutions share the same initializer and padding settings
	conv_kwargs = dict(padding='same', kernel_initializer=RandomNormal(stddev=0.02))
	# first conv stage, with activation
	branch = Conv2D(n_filters, (3,3), **conv_kwargs)(input_layer)
	branch = InstanceNormalization(axis=-1)(branch)
	branch = Activation('relu')(branch)
	# second conv stage, no activation
	branch = Conv2D(n_filters, (3,3), **conv_kwargs)(branch)
	branch = InstanceNormalization(axis=-1)(branch)
	# merge the branch with the block input
	return Concatenate()([branch, input_layer])

# build generator model

In [None]:
# define the standalone generator model
def define_generator(image_shape, n_resnet=9):
    """Build the (uncompiled) generator network.

    Layout: a 7x7 stem convolution with dropout, two stride-2 downsampling
    convolutions, ``n_resnet`` blocks from ``resnet_block`` each followed by
    Dropout(0.5), two stride-2 transposed convolutions, and a 7x7 output
    convolution with 3 filters squashed by ``tanh``.

    The Dropout layers are intended for Monte Carlo dropout. Keras Dropout is
    only stochastic when the layer runs in training mode, so sampling at
    inference time requires calling the model with ``training=True``.

    Args:
        image_shape: shape of one input image, passed to ``Input(shape=...)``.
        n_resnet: number of residual-style blocks in the bottleneck.

    Returns:
        A Keras ``Model`` mapping the input image to the generated image.
    """
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)

    # c7s1-64 stem
    x = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
    x = InstanceNormalization(axis=-1)(x)
    x = Activation('relu')(x)
    x = Dropout(0.5)(x)

    # d128, d256: stride-2 downsampling
    for n_filters in (128, 256):
        x = Conv2D(n_filters, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(x)
        x = InstanceNormalization(axis=-1)(x)
        x = Activation('relu')(x)

    # R256 blocks, each followed by dropout (for Monte Carlo uncertainty estimates)
    for _block in range(n_resnet):
        x = resnet_block(256, x)
        x = Dropout(0.5)(x)

    # u128, u64: stride-2 upsampling
    for n_filters in (128, 64):
        x = Conv2DTranspose(n_filters, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(x)
        x = InstanceNormalization(axis=-1)(x)
        x = Activation('relu')(x)

    # c7s1-3 output head
    x = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(x)
    x = InstanceNormalization(axis=-1)(x)
    out_image = Activation('tanh')(x)

    return Model(in_image, out_image)

# combine all models

In [None]:
# Define the composite model with confidence map
def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    """Build and compile the composite model used to update ``g_model_1``.

    The model takes two image inputs and produces four outputs:

    1. ``d_model(g_model_1(input_gen))``            - adversarial output (MSE, weight 1)
    2. ``g_model_1(input_id)``                      - identity output (MAE, weight 5)
    3. ``g_model_2(g_model_1(input_gen))``          - forward cycle (MAE, weight 10)
    4. ``g_model_1(g_model_2(input_id))``           - backward cycle (MAE, weight 10)

    Side effect: sets ``g_model_1.trainable = True`` and marks ``d_model`` and
    ``g_model_2`` as not trainable before compiling.

    Args:
        g_model_1: generator to be updated through this composite model.
        d_model: discriminator judging the output of ``g_model_1``.
        g_model_2: the other generator, used for the cycle outputs.
        image_shape: shape of one input image for both inputs.

    Returns:
        A compiled ``tf.keras.Model`` with inputs ``[input_gen, input_id]``.
    """
    # ensure the model we're updating is trainable
    g_model_1.trainable = True
    # mark discriminator as not trainable
    d_model.trainable = False
    # mark other generator model as not trainable
    g_model_2.trainable = False
    # discriminator element
    input_gen = tf.keras.Input(shape=image_shape)
    gen1_out = g_model_1(input_gen)
    output_d = d_model(gen1_out)
    # identity element
    input_id = tf.keras.Input(shape=image_shape)
    output_id = g_model_1(input_id)
    # forward cycle
    output_f = g_model_2(gen1_out)
    # backward cycle
    gen2_out = g_model_2(input_id)
    output_b = g_model_1(gen2_out)
    # define model
    model = tf.keras.Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
    # compile model; `learning_rate` replaces the deprecated `lr` alias, which
    # newer Keras optimizers no longer accept
    opt = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
    return model

# Utils for training loops

In [None]:
# load and prepare training images
def load_real_samples(filename):
	"""Load the two image arrays from an ``.npz`` archive and rescale them.

	The archive must contain the arrays ``arr_0`` and ``arr_1``. Each array is
	mapped with ``(x - 127.5) / 127.5``, which sends 0 to -1 and 255 to 1.

	Args:
		filename: path to the ``.npz`` file.

	Returns:
		A two-element list ``[X1, X2]`` with the rescaled arrays.
	"""
	# open the archive in a context manager so its file handle is closed
	# as soon as both arrays have been read into memory
	with load(filename) as data:
		X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5

	return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	"""Draw a random batch (with replacement) and its 'real' target map.

	Args:
		dataset: array indexed along its first axis.
		n_samples: number of items to draw.
		patch_shape: side length of the square target map.

	Returns:
		``(X, y)`` where ``X`` holds the drawn items and ``y`` is an all-ones
		array of shape ``(n_samples, patch_shape, patch_shape, 1)``.
	"""
	n_available = dataset.shape[0]
	# indices are drawn uniformly from [0, n_available)
	picks = randint(0, n_available, n_samples)
	real_targets = ones((n_samples, patch_shape, patch_shape, 1))
	return dataset[picks], real_targets

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, dataset, patch_shape):
	"""Translate a batch with ``g_model.predict`` and build its 'fake' target map.

	Args:
		g_model: object exposing ``predict(batch)``.
		dataset: batch passed to ``predict``.
		patch_shape: side length of the square target map.

	Returns:
		``(X, y)`` where ``X`` is the prediction and ``y`` is an all-zeros
		array of shape ``(len(X), patch_shape, patch_shape, 1)``.
	"""
	generated = g_model.predict(dataset)
	fake_targets = zeros((len(generated), patch_shape, patch_shape, 1))
	return generated, fake_targets

# save the generator models to file
def save_models(step, g_model_AtoB, g_model_BtoA):
	"""Save both generators to ``.h5`` files named after ``step + 1``.

	Args:
		step: zero-based training step; the file names use ``step + 1``
			zero-padded to six digits.
		g_model_AtoB: model saved as ``g_model_AtoB_<n>.h5``.
		g_model_BtoA: model saved as ``g_model_BtoA_<n>.h5``.
	"""
	tag = step + 1
	# A->B generator first, then B->A
	path_a2b = 'g_model_AtoB_%06d.h5' % tag
	g_model_AtoB.save(path_a2b)
	path_b2a = 'g_model_BtoA_%06d.h5' % tag
	g_model_BtoA.save(path_b2a)
	print('>Saved: %s and %s' % (path_a2b, path_b2a))

def update_image_pool(pool, images, max_size=50):
    """Route a batch of generated images through a bounded history pool.

    While the pool holds fewer than ``max_size`` images, each new image is
    added to the pool and also returned. Once the pool is full, each new image
    is either returned directly (when ``random.random() < 0.5``) or swapped
    with a randomly chosen pooled image, in which case the pooled image is
    returned and the new one takes its place.

    Args:
        pool: list that is mutated in place.
        images: iterable of images to process.
        max_size: maximum number of images kept in ``pool``.

    Returns:
        ``np.ndarray`` built from the selected images, one per input image.
    """
    chosen = []
    for img in images:
        if len(pool) >= max_size:
            if random.random() < 0.5:
                # pass the fresh image through; pool stays untouched
                chosen.append(img)
            else:
                # hand out an older image and store the fresh one in its slot
                slot = np.random.randint(0, len(pool))
                chosen.append(pool[slot])
                pool[slot] = img
            continue
        # pool still filling up: keep and use the fresh image
        pool.append(img)
        chosen.append(img)
    return np.asarray(chosen)


In [None]:
# generate samples with Monte Carlo Dropout and save as a plot
def summarize_performance(step, g_model, trainX, name, n_samples=3, n_mc_samples=3):
    """Save two diagnostic figures of Monte Carlo dropout translations.

    ``n_samples`` random inputs are drawn from ``trainX`` and pushed through
    ``g_model`` ``n_mc_samples`` times. Both figures use the same grid with
    ``n_mc_samples + 2`` rows and ``n_samples`` columns:

    * row 1: the input images,
    * rows 2 .. n_mc_samples + 1: one row per stochastic forward pass,
    * last row: the per-pixel mean over all passes.

    The first figure (``<name>_generated_plot_<step+1>.png``) shows channel 0
    only; the second (``<name>_generated_real_plot_<step+1>.png``) shows the
    full images.

    Args:
        step: zero-based training step, used (as ``step + 1``) in file names.
        g_model: generator; it is called as ``g_model(batch, training=True)``.
        trainX: array of training images to sample from.
        name: prefix for the output file names.
        n_samples: number of input images (grid columns).
        n_mc_samples: number of stochastic forward passes.
    """
    def _save_grid(image_rows, filename, channel=None):
        # one grid row per batch in image_rows, one column per sample
        n_rows = len(image_rows)
        fig = pyplot.figure(figsize=(n_samples * 6, n_rows * 6))
        fig.subplots_adjust(hspace=0.3, wspace=0.1)
        for row_idx, batch in enumerate(image_rows):
            for col_idx in range(n_samples):
                pyplot.subplot(n_rows, n_samples, row_idx * n_samples + col_idx + 1)
                pyplot.axis('off')
                image = batch[col_idx]
                pyplot.imshow(image if channel is None else image[:, :, channel])
        pyplot.savefig(filename)
        pyplot.close(fig)

    # select a sample of input images (the target map is not needed here)
    X_in, _ = generate_real_samples(trainX, n_samples, 0)

    # Monte Carlo dropout: predict() runs layers in inference mode, where
    # Dropout is a no-op and every pass would be identical. Calling the model
    # with training=True keeps the Dropout layers active on each pass.
    generated_images = []
    for _ in range(n_mc_samples):
        X_out = np.asarray(g_model(X_in, training=True))
        generated_images.append((X_out + 1) / 2.0)  # scale from [-1,1] to [0,1]

    # per-pixel mean over the Monte Carlo passes
    avg_generated_images = np.mean(generated_images, axis=0)

    # inputs are in [-1,1]; rescale so imshow does not clip them
    X_in_display = (X_in + 1) / 2.0

    image_rows = [X_in_display] + generated_images + [avg_generated_images]

    # figure 1: first channel only
    filename1 = '%s_generated_plot_%06d.png' % (name, (step + 1))
    _save_grid(image_rows, filename1, channel=0)

    # figure 2: full images
    filename2 = '%s_generated_real_plot_%06d.png' % (name, (step + 1))
    _save_grid(image_rows, filename2)


# Train Model

In [None]:
# train cyclegan models
def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset, n_epochs=60, n_batch=1):
	"""Run the CycleGAN training loop.

	Each step draws one real batch per domain, generates fakes with both
	generators, passes the fakes through the image pools, and then updates in
	this order: generator B->A (via ``c_model_BtoA``), discriminator A,
	generator A->B (via ``c_model_AtoB``), discriminator B. Losses are printed
	on every step; sample plots and model checkpoints are written periodically.

	Args:
		d_model_A: discriminator for domain A.
		d_model_B: discriminator for domain B.
		g_model_AtoB: generator translating A to B.
		g_model_BtoA: generator translating B to A.
		c_model_AtoB: composite model that updates ``g_model_AtoB``.
		c_model_BtoA: composite model that updates ``g_model_BtoA``.
		dataset: pair ``(trainA, trainB)`` of image arrays.
		n_epochs: number of passes over ``trainA`` (default 60).
		n_batch: batch size (default 1).
	"""
	# determine the output square shape of the discriminator
	n_patch = d_model_A.output_shape[1]
	# unpack dataset
	trainA, trainB = dataset
	# prepare image pool for fakes
	poolA, poolB = list(), list()
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
		X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)
		X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)
		# update fakes from pool
		X_fakeA = update_image_pool(poolA, X_fakeA)
		X_fakeB = update_image_pool(poolB, X_fakeB)
		# update generator B->A via adversarial and cycle loss
		g_loss2, _, _, _, _  = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
		# update discriminator for A -> [real/fake]
		dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
		dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
		# update generator A->B via adversarial and cycle loss
		g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
		# update discriminator for B -> [real/fake]
		dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
		dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
		# summarize performance
		print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2))
		# evaluate the model performance every so often
		# NOTE(review): the double modulo fires whenever the step's position
		# within the epoch is a multiple of n_epochs (every 60 steps with the
		# defaults and 300 batches per epoch), not once per epoch - confirm
		# this plotting interval is intended.
		if (i+1) % (bat_per_epo * 1) % (n_epochs) == 0:
			# plot A->B translation
			summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
			# plot B->A translation
			summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
		# checkpoint both generators every 5 epochs
		if (i+1) % (bat_per_epo * 5) == 0:
			# save the models
			save_models(i, g_model_AtoB, g_model_BtoA)

# Driver: load the paired dataset, build all six models, and start training.

# load image data (two arrays, each rescaled to [-1,1] by load_real_samples)
# NOTE(review): the path is hard-coded to a mounted Google Drive location.
dataset = load_real_samples('/content/drive/MyDrive/mri2ct_512.npz')
print('Loaded', dataset[0].shape, dataset[1].shape)
# define input shape based on the loaded dataset (drop the leading sample axis)
image_shape = dataset[0].shape[1:]
# generator: A -> B
g_model_AtoB = define_generator(image_shape)
# generator: B -> A
g_model_BtoA = define_generator(image_shape)
# discriminator: A -> [real/fake]
d_model_A = define_discriminator(image_shape)
# discriminator: B -> [real/fake]
d_model_B = define_discriminator(image_shape)
# composite: A -> B -> [real/fake, A]; updates g_model_AtoB using d_model_B
c_model_AtoB = define_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)
# composite: B -> A -> [real/fake, B]; updates g_model_BtoA using d_model_A
c_model_BtoA = define_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)
# train models
train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset)

Loaded (300, 512, 512, 3) (300, 512, 512, 3)




>1, dA[2.154,4.941] dB[1.373,58.094] g[1198.283,1716.061]
>2, dA[84.379,10.852] dB[6.921,2.813] g[1214.799,1731.016]
>3, dA[11.550,13.354] dB[11.979,10.251] g[1107.733,1696.099]
>4, dA[5.917,7.573] dB[0.739,0.790] g[1112.313,1667.364]
>5, dA[1.408,0.406] dB[0.546,0.191] g[1132.506,1672.333]
>6, dA[0.438,0.553] dB[0.239,0.178] g[1130.948,1659.328]
>7, dA[0.201,0.208] dB[0.338,0.125] g[1102.644,1642.618]
>8, dA[0.141,0.215] dB[0.169,0.103] g[1168.447,1677.707]
>9, dA[0.152,0.144] dB[0.137,0.139] g[1106.346,1645.746]
>10, dA[0.148,0.138] dB[0.163,0.145] g[1103.117,1639.196]
>11, dA[0.155,0.123] dB[0.122,0.140] g[1133.817,1658.740]
>12, dA[0.122,0.122] dB[0.095,0.116] g[1206.470,1707.637]
>13, dA[0.138,0.105] dB[0.092,0.100] g[1214.724,1723.297]
>14, dA[0.124,0.106] dB[0.097,0.113] g[1113.651,1655.117]
>15, dA[0.121,0.097] dB[0.075,0.097] g[1202.783,1704.508]
>16, dA[0.110,0.097] dB[0.076,0.096] g[1219.084,1725.410]
>17, dA[0.118,0.100] dB[0.140,0.159] g[1108.591,1650.104]
>18, dA[0.106,0.







>61, dA[0.095,0.085] dB[0.143,0.189] g[1096.558,1633.596]
>62, dA[0.094,0.081] dB[0.164,0.186] g[1099.820,1639.659]
>63, dA[0.050,0.093] dB[0.338,0.672] g[1111.819,1651.131]
>64, dA[0.067,0.054] dB[1.111,0.225] g[1211.078,1711.232]
>65, dA[0.061,0.109] dB[0.088,0.317] g[1113.307,1647.622]
>66, dA[0.096,0.031] dB[0.121,0.091] g[1151.393,1674.467]
>67, dA[0.074,0.083] dB[0.357,0.351] g[1118.358,1654.508]
>68, dA[0.081,0.055] dB[0.220,0.296] g[1112.752,1648.941]
>69, dA[0.046,0.121] dB[0.480,0.310] g[1196.956,1697.199]
>70, dA[0.068,0.089] dB[0.202,0.077] g[1208.474,1717.643]
>71, dA[0.029,0.151] dB[0.304,0.859] g[1211.146,1718.712]
>72, dA[0.111,0.050] dB[0.842,0.100] g[1111.455,1649.364]
>73, dA[0.088,0.176] dB[0.181,0.455] g[1210.656,1716.059]
>74, dA[0.018,0.089] dB[0.268,0.194] g[1105.782,1646.774]
>75, dA[0.051,0.040] dB[0.619,0.287] g[1202.746,1710.145]
>76, dA[0.058,0.144] dB[0.057,0.653] g[1192.971,1708.879]
>77, dA[0.090,0.189] dB[0.337,0.174] g[1096.498,1628.821]
>78, dA[0.172,







>121, dA[0.089,0.224] dB[0.451,0.320] g[1141.260,1666.286]
>122, dA[0.254,0.145] dB[0.223,0.059] g[1086.533,1620.333]
>123, dA[0.128,0.233] dB[0.130,0.099] g[1100.064,1639.375]
>124, dA[0.084,0.060] dB[0.086,0.050] g[1097.178,1637.052]
>125, dA[0.181,0.208] dB[0.036,0.137] g[1184.100,1679.384]
>126, dA[1.128,0.424] dB[0.143,0.373] g[1101.373,1644.365]
>127, dA[1.465,2.477] dB[0.756,0.512] g[1104.891,1644.327]
>128, dA[1.224,0.211] dB[0.091,0.153] g[1101.688,1644.167]
>129, dA[0.744,0.449] dB[0.056,0.074] g[1127.203,1661.925]
>130, dA[0.143,0.149] dB[0.050,0.039] g[1186.519,1700.129]
>131, dA[0.173,0.116] dB[0.068,0.101] g[1165.254,1687.974]
>132, dA[0.081,0.078] dB[0.061,0.038] g[1168.211,1684.956]
>133, dA[0.066,0.093] dB[0.125,0.139] g[1099.440,1640.695]
>134, dA[0.073,0.064] dB[0.034,0.102] g[1109.154,1651.043]
>135, dA[0.066,0.061] dB[0.341,0.284] g[1185.895,1702.345]
>136, dA[0.039,0.177] dB[0.031,0.362] g[1092.179,1626.813]
>137, dA[0.294,0.098] dB[0.791,0.582] g[1082.459,1610.99

