# Notebook 17: Image Generation -- Guidance Scale & Steps

---

## Inference Engineering Course

Welcome to Notebook 17! Here we explore how **diffusion models** generate images and how two key parameters -- **guidance scale** and **number of steps** -- affect both quality and speed.

### What You Will Learn

| Topic | Description |
|-------|-------------|
| **Diffusion Models** | How images are generated from noise |
| **Guidance Scale** | Classifier-free guidance and its effect on quality |
| **Step Count** | Tradeoff between quality and inference speed |
| **CFG Math** | The mathematics behind classifier-free guidance |
| **Comparison Grids** | Side-by-side visual comparisons |

### Classifier-Free Guidance (CFG) in a Nutshell

The prediction is a blend of conditional and unconditional outputs:

$$\hat{\epsilon} = \epsilon_{\text{uncond}} + w \cdot (\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$$

Where `w` is the **guidance scale**:
- `w = 1.0`: No guidance (just conditional)
- `w = 7.0`: Standard guidance (good balance)
- `w > 10`: Strong guidance (high prompt fidelity, may oversaturate)
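
As a small numeric illustration of the blend above, the sketch below applies the formula element-wise to two made-up noise vectors. The function name `cfg_blend` and the sample values are assumptions for illustration; a real pipeline applies the same arithmetic to full latent tensors.

```python
def cfg_blend(eps_uncond, eps_cond, w):
    """Classifier-free guidance: eps_uncond + w * (eps_cond - eps_uncond), element-wise."""
    return [u + w * (c - u) for u, c in zip(eps_uncond, eps_cond)]

# Made-up noise predictions for a 3-element "latent" (illustrative values only)
eps_uncond = [0.0, 1.0, -1.0]
eps_cond = [1.0, 1.0, 0.0]

for w in (0.0, 1.0, 7.5):
    print(f"w={w}: {cfg_blend(eps_uncond, eps_cond, w)}")
```

With `w = 0` the result is the unconditional prediction, with `w = 1` it is the conditional one, and larger values push further along the conditional-minus-unconditional direction.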

---

## Part 1: Setup & Installations

**Important**: This notebook requires a GPU. In Colab: Runtime -> Change runtime type -> T4 GPU.

We use Stable Diffusion v1.5 via the `diffusers` library.

In [None]:
%%capture
!pip install diffusers transformers accelerate torch matplotlib numpy Pillow

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import time
import warnings
warnings.filterwarnings('ignore')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Image generation will be very slow on CPU.")
    print("Enable GPU: Runtime -> Change runtime type -> T4 GPU")

## Part 2: Loading the Diffusion Model

We will use **Stable Diffusion v1.5**, which generates 512x512 images by default. To fit comfortably in Colab GPU memory, we load the weights in float16 precision.
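
To see why half precision helps, here is a back-of-the-envelope estimate of weight memory. The helper name `weight_memory_gb` is hypothetical, and the figure of roughly 860 million parameters for the SD v1.5 UNet is an approximate, commonly quoted number assumed here for illustration. Activations, the text encoder, and the VAE add to the total.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Memory needed just to hold the weights, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

UNET_PARAMS = 860e6  # assumed approximate parameter count, for illustration only

fp32_gb = weight_memory_gb(UNET_PARAMS, 4)  # float32 uses 4 bytes per value
fp16_gb = weight_memory_gb(UNET_PARAMS, 2)  # float16 uses 2 bytes per value
print(f"float32: {fp32_gb:.2f} GB, float16: {fp16_gb:.2f} GB")
```

Halving the bytes per value halves the weight footprint, which leaves more room for activations at 512x512.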

In [None]:
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

print("Loading Stable Diffusion v1.5 (this may take a few minutes)...")

# Use float16 for memory efficiency on T4 GPU
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
    safety_checker=None,  # Disable for speed in educational context
)

# Use DPM-Solver for fast inference
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)

# Enable memory optimizations
if device == 'cuda':
    pipe.enable_attention_slicing()

print("Model loaded successfully!")
print(f"Image size: {pipe.unet.config.sample_size * 8}x{pipe.unet.config.sample_size * 8}")

In [None]:
def generate_image(prompt, guidance_scale=7.5, num_steps=25, seed=42, size=512):
    """
    Generate an image with timing information.
    
    Args:
        prompt: Text description of the desired image
        guidance_scale: CFG scale (1.0-20.0)
        num_steps: Number of denoising steps
        seed: Random seed for reproducibility
        size: Image width and height
    """
    generator = torch.Generator(device=device).manual_seed(seed)
    
    start = time.time()
    with torch.no_grad():
        result = pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            generator=generator,
            width=size,
            height=size,
        )
    elapsed = time.time() - start
    
    return result.images[0], elapsed

# Test generation
print("Generating test image...")
test_img, test_time = generate_image(
    "A beautiful sunset over mountain peaks, oil painting style",
    guidance_scale=7.5, num_steps=20, size=512
)
print(f"Generated in {test_time:.1f}s")

plt.figure(figsize=(6, 6))
plt.imshow(test_img)
plt.axis('off')
plt.title('Test Image (CFG=7.5, Steps=20)', fontsize=13)
plt.show()

## Part 3: Understanding Classifier-Free Guidance

### The Math Behind CFG

At each denoising step, the model predicts noise. With CFG, we compute:

$$\hat{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \varnothing) + w \cdot [\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)]$$

Where:
- $\epsilon_\theta(x_t, c)$: Noise predicted WITH the text prompt (conditional)
- $\epsilon_\theta(x_t, \varnothing)$: Noise predicted WITHOUT the text prompt (unconditional)
- $w$: Guidance scale
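
Rearranging the formula makes the role of $w$ explicit. This is only an algebraic restatement of the equation above, not an additional modeling assumption:

$$\hat{\epsilon}_\theta(x_t, c) = (1 - w) \cdot \epsilon_\theta(x_t, \varnothing) + w \cdot \epsilon_\theta(x_t, c)$$

For $0 \le w \le 1$ this is an interpolation between the two predictions. For $w > 1$ the unconditional prediction receives a negative weight, so the result is an extrapolation past the conditional prediction, away from the unconditional one.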

### Intuition

| Guidance Scale | Effect |
|---------------|--------|
| w = 1.0 | Pure conditional generation (the guidance term has no effect) |
| w = 3.0 | Mild guidance (creative, may drift from prompt) |
| w = 7.0 | Standard (good balance of quality and fidelity) |
| w = 15.0 | Strong guidance (high fidelity, may oversaturate) |
| w = 20.0+ | Very strong (often produces artifacts) |

In [None]:
# Visualize the CFG formula
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Simulate 1D noise predictions
np.random.seed(42)
x = np.linspace(0, 10, 200)
unconditional = np.sin(x) * 0.5 + np.random.normal(0, 0.1, len(x))
conditional = np.sin(x) * 0.5 + 0.8 * np.sin(2*x) + np.random.normal(0, 0.1, len(x))
difference = conditional - unconditional

for ax, w, title in zip(axes, [1.0, 7.5, 15.0], 
                         ['Low (w=1.0)', 'Standard (w=7.5)', 'High (w=15.0)']):
    guided = unconditional + w * difference
    
    ax.plot(x, unconditional, '--', color='gray', alpha=0.5, label='Unconditional')
    ax.plot(x, conditional, '--', color='blue', alpha=0.5, label='Conditional')
    ax.plot(x, guided, '-', color='red', linewidth=2, label=f'Guided (w={w})')
    
    ax.set_title(f'CFG: {title}', fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
    ax.set_ylim(-5, 10)

plt.suptitle('Classifier-Free Guidance: How w Amplifies the Prompt Signal',
             fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

## Part 4: Varying Guidance Scale

Let's generate images with different guidance scales to see the effect visually.

In [None]:
# Generate images with different guidance scales
prompt = "A photorealistic cat wearing a tiny crown, sitting on a red velvet throne"
guidance_scales = [1.0, 3.0, 7.0, 15.0]
num_steps = 25

guidance_images = []
guidance_times = []

print(f"Prompt: '{prompt}'")
print(f"Generating {len(guidance_scales)} images with different guidance scales...")

for gs in guidance_scales:
    img, elapsed = generate_image(prompt, guidance_scale=gs, num_steps=num_steps, size=512)
    guidance_images.append(img)
    guidance_times.append(elapsed)
    print(f"  CFG={gs:>5.1f}: {elapsed:.1f}s")

# Display comparison grid
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for ax, img, gs, t in zip(axes, guidance_images, guidance_scales, guidance_times):
    ax.imshow(img)
    ax.set_title(f'CFG Scale = {gs}\n({t:.1f}s)', fontsize=12, fontweight='bold')
    ax.axis('off')

plt.suptitle(f'Effect of Guidance Scale (Steps={num_steps})',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## Part 5: Varying Step Count

The number of denoising steps directly controls the **quality vs speed tradeoff**:

- Fewer steps = faster but potentially lower quality
- More steps = slower but generally higher quality (with diminishing returns)

Modern schedulers (like DPM-Solver) can produce good results in as few as 15-25 steps.
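
To make "fewer steps" concrete: the model is trained on a long sequence of noise levels (1000 is a common choice and is the value assumed here), and at inference a scheduler visits only a subset of them. The sketch below picks evenly spaced timesteps in descending order. The function name `evenly_spaced_timesteps` is hypothetical, and real schedulers such as DPM-Solver use their own spacing rules, so treat this as an illustration of the idea rather than the library's behavior.

```python
def evenly_spaced_timesteps(num_inference_steps, num_train_timesteps=1000):
    """Return `num_inference_steps` training timesteps, from noisiest to cleanest."""
    stride = num_train_timesteps // num_inference_steps
    # Start from the highest selected timestep and walk down toward 0
    return [i * stride for i in reversed(range(num_inference_steps))]

for n in (5, 25):
    ts = evenly_spaced_timesteps(n)
    print(f"{n} steps -> stride {ts[0] - ts[1]}, first few timesteps: {ts[:4]}")
```

Fewer steps mean each step has to bridge a larger gap in noise level, which is where low-step artifacts come from.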

In [None]:
# Generate images with different step counts
prompt = "A serene Japanese garden with a stone path, cherry blossoms, watercolor style"
step_counts = [5, 10, 25, 50]
guidance_scale = 7.5

step_images = []
step_times = []

print(f"Prompt: '{prompt}'")
print(f"Generating {len(step_counts)} images with different step counts...")

for steps in step_counts:
    img, elapsed = generate_image(prompt, guidance_scale=guidance_scale, 
                                   num_steps=steps, size=512)
    step_images.append(img)
    step_times.append(elapsed)
    print(f"  Steps={steps:>3d}: {elapsed:.1f}s")

# Display comparison grid
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for ax, img, steps, t in zip(axes, step_images, step_counts, step_times):
    ax.imshow(img)
    ax.set_title(f'{steps} Steps\n({t:.1f}s)', fontsize=12, fontweight='bold')
    ax.axis('off')

plt.suptitle(f'Effect of Step Count (CFG={guidance_scale})',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## Part 6: Quality vs Speed Tradeoff Analysis

Let's analyze the time-quality tradeoff in more detail. We treat the 50-step image as a reference and measure how far each lower-step image is from it with pixel-level RMSE. Note that this measures convergence toward the reference rather than perceptual quality.
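
As a dependency-free sketch of the kind of pixel-level metric meant here, the snippet below computes RMSE and the closely related PSNR between two tiny grayscale images stored as flat lists. The function names and the 4-pixel images are made up for illustration, and 8-bit pixel values (maximum 255) are assumed.

```python
import math

def rmse(img_a, img_b):
    """Root-mean-square error between two equally sized flat pixel lists."""
    squared = [(a - b) ** 2 for a, b in zip(img_a, img_b)]
    return math.sqrt(sum(squared) / len(squared))

def psnr(img_a, img_b, max_val=255.0):
    """Peak signal-to-noise ratio in dB; infinite when the images are identical."""
    err = rmse(img_a, img_b)
    return math.inf if err == 0 else 20 * math.log10(max_val / err)

reference = [10, 200, 30, 120]   # toy "50-step" image
candidate = [12, 198, 30, 124]   # toy "low-step" image

print(f"RMSE: {rmse(candidate, reference):.3f}")
print(f"PSNR: {psnr(candidate, reference):.1f} dB")
```

Lower RMSE (or higher PSNR) means the candidate is closer to the reference, which is not necessarily the same as looking better to a human viewer.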

In [None]:
# Measure quality vs speed tradeoff
prompt = "A fantasy castle on a floating island, dramatic lighting, digital art"
all_steps = [3, 5, 8, 10, 15, 20, 25, 30, 40, 50]

gen_times = []
images = []

print("Generating images across step counts (this may take a few minutes)...")
for steps in all_steps:
    img, elapsed = generate_image(prompt, guidance_scale=7.5, num_steps=steps, size=512)
    images.append(np.array(img))
    gen_times.append(elapsed)
    print(f"  Steps={steps:>3d}: {elapsed:.2f}s")

# Use the 50-step image as reference for quality comparison
reference = images[-1].astype(float)

# Calculate pixel-level RMSE of each image against the reference
rmse_scores = []
for img in images:
    diff = img.astype(float) - reference
    rmse = np.sqrt(np.mean(diff ** 2))
    rmse_scores.append(rmse)

# Map RMSE to a relative [0, 1] score: 1 = identical to the reference, 0 = the worst image in this sweep
max_rmse = max(rmse_scores)
quality_scores = [1 - (r / max_rmse) if max_rmse > 0 else 1.0 for r in rmse_scores]

# Plot
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Left: Time vs Steps
ax = axes[0]
ax.plot(all_steps, gen_times, 'o-', color='#F44336', linewidth=2.5, markersize=8)
ax.set_xlabel('Number of Steps', fontsize=12)
ax.set_ylabel('Generation Time (seconds)', fontsize=12)
ax.set_title('Generation Time vs Steps', fontsize=14, fontweight='bold')

# Middle: Quality vs Steps
ax = axes[1]
ax.plot(all_steps, quality_scores, 's-', color='#4CAF50', linewidth=2.5, markersize=8)
ax.set_xlabel('Number of Steps', fontsize=12)
ax.set_ylabel('Quality Score (vs 50-step ref)', fontsize=12)
ax.set_title('Quality vs Steps', fontsize=14, fontweight='bold')
ax.set_ylim(0, 1.05)

# Right: Quality vs Time (Pareto frontier)
ax = axes[2]
scatter = ax.scatter(gen_times, quality_scores, c=all_steps, cmap='viridis',
                     s=150, edgecolors='black', linewidths=1, zorder=5)
ax.plot(gen_times, quality_scores, '--', color='gray', alpha=0.5)

for i, steps in enumerate(all_steps):
    ax.annotate(f'{steps}', (gen_times[i], quality_scores[i]),
               textcoords='offset points', xytext=(5, 5), fontsize=9)

ax.set_xlabel('Generation Time (seconds)', fontsize=12)
ax.set_ylabel('Quality Score', fontsize=12)
ax.set_title('Quality vs Speed Tradeoff', fontsize=14, fontweight='bold')
plt.colorbar(scatter, ax=ax, label='Steps')

plt.tight_layout()
plt.show()

print("\nKey observation: quality improvements show diminishing returns after ~20-25 steps")
print("with DPM-Solver. The 'sweet spot' is often 20-25 steps for a good quality-speed balance.")

## Part 7: Full Comparison Grid

Let's create a comprehensive grid showing the interaction between guidance scale and step count.

In [None]:
# Create a 2D comparison grid: Guidance Scale x Step Count
prompt = "A detailed portrait of a wise old wizard, fantasy art, detailed"
grid_guidance = [1.0, 5.0, 10.0, 15.0]
grid_steps = [5, 15, 30]

print(f"Generating {len(grid_guidance) * len(grid_steps)} images for comparison grid...")
grid_images = {}
grid_times = {}

for gs in grid_guidance:
    for steps in grid_steps:
        img, elapsed = generate_image(prompt, guidance_scale=gs, num_steps=steps, size=512)
        grid_images[(gs, steps)] = img
        grid_times[(gs, steps)] = elapsed
        print(f"  CFG={gs:>5.1f}, Steps={steps:>3d}: {elapsed:.1f}s")

# Plot the grid
fig, axes = plt.subplots(len(grid_steps), len(grid_guidance), 
                          figsize=(4*len(grid_guidance), 4*len(grid_steps) + 1))

for i, steps in enumerate(grid_steps):
    for j, gs in enumerate(grid_guidance):
        ax = axes[i][j]
        ax.imshow(grid_images[(gs, steps)])
        t = grid_times[(gs, steps)]
        ax.set_title(f'CFG={gs}, Steps={steps}\n({t:.1f}s)', fontsize=10)
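        # Hide ticks instead of calling axis('off'), which would also hide the row labels
        ax.set_xticks([])
        ax.set_yticks([])

# Add row labels on the first column
for i, steps in enumerate(grid_steps):
    axes[i][0].set_ylabel(f'{steps} Steps', fontsize=13, fontweight='bold')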

plt.suptitle('Guidance Scale vs Step Count Grid',
             fontsize=16, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()

## Part 8: Inference Optimization Tips

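For production image generation, a handful of optimizations can substantially improve throughput. The benchmark below measures the two simplest levers, fewer steps and lower resolution, and then lists further options: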

In [None]:
# Benchmark: standard vs optimized inference
prompt = "A mountain landscape at golden hour, photography"

# Standard: 50 steps
_, time_50 = generate_image(prompt, num_steps=50, size=512)

# Optimized: 20 steps (DPM-Solver already configured)
_, time_20 = generate_image(prompt, num_steps=20, size=512)

# Fast: 10 steps
_, time_10 = generate_image(prompt, num_steps=10, size=512)

# Smaller image: 10 steps, 256x256 (preview only: SD v1.5 is tuned for 512x512, so quality usually drops)
_, time_small = generate_image(prompt, num_steps=10, size=256)

# Visualize
methods = ['50 steps\n512x512', '20 steps\n512x512', '10 steps\n512x512', '10 steps\n256x256']
times = [time_50, time_20, time_10, time_small]
speedups = [time_50/t for t in times]

fig, ax = plt.subplots(figsize=(10, 5))
colors = ['#F44336', '#FF9800', '#4CAF50', '#2196F3']
bars = ax.bar(methods, times, color=colors, alpha=0.85, edgecolor='black')

for bar, t, sp in zip(bars, times, speedups):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.1,
           f'{t:.1f}s\n({sp:.1f}x)', ha='center', fontsize=11, fontweight='bold')

ax.set_ylabel('Generation Time (seconds)', fontsize=12)
ax.set_title('Image Generation Speed Optimizations', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("Optimization strategies:")
print("1. Use efficient schedulers (DPM-Solver, DDIM) for fewer steps")
print("2. Reduce resolution for previews, upscale later")
print("3. Use FP16 precision (already enabled)")
print("4. Use torch.compile() for ~20-40% speedup (PyTorch 2.0+)")
print("5. Use TensorRT or ONNX for production deployment")

## Part 9: The Denoising Process Visualized

Let's peek inside the denoising process to see how an image evolves from pure noise to the final result.

In [None]:
# Capture intermediate steps during denoising.
# The pipeline keeps the DPM-Solver scheduler configured in Part 2;
# callback_on_step_end does not require switching schedulers.
prompt = "A vibrant coral reef with tropical fish, underwater photography"
num_steps = 30

# Generate with callback to capture intermediate images
intermediate_images = []
intermediate_steps = []

def callback_fn(pipe, step, timestep, callback_kwargs):
    if step % 5 == 0 or step == num_steps - 1:  # Capture every 5th step and the final step
        # Decode the latent to an image
        latents = callback_kwargs['latents']
        with torch.no_grad():
            image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
            image = pipe.image_processor.postprocess(image, output_type='pil')[0]
            intermediate_images.append(image)
            intermediate_steps.append(step)
    return callback_kwargs

generator = torch.Generator(device=device).manual_seed(42)

final = pipe(
    prompt,
    guidance_scale=7.5,
    num_inference_steps=num_steps,
    generator=generator,
    callback_on_step_end=callback_fn,
)

# Display the denoising progression
n_intermediates = len(intermediate_images)
fig, axes = plt.subplots(1, n_intermediates, figsize=(4 * n_intermediates, 4))

for ax, img, step in zip(axes, intermediate_images, intermediate_steps):
    ax.imshow(img)
    pct = int((step + 1) / num_steps * 100)
    ax.set_title(f'Step {step}\n({pct}% done)', fontsize=11, fontweight='bold')
    ax.axis('off')

plt.suptitle('Denoising Process: From Noise to Image',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("Observe: Global structure forms first, then details are refined.")

## Part 10: Understanding the Diffusion Process Mathematically

Let's build a deeper understanding of the forward and reverse diffusion processes that underlie image generation.

### Forward Process (Adding Noise)

The forward process gradually adds Gaussian noise to a clean image over $T$ timesteps:

$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I)$$

Where $\beta_t$ is the noise schedule. After $T$ steps, the image is almost indistinguishable from pure Gaussian noise.
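
Because each step adds independent Gaussian noise, the steps compose into a closed form that jumps directly from the clean image $x_0$ to any timestep $t$. With $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ (standard DDPM notation, introduced here because the text above does not define it):

$$q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}\, x_0, (1 - \bar{\alpha}_t) I)$$

Equivalently, $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so a noisy sample at any $t$ can be drawn in one step instead of iterating. The forward-process visualization in this part uses this form.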

### Reverse Process (Denoising)

The model learns to reverse this process, predicting the noise to remove at each step:

$$p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 I)$$
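
The link between predicting the noise and recovering the image can be checked on a single scalar. The sketch below assumes the standard closed-form forward relation $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$ (with $\bar{\alpha}_t$ the cumulative product of $1 - \beta_s$) and a hypothetical perfect noise predictor that returns the true $\epsilon$. Function names and sample values are illustrative only.

```python
import math

def add_noise(x0, eps, alpha_bar):
    """Forward closed form: mix the clean value with noise at signal level alpha_bar."""
    return math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps

def predict_x0(x_t, eps_pred, alpha_bar):
    """Invert the forward relation using a noise prediction."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps_pred) / math.sqrt(alpha_bar)

x0, eps, alpha_bar = 0.8, -1.3, 0.25   # made-up pixel value, noise draw, and signal level
x_t = add_noise(x0, eps, alpha_bar)

# With a perfect noise estimate the clean value is recovered (up to rounding)
print(predict_x0(x_t, eps, alpha_bar))
# An imperfect estimate leaves a residual error that grows as alpha_bar shrinks
print(predict_x0(x_t, eps + 0.1, alpha_bar))
```

Many samplers use this kind of clean-image estimate inside each reverse step; how they combine it with $x_t$ to form $x_{t-1}$ is part of what distinguishes one scheduler from another.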

In [None]:
# Visualize the forward diffusion process (adding noise)
np.random.seed(42)

# Create a simple test image (checkerboard pattern)
img_size = 64
block_size = 8
checkerboard = np.zeros((img_size, img_size, 3))
for i in range(img_size):
    for j in range(img_size):
        if ((i // block_size) + (j // block_size)) % 2 == 0:
            checkerboard[i, j] = [0.9, 0.3, 0.1]  # Orange
        else:
            checkerboard[i, j] = [0.1, 0.3, 0.9]  # Blue

# Define noise schedule (linear)
T = 1000  # Total timesteps
beta_start, beta_end = 0.0001, 0.02
betas = np.linspace(beta_start, beta_end, T)
alphas = 1.0 - betas
alpha_cumprod = np.cumprod(alphas)

# Show forward process at different timesteps
display_steps = [0, 50, 200, 500, 750, 999]
fig, axes = plt.subplots(2, len(display_steps), figsize=(3.5 * len(display_steps), 7))

for idx, t in enumerate(display_steps):
    # Add noise: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    noise = np.random.randn(*checkerboard.shape)
    alpha_bar = alpha_cumprod[t]
    noisy = np.sqrt(alpha_bar) * checkerboard + np.sqrt(1 - alpha_bar) * noise
    noisy = np.clip(noisy, 0, 1)
    
    # Top: noisy image
    axes[0][idx].imshow(noisy)
    axes[0][idx].set_title(f't = {t}\nalpha_bar = {alpha_bar:.4f}', fontsize=10, fontweight='bold')
    axes[0][idx].axis('off')
    
    # Bottom: share of signal vs noise variance (bar chart)
    signal_pct = alpha_bar * 100
    noise_pct = (1 - alpha_bar) * 100
    axes[1][idx].barh(['Signal', 'Noise'], [signal_pct, noise_pct],
                       color=['#4CAF50', '#F44336'], alpha=0.8)
    axes[1][idx].set_xlim(0, 100)
    axes[1][idx].set_xlabel('%')

plt.suptitle('Forward Diffusion Process: Gradually Adding Noise',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print('At t=0, the image is almost clean. At t=999, it is almost pure noise.')
print('The reverse process (denoising) starts from noise and recovers the image.')

In [None]:
# Visualize noise schedules and their impact
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Different noise schedules
T_steps = 1000
t_axis = np.arange(T_steps)

# Linear schedule
beta_linear = np.linspace(0.0001, 0.02, T_steps)
alpha_bar_linear = np.cumprod(1.0 - beta_linear)

# Cosine schedule (used in improved DDPM)
s = 0.008
steps = np.arange(T_steps + 1)
f_t = np.cos(((steps / T_steps) + s) / (1 + s) * np.pi / 2) ** 2
alpha_bar_cosine = f_t[1:] / f_t[0]
alpha_bar_cosine = np.clip(alpha_bar_cosine, 1e-5, 1.0)

# Scaled linear (used in SD)
beta_scaled = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, T_steps) ** 2
alpha_bar_scaled = np.cumprod(1.0 - beta_scaled)

# Plot 1: Alpha bar comparison
ax = axes[0]
ax.plot(t_axis, alpha_bar_linear, label='Linear', linewidth=2)
ax.plot(t_axis, alpha_bar_cosine, label='Cosine', linewidth=2)
ax.plot(t_axis, alpha_bar_scaled, label='Scaled Linear', linewidth=2)
ax.set_xlabel('Timestep', fontsize=12)
ax.set_ylabel('alpha_bar (signal retention)', fontsize=12)
ax.set_title('Noise Schedules Comparison', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)

# Plot 2: SNR (Signal-to-Noise Ratio)
ax = axes[1]
snr_linear = alpha_bar_linear / (1 - alpha_bar_linear + 1e-10)
snr_cosine = alpha_bar_cosine / (1 - alpha_bar_cosine + 1e-10)
snr_scaled = alpha_bar_scaled / (1 - alpha_bar_scaled + 1e-10)

ax.semilogy(t_axis, snr_linear, label='Linear', linewidth=2)
ax.semilogy(t_axis, snr_cosine, label='Cosine', linewidth=2)
ax.semilogy(t_axis, snr_scaled, label='Scaled Linear', linewidth=2)
ax.set_xlabel('Timestep', fontsize=12)
ax.set_ylabel('SNR (log scale)', fontsize=12)
ax.set_title('Signal-to-Noise Ratio', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)

# Plot 3: Number of effective steps per region
ax = axes[2]
for idx_s, (schedule_name, alpha_bar) in enumerate([('Linear', alpha_bar_linear), ('Cosine', alpha_bar_cosine)]):
    snr = alpha_bar / (1 - alpha_bar + 1e-10)
    high = np.sum(snr > 10)
    mid = np.sum((snr > 0.1) & (snr <= 10))
    low = np.sum(snr <= 0.1)
    
    ax.bar(idx_s, high, color='#4CAF50', alpha=0.8, label='High SNR' if idx_s == 0 else '')
    ax.bar(idx_s, mid, bottom=high, color='#FF9800', alpha=0.8, label='Medium SNR' if idx_s == 0 else '')
    ax.bar(idx_s, low, bottom=high+mid, color='#F44336', alpha=0.8, label='Low SNR' if idx_s == 0 else '')

ax.set_xticks([0, 1])
ax.set_xticklabels(['Linear', 'Cosine'])
ax.set_ylabel('Number of Steps', fontsize=12)
ax.set_title('Steps per SNR Region', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

print('The cosine schedule spends more steps in the medium SNR range and fewer at very low SNR,')
print('where the sample is almost pure noise. This often leads to better quality.')

## Part 11: Computational Cost Analysis

Let's quantify the computational cost of image generation and understand what drives inference time.
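
A simple way to reason about inference time is an additive model: one-off costs for the text encoder and VAE decoder, plus a per-step cost for the UNet and scheduler. The sketch below reuses the rough T4 figures from the latency chart in this part (15 ms, 40 ms, 30 ms, 0.5 ms) and assumes the 40 ms UNet figure is for a single forward pass, with CFG costing two passes. The function name and that interpretation are assumptions, and measured times will differ.

```python
def estimate_latency_ms(steps, use_cfg=True, text_ms=15.0, unet_ms=40.0,
                        vae_ms=30.0, sched_ms=0.5):
    """Additive latency estimate: fixed costs plus per-step UNet and scheduler work."""
    passes_per_step = 2 if use_cfg else 1
    per_step = passes_per_step * unet_ms + sched_ms
    return text_ms + steps * per_step + vae_ms

for steps in (10, 20, 50):
    total = estimate_latency_ms(steps)
    unet_share = steps * 2 * 40.0 / total
    print(f"{steps:>2d} steps: ~{total / 1000:.2f} s, UNet share {unet_share:.0%}")
```

Under these assumptions the fixed costs are a small fraction of the total even at 10 steps, which is why step count and CFG dominate the estimate.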

In [None]:
# Computational cost analysis
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: FLOPs estimation for different configurations
ax = axes[0]

# Approximate FLOPs for SD v1.5 UNet per forward pass
# These are rough estimates for educational purposes
base_flops_per_step = 50e9  # ~50 GFLOPs per UNet forward pass at 512x512

configs = {
    '256x256\n20 steps\nno CFG': (0.25, 20, 1),
    '512x512\n20 steps\nno CFG': (1.0, 20, 1),
    '512x512\n20 steps\nCFG=7': (1.0, 20, 2),  # 2x for CFG
    '512x512\n50 steps\nCFG=7': (1.0, 50, 2),
    '768x768\n25 steps\nCFG=7': (2.25, 25, 2),
    '1024x1024\n25 steps\nCFG=7': (4.0, 25, 2),
}

names = list(configs.keys())
total_flops = []
for name, (size_mult, steps, cfg_mult) in configs.items():
    flops = base_flops_per_step * size_mult * steps * cfg_mult
    total_flops.append(flops / 1e12)  # TFLOPs

colors_cfg = plt.cm.viridis(np.linspace(0.2, 0.8, len(names)))
bars = ax.bar(range(len(names)), total_flops, color=colors_cfg, alpha=0.8, edgecolor='black')
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, fontsize=8, rotation=0)
ax.set_ylabel('Total TFLOPs', fontsize=12)
ax.set_title('Estimated Compute Cost\nfor Different Configurations', fontsize=14, fontweight='bold')

for bar, flop in zip(bars, total_flops):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.1,
           f'{flop:.1f}T', ha='center', fontsize=9, fontweight='bold')

# Right: Latency per step breakdown
ax = axes[1]
components = ['Text\nEncoder', 'UNet\n(per step)', 'VAE\nDecoder', 'Scheduler\n(per step)']
# Approximate times on T4 GPU at 512x512, FP16
component_times = [15, 40, 30, 0.5]  # ms
component_colors = ['#4CAF50', '#2196F3', '#FF9800', '#9E9E9E']

bars = ax.bar(components, component_times, color=component_colors, alpha=0.8, edgecolor='black')
for bar, t in zip(bars, component_times):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 1,
           f'{t}ms', ha='center', fontsize=11, fontweight='bold')

ax.set_ylabel('Time (milliseconds)', fontsize=12)
ax.set_title('Latency Breakdown per Component\n(512x512, FP16, T4 GPU)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print('Key insight: The UNet dominates inference time.')
print('Each denoising step requires one (or two with CFG) UNet forward passes.')
print('Text encoder and VAE decoder run only once, so reducing steps is the main lever.')

## Part 12: Key Takeaways

### Summary

1. **Guidance Scale** controls how strongly the model follows the text prompt. Typical range: 5-12 for good results.

2. **Step Count** controls the quality-speed tradeoff. With modern schedulers (DPM-Solver), 20-25 steps often suffice.

3. **CFG doubles compute**: Each step evaluates the UNet on both the conditional and the unconditional input, so guided sampling costs roughly twice the UNet compute of unguided sampling at the same step count.

4. **Diminishing returns**: With DPM-Solver, quality improvements flatten out after roughly 20-30 steps. Going beyond 50 steps rarely helps.

5. **Noise schedules** affect how computational budget is allocated between coarse structure and fine details.

6. **Production tips**: Use FP16, efficient schedulers, smaller preview sizes, and model compilation for fast inference.

### Inference Engineering Perspective

| Parameter | Speed Impact | Quality Impact |
|-----------|-------------|----------------|
| Guidance Scale | ~2x cost when CFG is enabled (two UNet passes per step); the value of `w` itself does not change cost | Major (prompt fidelity) |
| Step Count | Linear (more steps = slower) | Diminishing returns |
| Image Size | Roughly quadratic in side length (2x width and height = ~4x pixels, at least ~4x slower) | Higher detail |
| Precision (FP16) | ~1.5-2x faster | Minimal quality loss |
| Scheduler Choice | Similar cost per step; a better scheduler needs fewer steps | Major at low step counts |
| Noise Schedule | Same cost | Affects detail quality |

---

## Exercises

### Exercise 1: Find the Optimal Settings for a Prompt
For a given prompt, systematically find the best balance of guidance scale and steps.
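
One way to organize the sweep is to enumerate every combination up front and then loop over it. The sketch below only builds the list of settings; `build_sweep` is a hypothetical helper, and each entry would still need to be passed to `generate_image` from Part 2.

```python
from itertools import product

def build_sweep(prompts, guidance_scales, step_counts):
    """Return one settings dict per (prompt, guidance scale, step count) combination."""
    return [
        {"prompt": p, "guidance_scale": gs, "num_steps": n}
        for p, gs, n in product(prompts, guidance_scales, step_counts)
    ]

sweep = build_sweep(
    prompts=["A simple pencil sketch of a cat"],
    guidance_scales=[3, 7, 12],
    step_counts=[10, 25, 50],
)
print(len(sweep), "runs; first:", sweep[0])
```

Each dict can be unpacked as `generate_image(**settings, seed=42)`; keeping the seed fixed across runs keeps the comparison fair.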

In [None]:
# Exercise 1: Test different prompts and find their "sweet spot" settings
# Try: photorealistic vs artistic styles
# Do different styles need different guidance scales?

prompts_to_test = [
    "A photorealistic portrait of an astronaut on Mars",
    "Abstract art in the style of Kandinsky, colorful shapes",
    "A simple pencil sketch of a cat",
]

# TODO: For each prompt, test guidance_scale in [3, 7, 12]
# and steps in [10, 25, 50]. Which combination looks best?

print("Exercise 1: Find the optimal settings for each prompt style!")

### Exercise 2: Negative Prompts
Experiment with negative prompts to avoid unwanted features.

In [None]:
# Exercise 2: Use negative prompts
# pipe(prompt, negative_prompt="blurry, low quality, distorted")

# TODO: Compare generations with and without negative prompts
# Try different negative prompts for the same positive prompt

print("Exercise 2: Experiment with negative prompts!")

### Exercise 3: Scheduler Comparison
Compare different schedulers (DDIM, DPM-Solver, Euler) at the same step count.

In [None]:
# Exercise 3: Compare schedulers
# from diffusers import DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler

# TODO: For the same prompt and step count, compare:
# - DDIMScheduler
# - EulerDiscreteScheduler  
# - DPMSolverMultistepScheduler
# Which produces the best results at 15 steps?

print("Exercise 3: Compare different schedulers!")

---

**End of Notebook 17: Image Generation -- Guidance Scale & Steps**

Next: [Notebook 18 - VLM Inference: Image + Text](./18_vlm_inference.ipynb)