# CycleGAN Lite - Monet Style Transfer

**Lightweight CycleGAN** optimized to train in roughly 30-45 minutes on a single T4 GPU.

Key optimizations:
- 128x128 resolution (upscaled to 256x256 for submission)
- 6 ResNet blocks (vs 9)
- 32 base filters (vs 64)
- 10 epochs
- Batch size 4

In [None]:
!pip install -q torchmetrics[image] torch-fidelity

In [None]:
import os
import random
import itertools
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime
import zipfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image

# Report the framework version and whether a CUDA device can be used.
_cuda_available = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {_cuda_available}")
if _cuda_available:
    # Only query the device name when a GPU is actually present.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration - Optimized for speed
class Config:
    """Paths and hyper-parameters shared by every cell of this notebook.

    All values are plain class attributes; the notebook reads them through a
    single ``cfg = Config()`` instance.
    """

    # Paths: input image folders and the root folder for everything written out
    MONET_DIR = '/kaggle/input/gan-getting-started/monet_jpg'
    PHOTO_DIR = '/kaggle/input/gan-getting-started/photo_jpg'
    OUTPUT_DIR = '/kaggle/working'
    
    # Image settings
    IMAGE_SIZE = 128      # Side length used for training crops and for inference input
    OUTPUT_SIZE = 256     # Side length generated images are upscaled to before saving
    
    # Training - lighter settings
    EPOCHS = 10
    BATCH_SIZE = 4        # Batch size used by both data loaders
    LR = 2e-4             # Adam learning rate (generators and discriminators)
    BETA1 = 0.5           # Adam beta1
    BETA2 = 0.999         # Adam beta2
    
    # Loss weights
    LAMBDA_CYCLE = 10.0     # Multiplier applied to the cycle-consistency L1 loss
    LAMBDA_IDENTITY = 5.0   # Multiplier applied to the identity L1 loss
    
    # Architecture - lighter
    NGF = 32              # Generator base filter count (reduced from 64)
    NDF = 32              # Discriminator base filter count (reduced from 64)
    N_RESIDUAL = 6        # Number of residual blocks in each generator (reduced from 9)
    
    # Generation
    NUM_GENERATE = 7000   # Number of images written for the submission
    
    # Misc
    SEED = 42             # Seed passed to the Python, NumPy and PyTorch RNGs
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    NUM_WORKERS = 2       # Worker processes per DataLoader

cfg = Config()

# Seed every random number generator in use so runs start from the same state.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(cfg.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.SEED)
    # Speed optimization: let cuDNN benchmark and pick convolution algorithms.
    torch.backends.cudnn.benchmark = True

# Make sure the folders for per-epoch samples and submission images exist.
for _subdir in ('samples', 'images'):
    os.makedirs(os.path.join(cfg.OUTPUT_DIR, _subdir), exist_ok=True)

print(f"Training at {cfg.IMAGE_SIZE}x{cfg.IMAGE_SIZE}, output at {cfg.OUTPUT_SIZE}x{cfg.OUTPUT_SIZE}")
print(f"Epochs: {cfg.EPOCHS}, Batch: {cfg.BATCH_SIZE}")

## Dataset

In [None]:
class ImageDataset(Dataset):
    """Dataset over the image files (.jpg/.jpeg/.png) found directly in a folder.

    Args:
        root_dir: Folder that contains the image files.
        transform: Optional callable applied to each loaded RGB PIL image.

    Each item is the image at that index converted to RGB, passed through
    ``transform`` when one was given.
    """

    # Extensions are matched case-insensitively.
    EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.images = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(self.EXTENSIONS)
        )
        print(f"Loaded {len(self.images)} images from {root_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.images[idx])
        # The context manager closes the underlying file; convert() returns an
        # independent in-memory copy that stays valid afterwards.
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        # Compare against None so a transform object that happens to be falsy
        # (e.g. an empty container of transforms) is still applied.
        if self.transform is not None:
            img = self.transform(img)
        return img

# Training augmentation: resize slightly above the target size, take a random
# crop of the target size, random horizontal flip, then scale pixels to [-1, 1].
_augmentations = [
    transforms.Resize(int(cfg.IMAGE_SIZE * 1.1), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(cfg.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
]
train_transform = transforms.Compose(_augmentations)

# Both domains share the same augmentation pipeline.
monet_dataset = ImageDataset(cfg.MONET_DIR, train_transform)
photo_dataset = ImageDataset(cfg.PHOTO_DIR, train_transform)

# Identical loader settings for both domains; incomplete final batches are dropped.
_loader_kwargs = dict(
    batch_size=cfg.BATCH_SIZE,
    shuffle=True,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=True,
    drop_last=True,
)
monet_loader = DataLoader(monet_dataset, **_loader_kwargs)
photo_loader = DataLoader(photo_dataset, **_loader_kwargs)

print(f"Monet batches: {len(monet_loader)}, Photo batches: {len(photo_loader)}")

## Model Architecture

In [None]:
class ResBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv + instance-norm stages
    (ReLU in between) whose result is added back onto the input.

    Args:
        channels: Number of input and output feature channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv_stage():
            # Reflection padding of 1 followed by a 3x3 conv keeps the spatial size.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                nn.InstanceNorm2d(channels),
            ]

        stages = conv_stage() + [nn.ReLU(inplace=True)] + conv_stage()
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem conv, two stride-2 downsampling convs, ``n_res`` residual
    blocks, two stride-2 transposed-conv upsampling stages and a 7x7 output
    conv followed by tanh.

    Args:
        ngf: Number of filters produced by the stem convolution.
        n_res: Number of residual blocks between encoder and decoder.
    """

    def __init__(self, ngf=32, n_res=6):
        super().__init__()

        # Encoder: stem, then two downsampling stages that each double the width.
        enc_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, 7),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True),
        ]
        width = ngf
        for _ in range(2):
            enc_layers += [
                nn.Conv2d(width, width * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(inplace=True),
            ]
            width *= 2
        self.encoder = nn.Sequential(*enc_layers)

        # Residual blocks at the bottleneck width (ngf * 4).
        self.res_blocks = nn.Sequential(*[ResBlock(width) for _ in range(n_res)])

        # Decoder: two upsampling stages that each halve the width, then output.
        dec_layers = []
        for _ in range(2):
            dec_layers += [
                nn.ConvTranspose2d(width, width // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(width // 2),
                nn.ReLU(inplace=True),
            ]
            width //= 2
        dec_layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, 3, 7),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        features = self.encoder(x)
        features = self.res_blocks(features)
        return self.decoder(features)


class Discriminator(nn.Module):
    """Convolutional patch discriminator that outputs a one-channel score map.

    Five 4x4 convolutions: three with stride 2, then two with stride 1. All but
    the first and last are followed by instance norm; all but the last by
    LeakyReLU(0.2).

    Args:
        ndf: Number of filters produced by the first convolution.
    """

    def __init__(self, ndf=32):
        super().__init__()

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # First stage has no normalization.
        layers = [nn.Conv2d(3, ndf, 4, stride=2, padding=1), leaky()]

        # Three normalized stages, each doubling the channel count.
        in_ch = ndf
        for stride in (2, 2, 1):
            out_ch = in_ch * 2
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, stride=stride, padding=1),
                nn.InstanceNorm2d(out_ch),
                leaky(),
            ]
            in_ch = out_ch

        # Final projection to a single score channel.
        layers.append(nn.Conv2d(in_ch, 1, 4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Initialize models
def init_weights(m):
    """Initialize one module in place; intended for use with ``Module.apply``.

    Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02), bias (if any) zeroed.
    InstanceNorm2d with affine parameters: weights from N(1, 0.02), bias zeroed.
    Any other module is left unchanged.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
        return

    # Instance norm only has weight/bias when it was created with affine=True.
    if isinstance(m, nn.InstanceNorm2d) and m.weight is not None:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0.0)

# Two generators (one per translation direction) and one discriminator per domain.
G_M = Generator(cfg.NGF, cfg.N_RESIDUAL).to(cfg.DEVICE)  # Photo -> Monet
G_P = Generator(cfg.NGF, cfg.N_RESIDUAL).to(cfg.DEVICE)  # Monet -> Photo
D_M = Discriminator(cfg.NDF).to(cfg.DEVICE)
D_P = Discriminator(cfg.NDF).to(cfg.DEVICE)

# Re-initialize every layer of every network with init_weights.
for _net in (G_M, G_P, D_M, D_P):
    _net.apply(init_weights)

# Both generators share one architecture, as do both discriminators, so each
# count is taken once and doubled for the total.
n_params = sum(p.numel() for p in G_M.parameters())
_disc_params = sum(p.numel() for p in D_M.parameters())
print(f"Generator params: {n_params:,}")
print(f"Total params: {2 * (n_params + _disc_params):,}")

## Training Setup

In [None]:
# Loss and optimizers
# Objectives: squared error for the adversarial term, L1 for cycle and identity.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# One Adam optimizer for both generators and one for both discriminators.
_adam_kwargs = dict(lr=cfg.LR, betas=(cfg.BETA1, cfg.BETA2))
opt_G = optim.Adam(itertools.chain(G_M.parameters(), G_P.parameters()), **_adam_kwargs)
opt_D = optim.Adam(itertools.chain(D_M.parameters(), D_P.parameters()), **_adam_kwargs)


def _lr_factor(e):
    """LR multiplier for epoch ``e``: 1.0 up to the halfway epoch, then linear decay."""
    half = cfg.EPOCHS // 2
    return 1.0 - max(0, e - half) / (half + 1)


# LR scheduler
scheduler_G = optim.lr_scheduler.LambdaLR(opt_G, _lr_factor)
scheduler_D = optim.lr_scheduler.LambdaLR(opt_D, _lr_factor)

In [None]:
# Replay buffer
class ReplayBuffer:
    """Fixed-capacity pool of previously generated images.

    While the pool is not full, every incoming image is stored and returned
    unchanged. Once full, each incoming image is, with probability 0.5, swapped
    with a randomly chosen stored image (the stored one is returned and the new
    one takes its slot); otherwise the incoming image is returned as is.

    Args:
        max_size: Maximum number of images kept in the pool.
    """

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same length."""
        picked = []
        for sample in data:
            sample = sample.unsqueeze(0)  # keep a leading batch dimension of 1

            # Still filling up: store and pass through.
            if len(self.data) < self.max_size:
                self.data.append(sample)
                picked.append(sample)
                continue

            # Pool is full: coin flip between swapping with history or passing through.
            if random.random() > 0.5:
                slot = random.randint(0, self.max_size - 1)
                picked.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                picked.append(sample)
        return torch.cat(picked, 0)


buf_M = ReplayBuffer()
buf_P = ReplayBuffer()

## Training

In [None]:
def train():
    """Run the CycleGAN training loop and return the per-epoch loss history.

    Operates on the module-level networks (G_M, G_P, D_M, D_P), optimizers,
    schedulers, replay buffers and data loaders. Each epoch pairs Monet and
    photo batches until the shorter loader is exhausted, updates the
    generators (identity + adversarial + cycle losses) and then the
    discriminators (real vs. replay-buffered fakes). After every epoch a grid
    of translated sample photos is saved under ``<OUTPUT_DIR>/samples``.

    Returns:
        dict with keys 'G' and 'D'; each maps to a list holding the mean
        generator / discriminator loss of every epoch.
    """
    history = {'G': [], 'D': []}
    # A fixed set of photos so the per-epoch sample grids are comparable.
    fixed_photos = next(iter(photo_loader))[:4].to(cfg.DEVICE)

    print(f"Training for {cfg.EPOCHS} epochs...")
    start = datetime.now()
    # Defined up front so the closing message also works when EPOCHS is 0.
    elapsed = 0.0

    for epoch in range(cfg.EPOCHS):
        for net in (G_M, G_P, D_M, D_P):
            net.train()

        g_losses, d_losses = [], []
        n_batches = min(len(monet_loader), len(photo_loader))

        # zip stops as soon as the shorter loader runs out of batches.
        pbar = tqdm(zip(monet_loader, photo_loader), total=n_batches,
                    desc=f"Epoch {epoch+1}/{cfg.EPOCHS}")
        for real_M, real_P in pbar:
            real_M = real_M.to(cfg.DEVICE)
            real_P = real_P.to(cfg.DEVICE)

            # ===== Train Generators =====
            opt_G.zero_grad()

            # Identity: a generator fed an image of its target domain should return it.
            loss_id = (criterion_identity(G_M(real_M), real_M) +
                       criterion_identity(G_P(real_P), real_P)) * cfg.LAMBDA_IDENTITY

            # GAN: generators try to make the discriminators predict "real".
            fake_M = G_M(real_P)
            fake_P = G_P(real_M)
            pred_fake_M = D_M(fake_M)
            pred_fake_P = D_P(fake_P)
            # Build the targets from the discriminator output itself so they
            # match its patch-map size for any IMAGE_SIZE / batch size instead
            # of assuming a fixed 14x14 map.
            valid = torch.ones_like(pred_fake_M)
            fake_label = torch.zeros_like(pred_fake_M)
            loss_gan = (criterion_GAN(pred_fake_M, valid) +
                        criterion_GAN(pred_fake_P, valid))

            # Cycle: translating there and back should reproduce the input.
            loss_cycle = (criterion_cycle(G_P(fake_M), real_P) +
                          criterion_cycle(G_M(fake_P), real_M)) * cfg.LAMBDA_CYCLE

            loss_G = loss_gan + loss_cycle + loss_id
            loss_G.backward()
            opt_G.step()

            # ===== Train Discriminators =====
            # zero_grad also discards the discriminator gradients that the
            # generator backward pass accumulated above.
            opt_D.zero_grad()

            # D_M: real Monets vs. (possibly older) generated Monets.
            fake_M_buf = buf_M.push_and_pop(fake_M.detach())
            loss_D_M = (criterion_GAN(D_M(real_M), valid) +
                        criterion_GAN(D_M(fake_M_buf), fake_label)) * 0.5

            # D_P: real photos vs. (possibly older) generated photos.
            fake_P_buf = buf_P.push_and_pop(fake_P.detach())
            loss_D_P = (criterion_GAN(D_P(real_P), valid) +
                        criterion_GAN(D_P(fake_P_buf), fake_label)) * 0.5

            loss_D = loss_D_M + loss_D_P
            loss_D.backward()
            opt_D.step()

            g_losses.append(loss_G.item())
            d_losses.append(loss_D.item())
            pbar.set_postfix({'G': f"{loss_G.item():.3f}", 'D': f"{loss_D.item():.3f}"})

        # Record & schedule
        history['G'].append(np.mean(g_losses))
        history['D'].append(np.mean(d_losses))
        scheduler_G.step()
        scheduler_D.step()

        # Save samples (outputs are mapped from [-1, 1] back to [0, 1]).
        G_M.eval()
        with torch.no_grad():
            samples = G_M(fixed_photos) * 0.5 + 0.5
            save_image(samples, os.path.join(cfg.OUTPUT_DIR, 'samples', f'epoch_{epoch+1:02d}.png'), nrow=2)

        elapsed = (datetime.now() - start).total_seconds() / 60
        print(f"Epoch {epoch+1} | G: {history['G'][-1]:.4f} | D: {history['D'][-1]:.4f} | Time: {elapsed:.1f}min")

    print(f"\nTraining completed in {elapsed:.1f} minutes")
    return history

history = train()

## Generate Submission

In [None]:
def generate_submission():
    """Translate photos with G_M, save them as JPEGs and zip them for submission.

    Photos from ``cfg.PHOTO_DIR`` are processed in sorted order and the list is
    cycled through again if it holds fewer than ``cfg.NUM_GENERATE`` files.
    Each photo is resized to the training resolution, translated, upscaled to
    ``cfg.OUTPUT_SIZE`` and written to ``<OUTPUT_DIR>/images/NNNNN.jpg``. Only
    the files written by this call are added to ``<OUTPUT_DIR>/images.zip``.

    Returns:
        Number of images generated.

    Raises:
        FileNotFoundError: If ``cfg.PHOTO_DIR`` contains no usable image files.
    """
    G_M.eval()
    images_dir = os.path.join(cfg.OUTPUT_DIR, 'images')
    os.makedirs(images_dir, exist_ok=True)

    # Transform for generation (at training size, then upscale)
    gen_transform = transforms.Compose([
        transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE), transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    # Sorted for a reproducible output order; extensions matched
    # case-insensitively, consistent with ImageDataset.
    photo_files = sorted(
        f for f in os.listdir(cfg.PHOTO_DIR)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    if not photo_files:
        # Without this guard the generation loop would never make progress.
        raise FileNotFoundError(f"No input photos found in {cfg.PHOTO_DIR}")

    print(f"Generating {cfg.NUM_GENERATE} images...")
    written = []

    # cycle + islice: walk the photo list repeatedly until exactly NUM_GENERATE
    # images exist. The tqdm context manager closes the bar even on error.
    sources = itertools.islice(itertools.cycle(photo_files), cfg.NUM_GENERATE)
    with torch.no_grad(), tqdm(total=cfg.NUM_GENERATE) as pbar:
        for count, f in enumerate(sources):
            with Image.open(os.path.join(cfg.PHOTO_DIR, f)) as src:
                img = src.convert('RGB')
            img_t = gen_transform(img).unsqueeze(0).to(cfg.DEVICE)

            fake = G_M(img_t)
            fake = fake * 0.5 + 0.5  # map [-1, 1] back to [0, 1]

            # Upscale to the submission size; bicubic can overshoot, so clamp.
            fake = F.interpolate(fake, size=(cfg.OUTPUT_SIZE, cfg.OUTPUT_SIZE),
                                 mode='bicubic', align_corners=False)
            fake = fake.clamp(0, 1)

            out_name = f'{count:05d}.jpg'
            save_image(fake, os.path.join(images_dir, out_name))
            written.append(out_name)
            pbar.update(1)

    # Create zip from this run's files only, so leftovers from an earlier run
    # in the same folder cannot end up in the submission.
    zip_path = os.path.join(cfg.OUTPUT_DIR, 'images.zip')
    print("Creating submission zip...")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for out_name in written:
            zf.write(os.path.join(images_dir, out_name), out_name)

    print(f"Created: {zip_path} ({os.path.getsize(zip_path)/1024/1024:.1f} MB)")
    return len(written)

num_generated = generate_submission()

## FID Evaluation

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance

def _update_fid_from_dir(fid, image_dir, transform, real, num_samples, batch_size, label, desc):
    """Feed up to ``num_samples`` images from ``image_dir`` into ``fid`` in batches.

    Files are taken in sorted name order so the evaluated subset is the same
    on every run. Raises ValueError when the folder holds no image files.
    """
    files = sorted(
        f for f in os.listdir(image_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )[:num_samples]
    if not files:
        raise ValueError(f"No images found in {image_dir}")
    print(f"Processing {len(files)} {label} images...")

    pending = []
    for f in tqdm(files, desc=desc):
        with Image.open(os.path.join(image_dir, f)) as src:
            pending.append(transform(src.convert('RGB')))

        if len(pending) == batch_size:
            fid.update(torch.stack(pending).to(cfg.DEVICE), real=real)
            pending = []

    # Flush the last, possibly smaller, batch.
    if pending:
        fid.update(torch.stack(pending).to(cfg.DEVICE), real=real)


def compute_fid(real_dir, fake_dir, num_samples=300, batch_size=32):
    """Compute FID between real Monet images and generated images.

    Args:
        real_dir: Folder with the real (Monet) images.
        fake_dir: Folder with the generated images.
        num_samples: Maximum number of images read from each folder.
        batch_size: Number of images passed to the metric per update.

    Returns:
        The FID score as a Python float.

    Raises:
        ValueError: If either folder contains no image files.
    """
    print("Computing FID score...")

    # Images are resized to 299x299 and passed as float tensors in [0, 1],
    # which is the input range the metric is told to expect via normalize=True.
    fid_transform = transforms.Compose([
        transforms.Resize((299, 299), transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ])

    # Initialize FID metric
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(cfg.DEVICE)

    _update_fid_from_dir(fid, real_dir, fid_transform, True,
                         num_samples, batch_size, "real Monet", "Loading real")
    _update_fid_from_dir(fid, fake_dir, fid_transform, False,
                         num_samples, batch_size, "generated", "Loading fake")

    # Compute FID
    fid_score = fid.compute().item()

    print(f"\n{'='*40}")
    print(f"  FID Score: {fid_score:.2f}")
    print(f"{'='*40}")

    return fid_score

# Compute FID
images_dir = os.path.join(cfg.OUTPUT_DIR, 'images')
fid_score = compute_fid(cfg.MONET_DIR, images_dir, num_samples=300)

In [None]:
# Save FID results to CSV
import csv

# Collect the header plus one (metric, value) row per reported setting/result,
# then write them to the results CSV in a single call.
results_path = os.path.join(cfg.OUTPUT_DIR, 'fid_results.csv')
_rows = [
    ['metric', 'value'],
    ['fid_score', f'{fid_score:.4f}'],
    ['num_epochs', cfg.EPOCHS],
    ['image_size', cfg.IMAGE_SIZE],
    ['output_size', cfg.OUTPUT_SIZE],
    ['ngf', cfg.NGF],
    ['n_residual', cfg.N_RESIDUAL],
    ['num_generated', num_generated],
]
# newline='' lets the csv module control line endings itself.
with open(results_path, 'w', newline='') as f:
    csv.writer(f).writerows(_rows)

print(f"Results saved to {results_path}")

In [None]:
## Visualization

In [None]:
# Plot losses
# One panel per network: (axis index, history key, legend label, title, extra plot args).
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
_panels = (
    (0, 'G', 'Generator', 'Generator Loss', {}),
    (1, 'D', 'Discriminator', 'Discriminator Loss', {'color': 'orange'}),
)
for _col, _key, _label, _title, _extra in _panels:
    _axis = ax[_col]
    _axis.plot(history[_key], label=_label, **_extra)
    _axis.set_title(_title)
    _axis.legend()
    _axis.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'losses.png'))
plt.show()

In [None]:
# Show results
# Translate a random batch of photos and show inputs (top row) above outputs
# (bottom row).
G_M.eval()
photos = next(iter(DataLoader(photo_dataset, batch_size=6, shuffle=True)))

with torch.no_grad():
    fakes = G_M(photos.to(cfg.DEVICE))

# Map both tensors from [-1, 1] back to [0, 1] for display.
photos = photos * 0.5 + 0.5
fakes = fakes * 0.5 + 0.5

# Size the grid from the batch actually returned, so a dataset with fewer than
# six photos does not cause an index error; squeeze=False keeps `axes` 2-D
# even with a single column.
n_show = photos.size(0)
fig, axes = plt.subplots(2, n_show, figsize=(15, 5), squeeze=False)
for i in range(n_show):
    axes[0, i].imshow(photos[i].permute(1,2,0).cpu())
    axes[0, i].axis('off')
    axes[1, i].imshow(fakes[i].permute(1,2,0).cpu())
    axes[1, i].axis('off')
axes[0, 0].set_title('Photo')
axes[1, 0].set_title('Monet')
# \u2192 is the right arrow; written as an escape so the title cannot be
# corrupted by a file-encoding mismatch.
plt.suptitle(f'Photo \u2192 Monet (FID: {fid_score:.2f})', fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'results.png'))
plt.show()

## Summary

In [None]:
# Assemble the report line by line, then print each line in order.
_rule = "="*50
_plus_rule = "+"*50
_summary_lines = [
    "\n" + _rule,
    "         CYCLEGAN LITE SUMMARY",
    _rule,
    f"\nArchitecture:",
    f"  - Training size: {cfg.IMAGE_SIZE}x{cfg.IMAGE_SIZE}",
    f"  - Output size: {cfg.OUTPUT_SIZE}x{cfg.OUTPUT_SIZE}",
    f"  - ResNet blocks: {cfg.N_RESIDUAL}",
    f"  - Base filters: {cfg.NGF}",
    f"\nTraining:",
    f"  - Epochs: {cfg.EPOCHS}",
    f"  - Batch size: {cfg.BATCH_SIZE}",
    f"\nResults:",
    f"  - Images generated: {num_generated}",
    f"\n" + _plus_rule,
    f"  FID SCORE: {fid_score:.2f}",
    _plus_rule,
    f"\nOutput files:",
    f"  - Submission: {os.path.join(cfg.OUTPUT_DIR, 'images.zip')}",
    f"  - Results CSV: {results_path}",
    _rule,
]
for _line in _summary_lines:
    print(_line)