# Pix2Pix Caricature Generation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/USERNAME/REPO/blob/main/pix2pix_caricature/caricature_training.ipynb)

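This notebook trains a Pix2Pix model (a conditional GAN with a U-Net generator and a patch-based discriminator) that translates 512×512 face photos into caricatures, using paired `<id>_f.png` / `<id>_c.png` images stored on Google Drive.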

In [None]:
# Install required packages
!pip install tensorflow tensorflow_addons

In [None]:
import os
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from google.colab import drive

## Mount Google Drive

In [None]:
# Locations on Google Drive: the paired training images and the folder that
# receives generator/discriminator weight files during training.
DATA_DIR = '/content/drive/MyDrive/caricature Project Diffusion/paired_caricature'
CHECKPOINT_DIR = '/content/drive/MyDrive/caricature_checkpoints'

# Mount Drive so the paths above resolve inside the Colab runtime.
drive.mount('/content/drive')

## Data Loading and Preprocessing

In [None]:
def load_image(image_path):
    """Read a PNG file and return it as a float32 tensor scaled to [-1, 1].

    Args:
        image_path: Path to the PNG file (Python string or scalar string tensor).

    Returns:
        A float32 tensor of shape (height, width, 3); each pixel value ``v`` in
        [0, 255] becomes ``v / 127.5 - 1``.
    """
    raw_bytes = tf.io.read_file(image_path)
    pixels = tf.cast(tf.image.decode_png(raw_bytes, channels=3), tf.float32)
    # Map [0, 255] onto [-1, 1], the same range as the generator's tanh output.
    return pixels / 127.5 - 1

def load_paired_images(face_path, caricature_path):
    """Load one (face, caricature) training pair with ``load_image``.

    Returns:
        Tuple ``(face, caricature)`` of float32 tensors scaled to [-1, 1].
    """
    return load_image(face_path), load_image(caricature_path)

# Pair each face "<id>_f.png" with its caricature "<id>_c.png" by shared id.
# Sorting two independently built lists would silently mis-pair every sample
# after a single missing/extra file (or fail outright on a length mismatch).
_FACE_SUFFIX = '_f.png'
_CARICATURE_SUFFIX = '_c.png'

_data_files = os.listdir(DATA_DIR)
_face_ids = {name[:-len(_FACE_SUFFIX)] for name in _data_files if name.endswith(_FACE_SUFFIX)}
_caricature_ids = {name[:-len(_CARICATURE_SUFFIX)] for name in _data_files
                   if name.endswith(_CARICATURE_SUFFIX)}
_pair_ids = sorted(_face_ids & _caricature_ids)

_unmatched_ids = sorted(_face_ids ^ _caricature_ids)
if _unmatched_ids:
    print(f'Warning: skipping {len(_unmatched_ids)} image(s) without a partner, '
          f'e.g. {_unmatched_ids[:5]}')
if not _pair_ids:
    raise ValueError(f'No *{_FACE_SUFFIX} / *{_CARICATURE_SUFFIX} pairs found in {DATA_DIR!r}')

face_paths = [os.path.join(DATA_DIR, pair_id + _FACE_SUFFIX) for pair_id in _pair_ids]
caricature_paths = [os.path.join(DATA_DIR, pair_id + _CARICATURE_SUFFIX) for pair_id in _pair_ids]

# Create TensorFlow dataset
dataset = tf.data.Dataset.from_tensor_slices((face_paths, caricature_paths))
# Shuffle the cheap path strings (not decoded images) so every epoch visits the
# pairs in a new order instead of the same sorted order.
dataset = dataset.shuffle(len(face_paths), reshuffle_each_iteration=True)
# load_paired_images uses only TensorFlow ops, so it can be mapped directly.
# Wrapping it in tf.py_function would drop static shape information and run
# each element through the Python interpreter.
dataset = dataset.map(load_paired_images, num_parallel_calls=tf.data.AUTOTUNE)

## Data Augmentation

In [None]:
def augment(face_image, caricature_image, pad_size=51, crop_size=512):
    """Apply the same random flip and translation to a face/caricature pair.

    The two images are stacked so a single random draw drives both transforms,
    keeping the pair pixel-aligned.

    Args:
        face_image: Float image tensor, assumed (crop_size, crop_size, 3) —
            TODO confirm every training PNG is 512x512.
        caricature_image: Float image tensor with the same shape as ``face_image``.
        pad_size: Maximum shift in pixels along each axis (default 51, about
            10% of 512). Must be smaller than the image height/width, because
            REFLECT padding cannot exceed the input size.
        crop_size: Side length of the returned square crops.

    Returns:
        Tuple ``(face, caricature)`` of augmented tensors of spatial size
        ``crop_size x crop_size``.
    """
    pair = tf.stack([face_image, caricature_image], axis=0)

    # Random horizontal flip. Explicit tf.cond instead of a Python `if` on a
    # tensor, so the branch does not depend on AutoGraph conversion.
    flipped = tf.cond(
        tf.random.uniform([]) > 0.5,
        lambda: tf.image.flip_left_right(pair),
        lambda: pair,
    )

    # Random translation: reflect-pad, then crop at an offset drawn from
    # [0, 2 * pad_size] inclusive. tf.random.uniform's maxval is exclusive,
    # hence the +1; this gives a symmetric shift of -pad_size..+pad_size.
    padded = tf.pad(flipped, [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]],
                    mode='REFLECT')
    offset_y = tf.random.uniform([], 0, 2 * pad_size + 1, dtype=tf.int32)
    offset_x = tf.random.uniform([], 0, 2 * pad_size + 1, dtype=tf.int32)
    cropped = tf.image.crop_to_bounding_box(padded, offset_y, offset_x, crop_size, crop_size)

    return cropped[0], cropped[1]

BATCH_SIZE = 1

# Augment on the fly (map re-runs each epoch, so augmentations differ per
# epoch), then batch and prefetch so input prep overlaps with training.
augmented_dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = (
    augmented_dataset
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

## Pix2Pix Model Implementation

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Build a stride-2 Conv2D -> [BatchNorm] -> LeakyReLU block.

    With ``strides=2`` and ``padding='same'`` the block halves height and width.

    Args:
        filters: Number of output channels.
        size: Convolution kernel size.
        apply_batchnorm: Whether to insert BatchNormalization after the conv.

    Returns:
        A ``tf.keras.Sequential`` containing the layers.
    """
    stack = [
        tf.keras.layers.Conv2D(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False),
    ]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stack)

def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 Conv2DTranspose -> BatchNorm -> [Dropout(0.5)] -> ReLU block.

    With ``strides=2`` and ``padding='same'`` the block doubles height and width.

    Args:
        filters: Number of output channels.
        size: Transposed-convolution kernel size.
        apply_dropout: Whether to insert Dropout(0.5) after BatchNormalization.

    Returns:
        A ``tf.keras.Sequential`` containing the layers.
    """
    stack = [
        tf.keras.layers.Conv2DTranspose(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.5))
    stack.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stack)

def Generator():
    """Build the U-Net generator mapping a 512x512x3 face to a 512x512x3 caricature.

    The encoder has eight stride-2 ``downsample`` blocks (512 -> 2 pixels).
    The decoder has seven ``upsample`` blocks (2 -> 256 pixels); each output is
    concatenated with the encoder output of matching resolution. A final
    stride-2 Conv2DTranspose with tanh produces a 512x512x3 image in [-1, 1].

    Returns:
        A ``tf.keras.Model``.
    """
    image_in = tf.keras.layers.Input(shape=[512, 512, 3])

    # (filters, apply_batchnorm) per encoder step; output sizes
    # 256, 128, 64, 32, 16, 8, 4, 2.
    encoder_specs = [(64, False), (128, True), (256, True)] + [(512, True)] * 5
    encoders = [downsample(n_filters, 4, apply_batchnorm=use_bn)
                for n_filters, use_bn in encoder_specs]

    # (filters, apply_dropout) per decoder step; output sizes
    # 4, 8, 16, 32, 64, 128, 256.
    decoder_specs = [(512, True)] * 3 + [(512, False), (256, False), (128, False), (64, False)]
    decoders = [upsample(n_filters, 4, apply_dropout=use_dropout)
                for n_filters, use_dropout in decoder_specs]

    to_rgb = tf.keras.layers.Conv2DTranspose(
        3, 4, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='tanh')  # -> (512, 512, 3)

    features = image_in
    encoder_outputs = []
    for encoder in encoders:
        features = encoder(features)
        encoder_outputs.append(features)

    # The deepest (2x2) encoder output is the bottleneck, not a skip, so skips
    # are the remaining outputs from deepest to shallowest.
    for decoder, skip in zip(decoders, encoder_outputs[-2::-1]):
        upsampled = decoder(features)
        features = tf.keras.layers.Concatenate()([upsampled, skip])

    return tf.keras.Model(inputs=image_in, outputs=to_rgb(features))

def Discriminator():
    """Build the patch discriminator conditioned on the input face.

    Takes ``[input_image, target_image]`` (each 512x512x3) and concatenates
    them along channels. It applies three ``downsample`` blocks
    (512 -> 256 -> 128 -> 64), then two zero-pad + 4x4 valid convolutions.
    The result is a (62, 62, 1) grid of raw logits with no activation.

    Returns:
        A ``tf.keras.Model`` with inputs ``[input_image, target_image]``.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)

    source = tf.keras.layers.Input(shape=[512, 512, 3], name='input_image')
    candidate = tf.keras.layers.Input(shape=[512, 512, 3], name='target_image')

    h = tf.keras.layers.concatenate([source, candidate])  # (512, 512, 6)
    for n_filters, use_bn in ((64, False), (128, True), (256, True)):
        h = downsample(n_filters, 4, use_bn)(h)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (66, 66, 256)
    h = tf.keras.layers.Conv2D(512, 4, strides=1,
                               kernel_initializer=weight_init,
                               use_bias=False)(h)  # (63, 63, 512)
    h = tf.keras.layers.BatchNormalization()(h)
    h = tf.keras.layers.LeakyReLU()(h)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (65, 65, 512)
    logits = tf.keras.layers.Conv2D(1, 4, strides=1,
                                    kernel_initializer=weight_init)(h)  # (62, 62, 1)

    return tf.keras.Model(inputs=[source, candidate], outputs=logits)

# Build both networks, then one Adam optimizer per network
# (learning rate 2e-4, beta_1=0.5).
generator = Generator()
discriminator = Discriminator()

generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2)
)

## Loss Functions

In [None]:
LAMBDA = 100  # weight of the L1 reconstruction term relative to the adversarial term


def generator_loss(disc_generated_output, gen_output, target):
    """Compute the generator objective: adversarial loss + LAMBDA * L1.

    Args:
        disc_generated_output: Discriminator logits for (input, generated) pairs.
        gen_output: Generated image batch.
        target: Ground-truth caricature batch.

    Returns:
        ``(total, adversarial, l1)``. ``adversarial`` is binary cross-entropy
        of the logits against all-ones labels, ``l1`` is the mean absolute
        error between target and output, and ``total = adversarial + LAMBDA * l1``.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    adversarial = bce(tf.ones_like(disc_generated_output), disc_generated_output)
    l1 = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + (LAMBDA * l1), adversarial, l1

def discriminator_loss(disc_real_output, disc_generated_output):
    """Compute the discriminator objective.

    Args:
        disc_real_output: Logits for (input, real target) pairs.
        disc_generated_output: Logits for (input, generated) pairs.

    Returns:
        Binary cross-entropy of the real logits against ones plus binary
        cross-entropy of the generated logits against zeros.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    real_term = bce(tf.ones_like(disc_real_output), disc_real_output)
    fake_term = bce(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_term + fake_term

## Training

In [None]:
@tf.function
def train_step(input_image, target):
    """Run one optimization step for both the generator and the discriminator.

    Forward passes are recorded on two tapes, so each network's gradients
    come only from its own loss.

    Returns:
        ``(generator_total_loss, discriminator_loss)`` as scalar tensors.
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = generator(input_image, training=True)

        real_logits = discriminator([input_image, target], training=True)
        fake_logits = discriminator([input_image, fake], training=True)

        g_total, _, _ = generator_loss(fake_logits, fake, target)
        d_total = discriminator_loss(real_logits, fake_logits)

    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables
    g_grads = g_tape.gradient(g_total, g_vars)
    d_grads = d_tape.gradient(d_total, d_vars)

    generator_optimizer.apply_gradients(zip(g_grads, g_vars))
    discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))

    return g_total, d_total

def train(dataset, epochs):
    """Train the generator and discriminator for ``epochs`` passes over ``dataset``.

    Weights are written to CHECKPOINT_DIR every 20 epochs and after the final
    epoch. Files are named ``ckpt_epoch_<zero-based epoch>_generator.h5`` and
    ``..._discriminator.h5``.

    Args:
        dataset: Iterable of ``(input_image, target)`` batches.
        epochs: Number of epochs to run.

    Raises:
        ValueError: If ``dataset`` yields no batches during an epoch.
    """
    for epoch in range(epochs):
        start = time.time()
        # Accumulate as tensors (no per-step host sync); convert once per epoch.
        gen_loss_sum = 0.0
        disc_loss_sum = 0.0
        num_steps = 0

        for input_image, target in dataset:
            gen_loss, disc_loss = train_step(input_image, target)
            gen_loss_sum += gen_loss
            disc_loss_sum += disc_loss
            num_steps += 1

        if num_steps == 0:
            raise ValueError('Training dataset yielded no batches.')

        # Save periodically and always after the last epoch, so the final
        # weights are kept even when `epochs` is not a multiple of 20.
        if (epoch + 1) % 20 == 0 or epoch + 1 == epochs:
            checkpoint_prefix = os.path.join(CHECKPOINT_DIR, f"ckpt_epoch_{epoch}")
            # NOTE(review): Keras 3 (TF >= 2.16) requires save_weights filenames
            # ending in '.weights.h5'; if upgrading, rename here and in
            # generate_caricature together.
            generator.save_weights(checkpoint_prefix + '_generator.h5')
            discriminator.save_weights(checkpoint_prefix + '_discriminator.h5')

        # Report epoch means rather than the loss of the last batch only.
        mean_gen_loss = float(gen_loss_sum) / num_steps
        mean_disc_loss = float(disc_loss_sum) / num_steps
        print(f'Epoch {epoch+1}, Gen Loss: {mean_gen_loss:.4f}, Disc Loss: {mean_disc_loss:.4f} '
              f'(mean over {num_steps} steps)')
        print(f'Time taken for epoch {epoch+1}: {time.time()-start:.2f} sec')

# Create the checkpoint directory. exist_ok=True makes this a no-op when it
# already exists, without the check-then-create race of os.path.exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Train the model
EPOCHS = 200
train(train_dataset, EPOCHS)

## Inference

In [None]:
def generate_caricature(face_image_path, checkpoint_path):
    """Generate a caricature for one face image from saved generator weights.

    Note: the weights are loaded into the module-level ``generator`` (the same
    model object used for training), replacing whatever weights it held.

    Args:
        face_image_path: Path to the input face PNG.
        checkpoint_path: Checkpoint prefix as written by ``train``, e.g.
            ``os.path.join(CHECKPOINT_DIR, 'ckpt_epoch_199')``.
            ``'_generator.h5'`` is appended to it.

    Returns:
        A uint8 tensor of shape (512, 512, 3).
    """
    generator.load_weights(checkpoint_path + '_generator.h5')

    input_image = load_image(face_image_path)
    # The generator's Input is fixed at 512x512; resize other sizes instead of
    # failing inside the model.
    if input_image.shape[0] != 512 or input_image.shape[1] != 512:
        input_image = tf.image.resize(input_image, [512, 512])
    input_image = tf.expand_dims(input_image, 0)

    prediction = generator(input_image, training=False)

    # Map tanh output [-1, 1] to [0, 255]. Round rather than truncate (a bare
    # cast biases every pixel down), and clip so float error cannot wrap
    # around the uint8 range.
    prediction = tf.clip_by_value(tf.round((prediction + 1.0) * 127.5), 0.0, 255.0)
    prediction = tf.cast(prediction, tf.uint8)

    return prediction[0]

# Example usage:
# checkpoint_path = os.path.join(CHECKPOINT_DIR, 'ckpt_epoch_199')
# generated_caricature = generate_caricature('path_to_face_image.png', checkpoint_path)
# plt.imshow(generated_caricature)
# plt.axis('off')
# plt.show()