# Pix2Pix Caricature Generation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/USERNAME/REPO/blob/main/aicaricaturist/pix2pix_caricature/caricature_training.ipynb)

This notebook trains a Pix2Pix model in PyTorch for face-to-caricature translation: the generator receives a Canny edge map of a grayscale face photo and produces a grayscale caricature.

## Install required packages

In [None]:
!pip install torch torchvision opencv-python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.utils import save_image
import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from google.colab import drive
import cv2
import random

## Mount Google Drive

In [None]:
# Make Google Drive reachable from the Colab filesystem
drive.mount('/content/drive')

# Locations on Drive: paired training images in, checkpoints out
DATA_DIR = '/content/drive/MyDrive/caricature Project Diffusion/paired_caricature'
CHECKPOINT_DIR = '/content/drive/MyDrive/caricature_checkpoints'

# Ensure the checkpoint folder is present (no error if it already exists)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Data Loading and Preprocessing

In [None]:
class CaricatureDataset(Dataset):
    """Paired face/caricature dataset yielding (edge map, caricature) tensors.

    Files are discovered in ``root_dir`` by suffix: ``<name>_f.png`` is a face
    photo and ``<name>_c.png`` is the caricature that belongs to it. Each item
    is converted to grayscale, resized to 512x512, augmented, and returned as
    a pair of 1x512x512 tensors normalised to [-1, 1]; the face is replaced by
    its Canny edge map before conversion.
    """

    FACE_SUFFIX = '_f.png'
    CARICATURE_SUFFIX = '_c.png'

    def __init__(self, root_dir):
        """Index the image pairs found directly inside ``root_dir``."""
        self.root_dir = root_dir
        # Pair by shared file-name prefix instead of by position in two
        # independently sorted lists: sorting "*_f.png" and "*_c.png"
        # separately can order prefixes such as "a" and "a_d" differently,
        # and a single missing file would shift every later pair.
        self.face_paths, self.caricature_paths = self._pair_paths(root_dir)

        # Maps tensors from [0, 1] to [-1, 1]
        self.normalize = transforms.Normalize((0.5,), (0.5,))

        # Brightness/contrast variation only; saturation and hue are
        # disabled because the images are grayscale
        self.color_jitter = transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.0,
            hue=0.0,
        )

    @classmethod
    def _pair_paths(cls, root_dir):
        """Return aligned ``(face_paths, caricature_paths)`` for ``root_dir``.

        Every face file is matched to the caricature with the same prefix.
        Faces without a caricature are skipped (with a warning) so that index
        ``i`` always refers to the same subject in both lists.
        """
        import warnings

        face_paths = []
        caricature_paths = []
        unpaired = 0
        pattern = os.path.join(root_dir, '*' + cls.FACE_SUFFIX)
        for face_path in sorted(glob.glob(pattern)):
            caricature_path = face_path[:-len(cls.FACE_SUFFIX)] + cls.CARICATURE_SUFFIX
            if os.path.isfile(caricature_path):
                face_paths.append(face_path)
                caricature_paths.append(caricature_path)
            else:
                unpaired += 1

        if unpaired:
            warnings.warn(
                f'{unpaired} face image(s) in {root_dir!r} have no matching '
                f'*{cls.CARICATURE_SUFFIX} file and were skipped'
            )
        return face_paths, caricature_paths

    def __len__(self):
        """Number of complete face/caricature pairs."""
        return len(self.face_paths)

    def apply_transforms(self, face, caricature):
        """Convert both PIL images to 512x512 grayscale and augment them.

        The translation, rotation and horizontal flip use one set of random
        values for both images, so the pair stays spatially aligned.

        Returns:
            The transformed ``(face, caricature)`` PIL images.
        """
        size = 512
        canvas_size = 600

        # Convert to grayscale first
        face = TF.to_grayscale(face, num_output_channels=1)
        caricature = TF.to_grayscale(caricature, num_output_channels=1)

        # Resize to exactly 512x512
        face = TF.resize(face, (size, size), interpolation=TF.InterpolationMode.BILINEAR)
        caricature = TF.resize(caricature, (size, size), interpolation=TF.InterpolationMode.BILINEAR)

        # Random translation: paste onto a larger white canvas at a random
        # offset, then crop the centre. With a 600px canvas the effective
        # shift is within +/-44 pixels on each axis.
        max_shift = canvas_size - size
        x_shift = random.randint(0, max_shift)
        y_shift = random.randint(0, max_shift)
        face_canvas = Image.new('L', (canvas_size, canvas_size), 255)
        caricature_canvas = Image.new('L', (canvas_size, canvas_size), 255)
        face_canvas.paste(face, (x_shift, y_shift))
        caricature_canvas.paste(caricature, (x_shift, y_shift))

        # Crop back to 512x512 from the center
        start = (canvas_size - size) // 2
        face = TF.crop(face_canvas, start, start, size, size)
        caricature = TF.crop(caricature_canvas, start, start, size, size)

        # Random rotation (-15 to 15 degrees); uncovered corners become white
        angle = random.uniform(-15, 15)
        face = TF.rotate(face, angle, fill=255)
        caricature = TF.rotate(caricature, angle, fill=255)

        # Random horizontal flip
        if random.random() > 0.5:
            face = TF.hflip(face)
            caricature = TF.hflip(caricature)

        # Random brightness/contrast jitter.
        # NOTE(review): each ColorJitter call samples its own parameters, so
        # the two images are not jittered identically - confirm this is
        # intended for the caricature target.
        if random.random() > 0.5:
            face = self.color_jitter(face)
            caricature = self.color_jitter(caricature)

        return face, caricature

    def __getitem__(self, idx):
        """Return ``(face_edge_tensor, caricature_tensor)`` for pair ``idx``."""
        face_path = self.face_paths[idx]
        caricature_path = self.caricature_paths[idx]

        # Load images
        face = Image.open(face_path).convert('RGB')
        caricature = Image.open(caricature_path).convert('RGB')

        # Grayscale conversion, resize and augmentation of both images
        face, caricature = self.apply_transforms(face, caricature)

        # The generator input is the Canny edge map of the face
        face_np = np.array(face)
        face_edges = cv2.Canny(face_np, 100, 200)
        face_edges = Image.fromarray(face_edges)

        # Convert to tensors and normalize to [-1, 1]
        face_tensor = TF.to_tensor(face_edges)
        caricature_tensor = TF.to_tensor(caricature)

        face_tensor = self.normalize(face_tensor)
        caricature_tensor = self.normalize(caricature_tensor)

        return face_tensor, caricature_tensor

# Build the dataset and a shuffling loader that yields one pair per batch
dataset = CaricatureDataset(DATA_DIR)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=2,
)

## Model Implementation

In [None]:
def init_weights(m):
    """Initialise a conv or batch-norm layer in place (for ``Module.apply``).

    The layer type is detected from the class name: names containing
    ``Conv`` get weights drawn from a normal distribution with mean 0 and
    std 0.02; names containing ``BatchNorm`` get weights with mean 1 and
    std 0.02 plus a zero bias. Any other module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

class UNetBlock(nn.Module):
    """Encoder stage built around a bias-free 4x4, stride-2 convolution.

    The convolution (which halves even spatial sizes) is followed by an
    optional BatchNorm2d, a LeakyReLU with slope 0.2 and, when
    ``dropout > 0``, a Dropout layer.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        ]
        if normalize:
            stages += [nn.BatchNorm2d(out_channels)]
        stages += [nn.LeakyReLU(0.2)]
        if dropout > 0:
            stages += [nn.Dropout(dropout)]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the downsampled features."""
        downsampled = self.block(x)
        return downsampled

class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, then concatenate the encoder skip features.

    The stage is a bias-free 4x4, stride-2 transposed convolution followed by
    BatchNorm2d, ReLU and, when ``dropout > 0``, a Dropout layer.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
        )
        stages = [upsample, nn.BatchNorm2d(out_channels), nn.ReLU()]
        if dropout > 0:
            stages.append(nn.Dropout(dropout))
        self.block = nn.Sequential(*stages)

    def forward(self, x, skip):
        """Upsample ``x`` and append ``skip`` along the channel dimension."""
        upsampled = self.block(x)
        return torch.cat((upsampled, skip), dim=1)

class Generator(nn.Module):
    """U-Net generator mapping a 1-channel image to a 1-channel image.

    Eight stride-2 encoder blocks are mirrored by seven decoder blocks, each
    of which concatenates the matching encoder output. A final transposed
    convolution followed by Tanh produces the output.
    """

    def __init__(self):
        super().__init__()

        # Encoder: single-channel (grayscale) input, no batch norm on the first block
        self.down1 = UNetBlock(1, 64, normalize=False)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)
        self.down8 = UNetBlock(512, 512)

        # Decoder: from up2 onwards the input width is doubled by the
        # skip concatenation done in the previous block
        self.up1 = UNetUpBlock(512, 512, dropout=0.5)
        self.up2 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up3 = UNetUpBlock(1024, 512, dropout=0.5)
        self.up4 = UNetUpBlock(1024, 512)
        self.up5 = UNetUpBlock(1024, 256)
        self.up6 = UNetUpBlock(512, 128)
        self.up7 = UNetUpBlock(256, 64)

        # Output head: single-channel (grayscale) image squashed by Tanh
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, decode with skip connections, return the Tanh output."""
        encoders = (
            self.down1, self.down2, self.down3, self.down4,
            self.down5, self.down6, self.down7, self.down8,
        )
        decoders = (
            self.up1, self.up2, self.up3, self.up4,
            self.up5, self.up6, self.up7,
        )

        # Keep every encoder output; they are consumed deepest-first below
        skips = []
        for encode in encoders:
            x = encode(x)
            skips.append(x)

        # The innermost output is the decoder input, not a skip connection
        skips.pop()
        for decode in decoders:
            x = decode(x, skips.pop())

        return self.final(x)

class Discriminator(nn.Module):
    """Convolutional discriminator scoring (face, caricature) pairs.

    The two single-channel inputs are concatenated along the channel
    dimension and passed through four stride-2 convolution stages, a zero
    pad and a final convolution that emits one raw score per output location
    (no sigmoid is applied).
    """

    def __init__(self):
        super().__init__()

        def stage(in_channels, out_channels, normalize=True):
            # conv -> optional batch norm -> LeakyReLU, returned as a flat list
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            norm = [nn.BatchNorm2d(out_channels)] if normalize else []
            return [conv, *norm, nn.LeakyReLU(0.2)]

        # Input has 2 channels (1 for face + 1 for caricature); only the
        # first stage skips batch norm
        widths = [(2, 64), (64, 128), (128, 256), (256, 512)]
        layers = []
        for idx, (c_in, c_out) in enumerate(widths):
            layers.extend(stage(c_in, c_out, normalize=idx > 0))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, face, caricature):
        """Return the score map for the channel-wise concatenated pair."""
        pair = torch.cat((face, caricature), dim=1)
        return self.model(pair)

# Initialize models: build both networks and move their parameters to `device`
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Initialize weights: Module.apply calls init_weights on every submodule,
# which re-draws the conv and batch-norm parameters
generator.apply(init_weights)
discriminator.apply(init_weights)

# Loss functions
# - criterion_GAN works on raw scores (the discriminator ends in a plain
#   convolution, so the sigmoid is folded into BCEWithLogitsLoss)
# - criterion_L1 is the pixel-wise term between generated and target images
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

# Optimizers: one Adam instance per network, lr=2e-4 and betas=(0.5, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

## Training

In [None]:
# Per-epoch average losses; train_model appends to these and the plotting
# cell reads them
losses_D = []
losses_G = []

def train_model(num_epochs=1000):
    """Train the module-level generator and discriminator on ``dataloader``.

    Each batch first updates the generator (adversarial loss plus 100x the
    L1 loss against the target), then updates the discriminator on the real
    pair and on the detached generated pair.

    Args:
        num_epochs: Number of passes over ``dataloader``.

    Side effects:
        Appends one average loss per epoch to ``losses_D`` and ``losses_G``,
        prints progress every 100 batches, and writes a checkpoint to
        ``CHECKPOINT_DIR`` after every 20th epoch. Checkpoint file names use
        the zero-based epoch index.

    Raises:
        ValueError: If ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError when the
        # epoch averages are computed
        raise ValueError('dataloader is empty; nothing to train on')

    # generate_caricature() switches the shared generator to eval mode and
    # never switches it back. Without this, re-running training after an
    # inference cell would train with frozen batch-norm statistics and
    # dropout disabled.
    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        running_loss_D = 0.0
        running_loss_G = 0.0

        for i, (faces, caricatures) in enumerate(dataloader):
            faces = faces.to(device)
            caricatures = caricatures.to(device)

            # ---- Train Generator ----
            optimizer_G.zero_grad()
            gen_caricatures = generator(faces)

            # GAN loss
            pred_fake = discriminator(faces, gen_caricatures)
            # Create ground truth labels matching discriminator output size
            valid = torch.ones_like(pred_fake, requires_grad=False)
            fake = torch.zeros_like(pred_fake, requires_grad=False)

            loss_GAN = criterion_GAN(pred_fake, valid)

            # L1 loss
            loss_L1 = criterion_L1(gen_caricatures, caricatures)

            # Total loss
            loss_G = loss_GAN + 100 * loss_L1
            loss_G.backward()
            optimizer_G.step()

            # ---- Train Discriminator ----
            # zero_grad also discards the discriminator gradients that the
            # generator's backward pass accumulated above
            optimizer_D.zero_grad()

            # Real loss
            pred_real = discriminator(faces, caricatures)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss (detached so no gradient reaches the generator)
            pred_fake = discriminator(faces, gen_caricatures.detach())
            loss_fake = criterion_GAN(pred_fake, fake)

            # Total discriminator loss
            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward()
            optimizer_D.step()

            # Accumulate batch losses
            running_loss_D += loss_D.item()
            running_loss_G += loss_G.item()

            if i % 100 == 0:
                print(f'[Epoch {epoch}/{num_epochs}] [Batch {i}/{num_batches}] '
                      f'[D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}]')

        # Compute average loss over this epoch
        epoch_loss_D = running_loss_D / num_batches
        epoch_loss_G = running_loss_G / num_batches

        # Store for plotting
        losses_D.append(epoch_loss_D)
        losses_G.append(epoch_loss_G)

        if (epoch + 1) % 20 == 0:
            # Save model checkpoints
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
            }, os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch}.pt'))

# Train the model
train_model()

In [None]:
# Plot the per-epoch average generator and discriminator losses
plt.figure(figsize=(10, 5))
plt.plot(losses_G, label="G Loss")
plt.plot(losses_D, label="D Loss")
plt.title("Generator and Discriminator Loss Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

## Inference

In [None]:
def generate_caricature(face_path, checkpoint_path):
    """Generate a caricature for one face image from a saved checkpoint.

    The image is preprocessed the same way as the training data (grayscale,
    bilinear resize to 512x512, Canny edges, normalisation to [-1, 1]) before
    it is passed to the generator.

    Args:
        face_path: Path to the input face image.
        checkpoint_path: Path to a checkpoint file; only its
            ``'generator_state_dict'`` entry is used.

    Returns:
        A 512x512 float NumPy array with values in [0, 1].

    Note:
        The weights are loaded into the module-level ``generator``, which is
        left in eval mode afterwards.
    """
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()

    # Mirror the training preprocessing: grayscale -> 512x512 -> Canny.
    # The fixed size also matters structurally: the generator halves the
    # input eight times and concatenates skip connections, so sizes that do
    # not halve cleanly make the concatenation fail.
    face = Image.open(face_path).convert('RGB')
    face = TF.to_grayscale(face, num_output_channels=1)
    face = TF.resize(face, (512, 512), interpolation=TF.InterpolationMode.BILINEAR)
    face_edges = cv2.Canny(np.array(face), 100, 200)
    face_edges = Image.fromarray(face_edges)

    # Edge map -> tensor in [-1, 1] with a leading batch dimension
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    face_tensor = transform(face_edges).unsqueeze(0).to(device)

    with torch.no_grad():
        generated = generator(face_tensor)
        # Undo the [-1, 1] normalisation so the result can be shown as an image
        generated = (generated * 0.5 + 0.5).clamp(0, 1)
        generated = generated.squeeze().cpu().numpy()

    return generated

# Example usage: run one image from Drive through a saved generator and
# display the result as a grayscale image without axes.
# NOTE: train_model names checkpoints by zero-based epoch index and saves
# when (epoch + 1) is a multiple of 20, so 'checkpoint_epoch_199.pt' is the
# checkpoint written after the 200th epoch.
# NOTE(review): both paths are hard-coded to one Drive layout - adjust as needed.
path = '/content/drive/MyDrive/caricature Project Diffusion/test_01.png'
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'checkpoint_epoch_199.pt')
generated = generate_caricature(path, checkpoint_path)
plt.imshow(generated, cmap='gray')
plt.axis('off')
plt.show()