# Pix2Pix - Image Translation

This is an attempt to re-implement the paper Pix2pix-GAN

Paper: https://arxiv.org/pdf/1611.07004.pdf

In [1]:
import glob
import numpy as np
import tensorflow as tf

In [2]:
image_size = 256
channels = 3
num_images = 800 # 3200
batch_size = 16

In [3]:
def load_images(image_path):
    
    images = tf.io.read_file(image_path)
    images = tf.image.decode_jpeg(images, channels = channels)
    
    images = tf.image.resize(images, [256, 512], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    
    inp_images, tar_images = images[:, 256:, :], images[:, :256, :]
    
    inp_images = tf.cast(inp_images, tf.float32)
    tar_images = tf.cast(tar_images, tf.float32)
    
    inp_images = (inp_images/127.5) - 1
    tar_images = (tar_images/127.5) - 1
    
    inp_images = tf.image.resize(inp_images, [image_size, image_size], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    tar_images = tf.image.resize(tar_images, [image_size, image_size], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    
    return inp_images, tar_images


def load_data(path, num_images = num_images):
    data_list = glob.glob(path)[:num_images]
    
    data = tf.data.Dataset.list_files(data_list)
    data = data.map(load_images)
    data = data.shuffle(num_images).batch(batch_size)
    return data

In [4]:
DIR = 'E:\Image Datasets\Pix2Pix Anime\Datasets\data'
train_data = load_data(DIR + '\\train\\*.png')
test_data = load_data(DIR + '\\val\\*.png', 200)

In [5]:
class Pix2Pix(object):
    def __init__(self, image_size, channels):
        self.init = tf.random_normal_initializer(0, 0.02)
        self.image_shape = (image_size, image_size, channels)
        
    def encoder(self, inp, filters, kernel_size = (3, 3), batch_norm = True):
        
        x = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = (2, 2), 
                                   padding = 'same', kernel_initializer = self.init, use_bias = True)(inp)
        
        if batch_norm:
            x = tf.keras.layers.BatchNormalization()(x)
            
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        return x
    
    def decoder(self, inp, skip, filters, kernel_size = (3, 3), dropout = 0):
        
        x = tf.keras.layers.Conv2DTranspose(filters = filters, kernel_size = kernel_size, strides = (2, 2), 
                                            padding = 'same', kernel_initializer = self.init, use_bias = True)(inp)
        x = tf.keras.layers.BatchNormalization()(x)
        
        if dropout != 0:
            x = tf.keras.layers.Dropout(dropout)(x)
            
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Concatenate()([x, skip])
        return x
    
    
class Generator(Pix2Pix):
    def __init__(self, image_size, channels):
        super().__init__(image_size, channels)
        pass
    
    def generator(self):
        
        inputs = tf.keras.layers.Input(shape = self.image_shape)
        
        x = inputs
        
        e1 = self.encoder(x, 64, batch_norm = False)
        e2 = self.encoder(e1, 128)
        e3 = self.encoder(e2, 256)
        e4 = self.encoder(e3, 512)
        e5 = self.encoder(e4, 512)
        e6 = self.encoder(e5, 512)
        e7 = self.encoder(e6, 512)
        
        b = self.encoder(e7, 512)
        
        d1 = self.decoder(b, e7, 512)
        d2 = self.decoder(d1, e6, 512)
        d3 = self.decoder(d2, e5, 512)
        d4 = self.decoder(d3, e4, 512)
        d5 = self.decoder(d4, e3, 256)
        d6 = self.decoder(d5, e2, 128)
        d7 = self.decoder(d6, e1, 64)
        
        x = tf.keras.layers.Conv2DTranspose(filters = 3, kernel_size = (3, 3), strides = (2, 2), 
                                            padding = 'same', kernel_initializer = self.init, use_bias = True, 
                                            activation = 'tanh')(d7)
        x = tf.keras.layers.BatchNormalization()(x)
        
        return tf.keras.Model(inputs = inputs, outputs = x, name = 'pix2pix_generator')
    
    
class Discriminator(Pix2Pix):
    def __init__(self, image_size, channels):
        super().__init__(image_size, channels)
        pass
    
    def discriminator(self):
        inp_images = tf.keras.layers.Input(shape = self.image_shape, name = 'input_image')
        tar_images = tf.keras.layers.Input(shape = self.image_shape, name = 'target_image')
        
        x = tf.keras.layers.Concatenate()([inp_images, tar_images])
        
        x = self.encoder(x, 64)
        x = self.encoder(x, 128)
        x = self.encoder(x, 256)
        
        x = tf.keras.layers.Conv2D(filters = 1, kernel_size = (3, 3), strides = (1, 1), 
                                   padding = 'valid', kernel_initializer = self.init, use_bias = True)(x)
        
        return tf.keras.Model(inputs = [inp_images, tar_images], outputs = x, name = 'pix2pix_discriminator')

    
gen = Generator(image_size, channels).generator()
disc = Discriminator(image_size, channels).discriminator()

In [6]:
class Losses(object):
    def __init__(self):
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits = True)
        self.LAMBDA = 100
        
    def discriminator_loss(self, disc_real_out, disc_gen_out):
        real_loss = self.loss_object(tf.ones_like(disc_real_out), disc_real_out)
        gen_loss = self.loss_object(tf.zeros_like(disc_gen_out), disc_gen_out)
        return real_loss + gen_loss
    
    def generator_loss(self, disc_gen_out, gen_out, target):
        gen_loss = self.loss_object(tf.ones_like(disc_gen_out), disc_gen_out)
        l1_loss = tf.math.reduce_mean(tf.math.abs(target - gen_out)) * self.LAMBDA
        return gen_loss + l1_loss
    
l = Losses()

In [7]:
gen_optimizer = tf.keras.optimizers.Adam(learning_rate = 2e-4, beta_1 = 0.5)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate = 2e-4, beta_1 = 0.5)

In [8]:
@tf.function
def train_step(inp, tar):
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_out = gen(inp, training = True)
        
        disc_real_out = disc([inp, tar], training = True)
        disc_gen_out = disc([inp, gen_out], training = True)
        
        disc_loss = l.discriminator_loss(disc_real_out, disc_gen_out)
        gen_loss = l.generator_loss(disc_gen_out, gen_out, tar)
        
    gen_grads = gen_tape.gradient(gen_loss, gen.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, disc.trainable_variables)
    
    gen_optimizer.apply_gradients(zip(gen_grads, gen.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_grads, disc.trainable_variables))
    
    return disc_loss, gen_loss

In [9]:
def train(data, epochs = 1):
    for e in range(epochs):
        print(f'Epoch: {e} Starts')
        
        for inp, tar in data:
            disc_loss, gen_loss = train_step(inp, tar)
            print('.', end='')
            
        print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
        print(f'Epoch: {e} Ends\n\n\n')

In [10]:
train(train_data)

Epoch: 0 Starts
..................................................
Generator Loss: 75.62659454345703 	 Discriminator Loss: 0.5931085348129272
Epoch: 0 Ends



