In [None]:
# Mount Google Drive at /content/drive. The checkpoint directory configured
# later in this notebook lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install -q kaggle

In [None]:
from pathlib import Path

from google.colab import files

# files.upload() returns a dict mapping each saved file name to its bytes.
# When a kaggle.json already exists in the working directory, Colab stores the
# new upload under a different name (e.g. "kaggle (1).json"), so copying
# "kaggle.json" by name would pick up the stale file. Writing the returned
# bytes directly always installs the credentials that were just uploaded.
uploaded = files.upload()  # Upload your kaggle.json
if not uploaded:
    raise RuntimeError('No file was uploaded; kaggle.json is required to download the dataset.')

kaggle_dir = Path.home() / '.kaggle'
kaggle_dir.mkdir(parents=True, exist_ok=True)

kaggle_json = kaggle_dir / 'kaggle.json'
kaggle_json.write_bytes(next(iter(uploaded.values())))
# Owner read/write only, matching the original `chmod 600`.
kaggle_json.chmod(0o600)

Saving kaggle.json to kaggle (1).json


train/photos: 20655 images
train/sketches: 20655 images
val/photos: 1000 images
val/sketches: 1000 images
test/photos: 679 images
test/sketches: 679 images


In [None]:
!kaggle datasets download -d almightyj/person-face-sketches -p /content/dataset --force
!unzip -q /content/dataset/person-face-sketches.zip -d /content/dataset

Dataset URL: https://www.kaggle.com/datasets/almightyj/person-face-sketches
License(s): CC0-1.0
Downloading person-face-sketches.zip to /content/dataset
 98% 1.26G/1.29G [00:06<00:00, 256MB/s]
100% 1.29G/1.29G [00:08<00:00, 169MB/s]
replace /content/dataset/test/photos/10132.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
A
ll


In [None]:
import os

# Report how many directory entries each split holds for photos and sketches.
# Folders that do not exist are skipped silently.
splits = ['train', 'val', 'test']
for split in splits:
    for kind in ('photos', 'sketches'):
        folder = f'/content/dataset/{split}/{kind}'
        if not os.path.exists(folder):
            continue
        entry_count = len(os.listdir(folder))
        print(f"{split}/{kind}: {entry_count} images")

train/photos: 20655 images
train/sketches: 20655 images
val/photos: 1000 images
val/sketches: 1000 images
test/photos: 679 images
test/sketches: 679 images


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import random
import itertools
from tqdm import tqdm

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Using device: {device}')

Using device: cuda


In [None]:
# Training configuration read by the cells below.
num_epochs = 25  # upper bound of the epoch loop
batch_size = 1  # batch size given to both DataLoaders
lr = 0.0002  # Adam learning rate shared by all three optimizers
beta1 = 0.5  # Adam betas are (beta1, beta2)
beta2 = 0.999
lambda_cycle = 10.0  # multiplier applied to both cycle-consistency L1 losses
lambda_id = 5.0  # multiplier applied to both identity L1 losses
img_size = 256  # images are resized to img_size x img_size
channels = 3  # NOTE(review): not read by any visible cell; the networks hard-code 3 channels
num_residual = 9  # For generator
sample_interval = 500  # Save samples every N batches
# Checkpoints are written under the mounted Drive path.
checkpoint_dir = '/content/drive/MyDrive/cyclegan_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
# Sample images saved during training go to ./samples.
os.makedirs('samples', exist_ok=True)
max_samples = 5000  # Optimize: use subset to speed up

In [None]:
# Training-time preprocessing, applied in list order: resize to a fixed square,
# random horizontal flip, conversion to a tensor, then per-channel
# normalisation with mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)

In [None]:
class ImageDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Args:
        root: Directory whose entries are scanned (not recursive).
        transform: Optional callable applied to each loaded RGB PIL image.
        max_samples: If truthy, keep only the first ``max_samples`` paths in
            sorted order. ``None`` or ``0`` keeps every path.

    The retained paths are shuffled once at construction time using the
    ``random`` module.
    """

    # File extensions treated as images; matching ignores case so that files
    # such as "photo.JPG" or "scan.jpeg" are not silently dropped.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None, max_samples=None):
        self.transform = transform
        self.files = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        if max_samples:
            self.files = self.files[:max_samples]
        random.shuffle(self.files)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and return it, transformed if a transform was given."""
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(self.files[index]) as source:
            img = source.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        """Number of image paths retained by this dataset."""
        return len(self.files)

In [None]:
# Domain A holds the real face photos, domain B the sketches.
dir_A = '/content/dataset/train/photos'
dir_B = '/content/dataset/train/sketches'

# Build A first, then B, each limited to max_samples files.
dataset_A, dataset_B = (
    ImageDataset(folder, transform=transform, max_samples=max_samples)
    for folder in (dir_A, dir_B)
)

print(f'Loaded {len(dataset_A)} images for domain A (faces)')
print(f'Loaded {len(dataset_B)} images for domain B (sketches)')

# Both loaders share the same settings.
_loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
dataloader_A = DataLoader(dataset_A, **_loader_options)
dataloader_B = DataLoader(dataset_B, **_loader_options)

Loaded 5000 images for domain A (faces)
Loaded 5000 images for domain B (sketches)


In [None]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + instance-norm stages with a skip connection.

    A ReLU sits between the two stages; the block's output is added to its
    input. Channel count is unchanged (``in_channels`` in and out).
    """

    def __init__(self, in_channels):
        super().__init__()
        stages = []
        for needs_activation in (True, False):
            stages.append(nn.ReflectionPad2d(1))
            stages.append(nn.Conv2d(in_channels, in_channels, 3))
            stages.append(nn.InstanceNorm2d(in_channels))
            if needs_activation:
                # Only the first stage is followed by a ReLU.
                stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage block applied to ``x``."""
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Image-to-image network: 7x7 stem, two stride-2 downsampling convs,
    ``num_residual`` residual blocks at 256 channels, two stride-2 transposed
    convs, and a 7x7 output conv followed by tanh.

    Input and output both have 3 channels.
    """

    def __init__(self, num_residual=9):
        super().__init__()

        # Stem: reflection-padded 7x7 convolution to 64 channels.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: 64 -> 128 -> 256 channels, stride 2 each.
        for in_c, out_c in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(in_c, out_c, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_c),
                nn.ReLU(inplace=True),
            ]

        # Residual trunk at 256 channels.
        layers.extend(ResidualBlock(256) for _ in range(num_residual))

        # Upsampling: 256 -> 128 -> 64 channels, stride 2 each.
        for in_c, out_c in ((256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(in_c, out_c, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_c),
                nn.ReLU(inplace=True),
            ]

        # Output head: reflection-padded 7x7 convolution back to 3 channels, then tanh.
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack."""
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a one-channel score map.

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels),
    each followed by LeakyReLU(0.2), with instance norm on all but the first,
    then a final 4x4 convolution down to a single channel.
    """

    def __init__(self):
        super().__init__()

        def down(in_c, out_c, norm=True):
            """Stride-2 conv, optional instance norm, then LeakyReLU."""
            conv = nn.Conv2d(in_c, out_c, 4, stride=2, padding=1)
            activation = nn.LeakyReLU(0.2, inplace=True)
            if norm:
                return [conv, nn.InstanceNorm2d(out_c), activation]
            return [conv, activation]

        # The first stage has no normalisation.
        stages = down(3, 64, norm=False)
        for in_c, out_c in ((64, 128), (128, 256), (256, 512)):
            stages += down(in_c, out_c)
        # Final projection to a single-channel score map.
        stages.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the score map for ``x``."""
        return self.model(x)

In [None]:
# G translates face -> sketch, F translates sketch -> face.
# Networks are constructed in the order G, F, D_A, D_B.
G, F = (Generator(num_residual).to(device) for _ in range(2))
# D_A judges faces, D_B judges sketches.
D_A, D_B = (Discriminator().to(device) for _ in range(2))

In [None]:
# One Adam optimizer updates both generators together; each discriminator has its own.
_adam_betas = (beta1, beta2)
_generator_params = itertools.chain(G.parameters(), F.parameters())
optimizer_G = optim.Adam(_generator_params, lr=lr, betas=_adam_betas)
optimizer_D_A = optim.Adam(D_A.parameters(), lr=lr, betas=_adam_betas)
optimizer_D_B = optim.Adam(D_B.parameters(), lr=lr, betas=_adam_betas)

# Adversarial loss is mean squared error; cycle and identity losses are L1.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_id = nn.L1Loss()

In [None]:
# Resume from the latest checkpoint when one exists; otherwise start at epoch 0.
start_epoch = 0
checkpoint_path = os.path.join(checkpoint_dir, 'latest.pth')
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Checkpoint key -> object whose state is restored, in the original order.
    _restore_targets = {
        'G': G,
        'F': F,
        'D_A': D_A,
        'D_B': D_B,
        'opt_G': optimizer_G,
        'opt_D_A': optimizer_D_A,
        'opt_D_B': optimizer_D_B,
    }
    for _key, _target in _restore_targets.items():
        _target.load_state_dict(checkpoint[_key])
    # The stored epoch finished, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f'Resuming from epoch {start_epoch}')

In [None]:
def _endless(loader):
    """Yield batches from ``loader`` forever, starting a new pass whenever one ends."""
    while True:
        yield from loader


# itertools.cycle() keeps a copy of every batch from the first pass and replays
# those stored tensors afterwards, so memory grows with the dataset and later
# passes reuse the same order and the same augmented images. Re-iterating the
# DataLoader instead lets it reshuffle and re-apply the transform on each pass.
dataloader_A_cycle = _endless(dataloader_A)
dataloader_B_cycle = _endless(dataloader_B)

# One epoch draws as many batches as the longer of the two loaders holds.
num_batches = max(len(dataloader_A), len(dataloader_B))

for epoch in range(start_epoch, num_epochs):
    progress_bar = tqdm(range(num_batches))
    for i in progress_bar:
        real_A = next(dataloader_A_cycle).to(device)
        real_B = next(dataloader_B_cycle).to(device)

        # Generators
        optimizer_G.zero_grad()

        # Adversarial terms: each generator tries to make the matching
        # discriminator output ones. Each discriminator is run once and its
        # output reused to size the target tensor.
        fake_B = G(real_A)
        pred_fake_B = D_B(fake_B)
        loss_GAN_G = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
        fake_A = F(real_B)
        pred_fake_A = D_A(fake_A)
        loss_GAN_F = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle terms: A -> B -> A and B -> A -> B should reproduce the input.
        rec_A = F(fake_B)
        loss_cycle_A = criterion_cycle(rec_A, real_A) * lambda_cycle
        rec_B = G(fake_A)
        loss_cycle_B = criterion_cycle(rec_B, real_B) * lambda_cycle

        # Identity terms: a generator fed an image of its target domain should return it unchanged.
        id_A = F(real_A)
        loss_id_A = criterion_id(id_A, real_A) * lambda_id
        id_B = G(real_B)
        loss_id_B = criterion_id(id_B, real_B) * lambda_id

        loss_G = loss_GAN_G + loss_GAN_F + loss_cycle_A + loss_cycle_B + loss_id_A + loss_id_B
        loss_G.backward()
        optimizer_G.step()

        # Disc A
        optimizer_D_A.zero_grad()
        pred_real_A = D_A(real_A)
        loss_real_A = criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
        # detach() keeps the discriminator update from reaching into the generator.
        pred_fake_A = D_A(fake_A.detach())
        loss_fake_A = criterion_GAN(pred_fake_A, torch.zeros_like(pred_fake_A))
        loss_D_A = (loss_real_A + loss_fake_A) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        # Disc B
        optimizer_D_B.zero_grad()
        pred_real_B = D_B(real_B)
        loss_real_B = criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
        pred_fake_B = D_B(fake_B.detach())
        loss_fake_B = criterion_GAN(pred_fake_B, torch.zeros_like(pred_fake_B))
        loss_D_B = (loss_real_B + loss_fake_B) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        progress_bar.set_description(f'Epoch {epoch}: G={loss_G.item():.4f}, Da={loss_D_A.item():.4f}, Db={loss_D_B.item():.4f}')

        if i % sample_interval == 0:
            # Generator outputs come from tanh; (x + 1) / 2 rescales them to [0, 1] for saving.
            save_image((fake_B + 1) / 2, f'samples/fake_sketch_{epoch}_{i}.png')
            save_image((fake_A + 1) / 2, f'samples/fake_face_{epoch}_{i}.png')

    # Save checkpoint (overwrites the previous one) at the end of every epoch.
    torch.save({
        'epoch': epoch,
        'G': G.state_dict(),
        'F': F.state_dict(),
        'D_A': D_A.state_dict(),
        'D_B': D_B.state_dict(),
        'opt_G': optimizer_G.state_dict(),
        'opt_D_A': optimizer_D_A.state_dict(),
        'opt_D_B': optimizer_D_B.state_dict()
    }, os.path.join(checkpoint_dir, 'latest.pth'))
    print(f'Checkpoint saved at epoch {epoch}')

NameError: name 'itertools' is not defined

In [None]:
# Load latest: restore both generators from the most recent checkpoint,
# then switch them to evaluation mode.
_latest_path = os.path.join(checkpoint_dir, 'latest.pth')
checkpoint = torch.load(_latest_path, map_location=device)
for _state_key, _generator in (('G', G), ('F', F)):
    _generator.load_state_dict(checkpoint[_state_key])
for _generator in (G, F):
    _generator.eval()

# Test function
def test_image(img_path, is_sketch=False):
    """Translate one image file with the trained generators.

    Args:
        img_path: Path of the image to translate.
        is_sketch: If True, run F (sketch -> face); otherwise run G (face -> sketch).

    Returns:
        The translated image as a PIL image.
    """
    # Reuse the training preprocessing but drop the random horizontal flip,
    # so the same input always produces the same, un-mirrored output.
    preprocess = transforms.Compose([
        step for step in transform.transforms
        if not isinstance(step, transforms.RandomHorizontalFlip)
    ])

    # convert() returns a fully loaded copy, so the file can be closed right away.
    with Image.open(img_path) as source:
        img = source.convert('RGB')
    img_t = preprocess(img).unsqueeze(0).to(device)

    generator = F if is_sketch else G
    with torch.no_grad():
        out = generator(img_t)

    # Drop the batch dimension and rescale the tanh output to [0, 1].
    out = (out.squeeze(0).cpu() + 1) / 2
    return transforms.ToPILImage()(out)

# Example: Replace with your image path
out_img = test_image('/content/test_face.jpg')  # Assume face
# Image.show() hands the picture to an external viewer program, which a hosted
# notebook does not provide. Leaving the image as the cell's last expression
# lets the notebook render it inline instead.
out_img

Epoch 0: G=5.2119, Da=0.1675, Db=0.0308:  30%|██▉       | 6120/20655 [1:09:14<2:44:27,  1.47it/s]   


KeyboardInterrupt: 