# 🎨 I'm Something of a Painter Myself - FastCUT Implementation

This notebook implements a **lightweight FastCUT model** for the Kaggle competition to generate Monet-style paintings.

**Key Features:**
- Simplified FastCUT architecture (based on CUT paper, Park et al. ECCV 2020)
- Optimized for <10 minute runtime on T4 GPU
- 64x64 resolution for speed
- FID evaluation using torchmetrics

---

## 1. Setup & Environment

In [None]:
# Install required packages (run this cell first, then restart kernel if needed)
!pip install -q torchmetrics[image] torch-fidelity scipy

import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
import csv
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image
from torchmetrics.image.fid import FrechetInceptionDistance

print(f"PyTorch version: {torch.__version__}")
print("All imports successful!")

In [None]:
# ===== Configuration =====
class Config:
    """Hyper-parameters, paths and runtime settings shared by every cell.

    Every setting is a plain class attribute, so it can be read from the
    class itself or from the ``cfg`` instance created just below.
    """

    # --- Data locations (Kaggle layout; adjust when running elsewhere) ---
    MONET_DIR = '/kaggle/input/gan-getting-started/monet_jpg'
    PHOTO_DIR = '/kaggle/input/gan-getting-started/photo_jpg'

    # --- Output locations ---
    OUTPUT_DIR = '/kaggle/working'
    FAKE_IMAGES_DIR = '/kaggle/working/fake_images'
    REAL_IMAGES_DIR = '/kaggle/working/real_images'

    # --- Optimisation (kept small so the whole run stays short) ---
    IMAGE_SIZE = 64
    BATCH_SIZE = 16
    NUM_ITERATIONS = 400     # author's estimate: ~8-9 minutes on a T4
    LR_G = 2e-4
    LR_D = 2e-4
    BETA1 = 0.5
    BETA2 = 0.999

    # --- Network width / depth ---
    NGF = 48                 # base channel count of the generator
    NDF = 48                 # base channel count of the discriminator
    N_BLOCKS = 4             # residual blocks in the generator bottleneck

    # --- PatchNCE and loss weighting ---
    NCE_LAYERS = [0, 2]      # encoder feature indices for the PatchNCE loss
    NCE_TEMP = 0.07          # softmax temperature of the NCE logits
    LAMBDA_NCE = 1.0
    LAMBDA_GAN = 1.0

    # --- Image export for FID ---
    NUM_FAKE_IMAGES = 250
    NUM_REAL_IMAGES = 250

    # --- Misc ---
    SEED = 42
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    LOG_INTERVAL = 50        # iterations between console log lines
    SAVE_INTERVAL = 100      # iterations between sample-grid snapshots


cfg = Config()
print(f"Device: {cfg.DEVICE}")
if cfg.DEVICE == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Set seeds for reproducibility
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs so runs are repeatable.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic kernels with its auto-tuner disabled,
    trading some speed for reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, before any model weights are initialised or data is shuffled.
set_seed(cfg.SEED)
print(f"Seed set to {cfg.SEED}")

In [None]:
# Create output directories
# exist_ok makes re-running this cell harmless.
for _out_dir in (
    cfg.FAKE_IMAGES_DIR,
    cfg.REAL_IMAGES_DIR,
    os.path.join(cfg.OUTPUT_DIR, 'checkpoints'),
    os.path.join(cfg.OUTPUT_DIR, 'samples'),
):
    os.makedirs(_out_dir, exist_ok=True)
print("Directories created.")

## 2. Dataset Preparation

In [None]:
class ImageDataset(Dataset):
    """Dataset of RGB images read from a single flat directory.

    Args:
        root_dir: directory scanned (non-recursively) for image files with a
            .jpg/.jpeg/.png/.bmp extension (case-insensitive).
        transform: optional callable applied to each PIL image.
        max_samples: if truthy, keep only the first ``max_samples`` files.

    File names are sorted before use. ``os.listdir`` returns entries in an
    arbitrary order, so sorting makes the dataset, and any ``max_samples``
    subset, identical across runs.
    """

    VALID_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp'})

    def __init__(self, root_dir, transform=None, max_samples=None):
        self.root_dir = root_dir
        self.transform = transform

        self.image_paths = [
            os.path.join(root_dir, fname)
            for fname in sorted(os.listdir(root_dir))
            if os.path.splitext(fname)[1].lower() in self.VALID_EXTENSIONS
        ]

        if max_samples:
            self.image_paths = self.image_paths[:max_samples]

        print(f"Loaded {len(self.image_paths)} images from {root_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image at position ``idx``."""
        # The context manager releases the file handle deterministically.
        # convert() returns an independent, fully loaded image, so it stays
        # valid after the source file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

In [None]:
# Define transforms
# Resize, random mirror, tensor conversion, then map [0, 1] -> [-1, 1]
train_transform = transforms.Compose([
    transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

monet_dataset = ImageDataset(cfg.MONET_DIR, transform=train_transform)
photo_dataset = ImageDataset(cfg.PHOTO_DIR, transform=train_transform)

# Both domains share loader settings; drop_last keeps every batch full-size.
loader_kwargs = dict(
    batch_size=cfg.BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)
monet_loader = DataLoader(monet_dataset, **loader_kwargs)
photo_loader = DataLoader(photo_dataset, **loader_kwargs)

print(f"Monet batches: {len(monet_loader)}, Photo batches: {len(photo_loader)}")

In [None]:
# Visualize sample batch
def show_batch(images, title="Sample Batch", nrow=8):
    """Plot a grid of images that were normalised to [-1, 1]."""
    denorm = images * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    grid = make_grid(denorm, nrow=nrow, padding=2, normalize=False)
    hwc = grid.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for matplotlib

    plt.figure(figsize=(12, 6))
    plt.imshow(hwc)
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


# Preview one batch from each domain
sample_monet = next(iter(monet_loader))
sample_photo = next(iter(photo_loader))
show_batch(sample_monet[:8], "Real Monet Paintings")
show_batch(sample_photo[:8], "Real Photos")

## 3. FastCUT Model Architecture

Implementing a simplified version of:
- **Generator**: ResNet-based encoder-decoder
- **Discriminator**: PatchGAN discriminator
- **PatchNCE Loss**: Contrastive loss for unpaired translation

In [None]:
# ===== Building Blocks =====

class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with an identity skip connection.

    ``channels`` is both the input and output width; spatial size is preserved.
    """

    def __init__(self, channels):
        super().__init__()

        def padded_conv():
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                nn.InstanceNorm2d(channels),
            ]

        # The attribute name and layer order fix the state_dict keys
        # (block.1.*, block.5.*), so keep them stable for saved checkpoints.
        self.block = nn.Sequential(*padded_conv(), nn.ReLU(inplace=True), *padded_conv())

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class Downsample(nn.Module):
    """Halve spatial size (rounding up) with a stride-2 3x3 conv, then InstanceNorm + ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        norm = nn.InstanceNorm2d(out_ch)
        act = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled


class Upsample(nn.Module):
    """Double spatial size with a stride-2 3x3 transposed conv, then InstanceNorm + ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled

In [None]:
# ===== Generator =====

class Generator(nn.Module):
    """ResNet encoder-decoder generator for FastCUT.

    Encoder stages; their outputs are indexed 0-3 for ``nce_layers``:
        0: 7x7 stem conv              -> ngf channels,   input resolution
        1: first stride-2 Downsample  -> 2*ngf channels
        2: second stride-2 Downsample -> 4*ngf channels
        3: residual bottleneck        -> 4*ngf channels
    The decoder mirrors the two downsamples and ends with a Tanh.

    Args:
        in_ch, out_ch: input / output image channels.
        ngf: base channel width.
        n_blocks: number of residual blocks in the bottleneck.
        nce_layers: indices (see above) of the encoder features returned by
            ``encode`` and by ``forward(..., return_features=True)``.
            Defaults to ``(0, 2)``, the previously hard-coded choice.
    """

    def __init__(self, in_ch=3, out_ch=3, ngf=48, n_blocks=4, nce_layers=(0, 2)):
        super().__init__()

        # Initial conv
        self.initial = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_ch, ngf, 7),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True)
        )

        # Downsampling
        self.down1 = Downsample(ngf, ngf * 2)
        self.down2 = Downsample(ngf * 2, ngf * 4)

        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(ngf * 4) for _ in range(n_blocks)]
        )

        # Upsampling
        self.up1 = Upsample(ngf * 4, ngf * 2)
        self.up2 = Upsample(ngf * 2, ngf)

        # Final conv
        self.final = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, out_ch, 7),
            nn.Tanh()
        )

        # Plain list attribute so callers can still reassign it after construction.
        self.nce_layers = list(nce_layers)

    def _encoder_features(self, x, depth=None):
        """Run the first ``depth`` encoder stages (all when None); return each stage's output."""
        stages = (self.initial, self.down1, self.down2, self.res_blocks)
        if depth is not None:
            stages = stages[:depth]
        features = []
        for stage in stages:
            x = stage(x)
            features.append(x)
        return features

    def forward(self, x, return_features=False):
        """Translate ``x``.

        Returns the output image. When ``return_features`` is True, returns
        ``(output, selected_encoder_features_of_x)`` instead.
        """
        features = self._encoder_features(x)
        out = self.final(self.up2(self.up1(features[-1])))
        if return_features:
            return out, [features[i] for i in self.nce_layers]
        return out

    def encode(self, x):
        """Return only the encoder features selected by ``self.nce_layers``.

        Evaluation stops after the deepest requested stage. With the default
        ``(0, 2)``, the residual bottleneck is never run here.
        """
        if any(i < 0 for i in self.nce_layers):
            depth = None  # negative indices count from the end: need every stage
        else:
            depth = max(self.nce_layers, default=-1) + 1
        features = self._encoder_features(x, depth)
        return [features[i] for i in self.nce_layers]

In [None]:
# ===== Discriminator =====

class PatchDiscriminator(nn.Module):
    """Fully convolutional PatchGAN critic.

    Returns a one-channel map of raw (un-squashed) scores, one per
    overlapping input patch, rather than a single score per image.
    """

    def __init__(self, in_ch=3, ndf=48):
        super().__init__()

        def stage(c_in, c_out, stride, norm=True):
            layers = [nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1)]
            if norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Layer order must match the original Sequential indices
        # (model.0/2/5/8/11) so checkpoints keep loading.
        self.model = nn.Sequential(
            *stage(in_ch, ndf, stride=2, norm=False),
            *stage(ndf, ndf * 2, stride=2),
            *stage(ndf * 2, ndf * 4, stride=2),
            *stage(ndf * 4, ndf * 8, stride=1),
            nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1),
        )

    def forward(self, x):
        scores = self.model(x)
        return scores

In [None]:
# ===== PatchNCE Loss =====

class PatchSampleMLP(nn.Module):
    """Two-layer projection head (Linear -> ReLU -> Linear) for PatchNCE.

    Maps ``(..., in_dim)`` feature vectors to ``(..., out_dim)`` embeddings;
    the hidden width equals ``out_dim``.
    """

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        hidden = out_dim
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        projected = self.mlp(x)
        return projected


class PatchNCELoss(nn.Module):
    """PatchNCE contrastive loss (CUT, Park et al., ECCV 2020).

    For each image, ``num_patches`` spatial locations are drawn at random,
    using the same locations in the query and key maps. The query at a
    location is scored against:
      * the positive: the key at the same location;
      * the negatives: the keys at the other sampled locations of the same image.
    The loss is cross-entropy with the positive as the target class.

    Args:
        nce_temp: temperature dividing the cosine similarities.
        num_patches: maximum number of locations sampled per feature map.
    """

    def __init__(self, nce_temp=0.07, num_patches=64):
        super().__init__()
        self.nce_temp = nce_temp
        self.num_patches = num_patches
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, feat_q, feat_k):
        """Return the mean PatchNCE loss.

        Args:
            feat_q: query features of shape (B, C, H, W).
            feat_k: key features of the same shape. They are detached, so no
                gradient flows through the keys (as in the reference CUT code).

        Returns:
            Scalar tensor: cross-entropy averaged over all B * num_patches queries.
            This equals the mean of the per-image means, since every image
            contributes the same number of queries.
        """
        B, C, H, W = feat_q.shape
        n_locations = H * W
        num_patches = min(self.num_patches, n_locations)
        device = feat_q.device

        # reshape (not view) so permuted / non-contiguous inputs are accepted.
        feat_q = feat_q.reshape(B, C, n_locations)
        feat_k = feat_k.reshape(B, C, n_locations).detach()

        patch_ids = torch.randperm(n_locations, device=device)[:num_patches]

        # (B, P, C): one unit-norm row per sampled location.
        q = F.normalize(feat_q[:, :, patch_ids], dim=1).transpose(1, 2)
        k = F.normalize(feat_k[:, :, patch_ids], dim=1).transpose(1, 2)

        l_pos = (q * k).sum(dim=2, keepdim=True)   # (B, P, 1)
        l_neg = torch.bmm(q, k.transpose(1, 2))    # (B, P, P)
        # The diagonal of l_neg is the positive pair again. Mask it out, or
        # the positive would also be counted as one of its own negatives.
        self_pair = torch.eye(num_patches, dtype=torch.bool, device=device)
        l_neg = l_neg.masked_fill(self_pair, -10.0)

        logits = torch.cat([l_pos, l_neg], dim=2) / self.nce_temp
        logits = logits.reshape(B * num_patches, num_patches + 1)
        labels = torch.zeros(B * num_patches, dtype=torch.long, device=device)
        return self.cross_entropy(logits, labels)

In [None]:
# ===== Initialize Models =====

def init_weights(m):
    """Initialise one module in place; intended for ``model.apply(init_weights)``.

    * Modules whose class name contains 'Conv' (e.g. Conv2d, ConvTranspose2d):
      weight ~ N(0, 0.02).
    * Modules whose class name contains 'BatchNorm' or 'InstanceNorm':
      weight ~ N(1, 0.02), when the module has a weight (affine=True).
    * In both cases the bias, if present, is set to 0.
    Any other module (e.g. Linear) is left untouched.
    """
    name = type(m).__name__
    bias = getattr(m, 'bias', None)

    if 'Conv' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in name or 'InstanceNorm' in name:
        weight = getattr(m, 'weight', None)
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
    else:
        return

    if bias is not None:
        nn.init.constant_(bias.data, 0.0)

# Create models
G = Generator(ngf=cfg.NGF, n_blocks=cfg.N_BLOCKS).to(cfg.DEVICE)
D = PatchDiscriminator(ndf=cfg.NDF).to(cfg.DEVICE)

# Honour Config.NCE_LAYERS. Otherwise the generator falls back to its own
# built-in layer choice and the config value is silently ignored.
G.nce_layers = list(cfg.NCE_LAYERS)

G.apply(init_weights)
D.apply(init_weights)

# Probe the encoder once to learn the channel width of each NCE feature map;
# that width is the input size of the matching MLP projection head.
with torch.no_grad():
    dummy = torch.zeros(1, 3, cfg.IMAGE_SIZE, cfg.IMAGE_SIZE).to(cfg.DEVICE)
    feats = G.encode(dummy)
    feat_dims = [f.shape[1] for f in feats]
    print(f"Feature dimensions for NCE: {feat_dims}")

# One projection head per NCE layer, in the same order as G.nce_layers.
mlp_heads = nn.ModuleList([
    PatchSampleMLP(dim, 256).to(cfg.DEVICE) for dim in feat_dims
])


def count_params(model):
    """Return the number of trainable parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"Generator params: {count_params(G):,}")
print(f"Discriminator params: {count_params(D):,}")
print(f"MLP heads params: {count_params(mlp_heads):,}")

## 4. Training Loop

In [None]:
# ===== Loss Functions =====

def gan_loss_lsgan(pred, target_is_real):
    """Least-squares GAN loss: mean squared distance of ``pred`` to 1 (real) or 0 (fake)."""
    make_target = torch.ones_like if target_is_real else torch.zeros_like
    return F.mse_loss(pred, make_target(pred))

# Shared PatchNCE criterion; samples up to 64 spatial locations per feature map.
nce_loss_fn = PatchNCELoss(nce_temp=cfg.NCE_TEMP, num_patches=64)

In [None]:
# ===== Optimizers =====

# The generator and its NCE projection heads share one optimizer; D has its own.
adam_betas = (cfg.BETA1, cfg.BETA2)

optimizer_G = optim.Adam(
    [*G.parameters(), *mlp_heads.parameters()],
    lr=cfg.LR_G,
    betas=adam_betas,
)
optimizer_D = optim.Adam(D.parameters(), lr=cfg.LR_D, betas=adam_betas)

In [None]:
# ===== Training =====

def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def _patch_nce_over_layers(feats_q, feats_k):
    """Average the PatchNCE loss over the selected encoder layers.

    Each layer's features are projected by its own MLP head (applied at every
    spatial location) before being passed to ``nce_loss_fn``.
    """
    total = 0.0
    for feat_q, feat_k, mlp in zip(feats_q, feats_k, mlp_heads):
        B, C, H, W = feat_q.shape
        q = mlp(feat_q.permute(0, 2, 3, 1).reshape(-1, C)).reshape(B, H, W, -1).permute(0, 3, 1, 2)
        k = mlp(feat_k.permute(0, 2, 3, 1).reshape(-1, C)).reshape(B, H, W, -1).permute(0, 3, 1, 2)
        total += nce_loss_fn(q, k)
    return total / len(mlp_heads)


def train_fastcut():
    """Train G, D and the NCE heads for ``cfg.NUM_ITERATIONS`` iterations.

    Each iteration:
      1. Generator step: LSGAN loss on ``G(photo)`` plus the PatchNCE loss
         between encoder features of the generated image (queries) and of
         the source photo (keys).
      2. Discriminator step: LSGAN loss on real Monet images vs. detached fakes.

    Side effects: prints progress, writes a sample grid to
    ``<OUTPUT_DIR>/samples`` every ``cfg.SAVE_INTERVAL`` iterations, and
    saves the final checkpoint to ``<OUTPUT_DIR>/checkpoints/fastcut_final.pth``.

    Returns:
        dict of per-iteration loss lists keyed 'G_loss', 'D_loss', 'NCE_loss', 'GAN_loss'.
    """
    G.train()
    D.train()
    mlp_heads.train()

    history = {'G_loss': [], 'D_loss': [], 'NCE_loss': [], 'GAN_loss': []}

    monet_iter = iter(monet_loader)
    photo_iter = iter(photo_loader)

    fixed_photos = next(iter(photo_loader))[:8].to(cfg.DEVICE)

    print(f"Starting training for {cfg.NUM_ITERATIONS} iterations...")
    print("=" * 60)

    start_time = datetime.now()
    pbar = tqdm(range(1, cfg.NUM_ITERATIONS + 1), desc="Training")

    for iteration in pbar:
        # Get batches, cycling each loader independently
        real_monet, monet_iter = _next_batch(monet_iter, monet_loader)
        real_photo, photo_iter = _next_batch(photo_iter, photo_loader)

        real_monet = real_monet.to(cfg.DEVICE)
        real_photo = real_photo.to(cfg.DEVICE)

        # ===== Train Generator =====
        optimizer_G.zero_grad()

        fake_monet = G(real_photo)

        # GAN loss
        pred_fake = D(fake_monet)
        loss_G_gan = gan_loss_lsgan(pred_fake, True)

        # PatchNCE loss. Queries must come from re-encoding the *generated*
        # image and keys from the source photo. The features returned while
        # generating are encodings of real_photo, identical to the keys, so
        # using them would not tie the output's content to the input at all.
        feats_q = G.encode(fake_monet)
        feats_k = G.encode(real_photo)
        loss_nce = _patch_nce_over_layers(feats_q, feats_k)

        loss_G = cfg.LAMBDA_GAN * loss_G_gan + cfg.LAMBDA_NCE * loss_nce
        loss_G.backward()
        optimizer_G.step()

        # ===== Train Discriminator =====
        optimizer_D.zero_grad()

        pred_real = D(real_monet)
        loss_D_real = gan_loss_lsgan(pred_real, True)

        pred_fake = D(fake_monet.detach())
        loss_D_fake = gan_loss_lsgan(pred_fake, False)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # Record
        g_val, d_val, nce_val = loss_G.item(), loss_D.item(), loss_nce.item()
        history['G_loss'].append(g_val)
        history['D_loss'].append(d_val)
        history['NCE_loss'].append(nce_val)
        history['GAN_loss'].append(loss_G_gan.item())

        pbar.set_postfix({'G': f"{g_val:.3f}", 'D': f"{d_val:.3f}", 'NCE': f"{nce_val:.3f}"})

        if iteration % cfg.LOG_INTERVAL == 0:
            elapsed = (datetime.now() - start_time).total_seconds() / 60
            print(f"\n[{iteration}/{cfg.NUM_ITERATIONS}] G: {g_val:.4f}, D: {d_val:.4f}, NCE: {nce_val:.4f} | Time: {elapsed:.1f}min")

        if iteration % cfg.SAVE_INTERVAL == 0:
            G.eval()
            with torch.no_grad():
                fake_samples = G(fixed_photos)
                fake_samples = fake_samples * 0.5 + 0.5
                save_image(fake_samples, os.path.join(cfg.OUTPUT_DIR, 'samples', f'iter_{iteration:04d}.png'), nrow=4)
            G.train()

    total_time = (datetime.now() - start_time).total_seconds() / 60
    print(f"\n{'=' * 60}")
    print(f"Training completed in {total_time:.2f} minutes")

    torch.save({
        'G': G.state_dict(),
        'D': D.state_dict(),
        'mlp_heads': mlp_heads.state_dict(),
        'history': history
    }, os.path.join(cfg.OUTPUT_DIR, 'checkpoints', 'fastcut_final.pth'))
    print("Checkpoint saved.")

    return history

In [None]:
# Run training
# Train; the returned per-iteration loss history feeds the training-curve plots.
history = train_fastcut()

## 5. Sample Generation

In [None]:
def generate_fake_images(generator, dataloader, num_images, output_dir):
    """Translate photos from ``dataloader`` and save ``num_images`` PNG files.

    The dataloader is cycled if it runs out before enough images have been
    written. Outputs are mapped from [-1, 1] back to [0, 1] and saved as
    ``fake_monet_XXXX.png`` in ``output_dir``. Leaves ``generator`` in eval mode.

    Returns:
        The number of images written.

    Raises:
        ValueError: if ``dataloader`` yields no batches while images are still needed.
    """
    generator.eval()
    count = 0
    photo_iter = iter(dataloader)

    print(f"Generating {num_images} fake Monet images...")

    # The context manager closes the progress bar even if generation fails.
    with torch.no_grad(), tqdm(total=num_images, desc="Generating") as pbar:
        while count < num_images:
            try:
                photos = next(photo_iter)
            except StopIteration:
                photo_iter = iter(dataloader)
                photos = next(photo_iter, None)
                if photos is None:
                    raise ValueError("dataloader yielded no batches; cannot generate images") from None

            fake_monets = generator(photos.to(cfg.DEVICE)) * 0.5 + 0.5

            # Only take as many images from this batch as are still needed.
            for img in fake_monets[: num_images - count]:
                save_path = os.path.join(output_dir, f'fake_monet_{count:04d}.png')
                save_image(img, save_path)
                count += 1
                pbar.update(1)

    print(f"Saved {count} fake Monet images to {output_dir}")
    return count


num_generated = generate_fake_images(G, photo_loader, cfg.NUM_FAKE_IMAGES, cfg.FAKE_IMAGES_DIR)

In [None]:
def copy_real_images(source_dir, output_dir, num_images, transform):
    """Save resized copies of real Monet images as the FID reference set.

    Takes the first ``num_images`` image files of ``source_dir`` in sorted
    filename order. ``os.listdir`` order is arbitrary, so sorting keeps the
    reference set stable across runs. Each image is resized to
    ``cfg.IMAGE_SIZE`` x ``cfg.IMAGE_SIZE`` with bilinear filtering and
    written as ``real_monet_XXXX.png`` into ``output_dir``.

    Args:
        transform: unused; kept only for call-site compatibility.

    Returns:
        The number of images written.
    """
    valid_ext = {'.jpg', '.jpeg', '.png', '.bmp'}
    image_paths = [
        os.path.join(source_dir, fname)
        for fname in sorted(os.listdir(source_dir))
        if os.path.splitext(fname)[1].lower() in valid_ext
    ][:num_images]

    print(f"Copying {len(image_paths)} real Monet images...")

    target_size = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
    for i, path in enumerate(tqdm(image_paths, desc="Copying")):
        # Close the source file promptly; convert/resize return independent images.
        with Image.open(path) as src:
            img = src.convert('RGB').resize(target_size, Image.BILINEAR)
        img.save(os.path.join(output_dir, f'real_monet_{i:04d}.png'))

    print(f"Saved {len(image_paths)} real Monet images to {output_dir}")
    return len(image_paths)


num_real = copy_real_images(cfg.MONET_DIR, cfg.REAL_IMAGES_DIR, cfg.NUM_REAL_IMAGES, train_transform)

## 6. FID Evaluation

In [None]:
def compute_fid(real_dir, fake_dir, batch_size=32):
    """Return the FID between the images in ``real_dir`` and ``fake_dir``.

    Images are resized to 299x299 and passed as float tensors in [0, 1]
    (hence ``normalize=True``) to torchmetrics' 2048-d Inception features.
    NOTE: FID estimated from only a few hundred images is noisy and biased;
    compare runs that use the same number of images.
    """
    print("Computing FID score...")

    to_inception_input = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
    ])
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(cfg.DEVICE)

    # Real statistics first, then fake: same order as the per-set messages.
    for directory, is_real, label in ((real_dir, True, "real"), (fake_dir, False, "fake")):
        dataset = ImageDataset(directory, transform=to_inception_input)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        print(f"Processing {len(dataset)} {label} images...")
        for batch in tqdm(loader, desc=f"{label.capitalize()} images"):
            fid.update(batch.to(cfg.DEVICE), real=is_real)

    fid_score = fid.compute().item()

    banner = '=' * 40
    print(f"\n{banner}")
    print(f"FID Score: {fid_score:.4f}")
    print(banner)

    return fid_score


fid_score = compute_fid(cfg.REAL_IMAGES_DIR, cfg.FAKE_IMAGES_DIR)

In [None]:
# Save results to CSV
results_path = os.path.join(cfg.OUTPUT_DIR, 'fid_results.csv')

result_rows = [
    ('metric', 'value'),
    ('fid_score', fid_score),
    ('num_real_images', num_real),
    ('num_fake_images', num_generated),
    ('image_size', cfg.IMAGE_SIZE),
    ('num_iterations', cfg.NUM_ITERATIONS),
    ('model', 'FastCUT'),
]

with open(results_path, 'w', newline='') as f:
    csv.writer(f).writerows(result_rows)

print(f"Results saved to {results_path}")

## 7. Visualization

In [None]:
# ===== Plot Training Curves =====

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (axis, history key, legend label, colour, title) for each panel
panels = [
    (axes[0, 0], 'G_loss', 'G Loss', 'blue', 'Generator Loss'),
    (axes[0, 1], 'D_loss', 'D Loss', 'red', 'Discriminator Loss'),
    (axes[1, 0], 'NCE_loss', 'NCE Loss', 'green', 'PatchNCE Loss (Contrastive)'),
    (axes[1, 1], 'GAN_loss', 'GAN Loss', 'purple', 'GAN Loss'),
]

for ax, key, label, color, title in panels:
    ax.plot(history[key], label=label, color=color, alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'training_curves.png'), dpi=150)
plt.show()
print("Training curves saved.")

In [None]:
# ===== Display Generated Images Grid =====

def display_generated_grid(fake_dir, num_images=16, nrow=4):
    """Show and save (as ``generated_grid.png``) the first ``num_images`` PNGs of ``fake_dir``.

    Files are taken in sorted filename order. The title shows the
    module-level ``fid_score``.

    Raises:
        FileNotFoundError: if ``fake_dir`` contains no ``.png`` files.
    """
    image_files = sorted(f for f in os.listdir(fake_dir) if f.endswith('.png'))[:num_images]
    if not image_files:
        # Without this check, torch.stack([]) fails with an opaque error.
        raise FileNotFoundError(f"No .png images found in {fake_dir}")

    to_tensor = transforms.ToTensor()
    images = []
    for fname in image_files:
        with Image.open(os.path.join(fake_dir, fname)) as img:
            images.append(to_tensor(img.convert('RGB')))

    grid = make_grid(torch.stack(images), nrow=nrow, padding=2)

    plt.figure(figsize=(14, 14))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.title(f'Generated Monet Images (FID: {fid_score:.2f})', fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'generated_grid.png'), dpi=150)
    plt.show()


display_generated_grid(cfg.FAKE_IMAGES_DIR, num_images=16)

In [None]:
# ===== Side-by-side Comparison: Photo -> Monet =====

def show_translation_examples(generator, dataloader, num_examples=8):
    """Plot source photos (top row) above their generated Monet versions (bottom row).

    Uses the first batch of ``dataloader`` and shows at most one batch's worth
    of examples. The figure is saved as ``translation_examples.png``.
    Leaves ``generator`` in eval mode.
    """
    generator.eval()
    photos = next(iter(dataloader))[:num_examples].to(cfg.DEVICE)
    # The batch can hold fewer images than requested; plot only what exists.
    num_examples = photos.shape[0]

    with torch.no_grad():
        fake_monets = generator(photos)

    photos = photos * 0.5 + 0.5
    fake_monets = fake_monets * 0.5 + 0.5

    # squeeze=False keeps `axes` 2-D even when num_examples == 1.
    fig, axes = plt.subplots(2, num_examples, figsize=(num_examples * 2, 4), squeeze=False)

    for i in range(num_examples):
        axes[0, i].imshow(photos[i].permute(1, 2, 0).cpu().numpy())
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original Photo', fontsize=10)

        axes[1, i].imshow(fake_monets[i].permute(1, 2, 0).cpu().numpy())
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Generated Monet', fontsize=10)

    plt.suptitle('Photo -> Monet Style Transfer', fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(cfg.OUTPUT_DIR, 'translation_examples.png'), dpi=150)
    plt.show()


show_translation_examples(G, photo_loader)

## 8. Summary & Final Results

In [None]:
# ===== Print Final Summary =====

summary_lines = [
    "\n" + "=" * 60,
    "          FASTCUT TRAINING SUMMARY",
    "=" * 60,
    "",
    "Model Configuration:",
    "  - Architecture: FastCUT (Lightweight)",
    f"  - Image Size: {cfg.IMAGE_SIZE}x{cfg.IMAGE_SIZE}",
    f"  - Generator Filters: {cfg.NGF}",
    f"  - Discriminator Filters: {cfg.NDF}",
    f"  - ResNet Blocks: {cfg.N_BLOCKS}",
    "",
    "Training:",
    f"  - Total Iterations: {cfg.NUM_ITERATIONS}",
    f"  - Batch Size: {cfg.BATCH_SIZE}",
    f"  - Learning Rate (G): {cfg.LR_G}",
    f"  - Learning Rate (D): {cfg.LR_D}",
    "",
    "Evaluation:",
    f"  - Real Images: {num_real}",
    f"  - Fake Images: {num_generated}",
    "",
    "+" + "-" * 30 + "+",
    f"|{'FID SCORE:':^15}{fid_score:^15.4f}|",
    "+" + "-" * 30 + "+",
    "",
    "Output Files:",
    f"  - Fake images: {cfg.FAKE_IMAGES_DIR}",
    f"  - Checkpoint: {os.path.join(cfg.OUTPUT_DIR, 'checkpoints', 'fastcut_final.pth')}",
    f"  - Results CSV: {results_path}",
    "=" * 60,
]

# One print per line keeps the output identical to printing each line separately.
for line in summary_lines:
    print(line)