# 🖼️ Task 04: Image-to-Image Translation using Pix2Pix (cGAN)

This notebook demonstrates how to build and train a **conditional Generative Adversarial Network (cGAN)** called **Pix2Pix**, which performs **image-to-image translation** using paired datasets.

---

In [None]:
# Install the runtime dependencies (IPython shell escapes: run inside Jupyter/Colab).
!pip install tensorflow matplotlib
!pip install tensorflow-datasets

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import os

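## 📚 Load Dataset (Horse → Zebra)

> Note: `cycle_gan/horse2zebra` is an *unpaired* dataset (horses and zebras live in separate splits), whereas Pix2Pix is designed for *paired* input/target images, so any input/target pairing built from it is arbitrary.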

In [None]:
def normalize(input_image, input_mask=None):
    """Scale pixel values linearly from [0, 255] to [-1, 1].

    Args:
        input_image: Image tensor of any numeric dtype; presumably uint8
            pixels in [0, 255] (other ranges are scaled by the same formula).
        input_mask: Unused. Kept (now optional) so existing callers that
            pass ``None`` keep working.

    Returns:
        A ``tf.float32`` tensor with the same shape, computed as
        ``x / 127.5 - 1``.
    """
    del input_mask  # accepted only for call-signature compatibility
    return tf.cast(input_image, tf.float32) / 127.5 - 1

def load_dataset(name='cycle_gan/horse2zebra', split='train',
                 batch_size=1, num_batches=100):
    """Load a batched dataset of ``(input_image, target_image)`` pairs.

    ``cycle_gan/*`` TFDS datasets store their two image domains as separate
    splits ``<split>A`` and ``<split>B`` and, with ``as_supervised=True``,
    yield ``(image, label)`` tuples. Both domains are normalized to [-1, 1]
    and zipped element-by-element so every example is an
    ``(input_image, target_image)`` tuple, the structure ``train`` expects.

    NOTE: these pairs are arbitrary (the domains are unpaired), so the L1
    reconstruction term of the Pix2Pix loss is only a rough proxy here.

    Args:
        name: TFDS dataset name.
        split: Split prefix; ``'A'`` and ``'B'`` are appended to it.
        batch_size: Images per batch (default 1).
        num_batches: Number of batches to keep (default 100).

    Returns:
        A ``tf.data.Dataset`` of ``(input_image, target_image)`` float32
        batches.
    """
    domain_a, domain_b = tfds.load(
        name=name, split=[f'{split}A', f'{split}B'], as_supervised=True)
    # as_supervised yields (image, label); only the image is needed.
    inputs = domain_a.map(lambda image, label: normalize(image, None))
    targets = domain_b.map(lambda image, label: normalize(image, None))
    paired = tf.data.Dataset.zip((inputs, targets))
    return paired.batch(batch_size).take(num_batches)

## 🧠 Define Generator (U-Net)

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Build a Conv2D -> [BatchNorm] -> LeakyReLU block that halves H and W.

    Args:
        filters: Number of output channels.
        size: Convolution kernel size.
        apply_batchnorm: Whether to insert BatchNormalization after the conv.

    Returns:
        A ``tf.keras.Sequential`` block.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    # Stride 2 with 'same' padding halves the spatial dimensions.
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result

def upsample(filters, size, apply_dropout=False):
    """Build a Conv2DTranspose -> BatchNorm -> [Dropout] -> ReLU block.

    Args:
        filters: Number of output channels.
        size: Transposed-convolution kernel size.
        apply_dropout: Whether to insert Dropout(0.5) after BatchNormalization.

    Returns:
        A ``tf.keras.Sequential`` block that doubles H and W.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    # Stride 2 with 'same' padding doubles the spatial dimensions.
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                               padding='same',
                                               kernel_initializer=initializer,
                                               use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    result.add(tf.keras.layers.ReLU())
    return result

def Generator():
    """Build the U-Net generator for 256x256x3 images.

    Eight ``downsample`` blocks reduce 256x256 to 1x1. Seven ``upsample``
    blocks then grow it back to 128x128, each output concatenated with the
    encoder activation of matching size (skip connections). A final
    stride-2 transposed conv produces a 256x256x3 image with ``tanh``
    activation, i.e. values in [-1, 1].

    Returns:
        A ``tf.keras.Model`` mapping an input image to a generated image.
    """
    inputs = tf.keras.Input(shape=[256, 256, 3])
    down_stack = [
        downsample(64, 4, False),   # -> (128, 128, 64)
        downsample(128, 4),         # -> (64, 64, 128)
        downsample(256, 4),         # -> (32, 32, 256)
        downsample(512, 4),         # -> (16, 16, 512)
        downsample(512, 4),         # -> (8, 8, 512)
        downsample(512, 4),         # -> (4, 4, 512)
        downsample(512, 4),         # -> (2, 2, 512)
        downsample(512, 4),         # -> (1, 1, 512)
    ]
    up_stack = [
        upsample(512, 4, apply_dropout=True),  # -> (2, 2, 512)
        upsample(512, 4, apply_dropout=True),  # -> (4, 4, 512)
        upsample(512, 4, apply_dropout=True),  # -> (8, 8, 512)
        upsample(512, 4),                      # -> (16, 16, 512)
        upsample(256, 4),                      # -> (32, 32, 256)
        upsample(128, 4),                      # -> (64, 64, 128)
        upsample(64, 4),                       # -> (128, 128, 64)
    ]
    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    # The bottleneck (last encoder output) has no decoder partner; pair the
    # remaining encoder outputs with the decoder in reverse order.
    skips = reversed(skips[:-1])
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

## 🛡 Define Discriminator (PatchGAN)

In [None]:
def Discriminator():
    """Build the PatchGAN discriminator.

    The input image and a candidate target (real or generated) are
    concatenated along the channel axis and reduced to a grid of logits.
    Shapes below are (H, W, C) with the batch dimension omitted.

    Returns:
        A ``tf.keras.Model`` taking ``[input_image, target_image]`` and
        returning a 30x30x1 map of unnormalized logits, one per patch.
    """
    # Defined locally: the original referenced an ``initializer`` that only
    # exists inside Generator(), which raised NameError here.
    initializer = tf.random_normal_initializer(0., 0.02)
    layers = tf.keras.layers

    inp = tf.keras.Input(shape=[256, 256, 3], name='input_image')
    tar = tf.keras.Input(shape=[256, 256, 3], name='target_image')
    x = layers.Concatenate()([inp, tar])             # (256, 256, 6)
    down1 = downsample(64, 4, False)(x)              # (128, 128, 64)
    down2 = downsample(128, 4)(down1)                # (64, 64, 128)
    down3 = downsample(256, 4)(down2)                # (32, 32, 256)
    zero_pad1 = layers.ZeroPadding2D()(down3)        # (34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1)  # (31, 31, 512)
    batchnorm1 = layers.BatchNormalization()(conv)
    leaky_relu = layers.LeakyReLU()(batchnorm1)
    zero_pad2 = layers.ZeroPadding2D()(leaky_relu)   # (33, 33, 512)
    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2)  # (30, 30, 1)
    return tf.keras.Model(inputs=[inp, tar], outputs=last)

## 🧮 Define Losses and Optimizers

In [None]:
LAMBDA = 100
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
    real_loss = loss_object(tf.ones_like(real), real)
    generated_loss = loss_object(tf.zeros_like(generated), generated)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss

def generator_loss(generated):
    gan_loss = loss_object(tf.ones_like(generated), generated)
    l1_loss = tf.reduce_mean(tf.abs(generated))
    total_gen_loss = gan_loss + LAMBDA * l1_loss
    return total_gen_loss

## 🏋️‍♂️ Train Step

In [None]:
@tf.function
def train_step(input_image, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        gen_total_loss = generator_loss(disc_generated_output)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    generator_gradients = gen_tape.gradient(gen_total_loss,
                                      generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                          discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients,
                                      generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                            discriminator.trainable_variables))
    return gen_total_loss, disc_loss

## 🧪 Training Loop

In [None]:
def train(dataset, epochs=40, checkpoint=None, checkpoint_prefix=None):
    """Train the module-level generator/discriminator on ``dataset``.

    Losses are printed every 10 steps. Every 10 epochs a checkpoint is saved
    and a sample prediction for the last batch of that epoch is plotted.

    Args:
        dataset: ``tf.data.Dataset`` of ``(input_image, target)`` batches.
        epochs: Number of passes over ``dataset``.
        checkpoint: ``tf.train.Checkpoint`` to save. If None, one is created
            that tracks ``generator``, ``discriminator`` and both optimizers.
        checkpoint_prefix: File prefix passed to ``checkpoint.save``.
            Defaults to ``./training_checkpoints/ckpt``.
    """
    # The original read undefined globals ``checkpoint``/``checkpoint_prefix``
    # and crashed at epoch 10; build sensible defaults instead.
    if checkpoint is None:
        checkpoint = tf.train.Checkpoint(
            generator_optimizer=generator_optimizer,
            discriminator_optimizer=discriminator_optimizer,
            generator=generator,
            discriminator=discriminator,
        )
    if checkpoint_prefix is None:
        checkpoint_prefix = os.path.join('./training_checkpoints', 'ckpt')

    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')
        last_batch = None  # stays None if the dataset yields nothing
        for n, (input_image, target) in dataset.enumerate():
            gen_loss, disc_loss = train_step(input_image, target)
            last_batch = (input_image, target)
            step = int(n)  # plain int, so it prints as a number, not a tensor repr
            if step % 10 == 0:
                print(f'Step {step}, Gen Loss: {float(gen_loss):.4f}, '
                      f'Disc Loss: {float(disc_loss):.4f}')
        if (epoch + 1) % 10 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            if last_batch is not None:
                generate_images(generator, last_batch[0], last_batch[1], epoch)

## 📊 Visualization

In [None]:
def generate_images(model, test_input, target, epoch=None):
    """Show the first input, its ground truth and the model's prediction.

    Args:
        model: Callable generator model.
        test_input: Batch of input images; element 0 is plotted.
        target: Batch of ground-truth images; element 0 is plotted.
        epoch: If given, the figure is also saved as
            ``image_at_epoch_<epoch>.png`` before being shown.
    """
    prediction = model(test_input, training=True)
    captions = ['Input Image', 'Ground Truth', 'Predicted Image']
    images = [test_input[0].numpy(), target[0].numpy(), prediction[0].numpy()]

    plt.figure(figsize=(12, 4))
    for position, (caption, image) in enumerate(zip(captions, images), start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # Map values from [-1, 1] to [0, 1] for display.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')

    if epoch is not None:
        plt.savefig(f"image_at_epoch_{epoch}.png")
    plt.show()

## 🚀 Main Execution

In [None]:
# Build both networks and give each its own Adam optimizer
# (learning rate 2e-4, beta_1 = 0.5).
generator, discriminator = Generator(), Discriminator()

generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5),
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5),
)

# Load the data and run the full training loop.
train_dataset = load_dataset()

EPOCHS = 40
train(train_dataset, epochs=EPOCHS)