# 🛍️ E-Commerce Fashion Image Generation

This notebook walks through the complete pipeline: preparing the data, training two models, and generating fashion images.

**Two Models:**
1. **Projected GAN** - Fast unconditional generation (~2-4 hours)
2. **Stable Diffusion + LoRA** - Text-conditioned generation (~4-8 hours)

**Hardware:** Optimized for RTX 4060 (8GB VRAM)

---

## 📋 Run Order

1. **Environment Setup** - Install dependencies, verify GPU
2. **Data Preparation** - Download/prepare DeepFashion dataset
3. **Projected GAN Training** - Train unconditional GAN
4. **Projected GAN Generation** - Generate random fashion images
5. **Stable Diffusion LoRA Training** - Fine-tune with LoRA
6. **Stable Diffusion LoRA Generation** - Generate from text prompts

---
## 1. Environment Setup

Install dependencies and verify GPU availability.
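
Before running the install cells, it can help to see which packages are already importable. The sketch below uses only the standard library. The package list is an assumption inferred from the imports used later in this notebook; `requirements.txt` remains the authoritative list.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


# Assumed list, based on the imports that appear in this notebook.
PACKAGES = ["torch", "numpy", "PIL", "matplotlib", "yaml", "xformers"]

for name in PACKAGES:
    status = "installed" if is_installed(name) else "missing"
    print(f"{name:<12} {status}")
```

If `xformers` shows as missing, the notebook should still run, since that package is described as optional.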

In [None]:
# Step 1: Install main dependencies (run once)
# Uncomment and run if you haven't installed yet

# !pip install -r requirements.txt

# Step 2: Install xformers (AFTER the above completes)
# This is optional but recommended for 8GB VRAM

# !pip install xformers --index-url https://download.pytorch.org/whl/cu121

In [None]:
# Verify GPU and CUDA
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ CUDA not available! Training will be very slow on CPU.")

In [None]:
# Common imports and setup
import os
import sys
import random
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import yaml

# Set project root
PROJECT_ROOT = Path.cwd()
print(f"Project root: {PROJECT_ROOT}")

# Add to path
sys.path.insert(0, str(PROJECT_ROOT))

# Set seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

---
## 2. Data Preparation

Prepare the DeepFashion dataset for training.

**Options:**
1. **Kaggle** - Download the full dataset (requires API key)
2. **Local** - Use your own fashion images
3. **Sample** - Create folder structure (add images manually)

In [None]:
# Create data directories
DATA_DIR = PROJECT_ROOT / "data"
RAW_DIR = DATA_DIR / "raw"
PROCESSED_DIR = DATA_DIR / "processed"

RAW_DIR.mkdir(parents=True, exist_ok=True)
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

print(f"Raw data directory: {RAW_DIR}")
print(f"Processed data directory: {PROCESSED_DIR}")

In [None]:
# Option 1: Download from Kaggle (uncomment to use)
# Requires: pip install kaggle && kaggle.json in ~/.kaggle/
# !python data/download_deepfashion.py --source kaggle

# Option 2: Use local images (uncomment and modify path)
# !python data/download_deepfashion.py --source local --local-path "C:/path/to/your/images"

# Option 3: Sample (creates folder structure, you add images manually)
# !python data/download_deepfashion.py --source sample

In [None]:
# Check available images
image_extensions = {'.jpg', '.jpeg', '.png', '.webp'}

def count_images(directory):
    count = 0
    for path in Path(directory).rglob('*'):
        if path.suffix.lower() in image_extensions:
            count += 1
    return count

if RAW_DIR.exists():
    num_images = count_images(RAW_DIR)
    print(f"Found {num_images} images in {RAW_DIR}")
    
    if num_images == 0:
        print("\n⚠️ No images found! Please add images to data/raw/ before proceeding.")
else:
    print(f"⚠️ Directory not found: {RAW_DIR}")

In [None]:
# Prepare dataset for both models (resize, crop, generate captions)
# Creates: 256x256 for GAN, 512x512 for LoRA
!python data/prepare_dataset.py --max-gan-images 5000 --max-lora-images 1000
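
The internals of `prepare_dataset.py` are not shown here. As a rough illustration of what "resize, crop" usually involves, the sketch below computes a center square crop box. The function name `center_crop_box` is hypothetical and is not part of the project.

```python
def center_crop_box(width: int, height: int) -> tuple[int, int, int, int]:
    """Return (left, top, right, bottom) for the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)


# A 400x300 landscape image keeps its full height and loses 50 px on each side.
print(center_crop_box(400, 300))  # (50, 0, 350, 300)
```

With Pillow, a box like this could be passed to `Image.crop(box)`, followed by `Image.resize((256, 256))` for the GAN set or `Image.resize((512, 512))` for the LoRA set.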

In [None]:
# Verify prepared datasets
GAN_DATA = PROCESSED_DIR / "projected_gan"
LORA_DATA = PROCESSED_DIR / "lora" / "images"

if GAN_DATA.exists():
    print(f"✓ Projected GAN dataset: {count_images(GAN_DATA)} images (256x256)")
else:
    print("⚠️ Projected GAN dataset not found")

if LORA_DATA.exists():
    print(f"✓ LoRA dataset: {count_images(LORA_DATA)} images (512x512)")
else:
    print("⚠️ LoRA dataset not found")

---
## 3. Projected GAN Training

Train Projected GAN for **unconditional** fashion image generation.

⏱️ **Expected time:** ~2-4 hours on RTX 4060

In [None]:
# Load GAN configuration
with open('config/projected_gan_config.yaml', 'r') as f:
    gan_config = yaml.safe_load(f)

print("Projected GAN Configuration:")
print(f"  Image size: {gan_config['model']['img_size']}x{gan_config['model']['img_size']}")
print(f"  Batch size: {gan_config['training']['batch_size']}")
print(f"  Total images: {gan_config['training']['total_kimg']}k")
print(f"  Mixed precision: {gan_config['training']['mixed_precision']}")
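
To get a feel for what `total_kimg` means in practice, the sketch below converts it into training iterations and dataset passes. It assumes `total_kimg` counts thousands of real images shown during training, which matches the "Total images" label above but is not confirmed by the trainer code. The numbers are illustrative, not values read from the config.

```python
def kimg_to_iterations(total_kimg: float, batch_size: int) -> int:
    """Number of training iterations needed to show total_kimg * 1000 images."""
    return int(total_kimg * 1000) // batch_size


def dataset_passes(total_kimg: float, dataset_size: int) -> float:
    """How many times the dataset is seen, on average, over the whole run."""
    return total_kimg * 1000 / dataset_size


# Illustrative values only: 100 kimg, batch size 16, 5000 training images.
print(kimg_to_iterations(100, 16))  # 6250
print(dataset_passes(100, 5000))    # 20.0
```

Halving `total_kimg` halves both figures, which is why reducing it is a quick way to shorten a test run.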

In [None]:
# Import Projected GAN components
from projected_gan.model import Generator, ProjectedDiscriminator
from projected_gan.train import Trainer, FashionDataset
from torch.utils.data import DataLoader

# Create dataset
gan_dataset = FashionDataset(
    root=str(GAN_DATA),
    img_size=gan_config['model']['img_size'],
    augment=True,
)

gan_dataloader = DataLoader(
    gan_dataset,
    batch_size=gan_config['training']['batch_size'],
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

print(f"✓ Dataset size: {len(gan_dataset)} images")
print(f"✓ Batches per epoch: {len(gan_dataloader)}")

In [None]:
# Create trainer
gan_trainer = Trainer(gan_config, device=DEVICE)

# Print model info
g_params = sum(p.numel() for p in gan_trainer.G.parameters() if p.requires_grad)
d_params = sum(p.numel() for p in gan_trainer.D.parameters() if p.requires_grad)
print(f"Generator parameters: {g_params:,}")
print(f"Discriminator trainable parameters: {d_params:,}")

In [None]:
# Train Projected GAN
# ⚠️ This will take 2-4 hours!
# Tip: Reduce total_kimg in config for faster testing

gan_trainer.train(gan_dataloader)

---
## 4. Projected GAN Generation

Generate random fashion images using the trained GAN.

In [None]:
# Load trained generator
from projected_gan.generate import load_generator, generate_images

GAN_CHECKPOINT = PROJECT_ROOT / 'outputs' / 'projected_gan' / 'checkpoint_final.pt'

if GAN_CHECKPOINT.exists():
    G, _ = load_generator(str(GAN_CHECKPOINT), device=DEVICE)
    print(f"✓ Loaded generator from {GAN_CHECKPOINT}")
else:
    print("⚠️ Checkpoint not found. Using trainer's generator.")
    G = gan_trainer.G

In [None]:
# Generate random fashion images
NUM_SAMPLES = 16
TRUNCATION = 0.7  # Lower = higher quality, less diversity

generated = generate_images(G, num_samples=NUM_SAMPLES, truncation=TRUNCATION, device=DEVICE)

# Display results
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for ax, img in zip(axes.flatten(), generated):
    # Move to CPU before converting; this is a no-op if the tensor is already there
    ax.imshow(img.detach().cpu().permute(1, 2, 0).numpy())
    ax.axis('off')
plt.suptitle(f'Generated Fashion Images (truncation={TRUNCATION})', fontsize=14)
plt.tight_layout()
plt.show()
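
How `generate_images` applies `truncation` is not shown in this notebook. One common form of the truncation trick pulls each latent vector toward a center point, which is why lower values trade diversity for quality. The sketch below illustrates that idea with plain Python lists; `truncate_latent` is a hypothetical helper, not the project's implementation.

```python
import random


def truncate_latent(z, center, truncation):
    """Interpolate a latent vector toward `center`.

    truncation=1.0 leaves z unchanged; truncation=0.0 collapses it to center.
    """
    return [c + truncation * (v - c) for v, c in zip(z, center)]


rng = random.Random(42)
z = [rng.gauss(0.0, 1.0) for _ in range(4)]
center = [0.0] * 4  # assumed center; some generators use a learned average instead

print(truncate_latent(z, center, 0.7))
print(truncate_latent([2.0, -2.0], [0.0, 0.0], 0.5))  # [1.0, -1.0]
```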

---
## 5. Stable Diffusion LoRA Training

Fine-tune Stable Diffusion v1.5 with LoRA for **text-conditioned** generation.

⏱️ **Expected time:** ~4-8 hours on RTX 4060
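
LoRA fits on an 8GB card because it trains two small low-rank matrices per adapted layer instead of the full weight matrix. The arithmetic below compares the two parameter counts. The layer size (768x768) and rank (4) are illustrative assumptions, not values read from `lora_config.yaml`.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two LoRA factors: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


def full_params(d_in: int, d_out: int) -> int:
    """Parameters in the original dense weight matrix."""
    return d_in * d_out


d_in = d_out = 768  # illustrative layer size
rank = 4            # illustrative LoRA rank

lora = lora_trainable_params(d_in, d_out, rank)
full = full_params(d_in, d_out)
print(f"LoRA: {lora:,} vs full: {full:,} ({lora / full:.2%} of the layer)")
```

Doubling the rank doubles the trainable parameters, so the rank setting is a direct lever on memory use and adapter capacity.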

In [None]:
# Clear GPU memory before loading Stable Diffusion
import gc

if 'G' in dir(): del G
if 'gan_trainer' in dir(): del gan_trainer
gc.collect()  # collect reference cycles that may still hold GPU tensors
torch.cuda.empty_cache()

print(f"GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [None]:
# Load LoRA configuration
with open('config/lora_config.yaml', 'r') as f:
    lora_config = yaml.safe_load(f)

print("LoRA Configuration:")
print(f"  Base model: {lora_config['model']['name']}")
print(f"  LoRA rank: {lora_config['lora']['rank']}")
print(f"  Batch size: {lora_config['training']['batch_size']}")
print(f"  Gradient accumulation: {lora_config['training']['gradient_accumulation_steps']}")
print(f"  Effective batch: {lora_config['training']['batch_size'] * lora_config['training']['gradient_accumulation_steps']}")
print(f"  Epochs: {lora_config['training']['num_train_epochs']}")
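
The effective batch size printed above determines how many optimizer updates a run performs. The sketch below works through that arithmetic. It assumes the trainer applies an update for a partial accumulation window at the end of each epoch, which may not match `LoRATrainer`; the example numbers are illustrative.

```python
import math


def optimizer_steps(num_images: int, batch_size: int,
                    grad_accum_steps: int, num_epochs: int) -> int:
    """Total optimizer updates for a run with gradient accumulation."""
    batches_per_epoch = math.ceil(num_images / batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    return updates_per_epoch * num_epochs


# Illustrative values: 1000 images, batch size 1, 4 accumulation steps, 10 epochs.
print(optimizer_steps(1000, 1, 4, 10))  # 2500
```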

In [None]:
# Import LoRA components and create trainer
# ⚠️ This downloads ~4GB model on first run
from stable_diffusion_lora.train_lora import LoRATrainer, FashionLoRADataset

lora_trainer = LoRATrainer(lora_config, device=DEVICE)

In [None]:
# Create LoRA dataset
lora_data_dir = PROCESSED_DIR / 'lora'

lora_dataset = FashionLoRADataset(
    data_dir=str(lora_data_dir),
    tokenizer=lora_trainer.tokenizer,
    resolution=lora_config['data']['resolution'],
)

lora_dataloader = DataLoader(
    lora_dataset,
    batch_size=lora_config['training']['batch_size'],
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

print(f"✓ Dataset size: {len(lora_dataset)} images")

In [None]:
# Train LoRA
# ⚠️ This will take 4-8 hours!
# Tip: Reduce num_train_epochs in config for faster testing

lora_trainer.train(lora_dataloader)

---
## 6. Stable Diffusion LoRA Generation

Generate fashion images from text prompts using the fine-tuned LoRA model.

In [None]:
# Clear memory and load generation pipeline
if 'lora_trainer' in dir(): del lora_trainer
torch.cuda.empty_cache()

from stable_diffusion_lora.generate import load_pipeline, generate_images as generate_sd_images

LORA_CHECKPOINT = PROJECT_ROOT / 'outputs' / 'lora' / 'checkpoint-final'

if LORA_CHECKPOINT.exists():
    pipeline = load_pipeline(str(LORA_CHECKPOINT), device=DEVICE)
    print(f"✓ Loaded LoRA from {LORA_CHECKPOINT}")
else:
    print(f"⚠️ Checkpoint not found: {LORA_CHECKPOINT}")

In [None]:
# Generate from text prompts
PROMPTS = [
    "a high quality fashion photo of an elegant black dress",
    "a fashion photo of a casual white t-shirt, studio lighting",
    "a professional product photo of blue jeans",
    "a fashion photo of a red evening gown, luxury",
]

NEGATIVE = "low quality, blurry, distorted, ugly"

if 'pipeline' in dir():
    images = generate_sd_images(
        pipeline, prompts=PROMPTS, num_inference_steps=30,
        guidance_scale=7.5, negative_prompt=NEGATIVE, seed=42,
    )
    
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, img, prompt in zip(axes, images, PROMPTS):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(prompt[:35] + '...', fontsize=10)
    plt.tight_layout()
    plt.show()
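
The `guidance_scale` argument controls how strongly the output follows the prompt. Assuming `generate_sd_images` forwards it to a standard Stable Diffusion pipeline, it is used for classifier-free guidance: at each denoising step the model predicts noise with and without the prompt, and the two predictions are combined. The sketch below shows that combination on plain lists; `apply_guidance` is a hypothetical name for illustration.

```python
def apply_guidance(noise_uncond, noise_text, guidance_scale):
    """Combine unconditional and text-conditioned noise predictions."""
    return [u + guidance_scale * (t - u)
            for u, t in zip(noise_uncond, noise_text)]


# Where the two predictions agree, the scale has no effect;
# where they differ, the difference is amplified.
print(apply_guidance([0.0, 1.0], [1.0, 1.0], 7.5))  # [7.5, 1.0]
```

A scale of 1.0 returns the text-conditioned prediction unchanged. Larger values push the result further toward the prompt, usually at some cost to variety.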

In [None]:
# Try your own prompt!
MY_PROMPT = "a fashion photograph of a summer floral dress, bright colors"

if 'pipeline' in dir():
    img = pipeline(MY_PROMPT, negative_prompt=NEGATIVE, num_inference_steps=30).images[0]
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.title(MY_PROMPT)
    plt.show()

---
## 🎉 Complete!

You've trained and used both models for fashion image generation.

| Model | Output | Best For |
|-------|--------|----------|
| **Projected GAN** | Random 256x256 images | Quick prototyping, diverse outputs |
| **SD + LoRA** | Text-conditioned 512x512 images | Specific product descriptions |

### Checkpoints Saved:
- `outputs/projected_gan/checkpoint_final.pt`
- `outputs/lora/checkpoint-final/`