In [15]:
import os
from random import random
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Add
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from matplotlib import pyplot

In [16]:
# Mount Google Drive so later cells can write models and plots under
# /content/drive/MyDrive/Strojno_projekt/... (the paths used below).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
from keras.layers import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K


class InstanceNormalization(Layer):
    """Instance normalization layer.
    Normalize the activations of the previous layer separately for each
    sample, i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each
            instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to the standard deviation (not the
            variance) to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a Sequential model.
    # Output shape
        Same shape as input.
    # Raises
        ValueError: at build time if `axis` is 0, if `axis` is given for a
            rank-2 input, or if the size of `axis` is undefined.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](
        https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate `axis` and create the `gamma`/`beta` weights.

        With `axis=None` the weights are single scalars (shape (1,));
        otherwise there is one value per entry along `axis`.
        """
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            dim = input_shape[self.axis]
            # A per-feature weight needs a known size; fail clearly instead of
            # passing a None shape to add_weight.
            if dim is None:
                raise ValueError('Axis ' + str(self.axis) + ' of the input tensor '
                                 'should have a defined dimension, but the layer '
                                 'received an input with shape ' +
                                 str(input_shape) + '.')
            shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        """Normalize each sample by its own mean/std, then apply gamma/beta."""
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        # Keep `axis` out of the reduction so statistics are per feature.
        if self.axis is not None:
            del reduction_axes[self.axis]

        # Drop the batch axis so statistics are computed per sample.
        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        # epsilon is added to the standard deviation, not to the variance.
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        # Reshape gamma/beta so they broadcast along `axis` only.
        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [18]:
# define the Discriminator model: 70x70 patchGAN

def discriminator(image_shape,
                  plot_path='/content/drive/MyDrive/Strojno_projekt/models/discriminator_cycleGAN.png'):
  """Build and compile the 70x70 PatchGAN discriminator.

  Layout: C64 -> C128 -> C256 (4x4 kernels, stride 2), then C512 (4x4,
  stride 1), then a 4x4 convolution to a 1-channel patch map. Every
  convolution except the first and the output one is followed by instance
  normalization; hidden layers use LeakyReLU(0.2).

  Args:
    image_shape: shape of one input image (channels last), e.g. (256, 256, 3).
    plot_path: file to which `plot_model` writes the architecture diagram;
      pass None to skip plotting.

  Returns:
    A Keras `Model` compiled with MSE loss (weight 0.5) and Adam(2e-4, beta_1=0.5).
  """
  init = RandomNormal(stddev=0.02)  # weight initialization
  in_image = Input(shape=image_shape)  # source image input

  # C64: no normalization on the first layer
  d = Conv2D(64, (4,4), strides=(2,2), padding="same", kernel_initializer=init)(in_image)
  d = LeakyReLU(alpha=0.2)(d)
  # C128
  d = Conv2D(128, (4,4), strides=(2,2), padding="same", kernel_initializer=init)(d)
  d = InstanceNormalization(axis=-1)(d)
  d = LeakyReLU(alpha=0.2)(d)
  # C256
  d = Conv2D(256, (4,4), strides=(2,2), padding="same", kernel_initializer=init)(d)
  d = InstanceNormalization(axis=-1)(d)
  d = LeakyReLU(alpha=0.2)(d)

  # second-to-last layer: 4x4 kernel, stride 1
  d = Conv2D(512, (4,4), padding="same", kernel_initializer=init)(d)
  d = InstanceNormalization(axis=-1)(d)
  d = LeakyReLU(alpha=0.2)(d)

  patch_out = Conv2D(1, (4,4), padding="same", kernel_initializer=init)(d)  # patch output

  model = Model(inputs=in_image, outputs=patch_out)
  # `learning_rate` replaces the deprecated `lr` keyword, which newer
  # TensorFlow Keras optimizers reject.
  model.compile(loss="mse", optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss_weights=[0.5])
  if plot_path is not None:
    plot_model(model, to_file=plot_path, show_shapes=True)

  return model

In [19]:
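# residual-style block with two 3×3 convolutional layers (same number of filters in both); note that the block output is merged with its input by channel-wise concatenation rather than the element-wise addition of a classic ResNet block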

def resnet_block(n_filters, input_layer, merge="concat"):
  """Build one residual-style block: two 3x3 convolutions merged with the input.

  Layout: Conv(3x3)-InstanceNorm-ReLU -> Conv(3x3)-InstanceNorm, then the
  result is combined with `input_layer` according to `merge`.

  Args:
    n_filters: number of filters in both convolutions.
    input_layer: Keras tensor (channels last) feeding the block.
    merge: how the block output is combined with `input_layer`.
      "concat" (default, the original behaviour) concatenates along the
      channel axis. The output then has n_filters + input channels, so the
      channel count grows with every stacked block.
      "add" is the element-wise skip connection of a classic ResNet block.
      It requires `input_layer` to have exactly `n_filters` channels.

  Returns:
    The block's output Keras tensor.

  Raises:
    ValueError: if `merge` is neither "concat" nor "add".
  """
  # Validate before creating any layers so a bad argument leaves no orphan layers.
  if merge not in ("concat", "add"):
    raise ValueError('merge must be "concat" or "add", got %r' % (merge,))

  init = RandomNormal(stddev=0.02)  # weight initialization

  g = Conv2D(n_filters, (3,3), padding="same", kernel_initializer=init)(input_layer)
  g = InstanceNormalization(axis=-1)(g)
  g = Activation("relu")(g)

  g = Conv2D(n_filters, (3,3), padding="same", kernel_initializer=init)(g)
  g = InstanceNormalization(axis=-1)(g)

  if merge == "add":
    return Add()([g, input_layer])
  return Concatenate()([g, input_layer])  # channel-wise concatenation with the input

In [20]:
# define the Generator model: encoder-decoder type architecture
# c7s1-k =  7×7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1. 
# dk = 3×3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2.
# Rk = residual block that contains two 3×3 convolutional layers
# uk = 3×3 fractional-strided-Convolution InstanceNorm-ReLU layer with k filters and stride 1/2

def generator(image_shape, n_resnet=9,
              plot_path='/content/drive/MyDrive/Strojno_projekt/models/generator_cycleGAN.png'):
  """Build the ResNet-style encoder-decoder CycleGAN generator.

  Layout: c7s1-64, d128, d256, `n_resnet` x R256, u128, u64, c7s1-C,
  followed by InstanceNorm and tanh. C is the channel count of
  `image_shape`.

  Args:
    image_shape: shape of one input image, channels last, e.g. (256, 256, 3).
    n_resnet: number of residual blocks at the bottleneck.
    plot_path: file to which `plot_model` writes the architecture diagram;
      pass None to skip plotting.

  Returns:
    An uncompiled Keras `Model` mapping an image to an image of the same
    shape with values in (-1, 1).
  """
  init = RandomNormal(stddev=0.02)  # weight initialization
  in_image = Input(shape=image_shape)  # image input

  # c7s1-64
  g = Conv2D(64, (7,7), padding="same", kernel_initializer=init)(in_image)
  g = InstanceNormalization(axis=-1)(g)
  g = Activation("relu")(g)

  # d128
  g = Conv2D(128, (3,3), strides=(2,2), padding="same", kernel_initializer=init)(g)
  g = InstanceNormalization(axis=-1)(g)
  g = Activation("relu")(g)

  # d256
  g = Conv2D(256, (3,3), strides=(2,2), padding="same", kernel_initializer=init)(g)
  g = InstanceNormalization(axis=-1)(g)
  g = Activation("relu")(g)

  # R256
  for _ in range(n_resnet):
    g = resnet_block(256, g)

  # u128
  g = Conv2DTranspose(128, (3,3), strides=(2,2), padding="same", kernel_initializer=init)(g)
  g = InstanceNormalization(axis=-1)(g)
  g = Activation("relu")(g)

  # u64
  g = Conv2DTranspose(64, (3,3), strides=(2,2), padding="same", kernel_initializer=init)(g)
  g = InstanceNormalization(axis=-1)(g)
  g = Activation("relu")(g)

  # c7s1-C: the output has as many channels as the input, so the image can be
  # fed to the opposite generator and compared with real images in the
  # identity/cycle losses (identical to the previous hard-coded 3 for RGB).
  g = Conv2D(image_shape[-1], (7,7), padding="same", kernel_initializer=init)(g)
  # NOTE(review): the reference CycleGAN implementation applies tanh directly
  # after this convolution; the normalization is kept to preserve this model.
  g = InstanceNormalization(axis=-1)(g)
  out_image = Activation("tanh")(g)

  model = Model(inputs=in_image, outputs=out_image)
  if plot_path is not None:
    plot_model(model, to_file=plot_path, show_shapes=True)

  return model

In [21]:
# define a composite model for updating generators by adversarial and cycle loss

def composite_model(g_model_1, d_model, g_model_2, image_shape,
                    plot_path='/content/drive/MyDrive/Strojno_projekt/models/gan_cycleGAN.png'):
  """Build the composite model used to update `g_model_1` only.

  Side effect: sets `g_model_1.trainable = True` and
  `d_model.trainable = g_model_2.trainable = False` before compiling, so the
  compiled composite updates only `g_model_1`.

  Inputs:  [input_gen, input_id], both of shape `image_shape`.
  Outputs, with their losses and weights:
    - d_model(g1(input_gen))            adversarial,     mse, weight 1
    - g1(input_id)                      identity,        mae, weight 5
    - g2(g1(input_gen))                 forward cycle,   mae, weight 10
    - g1(g2(input_id))                  backward cycle,  mae, weight 10

  Args:
    g_model_1: generator being trained (source -> target domain).
    d_model: discriminator for the target domain (frozen here).
    g_model_2: opposite generator (frozen here).
    image_shape: shape of one input image.
    plot_path: file to which `plot_model` writes the diagram; None skips plotting.

  Returns:
    The compiled composite Keras `Model`.
  """
  g_model_1.trainable = True
  d_model.trainable = False
  g_model_2.trainable = False

  # adversarial loss
  input_gen = Input(shape=image_shape)
  gen1_out = g_model_1(input_gen)
  output_d = d_model(gen1_out)

  # identity loss
  input_id = Input(shape=image_shape)
  output_id = g_model_1(input_id)

  # cycle loss - forward
  output_f = g_model_2(gen1_out)

  # cycle loss - backward
  gen2_out = g_model_2(input_id)
  output_b = g_model_1(gen2_out)

  # define model graph
  model = Model(inputs=[input_gen, input_id], outputs=[output_d, output_id, output_f, output_b])

  # `learning_rate` replaces the deprecated `lr` keyword, which newer
  # TensorFlow Keras optimizers reject.
  opt = Adam(learning_rate=0.0002, beta_1=0.5)
  model.compile(loss=["mse", "mae", "mae", "mae"], loss_weights=[1, 5, 10, 10], optimizer=opt)
  if plot_path is not None:
    plot_model(model, to_file=plot_path, show_shapes=True)

  return model

In [22]:
# load and prepare training images

def load_real_samples(filename):
  """Load both image domains from an ``.npz`` archive and scale them to [-1, 1].

  Args:
    filename: path to an ``.npz`` file with domain A stored under ``arr_0``
      and domain B under ``arr_1``, pixel values in [0, 255].

  Returns:
    ``[X1, X2]``: float32 arrays with pixels mapped from [0, 255] to [-1, 1].
  """
  # numpy.load returns an NpzFile that keeps the archive open; the context
  # manager closes it once both arrays have been read.
  with load(filename) as data:
    X1, X2 = data["arr_0"], data["arr_1"]

  # [0,255] -> [-1,1]; float32 (Keras' default dtype) halves memory vs the
  # float64 that implicit promotion would produce.
  X1 = (X1.astype("float32") - 127.5) / 127.5
  X2 = (X2.astype("float32") - 127.5) / 127.5

  return [X1, X2]

In [23]:
# select a batch of random samples, returns images and target

def generate_real_samples(dataset, n_samples, patch_shape):
  """Pick `n_samples` random rows of `dataset` (with replacement) and label them real.

  Returns:
    (images, labels): `labels` is an all-ones array of shape
    (n_samples, patch_shape, patch_shape, 1).
  """
  chosen = randint(0, dataset.shape[0], n_samples)
  real_labels = ones((n_samples, patch_shape, patch_shape, 1))
  return dataset[chosen], real_labels

In [24]:
# generate a batch of images, returns images and targets
 
def generate_fake_samples(g_model, dataset, patch_shape):
  """Translate `dataset` with `g_model` and label the results as fake.

  Args:
    g_model: Keras model used for translation.
    dataset: batch of input images.
    patch_shape: side length of the discriminator patch labels.

  Returns:
    (images, labels): the model's predictions and an all-zeros array of shape
    (len(images), patch_shape, patch_shape, 1).
  """
  # verbose=0: this runs on every training iteration, and recent Keras
  # versions otherwise print a progress bar per predict() call.
  X = g_model.predict(dataset, verbose=0)

  y = zeros((len(X), patch_shape, patch_shape, 1))

  return X, y

In [25]:
# periodically save the generator models to file

def save_models(step, g_model_AtoB, g_model_BtoA,
                save_dir="/content/drive/MyDrive/Strojno_projekt/models/cycleGAN_saved_models"):
  """Save both generators as ``.h5`` files tagged with the 1-based step number.

  Args:
    step: 0-based training iteration; file names use ``step + 1`` (6 digits).
    g_model_AtoB, g_model_BtoA: models exposing ``save(path)``.
    save_dir: output directory; created if it does not exist.

  Returns:
    (filename1, filename2): paths of the saved A->B and B->A models.
  """
  # Saving fails if the directory is missing, e.g. on a fresh Drive.
  os.makedirs(save_dir, exist_ok=True)

  filename1 = os.path.join(save_dir, "g_model_AtoB_%06d.h5" % (step + 1))
  g_model_AtoB.save(filename1)

  filename2 = os.path.join(save_dir, "g_model_BtoA_%06d.h5" % (step + 1))
  g_model_BtoA.save(filename2)

  print("\n")
  print(">Saved: %s and %s" % (filename1, filename2))
  return filename1, filename2

In [26]:
# periodically translate a few training images with the current generator and save a plot of the inputs and their translations

def summarize_performance(step, g_model, trainX, name, n_samples=5,
                          save_dir="/content/drive/MyDrive/Strojno_projekt/models/cycleGAN_saved_models"):
  """Plot random inputs (top row) and their translations (bottom row) to a PNG.

  Args:
    step: 0-based training iteration; the file name uses ``step + 1``.
    g_model: generator used for the translation.
    trainX: source-domain images scaled to [-1, 1].
    name: tag put in the file name, e.g. "AtoB".
    n_samples: number of image pairs to show.
    save_dir: output directory; created if it does not exist.

  Returns:
    Path of the saved PNG file.
  """
  X_in, _ = generate_real_samples(trainX, n_samples, 0)  # select a sample of input images
  X_out, _ = generate_fake_samples(g_model, X_in, 0)  # generate translated images
  # [-1,1] -> [0,1] for imshow
  X_in = (X_in + 1) / 2.0
  X_out = (X_out + 1) / 2.0

  # An explicit figure avoids drawing into whatever figure pyplot considers
  # current; squeeze=False keeps `axes` 2-D even when n_samples == 1.
  fig, axes = pyplot.subplots(2, n_samples, squeeze=False)
  for row, images in enumerate((X_in, X_out)):
    for i in range(n_samples):
      img = images[i]
      if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]  # imshow rejects (H, W, 1); show single-channel images as 2-D
      ax = axes[row, i]
      ax.axis("off")
      ax.imshow(img, cmap="gray" if img.ndim == 2 else None)

  # This runs before save_models in the training loop, so the directory may not exist yet.
  os.makedirs(save_dir, exist_ok=True)
  filename1 = os.path.join(save_dir, "%s_generated_plot_%06d.png" % (name, (step + 1)))
  fig.savefig(filename1)
  pyplot.close(fig)
  return filename1

In [27]:
# update image pool for fake images to reduce model oscillation
# update discriminators using a history of generated images rather than the ones produced by the latest generators

def update_image_pool(pool, images, max_size=50):
	selected = []
	for image in images:
		if len(pool) < max_size:
			# stock the pool
			pool.append(image)
			selected.append(image)
		elif random() < 0.5:
			# use image, but don't add it to the pool
			selected.append(image)
		else:
			# replace an existing image and use replaced image
			ix = randint(0, len(pool))
			selected.append(pool[ix])
			pool[ix] = image
      
	return asarray(selected)

In [28]:
# train cycleGAN models

def _plot_loss_curve(epoch_marks, losses, title):
  """Show one recorded loss series against training progress measured in epochs."""
  pyplot.plot(epoch_marks, losses)
  pyplot.xlabel("Epoch")
  pyplot.ylabel("Loss")
  pyplot.title(title)
  pyplot.show()


def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset, epochs=5):
  """Train both CycleGAN generators and discriminators with batch size 1.

  Each iteration:
    1. draws one real image per domain;
    2. translates each with the opposite generator;
    3. passes the fakes through the image pools;
    4. updates generator B->A, discriminator A, generator A->B and
       discriminator B, in that order.

  Losses are recorded four times per epoch (every step when an epoch has
  fewer than four steps). After every epoch, sample translations are plotted
  and both generators are saved. Finally, each loss history is plotted
  against training progress in epochs.

  Args:
    d_model_A, d_model_B: compiled discriminators for domains A and B.
    g_model_AtoB, g_model_BtoA: the two generators.
    c_model_AtoB, c_model_BtoA: composite models that update the respective
      generator (built by `composite_model`).
    dataset: pair (trainA, trainB) of image arrays, presumably scaled to
      [-1, 1] as returned by `load_real_samples`.
    epochs: number of epochs; one epoch is len(trainA) iterations.

  Returns:
    dict of recorded losses keyed "dA1", "dA2", "dB1", "dB2", "g1", "g2",
    plus "epoch" holding the (fractional) epoch at which each entry was recorded.
  """
  n_epochs, batch_size = epochs, 1
  n_patch = d_model_A.output_shape[1]  # side length of the discriminator's square patch output

  trainA, trainB = dataset
  poolA, poolB = [], []
  bat_per_epo = int(len(trainA) / batch_size)
  n_steps = bat_per_epo * n_epochs  # number of iterations
  # Record losses four times per epoch. max(1, ...) avoids a modulo-by-zero
  # crash when an epoch has fewer than four steps.
  record_every = max(1, bat_per_epo // 4)

  history = {key: [] for key in ("epoch", "dA1", "dA2", "dB1", "dB2", "g1", "g2")}

  for i in range(n_steps):
    # select a batch of real samples from each domain (A and B)
    X_realA, y_realA = generate_real_samples(trainA, batch_size, n_patch)
    X_realB, y_realB = generate_real_samples(trainB, batch_size, n_patch)

    # generate a batch of fake samples using both B to A and A to B generators
    X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)
    X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)

    # discriminators see a history of fakes rather than only the latest ones
    X_fakeA = update_image_pool(poolA, X_fakeA)
    X_fakeB = update_image_pool(poolB, X_fakeB)

    # Composite targets: [adversarial "real" labels, identity, forward cycle, backward cycle].
    g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])

    # update discriminator for A -> [real/fake]
    dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
    dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)

    # update generator A->B via the composite model
    g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])

    # update discriminator for B -> [real/fake]
    dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
    dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)

    print("Iteration>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]" % (i+1, dA_loss1, dA_loss2, dB_loss1, dB_loss2, g_loss1, g_loss2))

    if (i + 1) % record_every == 0:
      history["epoch"].append((i + 1) / bat_per_epo)
      for key, value in (("dA1", dA_loss1), ("dA2", dA_loss2),
                         ("dB1", dB_loss1), ("dB2", dB_loss2),
                         ("g1", g_loss1), ("g2", g_loss2)):
        history[key].append(value)

    # end of an epoch: plot sample translations and checkpoint the generators
    if (i + 1) % bat_per_epo == 0:
      summarize_performance(i, g_model_AtoB, trainA, "AtoB")  # plot A->B translation
      summarize_performance(i, g_model_BtoA, trainB, "BtoA")  # plot B->A translation
      save_models(i, g_model_AtoB, g_model_BtoA)

  for key, title in (("dA1", "DiscriminatorA1 loss"),
                     ("dA2", "DiscriminatorA2 loss"),
                     ("dB1", "DiscriminatorB1 loss"),
                     ("dB2", "DiscriminatorB2 loss"),
                     ("g1", "Generator1 loss"),
                     ("g2", "Generator2 loss")):
    _plot_loss_curve(history["epoch"], history[key], title)

  return history