# Notebook 29: Few-Step Image Generation

## Inference Engineering Course

---

## Overview

Diffusion models generate stunning images but require many iterative denoising steps (typically 20-50), making them slow. **Few-step generation** techniques reduce this to just 1-8 steps while maintaining quality.

### The Core Trade-off

```
Steps:  50     25     8      4      1
Speed:  Slow   ──────────────────►  Fast
Quality: Best  ──────────────────►  Degraded
```

### Key Techniques

| Technique | Steps | Quality | How It Works |
|-----------|-------|---------|-------------|
| Standard DDPM | 50-1000 | Best | Full denoising chain |
| DDIM | 20-50 | Great | Deterministic sampling |
| LCM (Latent Consistency) | 2-8 | Good | Consistency distillation |
| LCM-LoRA | 4-8 | Good | LoRA adapter for consistency |
| Consistency Models | 1-2 | Decent | Direct consistency mapping |
| Turbo/Lightning | 1-4 | Good | Adversarial distillation |

### What You'll Learn

1. How diffusion models work and why they need many steps
2. How consistency distillation enables few-step generation
3. Comparing quality at different step counts
4. Speed benchmarking and Pareto analysis
5. Practical use of LCM-LoRA

### Prerequisites
- Basic understanding of diffusion models
- Google Colab with GPU runtime (T4 for actual generation)

In [None]:
# ============================================================
# Install dependencies
# ============================================================
# peft is needed by recent diffusers versions to load LoRA adapters such as LCM-LoRA
!pip install diffusers transformers accelerate torch peft -q
!pip install matplotlib numpy Pillow -q

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import time
import warnings
warnings.filterwarnings('ignore')

print("Dependencies loaded!")

In [None]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties attribute is `total_memory` (bytes)
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

---

## Section 1: How Diffusion Models Work (Review)

Diffusion models learn to reverse a noise-adding process:

### Forward Process (Adding Noise)
$$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$$

### Reverse Process (Removing Noise)
$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t)\right) + \sigma_t \cdot z$$

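Here $\alpha_t = 1 - \beta_t$ for a per-step noise variance $\beta_t$, $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$ is the cumulative product, $\epsilon, z \sim \mathcal{N}(0, I)$, and $\sigma_t$ sets how much fresh noise is injected at each step.

The model $\epsilon_\theta$ predicts the noise at each step. More steps mean smaller, more gradual updates and less discretization error, which generally improves quality (with diminishing returns), but every step costs at least one full network forward pass.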
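
The two update rules above can be sketched in NumPy. This is a minimal sketch, not a real sampler: it assumes a linear $\beta$ schedule from 1e-4 to 0.02 over 1000 steps (a common DDPM choice that this notebook does not specify), the function names are illustrative, and the true noise stands in for the trained network $\epsilon_\theta$.

```python
import numpy as np

def make_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (an assumed, commonly used choice)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)  # cumulative product: alpha-bar_t
    return betas, alphas, alpha_bars

def q_sample(x0, t, eps, alpha_bars):
    """Forward process: jump from x_0 straight to x_t in closed form."""
    return np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps

def p_step(x_t, t, eps_pred, alphas, alpha_bars, sigma_t=0.0, z=None):
    """One reverse step x_t -> x_{t-1}, given a noise prediction."""
    coef = (1.0 - alphas[t]) / np.sqrt(1.0 - alpha_bars[t])
    mean = (x_t - coef * eps_pred) / np.sqrt(alphas[t])
    if z is None:
        z = np.zeros_like(x_t)
    return mean + sigma_t * z

rng = np.random.default_rng(0)
betas, alphas, alpha_bars = make_schedule()
x0 = rng.uniform(-1, 1, size=(8, 8))   # stand-in for a clean image
eps = rng.standard_normal(x0.shape)    # the noise that gets mixed in

# How much of the original signal survives at different timesteps
x_mid = q_sample(x0, 500, eps, alpha_bars)
x_end = q_sample(x0, 999, eps, alpha_bars)
print(f"signal weight at t=500: {np.sqrt(alpha_bars[500]):.3f}")
print(f"signal weight at t=999: {np.sqrt(alpha_bars[999]):.3f}")

# Sanity check at the first timestep: with the true noise as the
# "prediction" and sigma_t = 0, the reverse step returns x_0
x_first = q_sample(x0, 0, eps, alpha_bars)
x_rec = p_step(x_first, 0, eps, alphas, alpha_bars)
print("max reconstruction error:", np.abs(x_rec - x0).max())
```

In a real model the noise prediction is imperfect and depends on $x_t$, which is why the reverse process has to be applied over many small steps.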

In [None]:
# ============================================================
# Visualize the denoising process at different step counts
# ============================================================

np.random.seed(42)

def simulate_denoising(n_steps, image_size=64):
    """
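    Simulate denoising process for visualization.
    Creates a simple sinusoidal pattern and walks from pure noise toward it.
    This is a toy linear interpolation, not a real sampler: the last step
    uses t = 0, so every step count ends exactly on the target.
    """
    # Target: a clean sinusoidal pattern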
    x, y = np.meshgrid(np.linspace(-1, 1, image_size), np.linspace(-1, 1, image_size))
    target = np.sin(2 * np.pi * x) * np.cos(2 * np.pi * y)
    target = (target + 1) / 2  # Normalize to [0, 1]
    
    # Start from pure noise
    noisy = np.random.randn(image_size, image_size)
    
    # Simulate denoising steps
    intermediates = [noisy.copy()]
    for step in range(n_steps):
        t = 1.0 - (step + 1) / n_steps  # t goes from 1 to 0
        # Interpolate between noise and target (simplified)
        noisy = t * noisy + (1 - t) * target
        # Add small noise for stochasticity (except last step)
        if step < n_steps - 1:
            noisy += 0.05 * t * np.random.randn(image_size, image_size)
        intermediates.append(noisy.copy())
    
    return intermediates, target

# Compare different step counts
step_counts = [1, 4, 8, 20, 50]

fig, axes = plt.subplots(len(step_counts), 6, figsize=(18, len(step_counts) * 3))

for row, n_steps in enumerate(step_counts):
    intermediates, target = simulate_denoising(n_steps)
    
    # Show evenly spaced intermediates
    indices = np.linspace(0, len(intermediates) - 1, 5).astype(int)
    
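    for col, idx in enumerate(indices):
        axes[row, col].imshow(intermediates[idx], cmap='viridis', vmin=-1, vmax=1)
        axes[row, col].set_title(f'Step {idx}/{n_steps}', fontsize=9)
        # Hide ticks only: axis('off') would also hide the row label set below
        axes[row, col].set_xticks([])
        axes[row, col].set_yticks([])
    
    # Show target on the same color scale as the intermediates
    axes[row, 5].imshow(target, cmap='viridis', vmin=-1, vmax=1)
    axes[row, 5].set_title('Target', fontsize=9)
    axes[row, 5].axis('off')
    
    # Final-frame MSE. In this toy it is 0 for every row because the last
    # step uses t = 0; the rows differ in how gradual the path is.
    mse = np.mean((intermediates[-1] - target) ** 2)
    axes[row, 0].set_ylabel(f'{n_steps} steps\nMSE={mse:.4f}', fontsize=11, fontweight='bold')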

plt.suptitle('Denoising Process at Different Step Counts', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.savefig('denoising_steps.png', dpi=150, bbox_inches='tight')
plt.show()

print("More steps = more gradual denoising; in real samplers that usually means better quality")
print("But each step requires a full neural network forward pass!")

---

## Section 2: Consistency Models and LCM

### The Key Insight

Instead of learning to denoise one step at a time, **consistency models** learn to map ANY noisy version directly to the clean image:

```
Standard Diffusion:    x_T → x_{T-1} → ... → x_1 → x_0
                       (many steps, each step needs model)

Consistency Model:     x_T ──────────────────────► x_0
                       (one step, one model call)
```

### Consistency Property

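For any two points on the same trajectory, the consistency function maps both to the same clean output, and at the smallest noise level $t_{\min}$ it is the identity:

$$f(x_t, t) = f(x_{t'}, t') \quad \text{for all } t, t' \in [t_{\min}, T] \text{ on the same trajectory}, \qquad f(x_{t_{\min}}, t_{\min}) = x_{t_{\min}}$$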
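
The property can be checked numerically on a toy problem where the consistency function has a closed form. This is a sketch under stated assumptions, not how a real model is built: the data is a 1D Gaussian $\mathcal{N}(\mu, s^2)$, noise is added as $x_t = x_0 + t \cdot \epsilon$, and trajectories follow the probability-flow ODE $dx/dt = -t \, \nabla_x \log p_t(x)$, which for this data reduces to $dx/dt = t (x - \mu) / (s^2 + t^2)$. All names and constants are illustrative.

```python
import numpy as np

MU, S = 2.0, 0.5            # toy data distribution N(MU, S^2) (assumed)
T_MIN, T_MAX = 0.002, 80.0  # smallest and largest noise levels (assumed)

def trajectory_point(x_T, t):
    """Point at noise level t on the ODE trajectory passing through x_T at T_MAX."""
    return MU + (x_T - MU) * np.sqrt(S**2 + t**2) / np.sqrt(S**2 + T_MAX**2)

def consistency_fn(x_t, t):
    """Exact consistency function for this toy: (x_t, t) -> trajectory point at T_MIN."""
    return MU + (x_t - MU) * np.sqrt(S**2 + T_MIN**2) / np.sqrt(S**2 + t**2)

rng = np.random.default_rng(0)
x_T = T_MAX * rng.standard_normal()  # start of one trajectory (pure noise)

# Visit several noise levels along the same trajectory
levels = np.array([80.0, 20.0, 5.0, 1.0, 0.1, T_MIN])
points = trajectory_point(x_T, levels)
outputs = consistency_fn(points, levels)

for t, x, out in zip(levels, points, outputs):
    print(f"t={t:7.3f}  x_t={x:10.4f}  f(x_t, t)={out:.6f}")
```

Every row prints the same `f(x_t, t)` value even though `x_t` changes by orders of magnitude, and at `t = T_MIN` the function returns its input unchanged. A trained consistency model approximates this mapping with a neural network because no closed form exists for image data.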

### LCM (Latent Consistency Models)

LCM applies consistency distillation in the **latent space** of Stable Diffusion:

1. Start with a pre-trained Stable Diffusion model (teacher)
2. Train a student to directly predict the final latent in fewer steps
3. The student learns to be consistent across noise levels
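
Once a model is consistent across noise levels, sampling with 2-8 steps can alternate between "denoise in one jump" and "add back a smaller amount of noise". The sketch below follows the multistep sampling idea from the consistency-models literature on a 1D toy; it is not the `LCMScheduler` implementation. The Gaussian data distribution, the closed-form `consistency_fn` that stands in for the trained student, and the noise levels are all assumptions chosen for illustration.

```python
import numpy as np

MU, S = 2.0, 0.5            # toy data distribution N(MU, S^2) (assumed)
T_MIN, T_MAX = 0.002, 80.0  # smallest and largest noise levels (assumed)

def consistency_fn(x_t, t):
    """Closed-form stand-in for a trained consistency model on this toy."""
    return MU + (x_t - MU) * np.sqrt(S**2 + T_MIN**2) / np.sqrt(S**2 + t**2)

def multistep_sample(n_samples, noise_levels, rng):
    """Few-step sampling: one model call per entry in noise_levels.

    noise_levels must be decreasing and start at T_MAX.
    """
    # Step 1: start from pure noise and jump straight to a clean estimate
    x_t = noise_levels[0] * rng.standard_normal(n_samples)
    x0 = consistency_fn(x_t, noise_levels[0])
    # Later steps: re-noise to a lower level, then jump to clean again
    for t in noise_levels[1:]:
        z = rng.standard_normal(n_samples)
        x_t = x0 + np.sqrt(t**2 - T_MIN**2) * z
        x0 = consistency_fn(x_t, t)
    return x0

rng = np.random.default_rng(0)
one_step = multistep_sample(100_000, [T_MAX], rng)
four_step = multistep_sample(100_000, [T_MAX, 10.0, 2.0, 0.5], rng)

print(f"target:    mean={MU:.3f}  std={S:.3f}")
print(f"1 step:    mean={one_step.mean():.3f}  std={one_step.std():.3f}")
print(f"4 steps:   mean={four_step.mean():.3f}  std={four_step.std():.3f}")
```

With an exact consistency function the extra steps change little. With a learned, imperfect one, the re-noise and re-denoise loop gives the model more chances to correct its estimate, which is the compute-for-quality trade that few-step samplers expose.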

In [None]:
# ============================================================
# Visualize: Standard vs Consistency Model paths
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Simulate denoising trajectories
np.random.seed(42)

# Standard diffusion: many small steps
ax = axes[0]
timesteps = np.linspace(1, 0, 50)
n_trajectories = 5

for i in range(n_trajectories):
    noise = np.random.randn() * 3
    x = noise
    trajectory_x = [x]
    trajectory_t = [1.0]
    
    for t_idx in range(1, len(timesteps)):
        t = timesteps[t_idx]
        x = x * t + (1 - t) * 0 + np.random.randn() * 0.1 * t  # Denoise toward 0
        trajectory_x.append(x)
        trajectory_t.append(t)
    
    color = plt.cm.viridis(i / n_trajectories)
    ax.plot(trajectory_t, trajectory_x, '-', color=color, alpha=0.7, linewidth=1.5)
    ax.scatter([1], [trajectory_x[0]], color=color, s=50, zorder=5)
    ax.scatter([trajectory_t[-1]], [trajectory_x[-1]], color=color, s=50, marker='*', zorder=5)

ax.set_xlabel('Noise Level (t)', fontsize=12)
ax.set_ylabel('Value (x)', fontsize=12)
ax.set_title('Standard Diffusion\n(50 steps, following trajectory)', fontsize=13, fontweight='bold')
ax.invert_xaxis()
ax.grid(True, alpha=0.3)
ax.axhline(y=0, color='red', linestyle='--', alpha=0.5, label='Clean target')
ax.legend()

# Consistency model: direct mapping
ax = axes[1]

for i in range(n_trajectories):
    noise = np.random.randn() * 3
    # Sample a few noise levels
    noise_levels = [1.0, 0.7, 0.4, 0.1]
    for t in noise_levels:
        x_noisy = noise * t
        x_clean = 0  # All map to same clean output
        
        color = plt.cm.viridis(i / n_trajectories)
        ax.annotate('', xy=(0, x_clean), xytext=(t, x_noisy),
                    arrowprops=dict(arrowstyle='->', color=color, alpha=0.5, lw=1.5))
        ax.scatter([t], [x_noisy], color=color, s=30, zorder=5)

ax.scatter([0] * n_trajectories, [0] * n_trajectories, color='red', s=100, 
           marker='*', zorder=10, label='Clean target')
ax.set_xlabel('Noise Level (t)', fontsize=12)
ax.set_ylabel('Value (x)', fontsize=12)
ax.set_title('Consistency Model\n(Direct mapping, any noise level → clean)', fontsize=13, fontweight='bold')
ax.invert_xaxis()
ax.grid(True, alpha=0.3)
ax.legend()

plt.tight_layout()
plt.savefig('consistency_vs_standard.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Section 3: Generating Images with Different Step Counts

Let's use Stable Diffusion with LCM-LoRA to compare generation quality at different step counts.

In [None]:
# ============================================================
# Load Stable Diffusion with LCM-LoRA
# ============================================================

from diffusers import StableDiffusionPipeline, LCMScheduler, DPMSolverMultistepScheduler
import torch

# The adapter must match the base model family, so both IDs target SD 1.5
# (same base ID as the Exercise 1 starter; swap in another SD 1.5 checkpoint
# if this one is unavailable)
MODEL_ID = "runwayml/stable-diffusion-v1-5"  # Fits on T4 in fp16
LCM_LORA_ID = "latent-consistency/lcm-lora-sdv1-5"  # LCM LoRA adapter for SD 1.5

# Note: If running on CPU, we'll simulate the results
USE_GPU = device == 'cuda'

if USE_GPU:
    print("Loading Stable Diffusion pipeline...")
    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to(device)
        # Attach the LCM-LoRA weights used for the few-step runs
        pipe.load_lora_weights(LCM_LORA_ID)
        lora_loaded = True
        
        print("Pipeline loaded successfully!")
        PIPELINE_LOADED = True
    except Exception as e:
        print(f"Could not load pipeline: {e}")
        print("Will use simulated results.")
        PIPELINE_LOADED = False
else:
    print("No GPU detected. Using simulated results for visualization.")
    print("To generate real images, enable GPU runtime.")
    PIPELINE_LOADED = False

In [None]:
# ============================================================
# Generate images at different step counts (or simulate)
# ============================================================

prompt = "a beautiful sunset over mountains, oil painting style, detailed, warm colors"
step_configs = [1, 2, 4, 8, 15, 25, 50]

generation_times = []
images = []

if PIPELINE_LOADED:
    # LCM scheduler + LCM-LoRA for few steps, DPM-Solver on the plain base model for standard steps
    
    for n_steps in step_configs:
        print(f"Generating with {n_steps} steps...")
        
        # Use appropriate scheduler (and adapter state)
        if n_steps <= 8:
            if not lora_loaded:
                pipe.load_lora_weights(LCM_LORA_ID)
                lora_loaded = True
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        else:
            if lora_loaded:
                # Standard sampling should run without the consistency adapter
                pipe.unload_lora_weights()
                lora_loaded = False
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        
        # Re-seed per run so every step count starts from the same noise
        generator = torch.Generator(device=device).manual_seed(42)
        
        start = time.time()
        image = pipe(
            prompt,
            num_inference_steps=n_steps,
            guidance_scale=1.0 if n_steps <= 8 else 7.5,
            generator=generator,
            height=512,
            width=512,
        ).images[0]
        
        elapsed = time.time() - start
        generation_times.append(elapsed)
        images.append(np.array(image))
        
        print(f"  Done in {elapsed:.2f}s")
else:
    # Simulate results for visualization
    print("Simulating generation results...")
    
    for n_steps in step_configs:
        # Simulate generation time (roughly linear with steps)
        base_time = 0.5  # Fixed per-image overhead; same value as the Section 4 benchmark
        time_per_step = 0.15  # Time per denoising step on T4
        sim_time = base_time + n_steps * time_per_step
        generation_times.append(sim_time)
        
        # Simulate image quality (more steps = less noise)
        np.random.seed(42)
        base_image = np.zeros((64, 64, 3))
        # Create sunset gradient
        for i in range(64):
            r = min(1, 0.9 - i/100)
            g = max(0, 0.5 - i/150)
            b = max(0, 0.3 - i/200)
            base_image[i, :, :] = [r, g, b]
        # Add mountains
        for x in range(64):
            height = int(35 + 10 * np.sin(x/10) + 5 * np.sin(x/3))
            base_image[height:, x, :] = [0.15, 0.1, 0.2]
        
        # Add noise based on step count (fewer steps = more noise)
        noise_level = 0.3 / np.sqrt(n_steps)
        noisy_image = np.clip(base_image + noise_level * np.random.randn(64, 64, 3), 0, 1)
        images.append((noisy_image * 255).astype(np.uint8))
    
    print("Simulation complete!")

print(f"\nGeneration times:")
for steps, t in zip(step_configs, generation_times):
    print(f"  {steps:>3d} steps: {t:.2f}s")

In [None]:
# ============================================================
# Visualize: Quality comparison across step counts
# ============================================================

n_images = len(images)
fig, axes = plt.subplots(2, (n_images + 1) // 2, figsize=(20, 9))
axes = axes.flatten()

for i, (img, steps, gen_time) in enumerate(zip(images, step_configs, generation_times)):
    axes[i].imshow(img)
    axes[i].set_title(f'{steps} Steps\n({gen_time:.2f}s)', fontsize=12, fontweight='bold')
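    # Hide ticks only: axis('off') would also suppress the colored border below
    axes[i].set_xticks([])
    axes[i].set_yticks([])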
    
    # Color border based on speed
    if steps <= 4:
        color = '#4CAF50'  # Fast = green
    elif steps <= 15:
        color = '#FF9800'  # Medium = orange
    else:
        color = '#F44336'  # Slow = red
    
    for spine in axes[i].spines.values():
        spine.set_edgecolor(color)
        spine.set_linewidth(3)
        spine.set_visible(True)

# Hide extra axes
for i in range(n_images, len(axes)):
    axes[i].axis('off')

fig.suptitle(f'Image Quality vs Generation Steps\n"{prompt}"', 
            fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('step_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Section 4: Speed Benchmarking

Let's create a comprehensive speed benchmark comparing different step counts and configurations.
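
When you replace simulated numbers with real measurements, the timing loop matters as much as the model. Below is a minimal sketch of a benchmark helper with warmup runs and an optional synchronization hook. The names `benchmark` and `fake_generate` are hypothetical, and `fake_generate` only sleeps so the example runs without a GPU; on CUDA you would pass `torch.cuda.synchronize` as `sync` and call the real pipeline inside `fn`.

```python
import statistics
import time

def benchmark(fn, warmup=2, repeats=5, sync=None):
    """Time fn() after warmup runs; returns (median_seconds, all_times).

    sync: optional zero-argument callable run before reading the clock,
    e.g. torch.cuda.synchronize when fn launches asynchronous GPU work.
    """
    for _ in range(warmup):
        fn()  # warmup runs absorb one-time costs (caches, kernel selection)
    times = []
    for _ in range(repeats):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times), times

def fake_generate(n_steps, seconds_per_step=0.002):
    """Stand-in for pipe(prompt, num_inference_steps=n_steps)."""
    for _ in range(n_steps):
        time.sleep(seconds_per_step)

for n_steps in [1, 4, 8]:
    median_s, _ = benchmark(lambda: fake_generate(n_steps), warmup=1, repeats=3)
    print(f"{n_steps} steps: median {median_s * 1000:.1f} ms")
```

Reporting the median over several repeats keeps a single slow run from skewing the comparison between step counts.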

In [None]:
# ============================================================
# Comprehensive Speed Benchmark
# ============================================================

# Per-step times below are illustrative placeholders, not measurements;
# replace them with timings from your own runs where available
benchmark_steps = [1, 2, 4, 8, 15, 25, 50]

# Simulate benchmark for multiple resolutions
resolutions = {
    '256x256': 0.04,    # Time per step (seconds)
    '512x512': 0.15,    # Time per step
    '768x768': 0.35,    # Time per step
    '1024x1024': 0.65,  # Time per step
}

overhead = 0.5  # Fixed overhead per generation (text encoding, VAE decode, etc.)

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Plot 1: Generation time vs steps (different resolutions)
ax = axes[0]
for res_name, time_per_step in resolutions.items():
    times = [overhead + steps * time_per_step for steps in benchmark_steps]
    ax.plot(benchmark_steps, times, '-o', linewidth=2, markersize=6, label=res_name)

ax.set_xlabel('Number of Steps', fontsize=12)
ax.set_ylabel('Generation Time (seconds)', fontsize=12)
ax.set_title('Generation Time vs Steps\n(by resolution, T4 GPU)', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

# Plot 2: Speedup from reducing steps
ax = axes[1]
baseline_steps = 50
for res_name, time_per_step in resolutions.items():
    baseline_time = overhead + baseline_steps * time_per_step
    speedups = [baseline_time / (overhead + steps * time_per_step) for steps in benchmark_steps]
    ax.plot(benchmark_steps, speedups, '-s', linewidth=2, markersize=6, label=res_name)

ax.set_xlabel('Number of Steps', fontsize=12)
ax.set_ylabel('Speedup vs 50 Steps', fontsize=12)
ax.set_title('Speedup from Reducing Steps', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5)

# Plot 3: Images per minute throughput
ax = axes[2]
res_512_time_per_step = resolutions['512x512']
throughputs = [60 / (overhead + steps * res_512_time_per_step) for steps in benchmark_steps]
colors = ['#4CAF50' if s <= 4 else '#FF9800' if s <= 15 else '#F44336' for s in benchmark_steps]

bars = ax.bar([str(s) for s in benchmark_steps], throughputs, color=colors, alpha=0.8)
for bar, t in zip(bars, throughputs):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.3,
            f'{t:.1f}', ha='center', fontweight='bold', fontsize=10)

ax.set_xlabel('Number of Steps', fontsize=12)
ax.set_ylabel('Images per Minute', fontsize=12)
ax.set_title('Throughput at 512x512\n(T4 GPU)', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('speed_benchmark.png', dpi=150, bbox_inches='tight')
plt.show()

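# Compute the headline number from the simulated timings instead of hard-coding it
speedup_4_vs_50 = (overhead + 50 * res_512_time_per_step) / (overhead + 4 * res_512_time_per_step)
print(f"Key Insight: with these simulated timings, 4-step generation at 512x512 is ~{speedup_4_vs_50:.1f}x faster than 50-step,")
print("which brings generation close to interactive latency on a T4-class GPU.")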

---

## Section 5: Quality-Speed Pareto Analysis

The **Pareto frontier** shows the optimal trade-off between quality and speed -- configurations that are not dominated by any other configuration.

In [None]:
# ============================================================
# Quality-Speed Pareto Curve
# ============================================================

# Illustrative quality scores in [0, 1], where higher is better
# (hand-picked to mimic typical trends; not measured FID values)

methods = {
    'DDPM (50 steps)': {'steps': 50, 'quality': 0.95, 'time': 8.0, 'method': 'standard'},
    'DDIM (25 steps)': {'steps': 25, 'quality': 0.93, 'time': 4.2, 'method': 'standard'},
    'DPM++ (20 steps)': {'steps': 20, 'quality': 0.94, 'time': 3.5, 'method': 'standard'},
    'DPM++ (15 steps)': {'steps': 15, 'quality': 0.91, 'time': 2.8, 'method': 'standard'},
    'LCM (8 steps)': {'steps': 8, 'quality': 0.88, 'time': 1.7, 'method': 'consistency'},
    'LCM (4 steps)': {'steps': 4, 'quality': 0.83, 'time': 1.1, 'method': 'consistency'},
    'LCM (2 steps)': {'steps': 2, 'quality': 0.75, 'time': 0.8, 'method': 'consistency'},
    'LCM-LoRA (8 steps)': {'steps': 8, 'quality': 0.87, 'time': 1.7, 'method': 'lora'},
    'LCM-LoRA (4 steps)': {'steps': 4, 'quality': 0.82, 'time': 1.1, 'method': 'lora'},
    'Turbo (4 steps)': {'steps': 4, 'quality': 0.86, 'time': 1.0, 'method': 'turbo'},
    'Turbo (1 step)': {'steps': 1, 'quality': 0.72, 'time': 0.6, 'method': 'turbo'},
    'Lightning (4 steps)': {'steps': 4, 'quality': 0.88, 'time': 1.0, 'method': 'lightning'},
    'Lightning (2 steps)': {'steps': 2, 'quality': 0.80, 'time': 0.7, 'method': 'lightning'},
    'Consistency (1 step)': {'steps': 1, 'quality': 0.68, 'time': 0.6, 'method': 'consistency'},
}

# Find Pareto frontier
def is_pareto_optimal(points):
    """Find points on the Pareto frontier (maximize quality, minimize time)."""
    is_optimal = np.ones(len(points), dtype=bool)
    for i, (q1, t1) in enumerate(points):
        for j, (q2, t2) in enumerate(points):
            if i != j:
                if q2 >= q1 and t2 <= t1 and (q2 > q1 or t2 < t1):
                    is_optimal[i] = False
                    break
    return is_optimal

points = [(v['quality'], v['time']) for v in methods.values()]
pareto_mask = is_pareto_optimal(points)

# Visualization
fig, ax = plt.subplots(figsize=(14, 9))

method_colors = {
    'standard': '#2196F3',
    'consistency': '#4CAF50',
    'lora': '#FF9800',
    'turbo': '#9C27B0',
    'lightning': '#F44336',
}

for i, (name, data) in enumerate(methods.items()):
    color = method_colors[data['method']]
    marker = 'o' if not pareto_mask[i] else '*'
    size = 100 if not pareto_mask[i] else 250
    
    ax.scatter(data['time'], data['quality'], 
              c=color, s=size, marker=marker, alpha=0.8,
              edgecolors='black' if pareto_mask[i] else 'none',
              linewidths=2 if pareto_mask[i] else 0,
              zorder=10 if pareto_mask[i] else 5)
    
    # Label
    offset_x = 0.1
    offset_y = 0.01 if i % 2 == 0 else -0.02
    ax.annotate(name, (data['time'], data['quality']),
               xytext=(data['time'] + offset_x, data['quality'] + offset_y),
               fontsize=8, alpha=0.8)

# Draw Pareto frontier
pareto_points = [(methods[name]['time'], methods[name]['quality']) 
                 for i, name in enumerate(methods.keys()) if pareto_mask[i]]
pareto_points.sort()
if pareto_points:
    px, py = zip(*pareto_points)
    ax.plot(px, py, 'k--', alpha=0.4, linewidth=2, label='Pareto frontier')

# Legend for method types
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=c, markersize=10, label=m.title())
    for m, c in method_colors.items()
]
legend_elements.append(plt.Line2D([0], [0], marker='*', color='w', markerfacecolor='gray',
                                   markersize=15, markeredgecolor='black', markeredgewidth=2,
                                   label='Pareto optimal'))
ax.legend(handles=legend_elements, loc='lower left', fontsize=10, title='Method Type')

ax.set_xlabel('Generation Time (seconds)', fontsize=13)
ax.set_ylabel('Image Quality Score', fontsize=13)
ax.set_title('Quality-Speed Pareto Analysis\n(512x512, T4 GPU)', fontsize=15, fontweight='bold')
ax.grid(True, alpha=0.3)

# Add region annotations
ax.annotate('Real-time\nZone', xy=(0.8, 0.65), fontsize=12, 
            style='italic', alpha=0.5, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.2))
ax.annotate('High Quality\nZone', xy=(6, 0.94), fontsize=12,
            style='italic', alpha=0.5, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.2))

plt.tight_layout()
plt.savefig('pareto_curve.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPareto Optimal Configurations:")
for i, name in enumerate(methods.keys()):
    if pareto_mask[i]:
        d = methods[name]
        print(f"  {name}: Quality={d['quality']:.2f}, Time={d['time']:.1f}s, Steps={d['steps']}")

---

## Section 6: Distillation for Few-Step Models

How are these few-step models actually created? Through **distillation** from the original multi-step model.

### Types of Distillation for Diffusion Models

| Method | Key Idea | Quality | Speed |
|--------|----------|---------|-------|
| Progressive Distillation | Halve steps iteratively: 1024→512→256→...→4 | Good | Good |
| Consistency Distillation | Learn consistency mapping from teacher ODE | Good | Best |
| Adversarial Distillation | GAN discriminator ensures quality | Better | Good |
| Score Distillation | Match score function of teacher | Good | Good |
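
To make the progressive distillation row concrete, the sketch below builds the regression target for one halving stage: the student's single step should land where the teacher lands after two steps. Everything here is an assumption for illustration: the data is a 1D Gaussian so the optimal noise prediction has a closed form (standing in for the teacher network), the steps are deterministic DDIM updates, the $\bar{\alpha}$ values are arbitrary, and the loss is taken directly on the step output rather than on a reparameterized target as in the original method.

```python
import numpy as np

S2 = 0.25  # toy data variance: x_0 ~ N(0, S2) (assumed)

def teacher_eps(x_t, alpha_bar):
    """Optimal noise prediction E[eps | x_t] for Gaussian data (teacher stand-in)."""
    var_t = alpha_bar * S2 + (1.0 - alpha_bar)
    return np.sqrt(1.0 - alpha_bar) * x_t / var_t

def ddim_step(x_t, ab_t, ab_next, eps_fn):
    """One deterministic DDIM update from alpha-bar ab_t to ab_next."""
    eps = eps_fn(x_t, ab_t)
    x0_pred = (x_t - np.sqrt(1.0 - ab_t) * eps) / np.sqrt(ab_t)
    return np.sqrt(ab_next) * x0_pred + np.sqrt(1.0 - ab_next) * eps

def distillation_target(x_t, ab_t, ab_mid, ab_next, eps_fn):
    """Teacher takes two steps; the student is trained to match this in one."""
    x_mid = ddim_step(x_t, ab_t, ab_mid, eps_fn)
    return ddim_step(x_mid, ab_mid, ab_next, eps_fn)

def exact_ode(x_t, ab_t, ab_next):
    """Exact sampling-ODE solution for this Gaussian toy (reference only)."""
    var = lambda ab: ab * S2 + (1.0 - ab)
    return x_t * np.sqrt(var(ab_next) / var(ab_t))

x_t = np.array([1.0, -0.5, 2.0])
ab_t, ab_mid, ab_next = 0.1, 0.4, 0.8  # noisy -> cleaner (arbitrary values)

one_step = ddim_step(x_t, ab_t, ab_next, teacher_eps)  # untrained student behaves like this
target = distillation_target(x_t, ab_t, ab_mid, ab_next, teacher_eps)
exact = exact_ode(x_t, ab_t, ab_next)

student_loss = np.mean((one_step - target) ** 2)
print("one big step :", one_step)
print("two-step target:", target)
print("exact ODE    :", exact)
print(f"initial student loss: {student_loss:.5f}")
```

The two-step target sits closer to the exact solution than the single large step, which is why matching it lets the student halve the step count with little quality loss. Repeating the procedure with the student as the new teacher gives the 1024→512→...→4 schedule.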

In [None]:
# ============================================================
# Simulate progressive distillation process
# ============================================================

def simulate_progressive_distillation():
    """
    Simulate the progressive distillation process:
    Teacher (1024 steps) → Student (512) → ... → Final (4 steps)
    """
    stages = []
    
    # Each stage halves the number of steps
    step_schedule = [1024, 512, 256, 128, 64, 32, 16, 8, 4]
    quality_degradation = 0.005  # Quality loss per halving
    training_cost_per_stage = 100  # GPU hours (simplified)
    
    base_quality = 1.0
    cumulative_cost = 0
    
    for i, steps in enumerate(step_schedule):
        quality = base_quality - i * quality_degradation * (1 + i * 0.1)  # Accelerating degradation
        cumulative_cost += training_cost_per_stage * (0.5 ** max(0, i - 2))  # Cheaper later
        
        # Speed relative to 1024 steps
        speed = 1024 / steps
        
        stages.append({
            'steps': steps,
            'quality': quality,
            'speed': speed,
            'training_cost': cumulative_cost,
        })
    
    return stages

stages = simulate_progressive_distillation()

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

steps = [s['steps'] for s in stages]
qualities = [s['quality'] for s in stages]
speeds = [s['speed'] for s in stages]
costs = [s['training_cost'] for s in stages]

# Plot 1: Quality vs Steps
axes[0].plot(steps, qualities, 'b-o', linewidth=2, markersize=8)
axes[0].fill_between(steps, qualities, alpha=0.1, color='blue')
axes[0].set_xlabel('Number of Steps', fontsize=12)
axes[0].set_ylabel('Quality Score', fontsize=12)
axes[0].set_title('Quality Degradation\n(Progressive Distillation)', fontsize=13, fontweight='bold')
axes[0].set_xscale('log', base=2)
axes[0].invert_xaxis()
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim(0.85, 1.02)

# Annotate the sweet spot
axes[0].axvspan(4, 8, alpha=0.1, color='green')
axes[0].annotate('Sweet spot\n(4-8 steps)', xy=(6, 0.92), fontsize=11,
                 fontweight='bold', color='green', ha='center')

# Plot 2: Speedup vs Quality
axes[1].plot(qualities, speeds, 'r-s', linewidth=2, markersize=8)
for i, s in enumerate(stages):
    axes[1].annotate(f"{s['steps']} steps", (s['quality'], s['speed']),
                     textcoords="offset points", xytext=(10, 5), fontsize=9)
axes[1].set_xlabel('Quality Score', fontsize=12)
axes[1].set_ylabel('Speedup Factor', fontsize=12)
axes[1].set_title('Speed-Quality Trade-off', fontsize=13, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].set_yscale('log', base=2)

# Plot 3: Cumulative Training Cost
bars = axes[2].bar(range(len(stages)), costs, 
                   color=plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, len(stages))), alpha=0.8)
axes[2].set_xlabel('Distillation Stage', fontsize=12)
axes[2].set_ylabel('Cumulative Training Cost (GPU-hours)', fontsize=12)
axes[2].set_title('Training Cost', fontsize=13, fontweight='bold')
axes[2].set_xticks(range(len(stages)))
axes[2].set_xticklabels([f"{s['steps']}\nsteps" for s in stages], fontsize=8)
axes[2].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('distillation_process.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# Quality degradation analysis at different step counts
# ============================================================

# Simulate quality metrics at different steps
np.random.seed(42)

step_range = [1, 2, 4, 6, 8, 10, 15, 20, 25, 30, 40, 50]

# Synthetic curves with added noise, shaped to resemble typical trends (not measured values)
fid_scores = [45 * np.exp(-0.05 * s) + 8 + np.random.normal(0, 1) for s in step_range]
clip_scores = [0.20 + 0.008 * np.log(s + 1) + np.random.normal(0, 0.005) for s in step_range]
aesthetic_scores = [4.5 + 0.8 * (1 - np.exp(-0.1 * s)) + np.random.normal(0, 0.1) for s in step_range]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# FID (lower is better)
axes[0].plot(step_range, fid_scores, 'b-o', linewidth=2, markersize=7)
axes[0].fill_between(step_range, fid_scores, alpha=0.1, color='blue')
axes[0].set_xlabel('Steps', fontsize=12)
axes[0].set_ylabel('FID Score (lower = better)', fontsize=12)
axes[0].set_title('FID Score vs Steps', fontsize=13, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].axhspan(5, 15, alpha=0.1, color='green')
axes[0].text(30, 10, 'Acceptable', fontsize=10, color='green', fontweight='bold')

# CLIP Score (higher is better)
axes[1].plot(step_range, clip_scores, 'g-s', linewidth=2, markersize=7)
axes[1].fill_between(step_range, clip_scores, alpha=0.1, color='green')
axes[1].set_xlabel('Steps', fontsize=12)
axes[1].set_ylabel('CLIP Score (higher = better)', fontsize=12)
axes[1].set_title('CLIP Score vs Steps', fontsize=13, fontweight='bold')
axes[1].grid(True, alpha=0.3)

# Aesthetic Score
axes[2].plot(step_range, aesthetic_scores, 'r-^', linewidth=2, markersize=7)
axes[2].fill_between(step_range, aesthetic_scores, alpha=0.1, color='red')
axes[2].set_xlabel('Steps', fontsize=12)
axes[2].set_ylabel('Aesthetic Score (higher = better)', fontsize=12)
axes[2].set_title('Aesthetic Score vs Steps', fontsize=13, fontweight='bold')
axes[2].grid(True, alpha=0.3)

# Add diminishing returns annotation to all
for ax in axes:
    ax.axvline(x=8, color='orange', linestyle=':', alpha=0.5)
    ax.text(9, ax.get_ylim()[0] + 0.1 * (ax.get_ylim()[1] - ax.get_ylim()[0]),
           'Diminishing\nreturns\n(~8 steps)', fontsize=9, color='orange')

plt.tight_layout()
plt.savefig('quality_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

print("Key Insight: in these synthetic curves, most of the CLIP and aesthetic gain arrives within ~8-15 steps,")
print("while FID keeps improving more slowly; that knee is where few-step methods offer the best value.")

---

## Summary & Key Takeaways

| Concept | Key Insight |
|---------|-------------|
| **Standard Diffusion** | 20-50 steps, highest quality, slow |
| **LCM / Consistency** | 2-8 steps via consistency distillation, good quality |
| **LCM-LoRA** | Adds few-step sampling via a LoRA adapter to SD checkpoints that share the adapter's base model (e.g. SD 1.5 fine-tunes) |
| **Progressive Distillation** | Halves steps iteratively, each stage trains a student |
| **Quality-Speed Trade-off** | Diminishing returns after ~8 steps |
| **Pareto Frontier** | In this notebook's illustrative data, Lightning (2-4 steps) and Turbo (1 step) anchor the fast end of the frontier |

### Practical Recommendations

| Use Case | Recommended Method | Steps |
|----------|-------------------|-------|
| Real-time generation | SDXL-Turbo / Lightning | 1-4 |
| Interactive editing | LCM-LoRA | 4-8 |
| High-quality production | DPM++ scheduler | 20-30 |
| Maximum quality | DDPM / Full diffusion | 50+ |

---

## Exercises

### Exercise 1: LCM-LoRA Comparison
Load a Stable Diffusion model with and without LCM-LoRA. Generate the same prompt at 4 steps and compare quality.

### Exercise 2: Resolution vs Steps Trade-off
For a fixed time budget of 2 seconds, what's better: higher resolution with fewer steps, or lower resolution with more steps?

### Exercise 3: Batch Generation Analysis
Measure how batch size affects per-image generation time at different step counts.

### Exercise 4: Prompt Complexity Impact
Do complex prompts need more steps than simple ones? Generate images with varying prompt complexity at different step counts and evaluate.

In [None]:
# ============================================================
# Exercise 1 Starter: LCM-LoRA Comparison
# ============================================================

# Uncomment to run with GPU:

# import torch
# from diffusers import StableDiffusionPipeline, LCMScheduler
# 
# # Load base model
# pipe = StableDiffusionPipeline.from_pretrained(
#     "runwayml/stable-diffusion-v1-5",
#     torch_dtype=torch.float16,
# ).to("cuda")
# 
# # Load LCM-LoRA adapter
# pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
# pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
# 
# # Generate with 4 steps
# image = pipe(
#     "a beautiful castle on a hilltop, fantasy art",
#     num_inference_steps=4,
#     guidance_scale=1.0,  # LCM works best with guidance_scale=1.0
# ).images[0]
# 
# image.save("lcm_lora_4steps.png")
# print("Image saved!")

print("Uncomment the code above to generate images with LCM-LoRA!")
print("Requires GPU runtime.")