In [None]:
pip install git+https://www.github.com/keras-team/keras-contrib.git -q

In [None]:
import os

# Directory for the sample grids written by save_picture.
# exist_ok=True keeps this cell idempotent under Restart & Run All
# (plain `mkdir outputs` errors if the directory already exists).
os.makedirs("outputs", exist_ok=True)

In [None]:
from tensorflow.keras.optimizers import Adam
from keras.models import Model
from keras.models import Sequential
from keras.layers import Add,Conv2D,Input,Conv2DTranspose,Concatenate,Activation
from keras.layers import LeakyReLU
from keras.initializers import RandomNormal
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.utils.vis_utils import plot_model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import random

In [None]:
def resnet_block(number_of_filters,prev_layer):
    """Residual block used in the generator bottleneck.

    Computes prev_layer + F(prev_layer), where F is
    Conv3x3 -> InstanceNorm -> ReLU -> Conv3x3 -> InstanceNorm.
    The shortcut is an element-wise Add (as in ResNet / CycleGAN), so the output
    keeps prev_layer's shape. A Concatenate shortcut would instead grow the
    channel count by `number_of_filters` on every block (256 -> 2560 after 9).

    Args:
        number_of_filters: filters of both 3x3 convolutions; must equal the
            channel count of `prev_layer` for the Add to be valid.
        prev_layer: channels-last input tensor.

    Returns:
        Tensor with the same shape as `prev_layer`.
    """
    weight_initialization = RandomNormal(stddev=0.02)

    residual = Conv2D(number_of_filters, (3,3), padding='same', kernel_initializer=weight_initialization)(prev_layer)
    residual = InstanceNormalization(axis=-1)(residual)
    residual = Activation('relu')(residual)

    residual = Conv2D(number_of_filters, (3,3), padding='same', kernel_initializer=weight_initialization)(residual)
    residual = InstanceNormalization(axis=-1)(residual)

    # Identity shortcut: element-wise sum keeps the channel count constant.
    return Add()([residual, prev_layer])

In [None]:
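# PatchGAN discriminator: outputs a grid of real/fake scores, one per image patch.
# C64-C128-C256-C512 (Ck = 4x4 Conv with k filters, stride 2, LeakyReLU 0.2; InstanceNorm from C128 on)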

def discriminator(image_shape,model_name):
    """Build and compile a PatchGAN discriminator.

    Args:
        image_shape: input image shape, channels-last (e.g. (256, 256, 1)).
        model_name: Keras name given to the returned model.

    Returns:
        A compiled Model (MSE loss weighted by 0.5, Adam lr=2e-4, beta_1=0.5)
        whose output is a 1-channel map of per-patch scores. For a 256x256
        input that map is 16x16.
    """
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)

    # C64, C128, C256, C512: four stride-2 stages. For a 256x256 input the
    # spatial size goes 128 -> 64 -> 32 -> 16. The first stage (C64) is not normalized.
    x = in_image
    for stage, filters in enumerate((64, 128, 256, 512)):
        x = Conv2D(filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(x)
        if stage > 0:
            x = InstanceNormalization(axis=-1)(x)
        x = LeakyReLU(alpha=0.2)(x)

    # Extra 512-filter stage with stride 1 (spatial size unchanged, 16x16).
    x = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(x)
    x = InstanceNormalization(axis=-1)(x)
    x = LeakyReLU(alpha=0.2)(x)

    # Single-channel patch score map (16x16 for a 256x256 input).
    patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(x)

    model = Model(in_image, patch_out, name=model_name)
    model.compile(
        loss='mse',
        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
        loss_weights=[0.5],
    )
    return model

In [None]:
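# c7s1-64,d128,d256, R256 x resnet_block_count (the CycleGAN paper uses 6 blocks for 128x128 and 9 for 256x256 inputs), u128,u64,c7s1-3 (c7s1-1 here: grayscale output)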

def generator(image_shape,resnet_block_count,name):
    """Build an (uncompiled) CycleGAN ResNet-style generator.

    Layout: c7s1-64, d128, d256, `resnet_block_count` x R256, u128, u64, c7s1-1.

    Args:
        image_shape: input shape, channels-last (e.g. (256, 256, 1)).
        resnet_block_count: number of residual blocks in the bottleneck.
        name: Keras name applied to the returned Model (e.g. "gen_a_to_b"), so
            summaries, composite models and saved files identify each generator.

    Returns:
        Model mapping an image to a 1-channel image in [-1, 1] (tanh output).
        It is not compiled here; its weights are updated through the composite models.
    """
    weight_initialization = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)
    # the resolutions noted below are for a 256x256 input

    # ---- encoder ----

    # c7s1-64: 7x7 conv, stride 1, 64 filters
    # 256x256
    g = Conv2D(64, (7,7), padding='same', kernel_initializer=weight_initialization)(in_image)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)

    # d128
    # 128x128
    g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=weight_initialization)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)

    # d256
    # 64x64
    g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=weight_initialization)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)

    # R256 x resnet_block_count: residual blocks at 64x64; spatial size unchanged
    for _ in range(resnet_block_count):
        g = resnet_block(256, g)

    # ---- decoder ----

    # u128
    # 128x128
    g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=weight_initialization)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)

    # u64
    # 256x256
    g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=weight_initialization)(g)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)

    # c7s1-3 in the paper; the images here are grayscale, so c7s1-1
    # 256x256
    g = Conv2D(1, (7,7), padding='same', kernel_initializer=weight_initialization)(g)
    # NOTE(review): normalizing the single output channel per image before tanh
    # fixes each output's pre-tanh mean/std to the learned beta/gamma; the
    # reference CycleGAN implementation goes straight from conv to tanh. Kept as-is.
    g = InstanceNormalization(axis=-1)(g)
    out_image = Activation('tanh')(g)

    # Not compiled: the generator is trained through the composite models.
    model = Model(in_image, out_image, name=name)
    return model

In [None]:
#4 losses for each composite model (loss weight in brackets)
#adversarial loss (least-squares GAN loss via mse) [1]
#forward cycle-consistency loss (mae) [10]
#backward cycle-consistency loss (mae) [10]
#identity loss (mae) [5]

In [None]:
def composite_model(image_shape,g_model_1, d_model, g_model_2,name):
    """Build the compiled model used to train generator `g_model_1`.

    Notation: g_model_1 maps domain X -> Y, g_model_2 maps Y -> X, and d_model
    judges domain-Y images. (For composite_a_to_b, X = A and Y = B.)

    Inputs:  [x (domain X image), y (domain Y image)]
    Outputs, with losses and weights:
        d_model(g1(x))   adversarial       mse  x1
        g1(y)            identity          mae  x5
        g2(g1(x))        forward cycle     mae  x10
        g1(g2(y))        backward cycle    mae  x10

    Only g_model_1 is marked trainable; d_model and g_model_2 are frozen here.
    NOTE(review): this relies on Keras honouring the trainable flags as they
    were at compile() time, because the mirror composite flips them afterwards.
    """
    g_model_1.trainable = True
    d_model.trainable = False
    g_model_2.trainable = False

    # Adversarial path: x -> g_model_1 -> fake y -> d_model score map.
    src_x = Input(shape=image_shape)
    fake_y = g_model_1(src_x)
    adv_score = d_model(fake_y)

    # Identity path: feeding a real y to g_model_1 should return it unchanged.
    src_y = Input(shape=image_shape)
    identity_y = g_model_1(src_y)

    # Forward cycle: x -> g_model_1 -> g_model_2 should reconstruct x.
    cycled_x = g_model_2(fake_y)

    # Backward cycle: y -> g_model_2 -> g_model_1 should reconstruct y.
    fake_x = g_model_2(src_y)
    cycled_y = g_model_1(fake_x)

    model = Model([src_x, src_y], [adv_score, identity_y, cycled_x, cycled_y], name=name)
    model.compile(
        loss=['mse', 'mae', 'mae', 'mae'],
        loss_weights=[1, 5, 10, 10],
        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    )
    return model

In [None]:
img_shape = (256,256,1)
resnet_blocks = 9  # residual blocks in each generator's bottleneck

# gen_a_to_b translates domain A -> B; gen_b_to_a translates B -> A.
gen_a_to_b = generator(img_shape, resnet_blocks, "gen_a_to_b")
gen_b_to_a = generator(img_shape, resnet_blocks, "gen_b_to_a")

# disc_a scores domain-A images, disc_b scores domain-B images.
disc_a, disc_b = (discriminator(img_shape, disc_name) for disc_name in ("disc_a", "disc_b"))

# One composite per generator; each trains only the generator passed first.
composite_a_to_b = composite_model(img_shape, gen_a_to_b, disc_b, gen_b_to_a, "composite_a_to_b")
composite_b_to_a = composite_model(img_shape, gen_b_to_a, disc_a, gen_a_to_b, "composite_b_to_a")

In [None]:
# Draw the layer graph of the A->B generator (gen_b_to_a is built by the same generator() call, so it has the same architecture).
plot_model(gen_a_to_b)

In [None]:
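#training steps per iteration (in the order the training loop below runs them)

# 1.train gen_b_to_a (through composite_b_to_a)
# 2.train disc_a (real A, then pooled fake A)
# 3.train gen_a_to_b (through composite_a_to_b)
# 4.train disc_b (real B, then pooled fake B)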

In [None]:
def make_dataset(path, size=(256,256)):
    """Load every image file in `path` as a grayscale array scaled to [-1, 1].

    Each image is resized with bicubic resampling, converted to 8-bit
    grayscale ("L"), and mapped from [0, 255] to [-1, 1] via x / 127.5 - 1.

    Args:
        path: directory containing the image files. Sub-directories are skipped.
        size: (width, height) passed to PIL's Image.resize. Defaults to 256x256.

    Returns:
        float32 array of shape (n_images, height, width, 1), in sorted filename
        order. An empty directory gives shape (0, height, width, 1).
    """
    images = []
    # os.listdir order is arbitrary; sort so the array order is reproducible.
    for fname in sorted(os.listdir(path)):
        fpath = os.path.join(path, fname)
        if not os.path.isfile(fpath):
            continue
        # The context manager closes the file handle that Image.open keeps.
        with Image.open(fpath) as img:
            gray = img.resize(size, Image.BICUBIC).convert("L")
        # float32 matches what Keras consumes and halves memory versus float64.
        pixels = np.asarray(gray, dtype=np.float32)
        images.append(pixels / 127.5 - 1.0)
    if not images:
        return np.empty((0, size[1], size[0], 1), dtype=np.float32)
    return np.stack(images)[..., np.newaxis]

In [None]:
def save_picture(model_name,epoch,step):
    """Translate 9 random test images, show them in a 3x3 grid, and save the grid.

    Args:
        model_name: "gen_a_to_b" translates images drawn from test_a with
            gen_a_to_b; any other value translates test_b with gen_b_to_a.
        epoch, step: used only in the file name
            outputs/{model_name}-{epoch}-{step}.png.
    """
    if model_name == "gen_a_to_b":
        source, model = test_a, gen_a_to_b
    else:
        source, model = test_b, gen_b_to_a

    picks = [random.randint(0, source.shape[0] - 1) for _ in range(9)]
    # verbose=0 keeps predict() from printing a progress bar on every call.
    images = model.predict(source[picks], verbose=0)

    fig, axs = plt.subplots(3, 3)
    for ax, img in zip(axs.flat, images):
        # Map tanh output [-1, 1] to [0, 1]. Pin vmin/vmax so imshow does not
        # re-stretch each image to its own min/max, which would hide brightness.
        ax.imshow(np.squeeze(img * 0.5 + 0.5), cmap="gray", vmin=0, vmax=1)
        ax.axis('off')
    # Save before show()/close() so the written file does not depend on
    # backend-specific show/close behaviour.
    fig.savefig(f"outputs/{model_name}-{epoch}-{step}.png")
    plt.show()
    plt.close(fig)

In [None]:
# horse2zebra splits: A = horses, B = zebras (Kaggle input mount).
DATASET_DIR = "/kaggle/input/horse2zebra-dataset"
train_a, train_b, test_a, test_b = (
    make_dataset(f"{DATASET_DIR}/{split}")
    for split in ("trainA", "trainB", "testA", "testB")
)

In [None]:
# Sanity check: array shapes of all four splits, then the value range of train_a (expected within [-1, 1]).
print(*(split.shape for split in (train_a, train_b, test_a, test_b)))
print(train_a.max(), train_a.min())

In [None]:
def generate_real_samples(domain, n_samples, patch_shape, dataset=None):
    """Draw random real images and their all-ones (real) PatchGAN targets.

    Args:
        domain: "train_a" samples from train_a; any other value samples from train_b.
        n_samples: number of images to draw (with replacement).
        patch_shape: side length of the discriminator's square output map.
        dataset: optional array to sample from instead of train_a/train_b.

    Returns:
        (x, y): x is `n_samples` images indexed from the dataset; y is a float
        array of ones with shape (n_samples, patch_shape, patch_shape, 1).
    """
    if dataset is None:
        dataset = train_a if domain == "train_a" else train_b
    # np.random.randint's upper bound is exclusive, so pass the full length;
    # a `shape[0]-1` bound would never select the last image.
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    x = dataset[ix]
    y = np.ones((n_samples, patch_shape, patch_shape, 1))
    return x, y


In [None]:
def generate_fake_samples(model_name,data,patch_shape):
    """Translate `data` with one generator and pair the result with all-zeros (fake) targets.

    Args:
        model_name: "gen_a_to_b" uses gen_a_to_b; any other value uses gen_b_to_a.
        data: batch of input images for the chosen generator.
        patch_shape: side length of the discriminator's square output map.

    Returns:
        (x, y): x is the generator output; y is a float array of zeros with
        shape (len(x), patch_shape, patch_shape, 1).
    """
    model = gen_a_to_b if model_name == "gen_a_to_b" else gen_b_to_a
    # This runs twice per training step; verbose=0 stops predict() from
    # printing a progress bar each time and flooding the notebook output.
    x = model.predict(data, verbose=0)
    y = np.zeros((len(x), patch_shape, patch_shape, 1))
    return x, y

In [None]:
# Image pool: the discriminators are updated with a mix of current and previously generated fakes
# (history buffer, max pool size 50) rather than only the latest generator output.

def update_image_pool(pool, images, max_size=50):
    """Return a batch of fakes mixed from `images` and the history `pool`.

    `pool` is a list mutated in place. While it holds fewer than `max_size`
    entries, each new image is stored and returned unchanged. Once full, each
    image is returned as-is with probability 0.5. Otherwise a randomly chosen
    pooled image is returned and the new image takes its slot.

    Returns:
        np.ndarray built from the selected images, one per input image.
    """
    chosen = []
    for img in images:
        if len(pool) < max_size:
            # Pool still filling: keep and use the new image.
            pool.append(img)
            chosen.append(img)
            continue
        if random.random() < 0.5:
            # Use the new image without storing it.
            chosen.append(img)
            continue
        # Swap: hand out an old image and store the new one in its place.
        slot = random.randint(0, len(pool) - 1)
        chosen.append(pool[slot])
        pool[slot] = img
    return np.asarray(chosen)

In [None]:
# CycleGAN training loop. Each step trains gen_b_to_a, disc_a, gen_a_to_b, disc_b in that order.
epochs = 30
batch_size = 1
steps_per_epoch = train_a.shape[0]
# Discriminator output is (None, H, W, 1); derive H so the targets always match the model.
patch_size = disc_a.output_shape[1]
log_every = 100  # steps between loss logs / sample grids
poolA, poolB = list(), list()

for i in range(epochs):
    print("#################",i+1,"#################")
    for j in range(steps_per_epoch):
        print("step ",j+1,end="\r")

        # real samples (targets = ones)
        X_realA, y_realA = generate_real_samples("train_a", batch_size, patch_size)
        X_realB, y_realB = generate_real_samples("train_b", batch_size, patch_size)

        # fake samples (targets = zeros)
        X_fakeA, y_fakeA = generate_fake_samples("gen_b_to_a", X_realB, patch_size)
        X_fakeB, y_fakeB = generate_fake_samples("gen_a_to_b", X_realA, patch_size)

        # mix in older fakes from the history pools
        X_fakeA = update_image_pool(poolA, X_fakeA)
        X_fakeB = update_image_pool(poolB, X_fakeB)

        # update generator B->A via adversarial, identity and cycle losses
        g_loss2, _, _, _, _ = composite_b_to_a.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])

        # update discriminator for A -> [real/fake]
        dA_loss1 = disc_a.train_on_batch(X_realA, y_realA)
        dA_loss2 = disc_a.train_on_batch(X_fakeA, y_fakeA)

        # update generator A->B via adversarial, identity and cycle losses
        g_loss1, _, _, _, _ = composite_a_to_b.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])

        # update discriminator for B -> [real/fake]
        dB_loss1 = disc_b.train_on_batch(X_realB, y_realB)
        dB_loss2 = disc_b.train_on_batch(X_fakeB, y_fakeB)

        if j % log_every == 0:
            # np.mean tolerates train_on_batch returning either a scalar or a 1-element list.
            losses = [float(np.mean(v)) for v in (dA_loss1, dA_loss2, dB_loss1, dB_loss2, g_loss1, g_loss2)]
            print('>%d-%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, j+1, *losses))
            save_picture("gen_a_to_b",i+1,j+1)
            save_picture("gen_b_to_a",i+1,j+1)

    # Checkpoint every epoch so an interrupted or time-limited session keeps its progress.
    gen_a_to_b.save("gen_a_to_b.h5")
    gen_b_to_a.save("gen_b_to_a.h5")

In [None]:
# Save both trained generators (.h5 -> Keras HDF5 format).
for out_path, net in (("gen_a_to_b.h5", gen_a_to_b), ("gen_b_to_a.h5", gen_b_to_a)):
    net.save(out_path)