In [1]:
# ==================== Imports ====================
import os
import itertools
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

# ==================== Paths ====================
base_path = "/kaggle/input/person-face-sketches"
train_real_dir = f"{base_path}/train/photos"      # photographs of faces
train_sketch_dir = f"{base_path}/train/sketches"  # sketches of faces

# Use the default CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ==================== Dataset ====================
class FaceSketchDataset(Dataset):
    """Serve (photo, sketch) image pairs from two directories.

    Image files in each directory are sorted by filename and the ``idx``-th
    photo is returned together with the ``idx``-th sketch. When the two
    directories hold different numbers of images, the dataset length is the
    smaller count and the extra files of the longer directory are never used.

    Only regular files whose extension is in ``IMG_EXTENSIONS`` (compared
    case-insensitively) are considered, so stray entries such as
    ``.DS_Store``, notes, or sub-directories neither crash ``Image.open`` nor
    shift the photo/sketch alignment.

    Args:
        real_dir: directory containing the photo images.
        sketch_dir: directory containing the sketch images.
        transform: optional callable applied independently to each PIL image.

    Each item is a dict ``{"real": ..., "sketch": ...}`` holding the (possibly
    transformed) RGB images.
    """

    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff")

    def __init__(self, real_dir, sketch_dir, transform=None):
        self.real_dir = real_dir
        self.sketch_dir = sketch_dir
        self.transform = transform
        self.real_images = self._list_images(real_dir)
        self.sketch_images = self._list_images(sketch_dir)

    def _list_images(self, directory):
        """Return the sorted names of image files directly inside ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return min(len(self.real_images), len(self.sketch_images))

    @staticmethod
    def _load_rgb(directory, name):
        """Open ``directory/name`` and return an RGB copy, closing the file handle."""
        with Image.open(os.path.join(directory, name)) as img:
            # convert() materializes the pixel data before the file is closed.
            return img.convert("RGB")

    def __getitem__(self, idx):
        real_img = self._load_rgb(self.real_dir, self.real_images[idx])
        sketch_img = self._load_rgb(self.sketch_dir, self.sketch_images[idx])

        if self.transform is not None:
            real_img = self.transform(real_img)
            sketch_img = self.transform(sketch_img)

        return {"real": real_img, "sketch": sketch_img}

# Resize every image to 128x128, convert to a [0, 1] tensor, then rescale each
# of the three RGB channels to [-1, 1] to match the generator's tanh output.
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Reduced size from 256x256 to 128x128
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = FaceSketchDataset(train_real_dir, train_sketch_dir, transform)
# Pinned host memory only speeds up host-to-CUDA copies; on a CPU-only machine
# PyTorch warns about pin_memory=True and ignores it, so enable it only with CUDA.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

# ==================== Generator ====================
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv + InstanceNorm stages plus a skip.

    Input and output have ``in_features`` channels and the same spatial size,
    so ``forward`` returns ``x + F(x)``.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_norm(channels):
            # Reflection padding of 1 keeps H and W unchanged through the 3x3 conv.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                nn.InstanceNorm2d(channels),
            ]

        # The attribute name and layer order determine the state_dict keys
        # (block.1.*, block.5.*), so both are kept for checkpoint compatibility.
        self.block = nn.Sequential(
            *conv_norm(in_features),
            nn.ReLU(inplace=True),
            *conv_norm(in_features),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style image-to-image generator, used for both translation directions.

    Layout:
      * 7x7 stem conv (in_channels -> 64);
      * two stride-2 downsampling convs (64 -> 128 -> 256);
      * ``n_residual_blocks`` ``ResidualBlock``s at 256 channels;
      * two stride-2 transposed-conv upsampling stages (256 -> 128 -> 64);
      * 7x7 output conv (64 -> out_channels) followed by ``tanh``.

    Spatial size is preserved when height and width are divisible by 4.

    Args:
        in_channels: channels of the input image (default 3, RGB).
        out_channels: channels of the generated image (default 3, RGB).
        n_residual_blocks: number of residual blocks in the bottleneck
            (default 3; the reference CycleGAN uses 6 at 128x128 and 9 at 256x256).

    With the defaults the module and its ``state_dict`` keys are identical to
    the previous hard-coded version, so existing checkpoints still load.

    Raises:
        ValueError: if ``n_residual_blocks`` is negative.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=3):
        super().__init__()
        if n_residual_blocks < 0:
            raise ValueError(f"n_residual_blocks must be >= 0, got {n_residual_blocks}")

        def downsample(in_feat, out_feat, norm=True):
            # Stride-2 4x4 conv with padding 1 halves even H and W.
            layers = [nn.Conv2d(in_feat, out_feat, 4, 2, 1)]
            if norm:
                layers.append(nn.InstanceNorm2d(out_feat))
            layers.append(nn.ReLU(inplace=True))
            return layers

        def upsample(in_feat, out_feat):
            # Stride-2 4x4 transposed conv with padding 1 doubles H and W.
            return [
                nn.ConvTranspose2d(in_feat, out_feat, 4, 2, 1),
                nn.InstanceNorm2d(out_feat),
                nn.ReLU(inplace=True),
            ]

        self.model = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            *downsample(64, 128),
            *downsample(128, 256),

            *[ResidualBlock(256) for _ in range(n_residual_blocks)],

            *upsample(256, 128),
            *upsample(128, 64),

            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, 7),
            nn.Tanh(),  # output range (-1, 1), matching the [-1, 1] input normalization
        )

    def forward(self, x):
        return self.model(x)

# ==================== Discriminator ====================
class Discriminator(nn.Module):
    """PatchGAN-style discriminator that outputs a grid of per-patch scores.

    Four stride-2 4x4 conv stages (3 -> 64 -> 128 -> 256 -> 512 channels) are
    each followed by LeakyReLU(0.2); every stage except the first also has
    InstanceNorm. A final stride-1 4x4 conv maps to one output channel, so a
    128x128 input produces a 1x7x7 score map.
    """

    # (in_channels, out_channels, use_instance_norm) for each stride-2 stage.
    _STAGES = ((3, 64, False), (64, 128, True), (128, 256, True), (256, 512, True))

    def __init__(self):
        super().__init__()
        stack = []
        for c_in, c_out, use_norm in self._STAGES:
            stack.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            if use_norm:
                stack.append(nn.InstanceNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Stride-1 head: one unnormalized score per receptive-field patch.
        stack.append(nn.Conv2d(512, 1, 4, 1, 1))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)

# ==================== Initialize Models ====================
# Two generators, one per translation direction.
G_S2R = Generator().to(device)  # Sketch → Real
G_R2S = Generator().to(device)  # Real → Sketch
# One discriminator per image domain (created in the order D_R, then D_S).
D_R, D_S = [Discriminator().to(device) for _domain in ("real", "sketch")]

# ==================== Losses ====================
criterion_GAN = nn.MSELoss()        # least-squares adversarial loss
criterion_cycle = nn.L1Loss()       # cycle-consistency reconstruction loss
criterion_identity = nn.L1Loss()    # identity-mapping regularizer

# ==================== Optimizers ====================
lr = 0.0002
beta1 = 0.5
# Shared Adam hyper-parameters for every optimizer below.
adam_kwargs = {"lr": lr, "betas": (beta1, 0.999)}

# A single optimizer updates both generators jointly.
optimizer_G = optim.Adam(itertools.chain(G_S2R.parameters(), G_R2S.parameters()), **adam_kwargs)
optimizer_D_R = optim.Adam(D_R.parameters(), **adam_kwargs)
optimizer_D_S = optim.Adam(D_S.parameters(), **adam_kwargs)

# ==================== Training ====================
EPOCHS = 3
save_dir = "/kaggle/working/saved_models"
os.makedirs(save_dir, exist_ok=True)

# Loss weights: cycle consistency dominates, identity is half of it.
LAMBDA_CYCLE = 10.0
LAMBDA_IDENTITY = 5.0

if len(train_loader) == 0:
    # Fail fast. Otherwise the per-epoch sample dump below raises a NameError
    # on `fake_real`, which hides the real cause (no images were loaded).
    raise RuntimeError(
        "train_loader yields no batches; check train_real_dir and train_sketch_dir"
    )


def _set_requires_grad(modules, flag):
    """Turn gradient tracking on or off for every parameter of each module."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad_(flag)


def _gan_loss(pred, target_is_real):
    """criterion_GAN of ``pred`` against an all-ones (real) or all-zeros (fake) target."""
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    return criterion_GAN(pred, target)


def _discriminator_step(disc, optimizer, real_batch, fake_batch):
    """Run one optimizer step for ``disc`` on a real and a fake batch; return the loss."""
    optimizer.zero_grad()
    loss_real = _gan_loss(disc(real_batch), True)
    # Detach so the discriminator loss does not backpropagate into the generator.
    loss_fake = _gan_loss(disc(fake_batch.detach()), False)
    loss = 0.5 * (loss_real + loss_fake)
    loss.backward()
    optimizer.step()
    return loss


for epoch in range(1, EPOCHS + 1):
    for i, batch in enumerate(train_loader):
        real = batch["real"].to(device)
        sketch = batch["sketch"].to(device)

        # === Train Generators ===
        # Freeze both discriminators. Gradients still flow *through* them to the
        # generators, but are no longer accumulated into D parameters; those
        # grads would only be cleared by the discriminators' zero_grad() anyway.
        _set_requires_grad((D_R, D_S), False)
        optimizer_G.zero_grad()

        fake_real = G_S2R(sketch)
        loss_GAN_S2R = _gan_loss(D_R(fake_real), True)

        fake_sketch = G_R2S(real)
        loss_GAN_R2S = _gan_loss(D_S(fake_sketch), True)

        recovered_real = G_S2R(fake_sketch)
        loss_cycle_real = criterion_cycle(recovered_real, real)

        recovered_sketch = G_R2S(fake_real)
        loss_cycle_sketch = criterion_cycle(recovered_sketch, sketch)

        # Identity terms: each generator should leave images already in its
        # output domain unchanged.
        loss_identity_real = criterion_identity(G_S2R(real), real)
        loss_identity_sketch = criterion_identity(G_R2S(sketch), sketch)

        loss_G = (
            loss_GAN_S2R + loss_GAN_R2S
            + LAMBDA_CYCLE * (loss_cycle_real + loss_cycle_sketch)
            + LAMBDA_IDENTITY * (loss_identity_real + loss_identity_sketch)
        )
        loss_G.backward()
        optimizer_G.step()

        # === Train Discriminators R and S ===
        _set_requires_grad((D_R, D_S), True)
        loss_D_R_total = _discriminator_step(D_R, optimizer_D_R, real, fake_real)
        loss_D_S_total = _discriminator_step(D_S, optimizer_D_S, sketch, fake_sketch)

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(train_loader)}] "
                  f"[G loss: {loss_G.item():.4f}] "
                  f"[D_R loss: {loss_D_R_total.item():.4f}] "
                  f"[D_S loss: {loss_D_S_total.item():.4f}]")

    # Save the last batch's translations, mapped from tanh's [-1, 1] back to
    # [0, 1]. Detach first so no autograd graph is carried into the image dump.
    save_image(fake_real.detach() * 0.5 + 0.5, f"/kaggle/working/fake_real_epoch_{epoch}.png")
    save_image(fake_sketch.detach() * 0.5 + 0.5, f"/kaggle/working/fake_sketch_epoch_{epoch}.png")

    # Save models
    torch.save(G_S2R.state_dict(), f"{save_dir}/G_S2R_epoch_{epoch}.pth")
    torch.save(G_R2S.state_dict(), f"{save_dir}/G_R2S_epoch_{epoch}.pth")
    torch.save(D_R.state_dict(), f"{save_dir}/D_R_epoch_{epoch}.pth")
    torch.save(D_S.state_dict(), f"{save_dir}/D_S_epoch_{epoch}.pth")


[Epoch 1/3] [Batch 0/2582] [G loss: 22.3031] [D_R loss: 0.6487] [D_S loss: 0.7607]
[Epoch 1/3] [Batch 100/2582] [G loss: 5.5232] [D_R loss: 0.2255] [D_S loss: 0.1530]
[Epoch 1/3] [Batch 200/2582] [G loss: 4.9291] [D_R loss: 0.3141] [D_S loss: 0.1451]
[Epoch 1/3] [Batch 300/2582] [G loss: 5.0709] [D_R loss: 0.2061] [D_S loss: 0.1806]
[Epoch 1/3] [Batch 400/2582] [G loss: 3.9239] [D_R loss: 0.2555] [D_S loss: 0.1323]
[Epoch 1/3] [Batch 500/2582] [G loss: 4.4528] [D_R loss: 0.2359] [D_S loss: 0.1401]
[Epoch 1/3] [Batch 600/2582] [G loss: 3.9457] [D_R loss: 0.1988] [D_S loss: 0.1777]
[Epoch 1/3] [Batch 700/2582] [G loss: 4.0915] [D_R loss: 0.2212] [D_S loss: 0.1403]
[Epoch 1/3] [Batch 800/2582] [G loss: 4.1181] [D_R loss: 0.1736] [D_S loss: 0.0601]
[Epoch 1/3] [Batch 900/2582] [G loss: 3.6196] [D_R loss: 0.2351] [D_S loss: 0.1396]
[Epoch 1/3] [Batch 1000/2582] [G loss: 3.7641] [D_R loss: 0.2116] [D_S loss: 0.1935]
[Epoch 1/3] [Batch 1100/2582] [G loss: 3.9031] [D_R loss: 0.2519] [D_S loss: