# 🎨 Fashion Image Generation with GANs & Diffusion Models

This notebook trains two state-of-the-art generative models on the DeepFashion dataset:

1. **Projected GAN** - Fast unconditional image generation
2. **Stable Diffusion + LoRA** - Text-conditioned image generation

## GPU Auto-Detection
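The notebook automatically detects your Kaggle GPU and picks preset batch sizes, worker counts, and the GAN image size:
- **T4 / P100 (16GB)**: GAN batch size 16, LoRA batch size 2 (×4 gradient accumulation)
- **V100**: GAN batch size 32, LoRA batch size 4 (×2 gradient accumulation)
- **A100 (40GB)**: GAN batch size 64, LoRA batch size 8 (×2 gradient accumulation), 512px GAN images
- **Any other GPU**: conservative defaults (GAN batch size 8, LoRA batch size 1)

Note: the GAN training cell additionally caps its batch size at 8 for stability.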

---

## 📦 1. Setup & Installation

In [None]:
# Install required packages (IPython shell magics; -q keeps pip output quiet)
# timm provides the pretrained backbone used by ProjectedDiscriminator
!pip install -q timm clean-fid ninja lpips scipy click
# Stable Diffusion components and LoRA adapters used by train_lora / generate_images
!pip install -q diffusers transformers accelerate peft safetensors
# Optional: train_lora tries to enable xformers attention and carries on if that fails
!pip install -q xformers
!pip install -q einops rich

print("‚úÖ Packages installed!")

In [None]:
# Imports
import os
import sys
import json
import random
import shutil
from pathlib import Path
from datetime import datetime
from typing import List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from PIL import Image
from tqdm.auto import tqdm
import timm
import matplotlib.pyplot as plt

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


set_seed(42)

# Report the runtime so a saved notebook run records which PyTorch build it used.
print("PyTorch: " + str(torch.__version__))
print("CUDA available:", torch.cuda.is_available())

In [None]:
# ============================================
# GPU Detection & Configuration
# ============================================

def detect_gpu_and_configure():
    """Inspect CUDA device 0 and return a settings dict tuned to it.

    The returned dict always holds 'device', 'gpu_name' and 'vram_gb', plus the
    preset keys 'gan_batch_size', 'lora_batch_size', 'lora_grad_accum',
    'num_workers' and 'gan_img_size'. Presets are selected by substring match
    on the device name; unmatched devices get the most conservative preset.

    Side effect: enables cuDNN benchmark mode.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("GPU not available! Enable GPU in Kaggle settings.")

    gpu_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

    print(f"üñ•Ô∏è GPU: {gpu_name}")
    print(f"üíæ VRAM: {vram_gb:.1f} GB")

    # (name fragments, banner, settings) - checked top to bottom, first hit wins.
    presets = (
        (
            ('A100',),
            "üöÄ A100 detected - Maximum performance mode!",
            {'gan_batch_size': 64, 'lora_batch_size': 8, 'lora_grad_accum': 2, 'num_workers': 4, 'gan_img_size': 512},
        ),
        (
            ('V100',),
            "üöÄ V100 detected - High performance mode!",
            {'gan_batch_size': 32, 'lora_batch_size': 4, 'lora_grad_accum': 2, 'num_workers': 4, 'gan_img_size': 256},
        ),
        (
            ('T4', 'P100'),
            "‚ö° T4/P100 detected - Balanced mode",
            {'gan_batch_size': 16, 'lora_batch_size': 2, 'lora_grad_accum': 4, 'num_workers': 2, 'gan_img_size': 256},
        ),
    )

    # Fallback used when no fragment matches the device name.
    banner = "üîß Unknown GPU - Using conservative settings"
    settings = {'gan_batch_size': 8, 'lora_batch_size': 1, 'lora_grad_accum': 8, 'num_workers': 2, 'gan_img_size': 256}

    for fragments, preset_banner, preset_settings in presets:
        if any(fragment in gpu_name for fragment in fragments):
            banner, settings = preset_banner, preset_settings
            break

    print(banner)

    config = {'device': 'cuda', 'gpu_name': gpu_name, 'vram_gb': vram_gb}
    config.update(settings)

    torch.backends.cudnn.benchmark = True
    return config

# Resolve the GPU preset once; later cells read batch sizes etc. from CONFIG.
CONFIG = detect_gpu_and_configure()
print(f"\nüìä Configuration:")
print("\n".join(f"   {name}: {value}" for name, value in CONFIG.items()))

## 📂 2. Dataset Preparation

This section loads the DeepFashion dataset from Kaggle.

**Required**: Add the DeepFashion dataset to your notebook:
1. Go to "Add Data" → Search "deepfashion"
2. Select a DeepFashion dataset variant

In [None]:
# Find DeepFashion Dataset
def find_fashion_images(search_dirs=None):
    """Recursively collect image files under one or more directories.

    Args:
        search_dirs: Iterable of directory paths, or a single path. Defaults to
            ['/kaggle/input'] when omitted. Directories that do not exist are
            skipped.

    Returns:
        list[Path]: Regular files whose suffix (case-insensitive) is .jpg,
        .jpeg, .png or .webp. Results are sorted within each search directory
        so that downstream truncation (`images[:max_images]`) picks the same
        subset on every run.
    """
    if search_dirs is None:
        search_dirs = ['/kaggle/input']
    elif isinstance(search_dirs, (str, Path)):
        # A bare path would otherwise be iterated character by character.
        search_dirs = [search_dirs]

    extensions = {'.jpg', '.jpeg', '.png', '.webp'}
    all_images = []
    for search_dir in search_dirs:
        search_path = Path(search_dir)
        if not search_path.exists():
            continue
        # is_file() filters out directories that merely look like images (e.g. "set.jpg/").
        all_images.extend(sorted(
            path for path in search_path.rglob('*')
            if path.suffix.lower() in extensions and path.is_file()
        ))
    print(f"Found {len(all_images)} images in {search_dirs}")
    return all_images

# Scan the Kaggle input mount once; both dataset-prep steps reuse this list.
ALL_IMAGES = find_fashion_images()

if ALL_IMAGES:
    print(f"\n‚úÖ Found {len(ALL_IMAGES)} images!")
    print(f"Sample paths: {ALL_IMAGES[:3]}")
else:
    # Nothing mounted: tell the user how to attach a dataset.
    print("\n".join([
        "\n‚ö†Ô∏è No images found! Please add a fashion dataset:",
        "1. Click 'Add Data' in the right panel",
        "2. Search for 'deepfashion' or 'fashion'",
        "3. Add the dataset and re-run this cell",
    ]))

In [None]:
# Create output directories and prepare datasets
# Everything the notebook writes lives under Kaggle's writable working directory.
OUTPUT_DIR = Path('/kaggle/working', 'outputs')
GAN_DATA_DIR = Path('/kaggle/working', 'gan_data')
LORA_DATA_DIR = Path('/kaggle/working', 'lora_data')

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
GAN_DATA_DIR.mkdir(parents=True, exist_ok=True)
# LoRA images sit in an 'images' subfolder next to metadata.jsonl.
LORA_DATA_DIR.joinpath('images').mkdir(parents=True, exist_ok=True)

# ============================================
# ADJUST THESE VALUES TO USE MORE IMAGES
# ============================================
# Upper bounds on how many source images are preprocessed for each model.
MAX_GAN_IMAGES = 50000   # Set to None to use ALL images (can be slow to prepare)
MAX_LORA_IMAGES = 5000   # LoRA needs fewer images

def prepare_dataset(images, output_dir, img_size=256, max_images=10000, with_captions=False):
    """Resize, center-crop and save images as square JPEGs for training.

    Each image is scaled so its shorter side equals `img_size`, center-cropped
    to `img_size` x `img_size`, and written as `img_{index:06d}.jpg`.

    Args:
        images: Sequence of source image paths.
        output_dir: Destination directory. With `with_captions`, images go to
            `output_dir/images` and a `metadata.jsonl` is written to `output_dir`.
        img_size: Side length of the saved square images, in pixels.
        max_images: Only the first `max_images` entries are processed
            (None processes all of them).
        with_captions: If True, pair every saved image with a caption picked at
            random from a small fixed list and record it in `metadata.jsonl`.

    Images that cannot be read or saved are skipped; the number skipped is
    reported and the first few failures are printed.
    """
    output_dir = Path(output_dir)
    images_dir = output_dir / 'images' if with_captions else output_dir
    images_dir.mkdir(parents=True, exist_ok=True)

    images = images[:max_images]
    print(f"Preparing {len(images)} images ({img_size}x{img_size})...")

    metadata = []
    captions = [
        "a high quality fashion photograph of clothing, professional product photo",
        "a fashion product image, studio lighting, white background",
        "professional fashion photography, elegant clothing, detailed fabric texture",
    ]
    skipped = 0

    for i, img_path in enumerate(tqdm(images, desc="Processing")):
        try:
            # Context manager closes the source file once the pixels are loaded.
            with Image.open(img_path) as source:
                img = source.convert('RGB')
            ratio = img_size / min(img.size)
            # Round and clamp: plain int() truncation can leave the short side at
            # img_size - 1, making the crop box overshoot and pad with black.
            new_size = tuple(max(img_size, round(side * ratio)) for side in img.size)
            img = img.resize(new_size, Image.LANCZOS)
            left = (img.size[0] - img_size) // 2
            top = (img.size[1] - img_size) // 2
            img = img.crop((left, top, left + img_size, top + img_size))

            filename = f'img_{i:06d}.jpg'
            img.save(images_dir / filename, quality=95)
        except Exception as e:
            # Best-effort: one bad file must not abort the run, but do not hide it.
            skipped += 1
            if skipped <= 5:
                tqdm.write(f"Skipping {img_path}: {type(e).__name__}: {e}")
            continue

        if with_captions:
            metadata.append({'file_name': filename, 'text': random.choice(captions)})

    if with_captions:
        with open(output_dir / 'metadata.jsonl', 'w', encoding='utf-8') as f:
            for item in metadata:
                f.write(json.dumps(item) + '\n')

    if skipped:
        print(f"Skipped {skipped} images that could not be processed")
    print(f"‚úÖ Dataset ready: {len(list(images_dir.glob('*.jpg')))} images")

# Prepare datasets
# GAN data: rebuild unless a reasonably complete set of JPEGs is already on disk.
gan_image_count = sum(1 for _ in GAN_DATA_DIR.glob('*.jpg'))
if gan_image_count < 100:
    prepare_dataset(ALL_IMAGES, GAN_DATA_DIR, img_size=CONFIG['gan_img_size'], max_images=MAX_GAN_IMAGES)
else:
    print(f"‚úÖ GAN dataset exists: {gan_image_count} images")

# LoRA data: an empty metadata.jsonl (left behind by a run that found no input
# images) counts as missing, otherwise preparation would be skipped forever.
lora_metadata = LORA_DATA_DIR / 'metadata.jsonl'
if not lora_metadata.exists() or lora_metadata.stat().st_size == 0:
    prepare_dataset(ALL_IMAGES, LORA_DATA_DIR, img_size=512, max_images=MAX_LORA_IMAGES, with_captions=True)
else:
    print(f"‚úÖ LoRA dataset exists")

---

# 🎯 Part 1: Projected GAN Training

Fast unconditional fashion image generation using feature projection from a frozen EfficientNet backbone.

In [None]:
# ============================================
# Projected GAN - Model Architecture
# ============================================

class MappingNetwork(nn.Module):
    """MLP that maps a latent code z to a style vector w.

    Consists of `num_layers` Linear layers, each followed by LeakyReLU(0.2).
    The first layer takes `z_dim` inputs; every layer outputs `w_dim` features.
    """

    def __init__(self, z_dim=256, w_dim=256, num_layers=4):
        super().__init__()
        # widths[i] -> widths[i + 1] is the shape of the i-th Linear layer.
        widths = [z_dim] + [w_dim] * num_layers
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        self.mapping = nn.Sequential(*stack)

    def forward(self, z):
        return self.mapping(z)

class AdaIN(nn.Module):
    """Adaptive instance normalization driven by a style vector.

    The input is instance-normalized (no learned affine), then scaled and
    shifted per channel by values predicted from `w` with a single Linear layer.
    The Linear bias is initialized so the scale starts at 1 and the shift at 0.
    """

    def __init__(self, w_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.style = nn.Linear(w_dim, num_features * 2)
        # First half of the output is the scale, second half the shift.
        with torch.no_grad():
            self.style.bias[:num_features].fill_(1.0)
            self.style.bias[num_features:].fill_(0.0)

    def forward(self, x, w):
        scale, shift = self.style(w).chunk(2, dim=1)
        # Add trailing spatial dims so the per-channel values broadcast over H x W.
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return scale * self.norm(x) + shift

class SynthesisBlock(nn.Module):
    """Generator stage: optional 2x upsample, then two conv + noise + AdaIN + LeakyReLU passes.

    Both 3x3 convolutions are wrapped in spectral normalization. Per-pass noise
    is scaled by a learnable scalar that starts at zero, so noise contributes
    nothing until training moves the scale away from zero.
    """

    def __init__(self, in_ch, out_ch, w_dim, upsample=True):
        super().__init__()
        self.upsample = upsample
        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 3, padding=1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_ch, out_ch, 3, padding=1))
        self.adain1 = AdaIN(w_dim, out_ch)
        self.adain2 = AdaIN(w_dim, out_ch)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.noise_scale1 = nn.Parameter(torch.zeros(1))
        self.noise_scale2 = nn.Parameter(torch.zeros(1))

    def forward(self, x, w):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

        # The two passes are identical apart from which sub-modules they use.
        passes = (
            (self.conv1, self.noise_scale1, self.adain1),
            (self.conv2, self.noise_scale2, self.adain2),
        )
        for conv, noise_scale, adain in passes:
            feat = conv(x)
            feat = feat + noise_scale * torch.randn_like(feat)
            x = self.act(adain(feat, w))
        return x

class Generator(nn.Module):
    """Style-based generator: learned 4x4 constant -> upsampling SynthesisBlocks -> RGB.

    A latent `z` is mapped to a style vector by `MappingNetwork`; that vector
    modulates every block. The number of blocks is `int(log2(img_size)) - 2`,
    each doubling the resolution, and channel width halves per block down to
    a floor of `base_ch`. The output passes through a spectrally-normalized
    1x1 conv and Tanh.
    """

    def __init__(self, z_dim=256, w_dim=256, img_size=256, base_ch=32):
        super().__init__()
        self.z_dim = z_dim
        self.mapping = MappingNetwork(z_dim, w_dim)

        top_ch = base_ch * 16
        # Small-magnitude initial constant input.
        self.const = nn.Parameter(torch.randn(1, top_ch, 4, 4) * 0.01)

        num_blocks = int(np.log2(img_size)) - 2
        stages = []
        channels = top_ch
        for _ in range(num_blocks):
            narrower = max(base_ch, channels // 2)
            stages.append(SynthesisBlock(channels, narrower, w_dim))
            channels = narrower
        self.blocks = nn.ModuleList(stages)

        self.to_rgb = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(channels, 3, 1)),
            nn.Tanh(),
        )

    def forward(self, z):
        style = self.mapping(z)
        # One copy of the learned constant per sample in the batch.
        feat = self.const.repeat(z.shape[0], 1, 1, 1)
        for stage in self.blocks:
            feat = stage(feat, style)
        return self.to_rgb(feat)

class ProjectedDiscriminator(nn.Module):
    """Multi-scale discriminator on top of a frozen, pretrained timm feature extractor.

    Three feature maps (timm `out_indices=[1, 2, 3]`) are each passed through a
    trainable 1x1 projector and a small conv head, producing one logit map per
    scale. Only the projectors and heads are trainable.

    Args:
        backbone: timm model name.
        proj_ch: Channel width of the projector/head layers.

    `forward` expects images in [-1, 1] and returns a list of logit tensors,
    one per feature scale.
    """

    def __init__(self, backbone='tf_efficientnet_lite0', proj_ch=128):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained=True, features_only=True, out_indices=[1, 2, 3])
        for p in self.backbone.parameters():
            p.requires_grad = False
        self.backbone.eval()
        # Probe the backbone once to learn the channel count of each feature map.
        with torch.no_grad():
            dims = [f.shape[1] for f in self.backbone(torch.zeros(1, 3, 256, 256))]
        self.projectors = nn.ModuleList([nn.Sequential(nn.Conv2d(d, proj_ch, 1), nn.LeakyReLU(0.2)) for d in dims])
        self.heads = nn.ModuleList([nn.Sequential(nn.Conv2d(proj_ch, proj_ch, 3, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(proj_ch, 1, 1)) for _ in dims])
        # ImageNet normalization constants. Non-persistent buffers follow .to(device)
        # without adding keys to the state dict, so existing checkpoints still load.
        self.register_buffer('_mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('_std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def train(self, mode=True):
        """Switch projectors/heads between train and eval; the frozen backbone always stays in eval."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        # Map [-1, 1] images to [0, 1], then apply the backbone's expected normalization.
        x = ((x + 1) / 2 - self._mean) / self._std
        features = self.backbone(x)
        return [head(proj(feat)) for feat, proj, head in zip(features, self.projectors, self.heads)]

def hinge_loss_dis(real, fake):
    """Discriminator hinge loss averaged over scales.

    `real` and `fake` are equal-length sequences of logit tensors. For each
    pair the loss is mean(relu(1 - real)) + mean(relu(1 + fake)); the per-pair
    losses are summed and divided by `len(real)`.
    """
    total = 0
    for real_logits, fake_logits in zip(real, fake):
        real_term = torch.mean(F.relu(1 - real_logits))
        fake_term = torch.mean(F.relu(1 + fake_logits))
        total = total + (real_term + fake_term)
    return total / len(real)

def hinge_loss_gen(fake):
    """Generator hinge loss: negated mean logit, averaged over the scales in `fake`."""
    score = 0
    for logits in fake:
        score = score + torch.mean(logits)
    return -score / len(fake)


print("‚úÖ GAN models defined!")

In [None]:
# ============================================
# GAN Dataset & Training Function
# ============================================

class FashionDataset(Dataset):
    """Image-folder dataset with augmentation for GAN training.

    Collects `*.jpg` and `*.png` directly under `root` (non-recursive) and
    returns augmented tensors normalized to [-1, 1].

    Args:
        root: Directory containing the training images.
        img_size: Side length of the random crop returned for every sample.
    """

    # How many replacement images to try before giving up on a sample.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(self, root, img_size=256):
        root = Path(root)
        # Sorted so index -> file is stable across runs and filesystems.
        self.images = sorted(root.glob('*.jpg')) + sorted(root.glob('*.png'))
        self.transform = transforms.Compose([
            transforms.Resize(int(img_size * 1.15)),   # Slightly larger than the crop
            transforms.RandomCrop(img_size),           # Random crop (not center)
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
            transforms.RandomAffine(degrees=8, translate=(0.08, 0.08), scale=(0.95, 1.05)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3),
        ])
        print(f"üìä Dataset: {len(self.images)} images with strong augmentation")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the augmented image at `idx`.

        If the file cannot be loaded, a randomly chosen replacement is tried, up
        to `_MAX_LOAD_ATTEMPTS` times in total.

        Raises:
            IndexError: if `idx` is out of range (needed so plain iteration stops).
            RuntimeError: if every attempt to load an image fails.
        """
        # Looked up outside the try block so an out-of-range index is not
        # mistaken for a corrupt file.
        path = self.images[idx]
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            try:
                with Image.open(path) as img:
                    return self.transform(img.convert('RGB'))
            except Exception as exc:
                # Best-effort: fall back to another image, but with a bounded
                # number of retries instead of unbounded recursion.
                last_error = exc
                path = random.choice(self.images)
        raise RuntimeError(
            f"Could not load a usable image after {self._MAX_LOAD_ATTEMPTS} attempts"
        ) from last_error

def _update_ema(ema_model, model, beta):
    """Blend `model` parameters into `ema_model` and copy its buffers verbatim.

    Buffers hold the spectral-norm power-iteration vectors. They must track the
    live model; otherwise the EMA copy normalizes its averaged weights with the
    vectors it was initialized with.
    """
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(beta).add_(p.detach(), alpha=1 - beta)
        for b_ema, b in zip(ema_model.buffers(), model.buffers()):
            b_ema.copy_(b)


def _gan_checkpoint(step, G, G_ema, D, opt_G, opt_D, scaler_D):
    """Bundle everything needed to resume GAN training into one dict."""
    return {
        'step': step,
        'G': G.state_dict(),
        'G_ema': G_ema.state_dict(),
        'D': D.state_dict(),
        'opt_G': opt_G.state_dict(),
        'opt_D': opt_D.state_dict(),
        'scaler_D': scaler_D.state_dict(),
    }


def train_gan(data_dir, output_dir, img_size=256, batch_size=8, total_kimg=500, device='cuda', 
               r1_gamma=0.1, resume_from=None):
    """Train the projected GAN and return the EMA generator.

    The discriminator step runs under autocast with a GradScaler; the generator
    forward/backward runs in FP32. R1 regularization is applied to the
    discriminator every few steps, learning rates are warmed up linearly, and
    gradients are norm-clipped. Sample grids and checkpoints are written to
    `output_dir` periodically; a final checkpoint (`generator_final.pt`) and
    sample grid (`samples_final.png`) are written at the end. Training stops
    early if too many consecutive NaN/Inf losses are seen.

    Args:
        data_dir: Folder of training images (read through `FashionDataset`).
        output_dir: Where samples and checkpoints are written.
        img_size: Training resolution passed to the generator and dataset.
        batch_size: Images per step.
        total_kimg: Training length in thousands of images shown.
        device: Torch device string.
        r1_gamma: R1 penalty weight; 0 disables the penalty.
        resume_from: Optional path to a checkpoint written by this function.

    Returns:
        The EMA `Generator`, in eval mode.

    Raises:
        ValueError: if the dataset yields no full batch (fewer images than
            `batch_size`), which would otherwise loop forever.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # =========================================
    # Initialize Models
    # =========================================
    G = Generator(z_dim=256, img_size=img_size).to(device)
    D = ProjectedDiscriminator().to(device)
    
    # EMA Generator for smoother outputs
    G_ema = Generator(z_dim=256, img_size=img_size).to(device)
    G_ema.load_state_dict(G.state_dict())
    G_ema.eval()
    ema_beta = 0.9999  # Slow EMA for stability
    
    # =========================================
    # Training hyper-parameters
    # =========================================
    g_lr = 0.0002
    d_lr = 0.0002
    warmup_steps = 500
    grad_clip = 0.5  # Tight clipping
    r1_interval = 8  # Apply the R1 penalty every 8th step
    checkpoint_interval = 5000
    max_nan_count = 20  # Stop after this many consecutive non-finite losses
    
    opt_G = torch.optim.Adam(G.parameters(), lr=g_lr, betas=(0.0, 0.99))
    opt_D = torch.optim.Adam(D.parameters(), lr=d_lr, betas=(0.0, 0.99))
    
    # Only use scaler for D (G runs in FP32 for stability)
    scaler_D = GradScaler('cuda')
    
    dataset = FashionDataset(data_dir, img_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, 
                           num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)
    fixed_z = torch.randn(16, 256, device=device)
    
    num_images = len(dataset)
    # With drop_last=True a dataset smaller than one batch yields nothing, and
    # the training loop below would spin forever without advancing `step`.
    if len(dataloader) == 0:
        raise ValueError(
            f"Dataset at {data_dir} has {num_images} images, which is fewer than "
            f"batch_size={batch_size}; cannot form a single batch."
        )
    total_samples = total_kimg * 1000
    num_epochs = total_samples / num_images
    
    # =========================================
    # Resume from checkpoint if available
    # =========================================
    start_step = 0
    if resume_from and Path(resume_from).exists():
        print(f"üìÇ Loading checkpoint: {resume_from}")
        ckpt = torch.load(resume_from, map_location=device)
        G.load_state_dict(ckpt['G'])
        G_ema.load_state_dict(ckpt['G_ema'])
        D.load_state_dict(ckpt['D'])
        opt_G.load_state_dict(ckpt['opt_G'])
        opt_D.load_state_dict(ckpt['opt_D'])
        # Older checkpoints were written without the scaler state.
        if 'scaler_D' in ckpt:
            scaler_D.load_state_dict(ckpt['scaler_D'])
        start_step = ckpt['step']
        print(f"‚úÖ Resumed from step {start_step}")
    
    print(f"\n{'='*55}")
    print(f" üõ°Ô∏è BULLETPROOF GAN Training")
    print(f"{'='*55}")
    print(f"üìä Dataset:       {num_images:,} images")
    print(f"üî¢ Batch size:    {batch_size}")
    print(f"üîÑ Total kimg:    {total_kimg} ({total_samples:,} samples)")
    print(f"üìà Epochs:        {num_epochs:.1f}")
    print(f"‚ö° Learning rate: G={g_lr}, D={d_lr}")
    print(f"‚úÇÔ∏è Grad clip:     {grad_clip}")
    print(f"üîí R1 gamma:      {r1_gamma} (every {r1_interval} steps)")
    print(f"üíæ Checkpoints:   every {checkpoint_interval} steps")
    print(f"üéØ Generator:     FP32 (most stable)")
    print()
    
    step = start_step
    nan_count = 0
    pbar = tqdm(total=total_samples, initial=start_step * batch_size, desc="Training")
    
    collapsed = False
    while step * batch_size < total_kimg * 1000 and not collapsed:
        for real in dataloader:
            real = real.to(device)
            
            # =====================================
            # Learning Rate Warmup
            # =====================================
            if step < warmup_steps:
                lr_scale = (step + 1) / warmup_steps
                for pg in opt_G.param_groups:
                    pg['lr'] = g_lr * lr_scale
                for pg in opt_D.param_groups:
                    pg['lr'] = d_lr * lr_scale
            
            # =====================================
            # Train Discriminator (FP16)
            # =====================================
            opt_D.zero_grad()
            with autocast('cuda'):
                z = torch.randn(batch_size, 256, device=device)
                with torch.no_grad():
                    fake = G(z).detach()
                d_loss = hinge_loss_dis(D(real), D(fake))
            
            if torch.isnan(d_loss) or torch.isinf(d_loss):
                nan_count += 1
                if nan_count >= max_nan_count:
                    tqdm.write(f"\n‚ùå COLLAPSE DETECTED! {nan_count} NaNs. Stopping.")
                    collapsed = True
                    break
                continue
            
            scaler_D.scale(d_loss).backward()
            
            # R1 regularization: gradient penalty on real images, computed in FP32
            if step % r1_interval == 0 and r1_gamma > 0:
                real_r1 = real.detach().requires_grad_(True)
                with autocast('cuda', enabled=False):
                    d_real = D(real_r1.float())
                    r1_grads = torch.autograd.grad(
                        outputs=sum([o.sum() for o in d_real]),
                        inputs=real_r1,
                        create_graph=True
                    )[0]
                    r1_penalty = r1_grads.pow(2).sum([1,2,3]).mean() * r1_gamma * 0.5
                scaler_D.scale(r1_penalty).backward()
            
            # Unscale before clipping so the clip threshold applies to true gradient norms.
            scaler_D.unscale_(opt_D)
            torch.nn.utils.clip_grad_norm_(D.parameters(), grad_clip)
            scaler_D.step(opt_D)
            scaler_D.update()
            
            # =====================================
            # Train Generator (FP32 for stability!)
            # =====================================
            opt_G.zero_grad()
            
            # NO autocast - run G in full FP32
            z = torch.randn(batch_size, 256, device=device)
            fake = G(z)
            
            # D can still use FP16 for forward
            with autocast('cuda'):
                g_loss = hinge_loss_gen(D(fake))
            
            if torch.isnan(g_loss) or torch.isinf(g_loss):
                nan_count += 1
                if nan_count >= max_nan_count:
                    tqdm.write(f"\n‚ùå COLLAPSE DETECTED! {nan_count} NaNs. Stopping.")
                    collapsed = True
                    break
                opt_G.zero_grad()
                continue
            else:
                nan_count = 0  # Reset on successful step
            
            g_loss.backward()
            torch.nn.utils.clip_grad_norm_(G.parameters(), grad_clip)
            opt_G.step()
            
            # Update EMA Generator (parameters averaged, buffers copied)
            _update_ema(G_ema, G, ema_beta)
            
            step += 1
            pbar.update(batch_size)
            
            # =====================================
            # Logging
            # =====================================
            if step % 100 == 0:
                curr_epoch = (step * batch_size) / num_images
                pbar.set_postfix({
                    'Ep': f'{curr_epoch:.1f}',
                    'D': f'{d_loss.item():.3f}', 
                    'G': f'{g_loss.item():.3f}'
                })
            
            # =====================================
            # Save samples every 2500 steps
            # =====================================
            if step % 2500 == 0 and step > 0:
                with torch.no_grad():
                    samples = G_ema(fixed_z)
                    samples = torch.clamp((samples + 1) / 2, 0, 1)
                    save_image(samples, output_dir / f'samples_{step:08d}.png', nrow=4)
                curr_epoch = (step * batch_size) / num_images
                tqdm.write(f"   üñºÔ∏è Saved samples at step {step} (epoch {curr_epoch:.1f})")
            
            # =====================================
            # Checkpoint every N steps
            # =====================================
            if step % checkpoint_interval == 0 and step > 0:
                ckpt_path = output_dir / f'checkpoint_{step}.pt'
                torch.save(_gan_checkpoint(step, G, G_ema, D, opt_G, opt_D, scaler_D), ckpt_path)
                tqdm.write(f"   üíæ Checkpoint saved: {ckpt_path.name}")
            
            if step * batch_size >= total_kimg * 1000:
                break
    
    pbar.close()
    
    # =========================================
    # Save Final Model
    # =========================================
    final_epochs = (step * batch_size) / num_images
    
    torch.save(
        _gan_checkpoint(step, G, G_ema, D, opt_G, opt_D, scaler_D),
        output_dir / 'generator_final.pt',
    )
    
    # Generate final samples
    with torch.no_grad():
        samples = G_ema(fixed_z)
        samples = torch.clamp((samples + 1) / 2, 0, 1)
        save_image(samples, output_dir / 'samples_final.png', nrow=4)
    
    if collapsed:
        print(f"\n‚ö†Ô∏è Training stopped early due to collapse at step {step}")
        print(f"   üí° Try reducing learning rate or increasing R1 gamma")
    else:
        print(f"\n‚úÖ GAN training complete!")
    
    print(f"   üìä Total steps: {step:,}, Epochs: {final_epochs:.1f}")
    print(f"   üíæ Saved to {output_dir}")
    print(f"   üîÑ To resume: set resume_from='{output_dir}/checkpoint_*.pt'")
    
    return G_ema

print("‚úÖ GAN training function defined!")

In [None]:
# ============================================
# BULLETPROOF GAN TRAINING
# ============================================
#
# TRAINING GUIDE (epoch counts assume a 10k-image dataset; the time and
# quality columns are the author's estimates):
# +------------+--------+-----------+-----------+
# | total_kimg | Epochs | Est. Time | Quality   |
# +------------+--------+-----------+-----------+
# | 500        | 50     | ~2.5hr    | Decent    |
# | 1000       | 100    | ~5hr      | Good      |
# | 1500       | 150    | ~7.5hr    | Very Good |
# | 2000       | 200    | ~10hr     | Best      |
# +------------+--------+-----------+-----------+
#
# Epochs = total_kimg * 1000 / number of training images, so a larger prepared
# dataset (MAX_GAN_IMAGES allows up to 50000) means proportionally fewer epochs.
#
# Note: Bulletproof settings are ~20% slower (author's estimate) but aim to avoid crashes.
# Checkpoints are saved every 5000 steps - you can resume if interrupted.

GAN_OUTPUT_DIR = OUTPUT_DIR / 'projected_gan'

# Use smaller batch for stability: the GPU preset is capped at 8 here
gan_batch = min(CONFIG['gan_batch_size'], 8)

# To resume from a checkpoint, set this:
RESUME_FROM = None  # or: GAN_OUTPUT_DIR / 'checkpoint_5000.pt'

# train_gan returns the EMA generator
generator = train_gan(
    data_dir=GAN_DATA_DIR,
    output_dir=GAN_OUTPUT_DIR,
    img_size=CONFIG['gan_img_size'],
    batch_size=gan_batch,
    total_kimg=1000,     # ~100 epochs at 10k images; fewer with a larger dataset
    device=CONFIG['device'],
    r1_gamma=0.1,        # Gentle R1 regularization
    resume_from=RESUME_FROM,
)

# Display results: show the final sample grid written by train_gan
if (GAN_OUTPUT_DIR / 'samples_final.png').exists():
    plt.figure(figsize=(12, 12))
    plt.imshow(Image.open(GAN_OUTPUT_DIR / 'samples_final.png'))
    plt.title('Generated Fashion Images (Projected GAN)', fontsize=14)
    plt.axis('off')
    plt.show()

---

# 🎯 Part 2: Stable Diffusion + LoRA Training

Fine-tune Stable Diffusion v1.5 with LoRA for text-conditioned fashion image generation.

In [None]:
# ============================================
# Stable Diffusion LoRA Imports & Setup
# ============================================

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from peft import LoraConfig, get_peft_model

class LoRADataset(Dataset):
    """Image/caption pairs for LoRA fine-tuning, read from `metadata.jsonl`.

    Each line of `data_dir/metadata.jsonl` is a JSON object with 'file_name'
    and 'text'. Entries whose image is missing from `data_dir/images` are
    dropped. If the metadata file does not exist the dataset is empty.

    Items are dicts with 'pixel_values' (float32 CHW tensor scaled to [-1, 1],
    randomly mirrored half of the time) and 'input_ids' (caption token ids
    padded/truncated to the tokenizer's maximum length).
    """

    def __init__(self, data_dir, tokenizer, resolution=512):
        self.data_dir = Path(data_dir)
        self.tokenizer = tokenizer
        self.resolution = resolution
        self.samples = []

        manifest = self.data_dir / 'metadata.jsonl'
        if manifest.exists():
            with open(manifest) as handle:
                for raw_line in handle:
                    record = json.loads(raw_line)
                    image_path = self.data_dir / 'images' / record['file_name']
                    if not image_path.exists():
                        continue
                    self.samples.append({'path': image_path, 'text': record['text']})
        print(f"Loaded {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        entry = self.samples[idx]
        target_size = (self.resolution, self.resolution)
        image = Image.open(entry['path']).convert('RGB').resize(target_size, Image.LANCZOS)
        # Horizontal flip augmentation with probability 0.5.
        if random.random() > 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        # uint8 [0, 255] -> float32 [-1, 1], then HWC -> CHW.
        scaled = np.array(image).astype(np.float32) / 127.5 - 1.0
        pixel_values = torch.from_numpy(scaled).permute(2, 0, 1)
        encoded = self.tokenizer(
            entry['text'],
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {'pixel_values': pixel_values, 'input_ids': encoded.input_ids.squeeze(0)}


print("‚úÖ LoRA imports ready!")

In [None]:
# ============================================
# LoRA Training Function
# ============================================

def train_lora(data_dir, output_dir, batch_size=2, grad_accum=4, num_epochs=30, lr=1e-4, lora_rank=64, device='cuda'):
    """Fine-tune the Stable Diffusion v1.5 UNet attention layers with LoRA.

    The VAE, text encoder and base UNet weights are frozen; LoRA adapters are
    attached to the attention projections (`to_q`, `to_k`, `to_v`, `to_out.0`)
    and trained with AdamW, a cosine schedule with 100 warmup steps, gradient
    accumulation and gradient-norm clipping at 1.0. The adapter is saved to
    `output_dir/checkpoint-final`.

    Args:
        data_dir: Folder readable by `LoRADataset` (metadata.jsonl + images/).
        output_dir: Where the trained adapter is written.
        batch_size: Micro-batch size.
        grad_accum: Micro-batches accumulated per optimizer step.
        num_epochs: Passes over the dataset.
        lr: Peak learning rate.
        lora_rank: LoRA rank (also used as lora_alpha).
        device: Torch device string.

    Returns:
        Tuple `(unet, vae, text_encoder, tokenizer)`, where `unet` is the
        LoRA-wrapped model.

    Raises:
        ValueError: if the dataset is too small to produce a single optimizer step.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model_id = "runwayml/stable-diffusion-v1-5"
    
    print(f"\n{'='*50}\n Loading Stable Diffusion v1.5\n{'='*50}")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    # Build and validate the dataset before loading the large models so an
    # empty/missing dataset fails fast instead of "training" for zero steps
    # and saving an untrained adapter.
    dataset = LoRADataset(data_dir, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    max_steps = num_epochs * len(dataloader) // grad_accum
    if max_steps == 0:
        raise ValueError(
            f"LoRA dataset at {data_dir} has {len(dataset)} samples, which is not enough for one "
            f"optimizer step (batch_size={batch_size}, grad_accum={grad_accum}, num_epochs={num_epochs})."
        )

    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device, dtype=torch.float16)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device, dtype=torch.float16)
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    
    print(f"Applying LoRA (rank={lora_rank})...")
    lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_rank, init_lora_weights="gaussian", target_modules=["to_q", "to_k", "to_v", "to_out.0"])
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()
    unet.enable_gradient_checkpointing()
    # xformers is optional; report why it was skipped instead of hiding the error.
    try:
        unet.enable_xformers_memory_efficient_attention()
        print("‚úì xformers enabled")
    except Exception as exc:
        print(f"xformers not enabled ({type(exc).__name__}); using default attention")
    
    # Only the LoRA adapter weights are trainable; hand just those to the optimizer.
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    lr_scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=100, num_training_steps=max_steps)
    
    print(f"\n{'='*50}\n Starting LoRA Training\n{'='*50}")
    print(f"Epochs: {num_epochs}, Batch: {batch_size}, Grad accum: {grad_accum}, Effective: {batch_size*grad_accum}\n")
    
    global_step = 0
    # Counts micro-batches across epochs. A per-epoch counter would carry leftover
    # gradients into the next epoch and run fewer than max_steps updates, so the
    # cosine schedule would never complete.
    micro_step = 0
    pbar = tqdm(range(max_steps), desc="Training")
    unet.train()
    optimizer.zero_grad()
    
    for epoch in range(num_epochs):
        for batch in dataloader:
            pixels = batch['pixel_values'].to(device, dtype=torch.float16)
            ids = batch['input_ids'].to(device)
            
            # Frozen encoders: no gradients needed for latents or text embeddings.
            with torch.no_grad():
                latents = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
                encoder_hidden = text_encoder(ids)[0]
            
            # Standard noise-prediction objective: add noise at a random timestep
            # and train the UNet to predict that noise.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy = noise_scheduler.add_noise(latents, noise, timesteps)
            
            with torch.autocast('cuda', dtype=torch.float16):
                pred = unet(noisy, timesteps, encoder_hidden).sample
            
            # Divide so the accumulated gradient is the mean over the micro-batches.
            loss = F.mse_loss(pred.float(), noise.float()) / grad_accum
            loss.backward()
            micro_step += 1
            
            if micro_step % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                pbar.update(1)
                if global_step % 10 == 0:
                    pbar.set_postfix({'loss': f'{loss.item()*grad_accum:.4f}', 'lr': f'{lr_scheduler.get_last_lr()[0]:.2e}'})
                if global_step >= max_steps:
                    break
        if global_step >= max_steps:
            break
    
    pbar.close()
    unet.save_pretrained(output_dir / 'checkpoint-final')
    print(f"\n‚úÖ LoRA training complete! Saved to {output_dir / 'checkpoint-final'}")
    return unet, vae, text_encoder, tokenizer

print("‚úÖ LoRA training function defined!")

In [None]:
# ============================================
# TRAIN LoRA
# ============================================

# Adapter checkpoint and generated images are written under this folder.
LORA_OUTPUT_DIR = OUTPUT_DIR / 'lora'

# Batch size and gradient accumulation come from the GPU preset in CONFIG.
# The returned models are kept in memory and reused by the generation cell.
unet, vae, text_encoder, tokenizer = train_lora(
    data_dir=LORA_DATA_DIR,
    output_dir=LORA_OUTPUT_DIR,
    batch_size=CONFIG['lora_batch_size'],
    grad_accum=CONFIG['lora_grad_accum'],
    num_epochs=30,  # Adjust based on time
    lr=1e-4,
    lora_rank=64,
    device=CONFIG['device'],
)

In [None]:
# ============================================
# Generate Images with LoRA
# ============================================

def generate_images(unet, vae, text_encoder, tokenizer, prompts, output_dir, device='cuda'):
    """Generate one image per prompt with the LoRA-adapted UNet and show them in a grid.

    Builds a Stable Diffusion v1.5 pipeline around the supplied components
    (safety checker disabled), runs 30 inference steps at guidance scale 7.5
    for every prompt, saves each image as `generated_{i:03d}.png` and the
    overview as `generated_grid.png` in `output_dir`.

    Args:
        unet, vae, text_encoder, tokenizer: Components returned by `train_lora`.
        prompts: Iterable of text prompts.
        output_dir: Destination folder (created if missing).
        device: Torch device string.

    Returns:
        list of PIL images, in prompt order (empty if `prompts` is empty).
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    prompts = list(prompts)
    if not prompts:
        return []

    # Reuse the in-memory tokenizer and skip loading the safety checker, which
    # would be discarded anyway.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, vae=vae,
        tokenizer=tokenizer, safety_checker=None, torch_dtype=torch.float16
    ).to(device)
    pipeline.safety_checker = None
    # The UNet arrives in train mode from train_lora; sampling should run in eval mode.
    unet.eval()

    # NOTE(review): train_lora keeps the text encoder in FP32 while the UNet/VAE are
    # FP16. Autocast lets the mixed-precision components interoperate; it has no
    # effect if everything is already FP16. Confirm against the diffusers version in use.
    device_type = torch.device(device).type
    use_autocast = device_type == 'cuda'

    print("\nGenerating images...")
    images = []
    for i, prompt in enumerate(prompts):
        print(f"  [{i+1}/{len(prompts)}] {prompt[:50]}...")
        with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=use_autocast):
            img = pipeline(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
        img.save(output_dir / f'generated_{i:03d}.png')
        images.append(img)

    # Display grid: up to 3 columns, as many rows as needed (6 prompts -> 2x3).
    count = len(images)
    cols = min(3, count)
    rows = -(-count // cols)  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    flat_axes = axes.flatten()
    for ax, img, prompt in zip(flat_axes, images, prompts):
        ax.imshow(img)
        ax.set_title(prompt[:40]+'...', fontsize=9)
        ax.axis('off')
    # Blank out cells left over when the prompt count is not a multiple of cols.
    for ax in flat_axes[count:]:
        ax.axis('off')
    plt.suptitle('Generated Fashion Images (SD + LoRA)', fontsize=14)
    plt.tight_layout()
    plt.savefig(output_dir / 'generated_grid.png', dpi=150)
    plt.show()

    # Free GPU memory held by the pipeline wrapper.
    del pipeline
    torch.cuda.empty_cache()
    return images

# Test prompts: six sample garment descriptions used to spot-check the adapter
prompts = [
    "a high quality fashion photograph of an elegant red dress, studio lighting",
    "professional product photo of a black leather jacket, minimalist background",
    "fashion photography of blue denim jeans, white background, sharp focus",
    "luxury fashion photograph of a silk blouse, soft lighting, professional",
    "high-end fashion product shot of designer sneakers, clean background",
    "professional fashion photograph of a wool coat, autumn fashion, detailed",
]

# Reuses the models returned by train_lora; images and the grid PNG are written to LORA_OUTPUT_DIR.
generated = generate_images(unet, vae, text_encoder, tokenizer, prompts, LORA_OUTPUT_DIR, CONFIG['device'])

---

## 📦 Save & Download Results

In [None]:
# ============================================
# Summary & Download
# ============================================
import zipfile

print(f"\n{'=' * 60}")
print(" üéâ Training Complete!")
print("=" * 60)

print("\nüìÅ Output Files:")
# List the first five entries (alphabetical) of each result folder.
for heading, folder in (("Projected GAN", GAN_OUTPUT_DIR), ("Stable Diffusion LoRA", LORA_OUTPUT_DIR)):
    print(f"\n  {heading}: {folder}")
    for entry in sorted(folder.glob('*'))[:5]:
        print(f"    - {entry.name}")

# Create zip files for download
def zip_dir(path, zip_name, dest_dir='/kaggle/working'):
    """Zip every file under `path` into `dest_dir/zip_name`.

    Args:
        path: Directory to archive (searched recursively).
        zip_name: File name of the archive to create.
        dest_dir: Folder the archive is written to; created if missing.
            Defaults to Kaggle's working directory.

    Entries are stored relative to `path`, in sorted order. If the archive
    itself lies inside `path` it is not added to its own contents.
    """
    source = Path(path)
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    zip_path = dest / zip_name
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as z:
        archive_real = zip_path.resolve()
        for f in sorted(source.rglob('*')):
            # Skip directories and the archive being written.
            if f.is_file() and f.resolve() != archive_real:
                z.write(f, f.relative_to(source))
    print(f"Created {zip_path} ({zip_path.stat().st_size / 1e6:.1f} MB)")

# Bundle each result folder into its own archive under /kaggle/working.
zip_dir(GAN_OUTPUT_DIR, 'gan_results.zip')
zip_dir(LORA_OUTPUT_DIR, 'lora_results.zip')

# Closing banner pointing the user at Kaggle's Output tab.
print("\n" + "="*60)
print(" Download from 'Output' tab on the right ‚Üí")
print("="*60)