### Image-to-image translation: pix2pix GAN
Pix2pix GANs are a type of conditional GAN that learns a mapping between paired images. For example, a pix2pix GAN can learn to turn black-and-white images into color images, turn Google Maps images into aerial photographs, or turn drawings into colorful images.
This notebook attempts to train a pix2pix GAN that learns the mapping between a drawing of a shoe and the corresponding RGB photograph of that shoe.


In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import glob
import random

In [None]:
def read_image(image_path):
    """Load a side-by-side paired image and split it into (input, target).

    The file is decoded as a 3-channel image and cut down the middle of its
    width: the left half is the input image, the right half the target image.
    Pixel values keep their original 0-255 range; no normalization is done here.

    Args:
        image_path: Path of the image file (Python string or scalar string tensor).

    Returns:
        ``(input_image, target_image)`` as float32 tensors of shape
        ``(height, width // 2, 3)`` and ``(height, width - width // 2, 3)``.
    """
    raw_bytes = tf.io.read_file(image_path)
    # expand_animations=False guarantees a rank-3 (H, W, 3) result with a
    # statically known channel count. Without it, decode_image returns a tensor
    # of unknown rank inside tf.data.map, and Keras layers such as Conv2D need a
    # known channel dimension to build.
    image = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    half_width = tf.shape(image)[1] // 2

    input_image = tf.cast(image[:, :half_width, :], dtype=tf.float32)
    target_image = tf.cast(image[:, half_width:, :], dtype=tf.float32)
    return input_image, target_image

In [None]:
def normalize(input_image, target_image):
    """Rescale both images from the [0, 255] pixel range to [-1, 1].

    Works on anything supporting ``/`` and ``-`` (tensors, arrays, floats).
    """
    def _to_signed_unit(pixels):
        # 0 -> -1, 127.5 -> 0, 255 -> 1
        return pixels / 127.5 - 1

    return _to_signed_unit(input_image), _to_signed_unit(target_image)

In [None]:
def preprocess_fn(image_path):
    """Read a paired image file and return its (input, target) halves scaled to [-1, 1]."""
    return normalize(*read_image(image_path))

In [None]:
train_dir = "/home/basaadi/projects/gan_trans/data/train"
# glob returns files in filesystem-dependent order; sort so the 1000 / 100
# train / test split below picks the same files on every run and machine.
sample_paths = sorted(glob.glob(train_dir + "/*.jpg"))
test_sample_paths = sample_paths[1000:1100]
sample_paths = sample_paths[:1000]
print(f"len sample paths : {len(sample_paths)}")
print(f"len test_sample_paths  : {len(test_sample_paths)}")

In [None]:
# Pick one random training file and run it through the preprocessing pipeline
# so the result can be inspected visually.
sample_path = random.choice(sample_paths)
input_image, target_image = preprocess_fn(image_path=sample_path)


In [None]:
fig, axs = plt.subplots(1, 2, figsize=(8, 8))
# The images were normalized to [-1, 1], but imshow clips float RGB data to
# [0, 1]; map back to [0, 1] so the full range is displayed.
axs[0].imshow((input_image + 1) / 2)
axs[0].set_title("input")
axs[1].imshow((target_image + 1) / 2)
axs[1].set_title("target")

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
batch_size = 32

train_dataset = tf.data.Dataset.from_tensor_slices(sample_paths)
# Shuffle the (cheap) file paths with a buffer covering the whole list before
# decoding: a 100-element buffer over 1000 items only shuffles locally, and
# shuffling after map keeps that many decoded image pairs in memory.
# max(..., 1) because shuffle requires a positive buffer size.
train_dataset = train_dataset.shuffle(max(len(sample_paths), 1))
train_dataset = train_dataset.map(preprocess_fn, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.batch(batch_size)
# Prepare the next batches while the current one is being trained on.
train_dataset = train_dataset.prefetch(AUTOTUNE)


In [None]:
# Model definition
class Block(tf.keras.layers.Layer):
    """Downsampling unit: 4x4 stride-2 Conv2D -> BatchNormalization -> LeakyReLU.

    With 'same' padding each block halves the spatial size (rounding up) and
    outputs ``out_ch`` channels.
    """

    def __init__(self, out_ch):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(out_ch, (4, 4), 2, kernel_initializer='he_normal', padding='same')
        self.BatchNormalization1 = tf.keras.layers.BatchNormalization()
        self.leaky_relu1 = tf.keras.layers.LeakyReLU(alpha=0.01)

    def call(self, x, training=None):
        x = self.conv1(x)
        # Pass the mode on explicitly so BatchNormalization uses batch statistics
        # only while training. None lets Keras infer it from the calling context.
        x = self.BatchNormalization1(x, training=training)
        x = self.leaky_relu1(x)
        return x

    def model(self, input_shape):
        """Wrap this block in a functional ``tf.keras.Model`` (e.g. for ``summary()``)."""
        x = tf.keras.Input(input_shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))


In [None]:
block = Block(30)
# summary() prints the layer table itself and returns None, so it is not
# wrapped in print() (which would only add a stray "None").
block.model((256, 256, 1)).summary()

In [None]:

class Encoder(tf.keras.layers.Layer):
    """Stack of downsampling ``Block`` layers that returns every intermediate map.

    Args:
        chs: Output channels of each successive ``Block``.
    """

    def __init__(self, chs=(32, 64, 128, 256, 512)):
        super().__init__()
        self.enc_blocks = [Block(ch) for ch in chs]

    def call(self, x, training=None):
        """Return the list of block outputs, shallowest first (last = deepest)."""
        ftrs = []
        for block in self.enc_blocks:
            # Forward the mode so every Block's BatchNormalization agrees with the caller.
            x = block(x, training=training)
            ftrs.append(x)
        return ftrs

    def model(self, input_shape):
        """Wrap the encoder in a functional ``tf.keras.Model`` (e.g. for ``summary()``)."""
        x = tf.keras.Input(input_shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

In [None]:
encoder = Encoder(chs=(64, 128, 256, 512, 512, 512, 512, 512))
# Push a dummy RGB batch through the encoder and list the feature-map shapes;
# each Block halves height and width. (Named to avoid shadowing builtin input().)
dummy_input = tf.ones((1, 256, 256, 3))
enc_ftrs = encoder(dummy_input)
for enc_ftr in enc_ftrs:
    print(enc_ftr.shape)

In [None]:
class Decoder(tf.keras.layers.Layer):
    """U-Net decoder: upsample, normalize, activate, then concatenate a skip map.

    ``chs[0]`` is the channel count of the incoming bottleneck. Each entry of
    ``chs[1:]`` is the filter count of one 4x4 stride-2 Conv2DTranspose stage.
    Each stage doubles height and width, applies BatchNormalization + ReLU (as in
    the pix2pix upsampler), then concatenates the matching encoder feature map
    along the channel axis.
    """

    def __init__(self, chs=(64, 32, 16)):
        super().__init__()
        self.chs = chs
        self.upconvs = [tf.keras.layers.Conv2DTranspose(ch, (4, 4), strides=(2, 2), padding='same') for ch in chs[1:]]
        # Without a normalization + nonlinearity after each upsampling, every
        # decoder stage would be purely linear.
        self.batch_norms = [tf.keras.layers.BatchNormalization() for _ in chs[1:]]
        self.relu = tf.keras.layers.ReLU()
        # One Concatenate layer per stage, created once instead of on every call.
        self.concats = [tf.keras.layers.Concatenate() for _ in chs[1:]]

    def call(self, x, encoder_features, training=None):
        """Decode ``x`` using skip connections.

        Args:
            x: Bottleneck feature map.
            encoder_features: Skip feature maps ordered deepest first; the first
                ``len(chs) - 1`` are used, one per stage, and extras are ignored.
            training: Forwarded to BatchNormalization (None = infer from context).

        Raises:
            ValueError: If fewer than ``len(chs) - 1`` skip maps are supplied.
        """
        num_stages = len(self.upconvs)
        if len(encoder_features) < num_stages:
            raise ValueError(
                f"Decoder needs {num_stages} encoder feature maps, got {len(encoder_features)}"
            )
        for upconv, batch_norm, concat, skip in zip(self.upconvs, self.batch_norms, self.concats, encoder_features):
            x = upconv(x)
            x = batch_norm(x, training=training)
            x = self.relu(x)
            x = concat([x, skip])
        return x

    def model(self, input_shape, encoder_feature_shapes=None):
        """Wrap the decoder in a functional ``tf.keras.Model``.

        ``call`` consumes skip connections, so their shapes (deepest first,
        without the batch dimension) must be given as ``encoder_feature_shapes``.
        """
        if encoder_feature_shapes is None:
            raise ValueError("Decoder.model() needs encoder_feature_shapes for the skip-connection inputs")
        x = tf.keras.Input(input_shape)
        skips = [tf.keras.Input(shape) for shape in encoder_feature_shapes]
        return tf.keras.Model(inputs=[x, *skips], outputs=self.call(x, skips))

In [None]:
# Display the shape of the deepest (last) encoder feature map, i.e. the
# bottleneck that is fed to the decoder below.
enc_ftrs[-1].shape

In [None]:
decoder = Decoder(chs=(512, 512, 512, 512, 512, 256, 128, 64))
# Bottleneck = deepest encoder map; skips = the remaining maps, deepest first.
dec_ftr = decoder(enc_ftrs[-1], enc_ftrs[-2::-1])

In [None]:

class Generator(tf.keras.Model):
    """U-Net generator: ``Encoder`` -> ``Decoder`` with skips -> final upsampling conv.

    With the default channel tuples (8 stride-2 encoder blocks, 7 decoder stages
    plus the final stride-2 transposed conv) the output has the input's spatial
    size and ``num_channels`` channels.
    """

    def __init__(self, enc_chs=(64, 128, 256, 512, 512, 512, 512, 512), dec_chs=(512, 512, 512, 512, 512, 256, 128, 64), num_channels=3):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        # tanh bounds the output to [-1, 1], the range normalize() maps target
        # images into; an unbounded output could never match them exactly.
        self.conv = tf.keras.layers.Conv2DTranspose(num_channels, (4, 4), strides=(2, 2), padding='same', activation='tanh')

    def call(self, x, training=None):
        enc_ftrs = self.encoder(x, training=training)
        # Deepest encoder map is the bottleneck; the rest, deepest first, are skips.
        bottleneck, skips = enc_ftrs[-1], enc_ftrs[-2::-1]
        out = self.decoder(bottleneck, skips, training=training)
        return self.conv(out)

    def model(self, input_shape):
        """Wrap the generator in a functional ``tf.keras.Model`` (e.g. for ``summary()``)."""
        x = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

In [None]:
generator = Generator()
# read_image decodes RGB (channels=3), so build with 3 input channels, not 1.
generator.build(input_shape=(1, 256, 256, 3))

In [None]:
class Discriminator(tf.keras.Model):
    """Conditional discriminator producing a spatial map of real/fake scores.

    The conditioning image and a candidate (real target or generated image) are
    concatenated on the channel axis, downsampled by ``Block`` layers, then passed
    through two zero-padded 4x4 stride-1 convolutions. The last one has a single
    sigmoid filter, so the output is a grid of per-location probabilities
    (PatchGAN-style) rather than one score per image.

    Args:
        enc_chs: Output channels of each downsampling ``Block``.
    """

    def __init__(self, enc_chs=(64, 128, 256)):
        super().__init__()
        self.initializer = tf.random_normal_initializer(0.0, 0.02)
        # Created once here instead of instantiating a new layer on every call.
        self.concat = tf.keras.layers.Concatenate()
        self.zero_pad1 = tf.keras.layers.ZeroPadding2D()
        self.zero_pad2 = tf.keras.layers.ZeroPadding2D()
        self.conv1 = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=self.initializer, use_bias=False)
        self.conv2 = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=self.initializer, activation='sigmoid')

        self.batch_norm1 = tf.keras.layers.BatchNormalization()

        self.leaky_relu = tf.keras.layers.LeakyReLU()
        self.enc_blocks = [Block(ch) for ch in enc_chs]

    def call(self, x1, x2, training=None):
        """Score ``x2`` given the conditioning image ``x1``.

        ``training`` is forwarded to every BatchNormalization (None = infer).
        """
        x = self.concat([x1, x2])
        for enc_block in self.enc_blocks:
            x = enc_block(x, training=training)

        x = self.zero_pad1(x)
        x = self.conv1(x)
        x = self.batch_norm1(x, training=training)
        x = self.leaky_relu(x)
        x = self.zero_pad2(x)
        return self.conv2(x)

    def model(self, input_shape=(256, 256, 3)):
        """Wrap the discriminator in a two-input functional ``tf.keras.Model``."""
        x1 = tf.keras.Input(shape=input_shape)
        x2 = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=[x1, x2], outputs=self.call(x1, x2))
        

In [None]:
discriminator = Discriminator(enc_chs=[64, 128, 256])
# summary() prints the table itself and returns None; don't wrap it in print().
discriminator.model().summary()

In [None]:
def generator_loss(disc_generated_output, gen_output, target, real_labels, l1_weight=100):
    """Generator objective: adversarial BCE plus a weighted L1 reconstruction term.

    Args:
        disc_generated_output: Discriminator scores for the generated images.
            These are treated as probabilities (BinaryCrossentropy with default
            ``from_logits=False``).
        gen_output: Generated images.
        target: Ground-truth target images.
        real_labels: "Real" labels (ones) shaped like ``disc_generated_output``.
        l1_weight: Weight of the L1 term (the pix2pix lambda); defaults to 100.

    Returns:
        ``(total_gen_loss, bce_loss, l1_loss)``.
    """
    bce = tf.keras.losses.BinaryCrossentropy()
    mae = tf.keras.losses.MeanAbsoluteError()

    # Adversarial term: the generator wants its outputs scored as real.
    bce_loss = bce(real_labels, disc_generated_output)
    l1_loss = mae(target, gen_output)

    total_gen_loss = bce_loss + l1_weight * l1_loss
    return total_gen_loss, bce_loss, l1_loss

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output, real_labels, fake_labels):
    """Sum of the BCE on real pairs (vs. ``real_labels``) and on generated pairs (vs. ``fake_labels``)."""
    bce = tf.keras.losses.BinaryCrossentropy()
    real_term = bce(real_labels, disc_real_output)
    fake_term = bce(fake_labels, disc_generated_output)
    return real_term + fake_term

In [None]:
generator = Generator()
discriminator = Discriminator()
# Both networks use Adam with lr = 2e-4, beta_1 = 0.5, beta_2 = 0.999.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)

In [None]:
@tf.function
def train_step(inputs):
    """Run one generator + discriminator update on a batch.

    Args:
        inputs: ``(input_image, target)`` pair of batched image tensors.

    Returns:
        ``(gen_gan_loss, l1_loss, disc_loss)`` for logging.
    """
    input_image, target = inputs

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # Both discriminator passes must run in training mode. Otherwise
        # BatchNormalization normalizes real pairs with its moving statistics
        # and generated pairs with batch statistics, and the real pass never
        # updates the moving averages.
        disc_real_output = discriminator(input_image, target, training=True)
        disc_generated_output = discriminator(input_image, gen_output, training=True)

        real_targets = tf.ones_like(disc_real_output)
        fake_targets = tf.zeros_like(disc_generated_output)

        gen_total_loss, gen_gan_loss, l1_loss = generator_loss(disc_generated_output, gen_output, target, real_targets)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output, real_targets, fake_targets)

    # Compute gradients after both tapes stop recording. Inside the block, the
    # generator's gradient computation was recorded on the still-active disc tape.
    gen_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

    return gen_gan_loss, l1_loss, disc_loss


In [None]:
EPOCHS = 1
def fit(epochs=None, dataset=None):
    """Train with ``train_step`` and print the mean losses after each epoch.

    Args:
        epochs: Number of epochs. ``None`` uses the module-level ``EPOCHS``, read
            at call time so later changes to ``EPOCHS`` still take effect.
        dataset: Iterable of ``(input_image, target)`` batches. ``None`` uses the
            module-level ``train_dataset``.

    Raises:
        ValueError: If the dataset yields no batches (nothing to average).
    """
    if epochs is None:
        epochs = EPOCHS
    if dataset is None:
        dataset = train_dataset

    for epoch in range(epochs):
        num_batches = 0
        gan_loss, l1_loss, disc_loss = 0.0, 0.0, 0.0
        for batch in dataset:
            num_batches += 1
            gan_l, l1_l, disc_l = train_step(batch)
            gan_loss += gan_l
            l1_loss += l1_l
            disc_loss += disc_l

        if num_batches == 0:
            raise ValueError("dataset yielded no batches; cannot average epoch losses")

        # float() so the log shows plain numbers rather than tf.Tensor reprs.
        gan_loss = float(gan_loss / num_batches)
        l1_loss = float(l1_loss / num_batches)
        disc_loss = float(disc_loss / num_batches)

        print(f"Epoch: {epoch}: D_Loss: {disc_loss}: G_Loss: {gan_loss}: l1_loss: {l1_loss}")



In [None]:
# Train for EPOCHS epochs over train_dataset, printing mean losses per epoch.
fit()

In [None]:
# generator.save_weights('gen'+str(100)+".h5")

In [None]:
# Held-out pipeline: decode + normalize, shuffle, then pairs of images per batch.
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_sample_paths)
    .map(preprocess_fn, num_parallel_calls=AUTOTUNE)
    .shuffle(100)
    .batch(2)
)

In [None]:
# NOTE(review): this rebinds `generator`, discarding the one trained by fit()
# above, and the save_weights cell is commented out, so gen100.h5 must exist
# from an earlier run.
generator = Generator()
# A subclassed model has no variables until built; build before loading.
generator.build(input_shape=(1, 256, 256, 3))
generator.load_weights(filepath="gen100.h5")

In [None]:
# Run the generator on a single test batch; img, target and preds stay bound
# for the plotting cell below.
for img, target in test_dataset.take(1):
    # NOTE(review): training=True makes BatchNormalization use this batch's
    # statistics instead of its moving averages; confirm this is intended at test time.
    preds = generator(img, training=True)


In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
panels = [("input", img[0]), ("target", target[0]), ("generated", preds[0])]
for ax, (title, image) in zip(axs, panels):
    # Images live in [-1, 1]; imshow clips float RGB data to [0, 1], so shift
    # back to [0, 1] before displaying.
    ax.imshow((image + 1) / 2)
    ax.set_title(title)

#### Sources
https://learnopencv.com/paired-image-to-image-translation-pix2pix/