# In [None]:
from keras.layers import Layer, Input, Dropout, Conv2D, Activation, add, UpSampling2D,     Conv2DTranspose, Flatten, Reshape
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization, InputSpec
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import time
import os
import keras.backend as K
import tensorflow as tf
from skimage.transform import resize
from skimage import color
from helper_funcs import *

# Configure CUDA device selection through environment variables:
# devices are ordered by PCI bus ID and only device "7" is made visible.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "7",
})

# ### Model parameters
# 
# This CycleGAN implementation allows a lot of freedom on both the training parameters and the network architecture.

# Options dictionary read by the layer blocks and the generator builder.
opt = {
    # Data
    'channels': 1,  # Number of channels of the generator output
    'img_shape': (256, 256, 1),  # Shape of the generator input

    # Architecture parameters
    'use_dropout': False,  # Dropout in residual blocks
    'use_bias': True,  # Use bias
    # Resize convolution - instead of transpose convolution in deconvolution
    # layers (uk) - can reduce checkerboard artifacts but the blurring might
    # affect the cycle-consistency
    'use_resize_convolution': True,

    # Tweaks
    'REAL_LABEL': 1.0,  # Use e.g. 0.9 to avoid training the discriminators to zero loss
}

# ### Model architecture
# 
# #### Layer blocks
# These are the individual layer blocks that are used to build the generators and discriminator. More information can be found in the appendix of the [CycleGAN paper](https://arxiv.org/abs/1703.10593).

# Discriminator layers
def ck(model, opt, x, k, use_normalization, use_bias):
    """Discriminator block: 4x4 stride-2 convolution, optional normalization, LeakyReLU(0.2).

    Args:
        model: Dict whose 'normalization' entry is the normalization layer class.
        opt: Options dict; not read by this block.
        x: Input tensor.
        k: Number of convolution filters.
        use_normalization: Apply the normalization layer after the convolution.
        use_bias: Whether the convolution has a bias term.

    Returns:
        The output tensor of the block.
    """
    conv = Conv2D(filters=k, kernel_size=4, strides=2, padding='same', use_bias=use_bias)
    out = conv(x)
    if use_normalization:
        norm = model['normalization'](axis=3, center=True, epsilon=1e-5)
        out = norm(out, training=True)
    return LeakyReLU(alpha=0.2)(out)

# First generator layer
def c7Ak(model, opt, x, k):
    """First generator block: 7x7 stride-1 convolution, normalization, ReLU.

    The convolution uses 'valid' padding, so any padding has to be applied by
    the caller (build_generator applies ReflectionPadding2D((3, 3)) first).

    Args:
        model: Dict whose 'normalization' entry is the normalization layer class.
        opt: Options dict; 'use_bias' toggles the convolution bias.
        x: Input tensor.
        k: Number of convolution filters.

    Returns:
        The output tensor of the block.
    """
    conv = Conv2D(filters=k, kernel_size=7, strides=1, padding='valid', use_bias=opt['use_bias'])
    norm = model['normalization'](axis=3, center=True, epsilon=1e-5)
    features = norm(conv(x), training=True)
    return Activation('relu')(features)

# Downsampling
def dk(model, opt, x, k):
    """Downsampling block: 3x3 stride-2 convolution, normalization, ReLU.

    NOTE: the original author remarked that this block should use reflection
    padding; it currently relies on the convolution's 'same' padding.

    Args:
        model: Dict whose 'normalization' entry is the normalization layer class.
        opt: Options dict; 'use_bias' toggles the convolution bias.
        x: Input tensor.
        k: Number of convolution filters.

    Returns:
        The output tensor of the block.
    """
    conv = Conv2D(filters=k, kernel_size=3, strides=2, padding='same', use_bias=opt['use_bias'])
    norm = model['normalization'](axis=3, center=True, epsilon=1e-5)
    features = norm(conv(x), training=True)
    return Activation('relu')(features)

# Residual block
def Rk(model, opt, x0):
    """Residual block: two (reflection pad, 3x3 conv, normalization) stages plus a skip connection.

    A ReLU follows the first stage, optionally followed by Dropout(0.5) when
    opt['use_dropout'] is set. The second stage has no activation; its output
    is added to the block input.

    Args:
        model: Dict whose 'normalization' entry is the normalization layer class.
        opt: Options dict; reads 'use_bias' and 'use_dropout'.
        x0: Input tensor; the size of its last axis sets the filter count.

    Returns:
        The sum of the transformed tensor and x0.
    """
    channels = int(x0.shape[-1])

    def pad_conv_norm(tensor):
        # Reflection-pad by 1 so the 'valid' 3x3 convolution keeps the spatial size.
        tensor = ReflectionPadding2D((1, 1))(tensor)
        tensor = Conv2D(filters=channels, kernel_size=3, strides=1, padding='valid',
                        use_bias=opt['use_bias'])(tensor)
        return model['normalization'](axis=3, center=True, epsilon=1e-5)(tensor, training=True)

    # First stage
    residual = pad_conv_norm(x0)
    residual = Activation('relu')(residual)
    if opt['use_dropout']:
        residual = Dropout(0.5)(residual)

    # Second stage, then merge with the skip connection
    residual = pad_conv_norm(residual)
    return add([residual, x0])

# Upsampling
def uk(model, opt, x, k):
    """Upsampling block: 2x upsampling, normalization, ReLU.

    Depending on opt['use_resize_convolution'] the upsampling is either a
    resize convolution (UpSampling2D, reflection padding, 3x3 'valid'
    convolution) or a stride-2 transposed convolution.

    Args:
        model: Dict whose 'normalization' entry is the normalization layer class.
        opt: Options dict; reads 'use_resize_convolution' and 'use_bias'.
        x: Input tensor.
        k: Number of convolution filters.

    Returns:
        The output tensor of the block.
    """
    if not opt['use_resize_convolution']:
        # Transposed convolution with stride 2 (fractionally-strided, stride 1/2).
        x = Conv2DTranspose(filters=k, kernel_size=3, strides=2, padding='same',
                            use_bias=opt['use_bias'])(x)
    else:
        # Nearest neighbor upsampling followed by a padded 3x3 convolution.
        x = UpSampling2D(size=(2, 2))(x)
        x = ReflectionPadding2D((1, 1))(x)
        x = Conv2D(filters=k, kernel_size=3, strides=1, padding='valid',
                   use_bias=opt['use_bias'])(x)

    normalized = model['normalization'](axis=3, center=True, epsilon=1e-5)(x, training=True)
    return Activation('relu')(normalized)

# #### Architecture functions

def build_generator(model, opt, name=None):
    """Assemble a generator network as a Keras Model.

    Layout: reflection-padded 7x7 block (32 filters), two downsampling blocks
    (64, 128 filters), nine residual blocks, two upsampling blocks (64, 32
    filters) and a reflection-padded 7x7 convolution with tanh activation
    producing opt['channels'] output channels.

    Args:
        model: Dict whose 'normalization' entry is the normalization layer class.
        opt: Options dict; reads 'img_shape' and 'channels' here, the layer
            blocks read their own options.
        name: Optional name given to the returned Model.

    Returns:
        A Model mapping an input of shape opt['img_shape'] to the generated image.
    """
    n_residual_blocks = 9  # Layers 4-12

    # Layer 1: Input
    inputs = Input(shape=opt['img_shape'])
    h = ReflectionPadding2D((3, 3))(inputs)
    h = c7Ak(model, opt, h, 32)

    # Layers 2-3: Downsampling
    for n_filters in (64, 128):
        h = dk(model, opt, h, n_filters)

    # Layers 4-12: Residual blocks
    for _ in range(n_residual_blocks):
        h = Rk(model, opt, h)

    # Layers 13-14: Upsampling
    for n_filters in (64, 32):
        h = uk(model, opt, h, n_filters)

    # Layer 15: Output
    h = ReflectionPadding2D((3, 3))(h)
    h = Conv2D(opt['channels'], kernel_size=7, strides=1, padding='valid', use_bias=True)(h)
    outputs = Activation('tanh')(h)

    return Model(inputs=inputs, outputs=outputs, name=name)

# Build the models
#
# `model` holds the normalization layer class used by every block and the two
# generators (A->B and B->A). Weights are not loaded here.

model = {'normalization': InstanceNormalization}
model['G_A2B'] = build_generator(model, opt, name='G_A2B_model')
model['G_B2A'] = build_generator(model, opt, name='G_B2A_model')


# In [None]:
# This cell tests the A2B generator.
#
# For every experiment folder under `weight_path` whose name matches the filter
# below, each saved A2B checkpoint (.hdf5) is loaded, every test image of the
# source domain is translated, and the (source, prediction, target) triple is
# written with join_and_save.
#
# Before running, change "weight_path", "save_path" and "dataset_path" to your
# own paths, and adapt the experiment-name filter and the name-splitting rules
# (marked "Change here") to your folder naming scheme.

GA2B = model['G_A2B']

# You need to update the "weight_path", "save_path", "dataset_path" information
weight_path = '/data/jiayuan/saved_models'
save_path = '/data/jiayuan/SaveTestImage'
dataset_path = '/data/jiayuan/dataset/'

# Spatial size of the generator input; test images are resized to it.
target_size = opt['img_shape'][:2]

files_name = os.listdir(weight_path)

for name in files_name:
    if 'trip-ce-ssimG_rate5True10.0100datasettime' not in name:  # You need to change this.
        continue

    # Derive the output folder names and the dataset name from the experiment folder name.
    name_parts = name.split('.0')  # Change here
    prefix_parts = name_parts[0].split('True')  # Change here
    file_name = prefix_parts[0] + prefix_parts[1]
    run_tag = name_parts[1]
    dataset_name = run_tag.split('time')[0]  # Change here

    # Source / target test folders depend only on the dataset, not on the checkpoint.
    if dataset_name == "T1-T2":
        source_dir = os.path.join(dataset_path, dataset_name, 'TestT1')
        target_dir = os.path.join(dataset_path, dataset_name, 'TestT2')
    else:
        source_dir = os.path.join(dataset_path, dataset_name, 'testCT')
        target_dir = os.path.join(dataset_path, dataset_name, 'testMRI')

    weight_file = os.path.join(weight_path, name)
    for weight in os.listdir(weight_file):
        if 'A2B' not in weight or '.hdf5' not in weight:
            continue

        checkpoint = os.path.join(weight_file, weight)
        print(checkpoint)
        GA2B.load_weights(checkpoint)

        # Sixth '_'-separated field of the checkpoint file name, up to the first '.'
        # (presumably the epoch - confirm against the training script's naming).
        checkpoint_tag = weight.split("_")[5].split('.')[0]

        out_dir = os.path.join(save_path, file_name, run_tag, 'a2b', checkpoint_tag)
        os.makedirs(out_dir, exist_ok=True)

        for images in os.listdir(source_dir):
            # Map intensities x -> 2x - 1 (i.e. to [-1, 1], assuming the resized
            # image lies in [0, 1]), the range of the generator's tanh output.
            real_image = resize(mpimg.imread(os.path.join(source_dir, images)), target_size)
            real_image = real_image[:, :, np.newaxis] * 2 - 1

            # The target image is looked up under the same file name in the other domain.
            real_another = resize(mpimg.imread(os.path.join(target_dir, images)), target_size)
            real_another = real_another[:, :, np.newaxis] * 2 - 1

            # Add the batch axis for predict, then drop it again.
            im = GA2B.predict(real_image[np.newaxis, ...])[0]

            # This is to save the source image, predict image, and target image.
            join_and_save(opt, (real_image, im, real_another), os.path.join(out_dir, images))

# In [None]:
# This cell tests the B2A generator.
#
# For every experiment folder under `weight_path` whose name matches the filter
# below, each saved B2A checkpoint (.hdf5) is loaded, every test image of the
# source domain is translated, and the (source, prediction, target) triple is
# written with join_and_save.
#
# "weight_path", "save_path" and "dataset_path" are the ones set in the A2B
# cell; change them there. Adapt the experiment-name filter and the
# name-splitting rules (marked "Change here") to your folder naming scheme.

GB2A = model['G_B2A']

# Spatial size of the generator input; test images are resized to it.
target_size = opt['img_shape'][:2]

files_name = os.listdir(weight_path)

for name in files_name:
    if 'trip-ce-ssimG_rate5True10.0100datasettime1' not in name:  # You need to change this.
        continue

    # Derive the output folder names and the dataset name from the experiment folder name.
    name_parts = name.split('.0')  # Change here
    prefix_parts = name_parts[0].split('True')  # Change here
    file_name = prefix_parts[0] + prefix_parts[1]
    run_tag = name_parts[1]
    dataset_name = run_tag.split('time')[0]  # Change here

    # Source / target test folders depend only on the dataset, not on the checkpoint.
    if dataset_name == "T1-T2":
        source_dir = os.path.join(dataset_path, dataset_name, 'TestT2')
        target_dir = os.path.join(dataset_path, dataset_name, 'TestT1')
    else:
        source_dir = os.path.join(dataset_path, dataset_name, 'testMRI')
        target_dir = os.path.join(dataset_path, dataset_name, 'testCT')

    weight_file = os.path.join(weight_path, name)
    for weight in os.listdir(weight_file):
        if 'B2A' not in weight or '.hdf5' not in weight:
            continue

        checkpoint = os.path.join(weight_file, weight)
        print(checkpoint)
        GB2A.load_weights(checkpoint)

        # Sixth '_'-separated field of the checkpoint file name, up to the first '.'
        # (presumably the epoch - confirm against the training script's naming).
        checkpoint_tag = weight.split("_")[5].split('.')[0]

        out_dir = os.path.join(save_path, file_name, run_tag, 'b2a', checkpoint_tag)
        os.makedirs(out_dir, exist_ok=True)

        for images in os.listdir(source_dir):
            # Map intensities x -> 2x - 1 (i.e. to [-1, 1], assuming the resized
            # image lies in [0, 1]), the range of the generator's tanh output.
            real_image = resize(mpimg.imread(os.path.join(source_dir, images)), target_size)
            real_image = real_image[:, :, np.newaxis] * 2 - 1

            # The target image is looked up under the same file name in the other domain.
            real_another = resize(mpimg.imread(os.path.join(target_dir, images)), target_size)
            real_another = real_another[:, :, np.newaxis] * 2 - 1

            # Add the batch axis for predict, then drop it again.
            im = GB2A.predict(real_image[np.newaxis, ...])[0]

            # This is to save the source image, predict image, and target image.
            join_and_save(opt, (real_image, im, real_another), os.path.join(out_dir, images))