## Monet Style Transfer - Optimized for Speed & Quality

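Uses a **pre-trained VGG19** perceptual loss together with a lightweight ResNet-style generator and a small least-squares GAN discriminator.  
Training takes roughly 10–20 minutes on a GPU and aims for a competitive FID score.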

In [None]:
import os, time, shutil, tempfile, random
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# Select the CUDA device when one is visible; otherwise run everything on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

## 2. Configuration

In [None]:
# 10-MIN Config
# --- Paths (Kaggle input / working directories) ---
DATA_DIR = '/kaggle/input/gan-getting-started'
OUTPUT_DIR = '/kaggle/working/outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # all figures and the G checkpoint go here

# --- Schedule: kept short to fit a ~10-minute budget ---
IMG_SIZE = 256
EPOCHS = 15
BATCH_SIZE = 16  # bigger batches -> fewer optimizer steps per epoch

# --- Two time-scale learning rates: the discriminator learns faster than G ---
LR_G, LR_D = 2e-4, 4e-4

# --- Loss weights: adversarial term vs. VGG perceptual (content) term ---
LAMBDA_ADV, LAMBDA_PERC = 1.0, 10.0

print(f"10-min config: {EPOCHS} epochs, batch={BATCH_SIZE}")

## 3. Dataset: Photos and Monet Paintings

In [None]:
# Dataset
class Dataset(torch.utils.data.Dataset):
    """Unpaired photo / Monet dataset.

    NOTE: this class name shadows the ``Dataset`` imported from
    ``torch.utils.data`` at the top of the notebook. It subclasses the torch
    class explicitly, so that is harmless here; refer to
    ``torch.utils.data.Dataset`` if you need the base class afterwards.

    Each item is ``(photo, monet)``. The photo is selected by index; the Monet
    painting is drawn uniformly at random, so the pairing changes every epoch.
    Both images are resized to ``IMG_SIZE`` x ``IMG_SIZE``, randomly flipped
    horizontally (independently), and normalized to [-1, 1].

    Args:
        data_dir: directory that contains ``photo_jpg/`` and ``monet_jpg/``.

    Raises:
        FileNotFoundError: if either sub-directory has no ``*.jpg`` files.
    """

    def __init__(self, data_dir):
        self.photos = sorted(Path(data_dir, 'photo_jpg').glob('*.jpg'))
        self.monets = sorted(Path(data_dir, 'monet_jpg').glob('*.jpg'))
        # Fail early with a clear message. Otherwise an empty folder surfaces
        # later as random.randint(0, -1) or an empty-DataLoader error.
        for sub, files in (('photo_jpg', self.photos), ('monet_jpg', self.monets)):
            if not files:
                raise FileNotFoundError(f"No .jpg files found in {Path(data_dir, sub)}")
        self.tf = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)  # [0, 1] -> [-1, 1]
        ])
        print(f"Photos: {len(self.photos)}, Monet: {len(self.monets)}")

    def __len__(self):
        # One epoch = one pass over the photos; Monets are resampled per item.
        return len(self.photos)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and release the file handle."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, i):
        p = self._load_rgb(self.photos[i])
        m = self._load_rgb(random.choice(self.monets))
        return self.tf(p), self.tf(m)

ds = Dataset(DATA_DIR)
# drop_last keeps every batch at BATCH_SIZE (the epoch loss average divides by len(dl))
dl = DataLoader(ds, BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

## 4. Model Architecture & VGG Perceptual Loss

In [None]:
# Smaller ResBlock
class ResBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs with instance norm.

    Channel count and spatial size are preserved, so the block's output is
    added directly to its input (identity skip connection).

    Args:
        ch: number of input and output channels.
    """

    def __init__(self, ch):
        super().__init__()
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(ch, ch, 3),
            nn.InstanceNorm2d(ch),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(ch, ch, 3),
            nn.InstanceNorm2d(ch),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# Smaller Generator - 4 ResBlocks instead of 6 by default (configurable via n_res)
class Generator(nn.Module):
    """Encoder / residual bottleneck / decoder image-to-image generator.

    Encoder: a reflection-padded 7x7 conv, then two stride-2 3x3 convs that
    downsample by 4x while widening channels nf -> 2*nf -> 4*nf.
    Bottleneck: ``n_res`` ResBlocks at 4*nf channels.
    Decoder: two stride-2 transposed convs back to nf channels, then a 7x7 conv
    to 3 channels followed by Tanh, so outputs lie in [-1, 1].
    For H and W divisible by 4, the output has the same spatial size as the input.

    Args:
        nf: base channel width (reduced from 64 to 48 for speed).
        n_res: number of residual blocks in the bottleneck. The default of 4
            keeps the original architecture and checkpoint layout.
    """

    def __init__(self, nf=48, n_res=4):
        super().__init__()
        self.enc = nn.Sequential(
            nn.ReflectionPad2d(3), nn.Conv2d(3, nf, 7), nn.InstanceNorm2d(nf), nn.ReLU(True),
            nn.Conv2d(nf, nf*2, 3, 2, 1), nn.InstanceNorm2d(nf*2), nn.ReLU(True),
            nn.Conv2d(nf*2, nf*4, 3, 2, 1), nn.InstanceNorm2d(nf*4), nn.ReLU(True),
        )
        self.res = nn.Sequential(*[ResBlock(nf*4) for _ in range(n_res)])
        self.dec = nn.Sequential(
            # output_padding=1 makes each transposed conv exactly double H and W
            nn.ConvTranspose2d(nf*4, nf*2, 3, 2, 1, 1), nn.InstanceNorm2d(nf*2), nn.ReLU(True),
            nn.ConvTranspose2d(nf*2, nf, 3, 2, 1, 1), nn.InstanceNorm2d(nf), nn.ReLU(True),
            nn.ReflectionPad2d(3), nn.Conv2d(nf, 3, 7), nn.Tanh()
        )

    def forward(self, x):
        return self.dec(self.res(self.enc(x)))

# Smaller Discriminator
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a spatial map of scores.

    Three stride-2 4x4 convs (LeakyReLU 0.2; instance norm on all but the
    first) are followed by a stride-1 4x4 conv down to one channel. Each output
    location therefore scores a patch of the input (PatchGAN-style), rather
    than the network producing one scalar per image.

    Args:
        nf: channel width of the first conv; later convs use 2*nf and 4*nf.
    """

    def __init__(self, nf=48):
        super().__init__()
        c1, c2, c3 = nf, nf * 2, nf * 4
        self.model = nn.Sequential(
            nn.Conv2d(3, c1, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(c1, c2, 4, 2, 1),
            nn.InstanceNorm2d(c2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(c2, c3, 4, 2, 1),
            nn.InstanceNorm2d(c3),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(c3, 1, 4, 1, 1),
        )

    def forward(self, x):
        scores = self.model(x)
        return scores

# VGG Loss - use new API
class VGGLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two images.

    Assumes both inputs are in [-1, 1] (matching the dataset's Normalize and
    the generator's Tanh). They are mapped to [0, 1], normalized with the
    ImageNet mean/std buffers, and passed through the first 16 layers of
    ``vgg19().features``. The VGG weights are frozen, so gradients flow only
    into the inputs.
    """

    def __init__(self):
        super().__init__()
        features = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:16]
        features.eval()
        features.requires_grad_(False)
        self.vgg = features
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _imagenet_norm(self, img):
        """Map a [-1, 1] image batch to ImageNet-normalized input space."""
        return ((img + 1) / 2 - self.mean) / self.std

    def forward(self, x, y):
        feat_x = self.vgg(self._imagenet_norm(x))
        feat_y = self.vgg(self._imagenet_norm(y))
        return F.l1_loss(feat_x, feat_y)

# Build the networks on the selected device (G first, then D, then the frozen VGG loss)
G, D = Generator().to(device), Discriminator().to(device)
vgg_loss = VGGLoss().to(device)

# Report parameter counts with thousands separators
n_params_G = sum(t.numel() for t in G.parameters())
n_params_D = sum(t.numel() for t in D.parameters())
print(f"G: {n_params_G:,} | D: {n_params_D:,}")

## 5. Training Setup

In [None]:
# Optimizers with TTUR: same Adam betas for both nets, but D uses LR_D (> LR_G)
ADAM_BETAS = (0.5, 0.999)
optG = optim.Adam(G.parameters(), lr=LR_G, betas=ADAM_BETAS)
optD = optim.Adam(D.parameters(), lr=LR_D, betas=ADAM_BETAS)

# Mean-squared error is used as the least-squares GAN objective in the training loop
mse = nn.MSELoss()

# Per-epoch mean losses, filled in by the training loop
G_losses = []
D_losses = []

## 6. Training Loop

In [None]:
# Fast Training with AMP
# Mixed precision (autocast + GradScaler) is enabled only on CUDA. On CPU the
# loop runs in full precision instead of relying on PyTorch's
# "CUDA not available" warn-and-disable fallback.
use_amp = device.type == 'cuda'
start = time.time()
fixed = next(iter(dl))[0][:4].to(device)  # fixed photo batch for previews (not used below)
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
G.train(); D.train()  # in case a later cell left the networks in eval mode

for epoch in range(EPOCHS):
    gL, dL = 0, 0
    for photo, monet in tqdm(dl, desc=f"E{epoch+1}", leave=False):
        photo, monet = photo.to(device), monet.to(device)

        # Run G once per step: D trains on a detached copy, and G later
        # backpropagates through the same tensor (G is not updated in between).
        with torch.amp.autocast('cuda', enabled=use_amp):
            fake = G(photo)

        # D step (least-squares GAN): push real -> 1 and fake -> 0.
        # Each D prediction is computed once and reused to shape its target.
        optD.zero_grad()
        with torch.amp.autocast('cuda', enabled=use_amp):
            pred_real = D(monet)
            pred_fake = D(fake.detach())
            loss_D = 0.5 * (mse(pred_real, torch.ones_like(pred_real)) +
                            mse(pred_fake, torch.zeros_like(pred_fake)))
        scaler.scale(loss_D).backward()
        scaler.step(optD)

        # G step: fool the just-updated D, and stay close to the input photo in VGG feature space
        optG.zero_grad()
        with torch.amp.autocast('cuda', enabled=use_amp):
            pred_gen = D(fake)
            loss_adv = mse(pred_gen, torch.ones_like(pred_gen))
            loss_perc = vgg_loss(fake, photo)
            loss_G = LAMBDA_ADV * loss_adv + LAMBDA_PERC * loss_perc
        scaler.scale(loss_G).backward()
        scaler.step(optG)
        scaler.update()  # single scale update per iteration, after both optimizers stepped

        gL += loss_G.item(); dL += loss_D.item()

    # Mean loss per batch for this epoch
    G_losses.append(gL/len(dl)); D_losses.append(dL/len(dl))

    if (epoch+1) % 5 == 0:
        print(f"E{epoch+1} D:{D_losses[-1]:.3f} G:{G_losses[-1]:.3f} [{(time.time()-start)/60:.1f}m]")

total = (time.time()-start)/60
print(f"\n✓ Done in {total:.1f} min")
torch.save(G.state_dict(), f'{OUTPUT_DIR}/G.pt')

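## 7. Training Visualization

The per-epoch generator and discriminator loss curves are plotted in the last cell of the notebook.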

## 8. FID Evaluation

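**FID (Fréchet Inception Distance)** compares the statistics of Inception-v3 features extracted from generated images with those extracted from real Monet paintings.  
A lower FID means the generated images are closer to the real distribution. With only 200 samples per side (as computed below), the value is a noisy estimate that is best used to compare runs.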

In [None]:
# Install torch-fidelity for FID computation (the FID cell below uses torch_fidelity directly)
import importlib
import subprocess
import sys

try:
    import torch_fidelity
    print("✓ torch-fidelity already installed")
except ImportError:
    print("Installing torch-fidelity...")
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'torch-fidelity'])
        # Refresh the import system's finder caches so the new package is importable
        importlib.invalidate_caches()
        import torch_fidelity
        print("✓ torch-fidelity installed")
    except (subprocess.CalledProcessError, ImportError) as e:
        # Best effort: the FID cell catches the resulting error and reports it
        print(f"⚠️  Could not install torch-fidelity: {e}")
        print("FID evaluation will be skipped")

In [None]:
# Fast FID (200 samples)
# NOTE: with only N_FID images per side the estimate is noisy; use it to
# compare runs of this notebook rather than as an absolute score.
N_FID = 200
print("Computing FID...")
temp = tempfile.mkdtemp()
real_d, fake_d = f'{temp}/real', f'{temp}/fake'
os.makedirs(real_d); os.makedirs(fake_d)

try:
    # sorted() makes the real subset deterministic (raw glob order depends on the filesystem)
    for i, p in enumerate(sorted(Path(DATA_DIR, 'monet_jpg').glob('*.jpg'))[:N_FID]):
        shutil.copy(p, f'{real_d}/{i}.jpg')

    # Translate up to N_FID photos and save them; generator output [-1, 1] -> [0, 1]
    to_pil = transforms.ToPILImage()
    G.eval(); cnt = 0
    with torch.no_grad():
        for photo, _ in dl:
            if cnt >= N_FID: break
            fake = G(photo.to(device)).cpu()  # one device->host copy per batch
            for img in fake:
                if cnt >= N_FID: break
                to_pil(((img + 1) / 2).clamp(0, 1)).save(f'{fake_d}/{cnt}.jpg')
                cnt += 1

    import torch_fidelity
    # Use the GPU for the Inception features only when the notebook runs on one
    fid_score = torch_fidelity.calculate_metrics(
        input1=fake_d, input2=real_d,
        cuda=device.type == 'cuda', fid=True, verbose=False,
    )['frechet_inception_distance']
    print(f"FID: {fid_score:.2f}")
except Exception as e:
    # Best effort: FID is optional; later cells check fid_score before formatting it
    print(f"FID error: {e}"); fid_score = None
finally:
    shutil.rmtree(temp, ignore_errors=True)

## 9. Visual Quality Assessment

In [None]:
# Results: input photos (row 1), their Monet-style translations (row 2), and
# real Monet paintings for reference (row 3; not paired with the photos)
G.eval()
fig, ax = plt.subplots(3, 6, figsize=(18, 9))
batch = next(iter(DataLoader(ds, 6, shuffle=True)))
photos = batch[0].to(device)

with torch.no_grad():
    fakes = G(photos).cpu()

for i in range(6):
    # Tensors are in [-1, 1]; map to [0, 1] and CHW -> HWC for imshow
    ax[0,i].imshow(((photos[i].cpu()+1)/2).permute(1,2,0).clamp(0,1))
    ax[0,i].set_title('Photo'); ax[0,i].axis('off')
    ax[1,i].imshow(((fakes[i]+1)/2).permute(1,2,0).clamp(0,1))
    ax[1,i].set_title('Generated'); ax[1,i].axis('off')
    # Show the real painting at the training resolution; the context manager closes the file
    with Image.open(ds.monets[i]) as im:
        real = transforms.Resize((IMG_SIZE, IMG_SIZE))(im.convert('RGB'))
    ax[2,i].imshow(real); ax[2,i].set_title('Real Monet'); ax[2,i].axis('off')

# Compare against None explicitly so any numeric FID (even 0.0) is shown
title = f'Photo → Monet (FID: {fid_score:.1f})' if fid_score is not None else 'Photo → Monet'
plt.suptitle(title, fontsize=14)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/results.png', dpi=150)
plt.show()

## 10. More Samples & Final Summary

In [None]:
# More samples: a 4x4 grid of Monet-style translations of 16 random photos
fig, ax = plt.subplots(4, 4, figsize=(12, 12))
loader16 = DataLoader(ds, 16, shuffle=True)
batch = next(iter(loader16))[0].to(device)
with torch.no_grad():
    out = G(batch).cpu()
for idx, axis in enumerate(ax.flat):
    # [-1, 1] -> [0, 1], CHW -> HWC for imshow
    axis.imshow(((out[idx] + 1) / 2).permute(1, 2, 0).clamp(0, 1))
    axis.axis('off')
plt.suptitle('Generated Monet-Style', fontsize=14)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/samples.png', dpi=150)
plt.show()

In [None]:
# Final summary; FID is omitted when its computation failed (fid_score is None)
if fid_score is not None:
    print(f"Training: {total:.1f} min | FID: {fid_score:.1f}")
else:
    print(f"Training: {total:.1f} min")

In [None]:
# Loss curves: mean generator / discriminator loss per epoch, as recorded during training
plt.figure(figsize=(8, 4))
for curve, curve_label in ((G_losses, 'G'), (D_losses, 'D')):
    plt.plot(curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(alpha=0.3)
plt.savefig(f'{OUTPUT_DIR}/loss.png')
plt.show()