# Pix2Pix - Image-to-Image Translation (Ultimate v4)

This notebook implements an advanced Pix2Pix-style conditional GAN aimed at high output fidelity.

## What's New in v4 (High Fidelity)
1.  **Perceptual Loss (VGG19)**: In addition to the pixel-wise L1 loss (which on its own tends to produce blurry results), the generator is penalized on the distance between "features" extracted by a pre-trained VGG19. This pushes the generator to respect structure and texture.
2.  **Replay Buffer**: The discriminator trains on a history of previously generated images (up to 50) to dampen oscillations.
3.  **ResNet Generator (9 blocks)**: A deeper architecture than the classic U-Net (inspired by CycleGAN/Pix2PixHD) for better information propagation.
4.  **TTUR**: Separate learning rates for the generator (2e-4) and the discriminator (4e-4).
5.  **Optimizations**: Kaiming init, `NUM_WORKERS = 4`.
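
To make item 2 concrete, here is a minimal pure-Python sketch of the history-buffer idea. `ToyReplayBuffer` is a hypothetical stand-in that stores plain integers instead of image tensors; it only illustrates the store/swap/pass-through logic and is not the class used for training.

```python
import random

class ToyReplayBuffer:
    """Pure-Python stand-in for an image history buffer (items are plain ints here)."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, items):
        out = []
        for item in items:
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new item and return it unchanged
                self.data.append(item)
                out.append(item)
            elif random.random() > 0.5:
                # Swap: return a stored item and overwrite its slot with the new one
                i = random.randrange(self.max_size)
                out.append(self.data[i])
                self.data[i] = item
            else:
                # Pass-through: return the new item without storing it
                out.append(item)
        return out

random.seed(0)
buf = ToyReplayBuffer(max_size=50)
returned = [buf.push_and_pop([step])[0] for step in range(1000)]
# Fraction of steps where the returned sample is older than the one just pushed
stale = sum(r != step for step, r in enumerate(returned)) / len(returned)
print(f"buffer size: {len(buf.data)}, stale fraction: {stale:.2f}")
```

Once the buffer is full, roughly half of the samples handed to the discriminator come from earlier generator states, which is what keeps it from overfitting to the generator's latest output.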

In [89]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

## 1. Configuration

In [90]:
class Config:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # TTUR: two different learning rates
    LR_G = 2e-4
    LR_D = 4e-4
    BATCH_SIZE = 1 
    NUM_WORKERS = 4 # Parallel CPU data loading
    IMAGE_SIZE = 256
    CHANNELS_IMG = 3
    L1_LAMBDA = 10 
    VGG_LAMBDA = 10 # Weight of the perceptual loss
    NUM_EPOCHS = 7
    LOAD_MODEL = False
    SAVE_MODEL = True
    CHECKPOINT_DISC = "disc_v4.pth.tar"
    CHECKPOINT_GEN = "gen_v4.pth.tar"
    
    TRAIN_DIR = "/kaggle/input/pix2pix-dataset/edges2shoes/edges2shoes/train"
    VAL_DIR = "/kaggle/input/pix2pix-dataset/edges2shoes/edges2shoes/val"
    
    TRAIN_SIZE_LIMIT = None 
    
print(f"Device used: {Config.DEVICE}")

Device used: cuda


## 2. Data Pipeline

In [91]:
class Pix2PixDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        if os.path.exists(root_dir):
            self.list_files = os.listdir(root_dir)
        else:
            print(f"Warning: the folder {root_dir} does not exist.")
            self.list_files = []

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
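        try:
            # Force 3 channels so the slicing below always sees an H x W x 3 array
            image = np.array(Image.open(img_path).convert("RGB"))
        except (OSError, ValueError):
            # Skip unreadable files; wrap around so the last index cannot overflow
            return self.__getitem__((index + 1) % len(self.list_files))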

        w = image.shape[1]
        cutoff = w // 2
        input_image = image[:, :cutoff, :] 
        target_image = image[:, cutoff:, :]

        input_image = Image.fromarray(input_image)
        target_image = Image.fromarray(target_image)
        
        # Random jitter: resize to 286x286, then take the same random crop from both halves
        resize = transforms.Resize((286, 286))
        input_image = resize(input_image)
        target_image = resize(target_image)
        
        i, j, h, w_crop = transforms.RandomCrop.get_params(
            input_image, output_size=(Config.IMAGE_SIZE, Config.IMAGE_SIZE)
        )
        input_image = transforms.functional.crop(input_image, i, j, h, w_crop)
        target_image = transforms.functional.crop(target_image, i, j, h, w_crop)
        
        if torch.rand(1) > 0.5:
            input_image = transforms.functional.hflip(input_image)
            target_image = transforms.functional.hflip(target_image)
            
        base_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        
        input_image = base_transform(input_image)
        target_image = base_transform(target_image)

        return input_image, target_image
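
As a quick sanity check of the splitting step, the sketch below applies the same half-width cut to a synthetic array. It assumes a 256 x 512 side-by-side layout with the input on the left and the target on the right, as the dataset class does; `split_pair` is a hypothetical helper name.

```python
import numpy as np

def split_pair(image):
    """Split an H x 2W x C side-by-side array into (left, right) halves."""
    cutoff = image.shape[1] // 2
    return image[:, :cutoff, :], image[:, cutoff:, :]

# Synthetic 256 x 512 RGB pair: left half black, right half white
pair = np.zeros((256, 512, 3), dtype=np.uint8)
pair[:, 256:, :] = 255

left, right = split_pair(pair)
print(left.shape, right.shape)
```

Both halves come out as 256 x 256 x 3 views of the original array, so no pixel data is copied until they are converted back to PIL images.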

## 3. Advanced Utilities: ReplayBuffer & VGG Loss

In [92]:
class ReplayBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.detach():  # detach() replaces the legacy .data attribute
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return torch.cat(to_return)

class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Pre-trained VGG19 feature extractor, truncated after conv5_4
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.slice = nn.Sequential()
        # Keep the first 35 layers (indices 0-34); index 35 would be the ReLU after conv5_4
        for i in range(35):
            self.slice.add_module(str(i), vgg[i])
        self.slice.eval()  # inference mode; the weights are frozen just below
        for param in self.slice.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        # Strictly, inputs should be mapped from [-1, 1] back to [0, 1] and then
        # normalized with the ImageNet mean/std that VGG was trained on.
        # This version skips that step and feeds the [-1, 1] images directly.
        x_vgg = self.slice(x)
        y_vgg = self.slice(y)
        loss = nn.MSELoss()(x_vgg, y_vgg)
        return loss
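
The comments in `VGGLoss.forward` mention that inputs should ideally be re-normalized for VGG. A minimal sketch of that mapping follows, assuming the generator output is in [-1, 1] and using the standard ImageNet mean/std; `to_vgg_input` is a hypothetical helper and is not part of the training run logged in this notebook.

```python
import torch

# ImageNet statistics commonly used with torchvision's pre-trained VGG weights
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def to_vgg_input(x):
    """Map a batch in [-1, 1] to ImageNet-normalized values (hypothetical helper)."""
    x01 = x * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    mean = IMAGENET_MEAN.to(device=x.device, dtype=x.dtype)
    std = IMAGENET_STD.to(device=x.device, dtype=x.dtype)
    return (x01 - mean) / std

batch = torch.zeros(2, 3, 8, 8)  # mid-gray in [-1, 1] space
out = to_vgg_input(batch)
print(out[0, :, 0, 0])
```

If adopted, the helper would be applied to both `x` and `y` before `self.slice`; whether it improves results on this dataset is untested here.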

## 4. Architecture ResNet Generator (9 Blocks) & Discriminator

In [93]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)

class ResNetGenerator(nn.Module):
    def __init__(self, in_channels=3, features=64, num_residuals=9):
        super().__init__()
        # Initial Conv
        model = [
            nn.Conv2d(in_channels, features, kernel_size=7, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]
        # Downsampling
        in_features = features
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual Blocks
        for _ in range(num_residuals):
            model += [ResidualBlock(in_features)]

        # Upsampling
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output Layer
        model += [
            nn.Conv2d(features, 3, kernel_size=7, padding=3, padding_mode="reflect"),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        # Standard PatchGAN discriminator
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, feature, 4, stride=1 if feature == features[-1] else 2, padding=1, bias=False, padding_mode="reflect"),
                    nn.InstanceNorm2d(feature, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        return self.model(x)
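
To see why the discriminator is called a PatchGAN, it helps to trace the spatial size through its convolutions. The sketch below uses the standard Conv2d output-size formula with the kernel/stride/padding values from the `Discriminator` class; `conv_out` and `patchgan_out` are hypothetical helper names.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def patchgan_out(size=256):
    # (kernel, stride, padding) for each conv in the Discriminator:
    # initial block, three feature blocks (the last with stride 1), final 1-channel conv
    layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]
    for k, s, p in layers:
        size = conv_out(size, k, s, p)
    return size

print(patchgan_out(256))  # 256 -> 128 -> 64 -> 32 -> 31 -> 30
```

For a 256 x 256 input the discriminator therefore emits a 30 x 30 grid of scores, each judging one overlapping patch of the image rather than the image as a whole.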

## 5. Training Loop v4

In [94]:
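def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # `a` is read as the LeakyReLU negative slope; the discriminator uses 0.2,
        # so a=0.02 only changes the Kaiming gain slightly
        nn.init.kaiming_normal_(m.weight, a=0.02, mode='fan_in', nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # Both networks use InstanceNorm2d, so this branch never fires in this notebook
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)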

def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, mse_loss, vgg_loss, g_scaler, d_scaler, replay_buffer):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x, y = x.to(Config.DEVICE), y.to(Config.DEVICE)
        
        # --- Generator Forward ---
        with torch.amp.autocast('cuda'):
            y_fake = gen(x)
            
            # Generator Loss
            D_fake = disc(x, y_fake)
            G_gan_loss = mse_loss(D_fake, torch.ones_like(D_fake))
            
            G_l1_loss = l1_loss(y_fake, y) * Config.L1_LAMBDA
            G_vgg_loss = vgg_loss(y_fake, y) * Config.VGG_LAMBDA
            
            G_loss = G_gan_loss + G_l1_loss + G_vgg_loss

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        
        # --- Discriminator Forward ---
        with torch.amp.autocast('cuda'):
            # Replay Buffer: fetch a mix of current and past fake images
            y_fake_buffer = replay_buffer.push_and_pop(y_fake.detach())
            
            D_real = disc(x, y)
            D_fake = disc(x, y_fake_buffer)
            
            D_real_loss = mse_loss(D_real, torch.ones_like(D_real) * 0.9) # Label smoothing
            D_fake_loss = mse_loss(D_fake, torch.zeros_like(D_fake))
            
            D_loss = (D_real_loss + D_fake_loss) * 0.5

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(D=D_loss.item(), G=G_loss.item(), VGG=G_vgg_loss.item())
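
Written out, the two objectives optimized in `train_fn` are as follows, where $\phi$ denotes the truncated VGG19, $\tilde{y}$ a fake image drawn from the replay buffer, $\mathrm{MAE}$ the mean absolute error computed by `nn.L1Loss`, and $\lambda_{L1} = \lambda_{VGG} = 10$ as set in `Config`. This is a restatement of the code, not an additional loss term.

$$
\mathcal{L}_G = \mathrm{MSE}\big(D(x, G(x)),\, 1\big) + \lambda_{L1}\,\mathrm{MAE}\big(G(x),\, y\big) + \lambda_{VGG}\,\mathrm{MSE}\big(\phi(G(x)),\, \phi(y)\big)
$$

$$
\mathcal{L}_D = \tfrac{1}{2}\Big[\mathrm{MSE}\big(D(x, y),\, 0.9\big) + \mathrm{MSE}\big(D(x, \tilde{y}),\, 0\big)\Big]
$$

The adversarial terms use the least-squares (MSE) form rather than binary cross-entropy, and the real target of 0.9 is the one-sided label smoothing applied in the discriminator step.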

## 6. Main

In [95]:
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(Config.DEVICE), y.to(Config.DEVICE)
    if not os.path.exists(folder):
        os.makedirs(folder)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5
        x = x * 0.5 + 0.5
        y = y * 0.5 + 0.5
        save_image(y_fake, folder + f"/gen_{epoch}.png")
        save_image(x, folder + f"/input_{epoch}.png")
        save_image(y, folder + f"/label_{epoch}.png")
    gen.train()

def main():
    disc = Discriminator().to(Config.DEVICE)
    gen = ResNetGenerator(num_residuals=9).to(Config.DEVICE)
    
    disc.apply(weights_init)
    gen.apply(weights_init)
    
    # TTUR: separate learning rates for D and G
    opt_disc = optim.Adam(disc.parameters(), lr=Config.LR_D, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=Config.LR_G, betas=(0.5, 0.999))

    # With step_size=20 and NUM_EPOCHS=7, these schedulers never decay the LR during this run
    scheduler_G = optim.lr_scheduler.StepLR(opt_gen, step_size=20, gamma=0.5)
    scheduler_D = optim.lr_scheduler.StepLR(opt_disc, step_size=20, gamma=0.5)

    MSE_LOSS = nn.MSELoss()
    L1_LOSS = nn.L1Loss()
    VGG_LOSS = VGGLoss().to(Config.DEVICE)

    train_dataset = Pix2PixDataset(root_dir=Config.TRAIN_DIR)
    
    if Config.TRAIN_SIZE_LIMIT and len(train_dataset) > Config.TRAIN_SIZE_LIMIT:
        indices = torch.randperm(len(train_dataset))[:Config.TRAIN_SIZE_LIMIT]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)

    if len(train_dataset) > 0:
        train_loader = DataLoader(
            train_dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=True,
            num_workers=Config.NUM_WORKERS,
            pin_memory=True,
        )
        val_dataset = Pix2PixDataset(root_dir=Config.VAL_DIR)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False) if len(val_dataset) > 0 else train_loader

        g_scaler = torch.amp.GradScaler('cuda')
        d_scaler = torch.amp.GradScaler('cuda')
        replay_buffer = ReplayBuffer()

        for epoch in range(Config.NUM_EPOCHS):
            print(f"Epoch {epoch}/{Config.NUM_EPOCHS} | LR_G: {scheduler_G.get_last_lr()[0]:.6f}")
            train_fn(
                disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, MSE_LOSS, VGG_LOSS, g_scaler, d_scaler, replay_buffer
            )
            
            scheduler_G.step()
            scheduler_D.step()
            
            # Samples and checkpoints are written every 5 epochs (epochs 0 and 5 in a 7-epoch run)
            if Config.SAVE_MODEL and epoch % 5 == 0:
                save_some_examples(gen, val_loader, epoch, folder="evaluation")
                torch.save(gen.state_dict(), Config.CHECKPOINT_GEN)
                torch.save(disc.state_dict(), Config.CHECKPOINT_DISC)
    else:
        print("No data found.")

if __name__ == "__main__":
    main()

Epoch 0/7 | LR_G: 0.000200


100%|██████████| 49825/49825 [53:55<00:00, 15.40it/s, D=0.00181, G=11.7, VGG=9.03] 


Epoch 1/7 | LR_G: 0.000200


100%|██████████| 49825/49825 [53:46<00:00, 15.44it/s, D=0.000111, G=23.3, VGG=19.9]


Epoch 2/7 | LR_G: 0.000200


100%|██████████| 49825/49825 [53:44<00:00, 15.45it/s, D=0.000293, G=12.3, VGG=7.96]


Epoch 3/7 | LR_G: 0.000200


100%|██████████| 49825/49825 [53:44<00:00, 15.45it/s, D=0.0002, G=13.7, VGG=11.3]  


Epoch 4/7 | LR_G: 0.000200


100%|██████████| 49825/49825 [53:43<00:00, 15.46it/s, D=2.27e-5, G=21.5, VGG=18.3] 


Epoch 5/7 | LR_G: 0.000200


100%|██████████| 49825/49825 [53:55<00:00, 15.40it/s, D=2.51e-5, G=11, VGG=8.85]   


Epoch 6/7 | LR_G: 0.000200


  6%|▋         | 3192/49825 [03:26<50:19, 15.45it/s, D=7.43e-5, G=10.7, VGG=8.78] 


KeyboardInterrupt: 
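
The log reports `LR_G: 0.000200` at every epoch because the step decay has a period of 20 epochs, longer than the run. A small sketch of the closed-form schedule makes this visible; `step_lr` is a hypothetical helper that mirrors the `step_size=20, gamma=0.5` settings and does not call PyTorch.

```python
def step_lr(base_lr, epoch, step_size=20, gamma=0.5):
    """Learning rate after `epoch` scheduler steps under a step decay."""
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 6, 20, 40):
    print(epoch, step_lr(2e-4, epoch))
```

The rate stays at 2e-4 through epoch 6 and would only halve at epoch 20, so a shorter `step_size` would be needed for the decay to take effect in a 7-epoch run.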

## 7. Inference & Custom Sketch

In [98]:
def predict_custom_sketch(image_path):
    print(f"--- Prediction on {image_path} ---")
    gen = ResNetGenerator(num_residuals=9).to(Config.DEVICE)
    if os.path.exists(Config.CHECKPOINT_GEN):
        gen.load_state_dict(torch.load(Config.CHECKPOINT_GEN, map_location=Config.DEVICE))
        gen.eval()
    else:
        print("Error: no checkpoint found.")
        return

    if not os.path.exists(image_path):
        print("Image not found")
        return
        
    img = Image.open(image_path).convert("RGB")
    t = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])
    x = t(img).unsqueeze(0).to(Config.DEVICE)
    with torch.no_grad():
        res = gen(x).squeeze().cpu() * 0.5 + 0.5
        plt.imshow(res.permute(1, 2, 0))
        plt.axis("off")
        plt.show()