# Stable Diffusion v1.5 Fine-Tuning Pipeline
End-to-end text-to-image fine-tuning of Stable Diffusion v1.5 on the COCO and Flickr30k caption datasets

## 1. Setup and Imports

In [None]:
import os
import sys
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Tuple
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Transformers and Diffusers
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

# PyTorch utilities
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

# Check GPU availability: pick the first CUDA device when one exists, else the CPU.
USE_GPU = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if USE_GPU else 0
DEVICE = torch.device('cuda:0' if USE_GPU else 'cpu')

# Record the runtime environment in the notebook output.
for _env_line in (
    f"PyTorch Version: {torch.__version__}",
    f"CUDA Available: {USE_GPU}",
    f"Device: {DEVICE}",
    f"Number of GPUs: {NUM_GPUS}",
):
    print(_env_line)

if USE_GPU:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Name: {_gpu_name}")
    print(f"GPU Memory: {_gpu_bytes / 1e9:.2f} GB")

## 2. Configuration
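**Model**: Stable Diffusion v1.5 (UNet: ~860M parameters; the text encoder and VAE are kept frozen by default)
- Training time: 20-40 hours per epoch (on a single GPU; scales with dataset size)
- Memory requirement: 8GB+ VRAM is only a lower bound. Full UNet fine-tuning with AdamW keeps fp32 weights, gradients and two optimizer moments (about 4 × 4 bytes × 860M ≈ 14GB) plus activations, so plan for roughly 16-24GB unless you use memory-saving techniques.
- Image resolution: 512x512 pixels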

In [None]:
@dataclass
class Config:
    """Hyper-parameters and output paths for fine-tuning Stable Diffusion v1.5.

    Numeric settings are validated in ``__post_init__``. Without that check, a bad
    value only surfaces deep inside training. For example,
    ``gradient_accumulation_steps=0`` becomes a modulo/division by zero.

    Raises:
        ValueError: if a numeric setting is outside its valid range.
    """

    # Model configuration
    model_id: str = "runwayml/stable-diffusion-v1-5"
    image_size: int = 512
    center_crop: bool = True
    random_flip: bool = True

    # Training configuration
    train_batch_size: int = 1  # Reduced for SD v1.5 (860M params)
    eval_batch_size: int = 1
    gradient_accumulation_steps: int = 4  # Effective batch size = train_batch_size * this
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 10
    warmup_steps: int = 500
    max_grad_norm: float = 1.0
    use_mixed_precision: bool = True

    # Fine-tuning strategy
    fine_tune_text_encoder: bool = False  # Keep frozen to save memory
    fine_tune_unet: bool = True
    fine_tune_vae: bool = False  # Keep frozen

    # Data configuration
    num_workers: int = 4
    seed: int = 42

    # Paths
    checkpoints_dir: str = "./models/checkpoints"
    results_dir: str = "./results"
    log_dir: str = "./logs"

    def __post_init__(self):
        """Reject out-of-range numeric settings early with a clear message."""
        at_least_one = {
            "image_size": self.image_size,
            "train_batch_size": self.train_batch_size,
            "eval_batch_size": self.eval_batch_size,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
        }
        for name, value in at_least_one.items():
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        non_negative = {
            "num_epochs": self.num_epochs,
            "warmup_steps": self.warmup_steps,
            "num_workers": self.num_workers,
            "weight_decay": self.weight_decay,
        }
        for name, value in non_negative.items():
            if value < 0:
                raise ValueError(f"{name} must be >= 0, got {value}")

        if self.learning_rate <= 0:
            raise ValueError(f"learning_rate must be > 0, got {self.learning_rate}")
        if self.max_grad_norm <= 0:
            raise ValueError(f"max_grad_norm must be > 0, got {self.max_grad_norm}")

config = Config()

# Make sure every output directory exists before anything tries to write into it.
for _out_dir in (config.checkpoints_dir, config.results_dir, config.log_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Effective batch = per-step batch times the number of accumulated steps.
_effective_batch = config.train_batch_size * config.gradient_accumulation_steps

print("Configuration:")
print(f"  Model: {config.model_id}")
print(f"  Batch size: {config.train_batch_size} (effective: {_effective_batch})")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Epochs: {config.num_epochs}")
print(f"  Fine-tune UNet: {config.fine_tune_unet}")
print(f"  Fine-tune Text Encoder: {config.fine_tune_text_encoder}")

## 3. Load Data

In [None]:
# Make the project's src/ directory importable so the caption dataloaders resolve.
sys.path.insert(0, '../src')

from dataloaders_text import caption_dataset

# One handler builds a dataloader per split.
dataloader_handler = caption_dataset()

# Training data is shuffled; validation order stays fixed.
train_dataloader = dataloader_handler.get_dataloader(
    split='train',
    batch_size=config.train_batch_size,
    num_workers=config.num_workers,
    shuffle=True,
)
val_dataloader = dataloader_handler.get_dataloader(
    split='val',
    batch_size=config.eval_batch_size,
    num_workers=config.num_workers,
    shuffle=False,
)

print(f"Train dataloader length: {len(train_dataloader)}")
print(f"Val dataloader length: {len(val_dataloader)}")

# Peek at one training batch and describe each field it contains.
sample_batch = next(iter(train_dataloader))
print(f"\nSample batch:")
for key, value in sample_batch.items():
    if isinstance(value, torch.Tensor):
        _summary = f"{value.shape} - {value.dtype}"
    elif isinstance(value, list):
        _summary = f"list of {len(value)} items"
    else:
        _summary = f"{type(value)}"
    print(f"  {key}: {_summary}")

## 4. Load Pretrained Model Components

In [None]:
print(f"Loading Stable Diffusion v1.5 from: {config.model_id}")
print("This may take a few minutes on first run...\n")

# Frozen components may be stored in float16 on GPU to save memory. Components the
# optimizer updates must stay float32: GradScaler.unscale_ rejects fp16 gradients
# ("Attempting to unscale FP16 gradients"), and fp16 AdamW updates lose precision.
frozen_dtype = torch.float16 if USE_GPU else torch.float32
text_encoder_dtype = torch.float32 if config.fine_tune_text_encoder else frozen_dtype
unet_dtype = torch.float32 if config.fine_tune_unet else frozen_dtype

try:
    # Load tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(
        config.model_id,
        subfolder="tokenizer"
    )
    print("✓ Tokenizer loaded")

    # Load text encoder (fp32 when it is being fine-tuned)
    text_encoder = CLIPTextModel.from_pretrained(
        config.model_id,
        subfolder="text_encoder",
        torch_dtype=text_encoder_dtype
    ).to(DEVICE)
    print("✓ Text encoder loaded")

    # Load VAE (always frozen, so half precision on GPU is fine)
    vae = AutoencoderKL.from_pretrained(
        config.model_id,
        subfolder="vae",
        torch_dtype=frozen_dtype
    ).to(DEVICE)
    print("✓ VAE loaded")

    # Load UNet (fp32 when it is being fine-tuned)
    unet = UNet2DConditionModel.from_pretrained(
        config.model_id,
        subfolder="unet",
        torch_dtype=unet_dtype
    ).to(DEVICE)
    print("✓ UNet loaded")

    # Load scheduler
    scheduler = DDPMScheduler.from_pretrained(
        config.model_id,
        subfolder="scheduler"
    )
    print("✓ Scheduler loaded")

    print(f"\n✓ All components loaded successfully!")
    print(f"\nModel Summary:")
    print(f"  Text Encoder: {sum(p.numel() for p in text_encoder.parameters()) / 1e6:.1f}M parameters")
    print(f"  VAE: {sum(p.numel() for p in vae.parameters()) / 1e6:.1f}M parameters")
    print(f"  UNet: {sum(p.numel() for p in unet.parameters()) / 1e6:.1f}M parameters")

except Exception as e:
    # Do not wipe the Hugging Face cache here. ~/.cache/huggingface/hub is shared by
    # every model on this machine. Continuing silently would also leave tokenizer,
    # unet, etc. undefined for later cells, so surface the real error instead.
    hf_model_cache = (
        Path.home() / '.cache' / 'huggingface' / 'hub'
        / ("models--" + config.model_id.replace("/", "--"))
    )
    print(f"✗ Error loading model: {e}")
    print(f"If the download was corrupted, delete only {hf_model_cache} and re-run this cell.")
    raise

## 5. Fine-Tuning Wrapper Class

In [None]:
class FineTuningWrapper(torch.nn.Module):
    """Compute the Stable Diffusion v1.5 denoising training loss.

    Bundles the UNet, CLIP text encoder, VAE, tokenizer and noise scheduler.
    ``forward(images, captions)`` returns the MSE between the UNet's prediction and
    the scheduler's training target.

    Args:
        unet: Denoising UNet; its ``.sample`` output is compared against the target.
        text_encoder: CLIP text encoder; frozen unless ``fine_tune_text_encoder``.
        vae: Autoencoder used (frozen) to map images to latents.
        tokenizer: Tokenizer matching ``text_encoder``.
        scheduler: Noise scheduler providing ``add_noise`` and
            ``config.num_train_timesteps`` (and ``get_velocity`` for v-prediction).
        fine_tune_text_encoder: If True, gradients flow through the text encoder.
        device: Device that inputs are moved to before encoding.
    """

    def __init__(
        self,
        unet,
        text_encoder,
        vae,
        tokenizer,
        scheduler,
        fine_tune_text_encoder: bool = False,
        device: str = 'cuda:0'
    ):
        super().__init__()
        self.unet = unet
        self.text_encoder = text_encoder
        self.vae = vae
        self.tokenizer = tokenizer
        self.scheduler = scheduler
        self.device = device
        self.fine_tune_text_encoder = fine_tune_text_encoder

        # Freeze components we're not fine-tuning. The noise scheduler is not an
        # nn.Module and has no parameters (no requires_grad_), so there is nothing
        # to freeze there.
        if not fine_tune_text_encoder:
            self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)

    @staticmethod
    def _param_dtype(module, default=torch.float32):
        """Return the dtype of ``module``'s first parameter, or ``default`` if it has none."""
        for param in module.parameters():
            return param.dtype
        return default

    def encode_text(self, captions):
        """Tokenize ``captions`` and return the text encoder's first output.

        Gradients are tracked only when the text encoder is being fine-tuned and
        autograd is enabled by the caller.
        """
        max_length = getattr(self.tokenizer, "model_max_length", 77)
        tokens = self.tokenizer(
            captions,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        )
        token_ids = tokens['input_ids'].to(self.device)

        # No attention_mask. diffusers' StableDiffusionPipeline only passes one when
        # text_encoder.config.use_attention_mask is set, which SD v1.5 does not set.
        # Passing one here would make training embeddings differ from inference ones.
        track_grad = self.fine_tune_text_encoder and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            text_embeddings = self.text_encoder(token_ids)[0]
        return text_embeddings

    def encode_images(self, images):
        """Encode images to scaled VAE latents (no gradients; the VAE is frozen).

        NOTE(review): assumes ``images`` is an (N, 3, H, W) batch normalized to
        [-1, 1], as the SD VAE expects. Confirm against the dataloader.
        """
        # Match the VAE's dtype so an fp16 VAE also works without autocast.
        vae_dtype = self._param_dtype(self.vae, images.dtype)
        images = images.to(device=self.device, dtype=vae_dtype)
        vae_config = getattr(self.vae, "config", None)
        scaling_factor = getattr(vae_config, "scaling_factor", 0.18215)
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample()
            latents = latents * scaling_factor
        return latents

    def forward(
        self,
        images,
        captions,
        timesteps=None
    ):
        """Return the scalar denoising loss for one batch.

        Args:
            images: Image batch passed to ``encode_images``.
            captions: List of caption strings, one per image.
            timesteps: Optional per-sample timesteps; sampled uniformly from
                ``[0, num_train_timesteps)`` when omitted.

        Raises:
            ValueError: if the scheduler's ``prediction_type`` is not supported.
        """
        # Cast frozen (possibly fp16) encoder outputs to the UNet's dtype so the
        # forward pass also works when autocast is disabled.
        unet_dtype = self._param_dtype(self.unet)
        latents = self.encode_images(images).to(unet_dtype)
        batch_size = latents.shape[0]
        text_embeddings = self.encode_text(captions).to(unet_dtype)

        # Sample random timesteps
        if timesteps is None:
            timesteps = torch.randint(
                0,
                self.scheduler.config.num_train_timesteps,
                (batch_size,),
                device=latents.device
            ).long()

        # Forward diffusion: add noise to the clean latents at the chosen timesteps.
        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        # The regression target depends on what the scheduler says the model predicts.
        prediction_type = getattr(self.scheduler.config, "prediction_type", "epsilon")
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = self.scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unsupported scheduler prediction_type: {prediction_type!r}")

        model_pred = self.unet(
            noisy_latents,
            timesteps,
            text_embeddings
        ).sample

        # Compute the MSE in float32 so an fp16 prediction under autocast does not
        # lose precision in the reduction.
        return torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")

# Bundle the loaded components into the loss module used by the training loop;
# the text encoder stays frozen unless the config asks to fine-tune it.
wrapper = FineTuningWrapper(
    unet,
    text_encoder,
    vae,
    tokenizer,
    scheduler,
    fine_tune_text_encoder=config.fine_tune_text_encoder,
    device=str(DEVICE),
)
print("✓ Fine-tuning wrapper created")

## 6. Setup Optimizer and Scheduler

In [None]:
# Collect the parameters the optimizer will update. Each component is trained only
# when its config flag is set.
unet.requires_grad_(config.fine_tune_unet)
trainable_params = []
if config.fine_tune_unet:
    trainable_params.extend(unet.parameters())
if config.fine_tune_text_encoder:
    text_encoder.requires_grad_(True)
    trainable_params.extend(text_encoder.parameters())
if not trainable_params:
    raise ValueError(
        "Nothing to train: set config.fine_tune_unet and/or config.fine_tune_text_encoder"
    )

print(f"Trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6:.1f}M")

# Create optimizer
optimizer = AdamW(
    trainable_params,
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
    betas=(0.9, 0.999)
)

# The training loop steps the optimizer every gradient_accumulation_steps batches
# within each epoch, so an epoch yields floor(len / accum) updates.
updates_per_epoch = len(train_dataloader) // config.gradient_accumulation_steps
if updates_per_epoch == 0:
    print(
        f"Warning: {len(train_dataloader)} batches per epoch is fewer than "
        f"gradient_accumulation_steps={config.gradient_accumulation_steps}; "
        "no optimizer steps will be taken."
    )
# Keep the schedule non-degenerate, and never warm up for longer than the whole run.
total_steps = max(1, updates_per_epoch * config.num_epochs)
warmup_steps = min(config.warmup_steps, total_steps)
scheduler_lr = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f"Optimizer: AdamW")
print(f"Learning rate: {config.learning_rate}")
print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")

# Mixed precision scaler: only used by the training loop when AMP is on AND a GPU
# is present, so don't create a (disabled, warning-emitting) scaler on CPU.
scaler = GradScaler() if (config.use_mixed_precision and USE_GPU) else None
print(f"Mixed precision: {config.use_mixed_precision}")

## 7. Training Loop

In [None]:
# Training setup
train_losses = []
val_losses = []  # NOTE(review): no validation pass is run yet; val_dataloader is unused
global_step = 0
use_amp = config.use_mixed_precision and USE_GPU

print(f"Starting training...")
print(f"Epochs: {config.num_epochs}")
print(f"Steps per epoch: {len(train_dataloader)}")
print(f"Total steps: {len(train_dataloader) * config.num_epochs}")
print(f"Device: {DEVICE}\n")
print("="*60)

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
    print("-" * 60)

    # Training phase
    unet.train()
    if config.fine_tune_text_encoder:
        text_encoder.train()

    # Drop any partial accumulation left over from the previous epoch (or a failed
    # batch) so it cannot leak into this epoch's first optimizer step.
    optimizer.zero_grad()

    epoch_loss = 0.0
    num_ok_batches = 0
    progress_bar = tqdm(train_dataloader, desc="Training", ncols=80)

    for batch_idx, batch in enumerate(progress_bar):
        try:
            images = batch['images'].to(DEVICE)
            captions = batch['captions']

            with autocast(enabled=use_amp):
                loss = wrapper(images, captions)

            # Backward pass, scaled for gradient accumulation
            scaled_loss = loss / config.gradient_accumulation_steps
            if use_amp:
                scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            # Optimizer step once enough batches have been accumulated
            if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
                if use_amp:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, config.max_grad_norm)
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                optimizer.zero_grad()
                scheduler_lr.step()
                global_step += 1

            loss_value = loss.item()
            epoch_loss += loss_value
            num_ok_batches += 1
            progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

        except Exception as e:
            # Best-effort: skip the bad batch but keep count so the epoch average
            # only reflects batches that actually trained.
            print(f"\nError in batch {batch_idx}: {e}")
            continue

    num_skipped = len(train_dataloader) - num_ok_batches
    if num_skipped:
        print(f"\nWarning: skipped {num_skipped} batch(es) in epoch {epoch + 1}")
    avg_epoch_loss = epoch_loss / num_ok_batches if num_ok_batches else float('nan')
    train_losses.append(avg_epoch_loss)
    print(f"\nEpoch {epoch + 1} - Average Loss: {avg_epoch_loss:.4f}")

    # Save checkpoint
    checkpoint_path = os.path.join(config.checkpoints_dir, f"checkpoint-epoch-{epoch + 1}")
    os.makedirs(checkpoint_path, exist_ok=True)
    unet.save_pretrained(os.path.join(checkpoint_path, "unet"))
    if config.fine_tune_text_encoder:
        text_encoder.save_pretrained(os.path.join(checkpoint_path, "text_encoder"))
    print(f"✓ Checkpoint saved: {checkpoint_path}")

print("\n" + "="*60)
print(f"Training completed!")
if train_losses:
    print(f"Final loss: {train_losses[-1]:.4f}")

## 8. Plot Training Loss

In [None]:
# Plot the per-epoch average training loss and save it next to the other results.
loss_plot_path = os.path.join(config.results_dir, 'training_loss.png')
epoch_numbers = list(range(1, len(train_losses) + 1))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_numbers, train_losses, marker='o', linewidth=2, markersize=8)
ax.set_title('Training Loss Over Epochs', fontsize=14, fontweight='bold')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(loss_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Loss plot saved to: {loss_plot_path}")

## 9. Load Fine-Tuned Model for Inference

In [None]:
# Build an inference pipeline from the last epoch's checkpoint, falling back to the
# original pretrained weights when that checkpoint does not exist.
checkpoint_dir = os.path.join(config.checkpoints_dir, f"checkpoint-epoch-{config.num_epochs}")

if not os.path.exists(checkpoint_dir):
    print(f"Checkpoint directory not found: {checkpoint_dir}")
    print("Using original pretrained model for inference")

    pipe_finetuned = StableDiffusionPipeline.from_pretrained(
        config.model_id,
        torch_dtype=torch.float32
    ).to(DEVICE)

    print("✓ Original pipeline loaded")
else:
    print(f"Loading fine-tuned model from: {checkpoint_dir}")

    unet_finetuned = UNet2DConditionModel.from_pretrained(
        os.path.join(checkpoint_dir, "unet")
    )

    # A text encoder is only checkpointed when it was fine-tuned; otherwise reuse
    # the base model's.
    finetuned_text_encoder_dir = os.path.join(checkpoint_dir, "text_encoder")
    if os.path.exists(finetuned_text_encoder_dir):
        text_encoder_finetuned = CLIPTextModel.from_pretrained(finetuned_text_encoder_dir)
    else:
        text_encoder_finetuned = CLIPTextModel.from_pretrained(
            config.model_id, subfolder="text_encoder"
        )

    pipe_finetuned = StableDiffusionPipeline.from_pretrained(
        config.model_id,
        unet=unet_finetuned,
        text_encoder=text_encoder_finetuned,
        torch_dtype=torch.float32
    ).to(DEVICE)

    print("✓ Fine-tuned pipeline created successfully")

## 10. Test Inference

In [None]:
# Generate test images with different prompts
test_prompts = [
    "a dog playing in the park",
    "a woman reading a book in a cafe",
    "a sunset over mountains",
    "children building a sandcastle at the beach",
    "a cat sitting on a windowsill"
]

# Generate at the resolution the model was fine-tuned on.
image_size = config.image_size

print(f"Generating test images with fine-tuned model on device: {DEVICE}\n")
print("="*60)

generated_images = []

for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")

    # Re-seed per prompt so each image is reproducible regardless of prompt order.
    generator = torch.Generator(device=DEVICE).manual_seed(42)
    with torch.no_grad():
        images = pipe_finetuned(
            prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
            height=image_size,
            width=image_size,
            generator=generator
        ).images

    generated_images.append(images[0])
    print(f"✓ Generated {image_size}x{image_size} image successfully")

print("\n" + "="*60)
print(f"Generated {len(generated_images)} test images")

num_shown = min(3, len(generated_images))
if num_shown == 0:
    # plt.subplots(1, 0) would raise, so there is nothing to plot.
    print("\nNo images were generated; skipping the preview figure.")
else:
    print("\nSample outputs saved (showing first 3):")

    # squeeze=False always yields a 2-D axes array, even for a single subplot.
    fig, axes = plt.subplots(1, num_shown, figsize=(15, 5), squeeze=False)
    for ax, prompt, img in zip(axes[0], test_prompts, generated_images):
        ax.imshow(img)
        # Only add an ellipsis when the title was actually truncated.
        title = prompt if len(prompt) <= 30 else prompt[:30] + "..."
        ax.set_title(title, fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    inference_plot_path = os.path.join(config.results_dir, "test_inference.png")
    plt.savefig(inference_plot_path, dpi=150, bbox_inches='tight')
    plt.show()

    print("\nTest images saved to: " + inference_plot_path)

## 11. Save Model Configuration and Results
Save the training configuration and results for reproducibility.

In [None]:
# Persist the key training settings and final loss as JSON for reproducibility.
config_json_path = os.path.join(config.results_dir, 'training_config.json')

if train_losses:
    _final_loss = float(train_losses[-1])
else:
    _final_loss = None

config_dict = dict(
    model_id=config.model_id,
    image_size=config.image_size,
    batch_size=config.train_batch_size,
    effective_batch_size=config.train_batch_size * config.gradient_accumulation_steps,
    learning_rate=config.learning_rate,
    epochs=config.num_epochs,
    warmup_steps=config.warmup_steps,
    fine_tune_text_encoder=config.fine_tune_text_encoder,
    fine_tune_unet=config.fine_tune_unet,
    use_mixed_precision=config.use_mixed_precision,
    device=str(DEVICE),
    final_loss=_final_loss,
)

with open(config_json_path, 'w') as f:
    json.dump(config_dict, f, indent=2)

print("✓ Configuration saved to: " + config_json_path)
print("\nTraining Summary:")
for key in config_dict:
    print(f"  {key}: {config_dict[key]}")

## 12. Clean Up and Summary

In [None]:
print("="*60)
print("TRAINING SUMMARY")
print("="*60)
print(f"\nModel: {config.model_id}")
print(f"Training epochs: {config.num_epochs}")
# Upper bound: assumes every batch is full and no batch was skipped.
print(f"Total samples processed: {len(train_dataloader) * config.train_batch_size * config.num_epochs}")

if train_losses:
    initial_loss, final_loss = train_losses[0], train_losses[-1]
    print(f"Initial loss: {initial_loss:.4f}")
    print(f"Final loss: {final_loss:.4f}")
    if initial_loss:
        print(f"Loss reduction: {(1 - final_loss/initial_loss)*100:.1f}%")
    else:
        # Relative reduction is undefined when starting from zero loss.
        print("Loss reduction: n/a (initial loss is zero)")
else:
    print("No training losses recorded (no epochs were run).")

print(f"\nCheckpoints saved in: {config.checkpoints_dir}")
print(f"Results saved in: {config.results_dir}")
print(f"Logs saved in: {config.log_dir}")
print("\n" + "="*60)

## 13. References and Documentation
- **Stable Diffusion v1.5**: [Runway ML](https://huggingface.co/runwayml/stable-diffusion-v1-5)
- **Paper**: [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
- **Datasets**: [COCO 2014/2017](https://cocodataset.org/), [Flickr30K](https://shannon.cs.illinois.edu/DenotationGraph/)
- **Frameworks**: PyTorch 2.0+, Hugging Face Transformers & Diffusers

### Key Parameters Explained
- **Batch Size**: Reduced to 1 for Stable Diffusion v1.5 due to 860M parameters
- **Gradient Accumulation**: 4 steps to achieve effective batch size of 4 while saving memory
- **Learning Rate**: 1e-4 following original SD paper recommendations
- **Epochs**: 10. Total wall-clock time scales with dataset size; see the per-epoch estimate in Section 2 (single GPU).

### Troubleshooting
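- **Out of Memory**: `train_batch_size` is already 1. Keep `use_mixed_precision` enabled (disabling it increases memory use), enable gradient checkpointing on the UNet (`unet.enable_gradient_checkpointing()`), or switch to a lower-memory optimizer or parameter-efficient method such as LoRA.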
- **Slow Training**: Use multiple GPUs with DistributedDataParallel
- **Model Download Failed**: Check internet connection and HuggingFace cache folder
- **Low Inference Quality**: Increase `num_inference_steps` beyond the 50 used above (e.g. up to 100; slower but can improve quality)