# 🎓 Lesson 1: The Physics of Stable Diffusion (From Scratch)

Welcome to the **Deep Dive**! Most courses teach you how to "drive" the car (using the `StableDiffusionPipeline`). Today, we are going to **build the engine**.

We will manually orchestrate the neural networks to generate an image, explaining the **science and math** at every step.

### 🔬 What we will explore:
1.  **The Manifold Hypothesis**: Why we need a VAE to compress images.
2.  **CLIP Embeddings**: Converting language into high-dimensional vectors.
3.  **The Physics of Diffusion**: How reversing a thermodynamic process generates art.
4.  **The Denoising Loop**: Manually stepping through the differential equation solver.

In [None]:
# Setup - Import necessary libraries
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import UniPCMultistepScheduler
import sys
import os
from pathlib import Path
import matplotlib.pyplot as plt

# Boilerplate to import our local config
project_root = Path(os.getcwd()).parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
from config import device_config

device = device_config.get_device()
dtype = device_config.dtype
print(f"Using device: {device}, dtype: {dtype}")

## 1. The Components of Creation

Stable Diffusion isn't one model; it's a team of three specialized neural networks working together.

### 🧠 The Team:
1.  **VAE (Variational Autoencoder)**: The *Data Compressor*. It compresses a 512x512x3 image into a 64x64x4 *latent representation* (48x fewer values), so the expensive denoising runs in this small latent space instead of on pixels.
2.  **Text Encoder (CLIP)**: The *Translator*. It turns the prompt into a sequence of 77 token embeddings, each a 768-dimensional vector.
3.  **U-Net**: The *Artist*. A large (~860M-parameter) encoder-decoder network built from ResNet blocks and cross-attention layers. It takes a noisy latent, the timestep, and the text embeddings, and outputs its estimate of the noise.
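To make the hand-offs between the three networks concrete, here is a small pure-Python sketch that only tracks tensor shapes for a single prompt. The helper name `sd15_shapes` is hypothetical; the numbers (8x VAE downsampling, 4 latent channels, 77 tokens x 768 dims) are the SD 1.5 values used later in this notebook.

```python
def sd15_shapes(height: int = 512, width: int = 512) -> dict:
    """Hypothetical helper: trace tensor shapes through SD 1.5 for one prompt."""
    if height % 8 or width % 8:
        raise ValueError("height and width must be multiples of 8")
    latent = (4, height // 8, width // 8)  # VAE: 8x spatial downsampling, 4 channels
    return {
        "image": (3, height, width),       # RGB pixels fed to / produced by the VAE
        "latent": latent,                  # what the U-Net actually denoises
        "text_embeddings": (77, 768),      # CLIP text encoder output per prompt
        "unet_output": latent,             # predicted noise has the latent's shape
    }


for name, shape in sd15_shapes().items():
    print(f"{name:>16}: {shape}")
```

Note that the U-Net's output has exactly the latent's shape: it predicts one noise value per latent value, which is what lets the scheduler subtract it.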

In [None]:
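# SD 1.5 weights. If this repo is unavailable on the Hugging Face Hub,
# try the mirror "stable-diffusion-v1-5/stable-diffusion-v1-5" instead.
model_id = "runwayml/stable-diffusion-v1-5"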

# 1. Load VAE
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype).to(device)

# 2. Load Text Encoder & Tokenizer
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype).to(device)

# 3. Load U-Net
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype).to(device)

# 4. Load Scheduler (The Math Solver)
scheduler = UniPCMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")

print(f"✅ All components loaded onto {device}!")

## 2. Text to Vector Space (Math of Language)

Computers don't understand words; they understand vectors. We use **CLIP** (Contrastive Language-Image Pre-Training).

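CLIP was trained so that matching images and captions land close together in one *shared* high-dimensional space. Stable Diffusion uses only CLIP's text encoder: the U-Net reads its per-token vectors (for example, the vector at the position of "dog") through cross-attention and uses them to steer generation.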
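"Close together" in such a shared space is usually measured with cosine similarity. The sketch below uses made-up 4-dimensional vectors (real CLIP vectors are far longer) purely to show the computation; none of the numbers come from an actual model.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 = same direction, 0.0 = orthogonal."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Toy, hand-picked vectors (hypothetical, not real CLIP outputs)
text_dog = [0.1, 0.9, 0.2, 0.0]
image_dog = [0.2, 0.8, 0.1, 0.1]
image_car = [0.9, 0.0, 0.1, 0.4]

print(cosine_similarity(text_dog, image_dog))  # high: matching pair
print(cosine_similarity(text_dog, image_car))  # low: mismatched pair
```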

In [None]:
prompt = ["a cyberpunk detective standing in rain, neon lights, highly detailed, 8k"]
height = 512
width = 512
num_inference_steps = 25
guidance_scale = 7.5 # Controls how strictly we follow the prompt
seed = 42
generator = torch.Generator("cpu").manual_seed(seed)  # CPU generator: same noise on any device
batch_size = len(prompt)

# --- DEEP DIVE: What does the model actually see? ---

# 1. Tokenization: Break text into tokens (numbers)
# The tokenizer has a vocabulary of ~49k sub-words.
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

print("MATCHING WORDS TO NUMBERS:")
print(f"Raw Input: '{prompt[0]}'")

# Let's peek at the first 15 tokens
raw_ids = text_input.input_ids[0][:15]
decoded_tokens = tokenizer.convert_ids_to_tokens(raw_ids)

for token_id, token_str in zip(raw_ids, decoded_tokens):
    print(f"  {token_id.item():<5} -> '{token_str}'")
    
print(f"\nTotal Sequence Length: {text_input.input_ids.shape[1]} (Standard for SD is 77 tokens)")
print("Notice the starting '<|startoftext|>' and ending '<|endoftext|>'? These are crucial flags for the model.")
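print("Everything after the first end token is padding; this tokenizer reuses the end-token id (49407) as its pad id.")
print("No attention mask is passed to the text encoder here, so padded positions still produce embeddings the U-Net can attend to.\n")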

# 2. Encoding: Convert tokens to 768-dim Vectors
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

# 3. Classifier-Free Guidance (The Magic Trick)
# We also encode an empty prompt "". Inside the denoising loop, the U-Net's two
# *noise predictions* are combined as: Final = Empty + Scale * (Text - Empty)
# This pushes each step AWAY from the generic, unconditional prediction and TOWARDS our prompt.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

# Stitch them together for batch processing
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

print(f"Embedding Shape: {text_embeddings.shape} (2 sequences [unconditional + prompt] x 77 tokens x 768 dimensions)")
print("We have successfully translated English into 'Math'!")
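The `padding="max_length"` and `truncation=True` arguments above always produce a fixed 77-slot sequence. Below is a minimal pure-Python sketch of that layout. It assumes CLIP's start/end ids 49406/49407, and that padding reuses the end id, as the SD 1.5 tokenizer does. The helper name `pad_clip_ids` and the example content ids are hypothetical.

```python
BOS_ID = 49406  # <|startoftext|>
EOS_ID = 49407  # <|endoftext|>, also used as the pad id by the SD 1.5 tokenizer
MAX_LEN = 77    # tokenizer.model_max_length for CLIP


def pad_clip_ids(content_ids, max_len=MAX_LEN, pad_id=EOS_ID):
    """Hypothetical helper: wrap content ids in BOS/EOS, truncate, then pad."""
    room = max_len - 2                                  # reserve slots for BOS and EOS
    ids = [BOS_ID] + list(content_ids[:room]) + [EOS_ID]
    return ids + [pad_id] * (max_len - len(ids))


short = pad_clip_ids([1000, 2000, 3000])                # made-up content ids
too_long = pad_clip_ids(list(range(100)))               # more than 75 content tokens
print(short[:6], len(short))        # [49406, 1000, 2000, 3000, 49407, 49407] 77
print(too_long[-2:], len(too_long)) # [74, 49407] 77
```

Truncation keeps only the first 75 content tokens, so words near the end of a very long prompt never reach the U-Net.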

## 3. The Latent Space & Initial Noise

### Why Latent Space?
Generating a 512x512 image directly means manipulating **786,432 values** ($512 \times 512 \times 3$) at every denoising step. This is too slow.
Instead, we work in **Latent Space** ($64 \times 64 \times 4 = 16,384$ values).
This is a **48x compression**! The VAE handles the translation.

We start with pure Gaussian noise $\mathcal{N}(0, I)$.
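The compression figures above are easy to check. "Pure Gaussian noise" simply means that every latent value is drawn independently from a standard normal distribution. Here is a stdlib-only sketch; the seed is an arbitrary choice.

```python
import random
import statistics

pixel_values = 512 * 512 * 3                  # RGB values in a 512x512 image
latent_values = (512 // 8) * (512 // 8) * 4   # 64 x 64 x 4 latent
ratio = pixel_values // latent_values
print(pixel_values, latent_values, ratio)     # 786432 16384 48

# Draw a flattened 4x64x64 latent from N(0, I)
rng = random.Random(42)
noise = [rng.gauss(0.0, 1.0) for _ in range(latent_values)]
print(statistics.mean(noise), statistics.stdev(noise))  # close to 0 and 1
```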

In [None]:
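# Sample with the CPU generator, then move: torch.randn raises an error if a CPU
# generator is combined with device="cuda", and CPU sampling keeps the seed device-independent.
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
    dtype=torch.float32,
).to(device=device, dtype=dtype)

# Scale to the initial noise std the scheduler expects (1.0 for UniPC; sigma-based schedulers such as Euler use a larger value)
latents = latents * scheduler.init_noise_sigma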

print(f"Latent Shape: {latents.shape}")

## 4. The Denoising Loop (Solving the ODE)

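Diffusion borrows its math from **non-equilibrium thermodynamics**: training gradually destroys structure by adding noise, and sampling runs that process in reverse.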

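The equation we are solving is the **Probability Flow ODE**. Here $\sigma(t)$ is the noise level at time $t$, and $\nabla_x \log p_t(x)$ is the *score*, the gradient of the log-density of the noisy data:
$$ dx = -\dot{\sigma}(t)\, \sigma(t)\, \nabla_x \log p_t(x)\, dt $$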

In English: "As the noise level $\sigma$ shrinks, move the data point $x$ *along* the score, toward regions where real data is more likely."
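The U-Net predicts noise rather than the score directly. Assume the common parameterization in which a noisy sample is $x = x_0 + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. (This is an assumption for this derivation; UniPC internally works with its own reparameterization.) Under that assumption, the two quantities differ only by a rescaling. Substituting it into the ODE above gives an update that the network's output can drive directly:

$$
\nabla_x \log p_t(x) \approx -\frac{\epsilon_\theta(x, \sigma(t))}{\sigma(t)}
\quad\Longrightarrow\quad
dx = \dot{\sigma}(t)\, \epsilon_\theta(x, \sigma(t))\, dt
$$

A first-order (Euler) discretization from noise level $\sigma_i$ to $\sigma_{i+1}$ is then

$$
x_{i+1} = x_i + (\sigma_{i+1} - \sigma_i)\, \epsilon_\theta(x_i, \sigma_i)
$$

Because $\sigma$ decreases during sampling, each step subtracts a scaled copy of the predicted noise. Multistep solvers such as UniPC refine this first-order update by reusing model outputs from previous steps.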

We will visualize the image decoding every 5 steps so you can see the structure emerging from chaos.
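Before running the real loop, it helps to watch the probability flow ODE work in a case where the answer is known. Take a toy 1D setup (an illustrative assumption, not Stable Diffusion itself): "data" drawn from $\mathcal{N}(0, s^2)$, with $\sigma(t) = t$. Noising then gives $\mathcal{N}(0, s^2 + \sigma^2)$, whose score $-x/(s^2 + \sigma^2)$ is exact. A plain Euler solver can stand in for the scheduler, and we can compare it with the closed-form solution. All names below are hypothetical.

```python
import math


def toy_score(x, sigma, s=1.0):
    """Exact score of N(0, s^2 + sigma^2): the toy 'data' distribution is N(0, s^2)."""
    return -x / (s * s + sigma * sigma)


def euler_probability_flow(x_T, sigma_max=10.0, steps=1000, s=1.0):
    """Integrate dx = -sigma * score(x, sigma) dsigma from sigma_max down to 0."""
    x = x_T
    for i in range(steps):
        sigma = sigma_max * (1 - i / steps)            # current noise level
        sigma_next = sigma_max * (1 - (i + 1) / steps)
        dx_dsigma = -sigma * toy_score(x, sigma, s)    # ODE with sigma(t) = t
        x += (sigma_next - sigma) * dx_dsigma          # Euler step (sigma decreases)
    return x


def exact_solution(x_T, sigma_max=10.0, s=1.0):
    """Closed form: along the ODE, x is proportional to sqrt(s^2 + sigma^2)."""
    return x_T * s / math.sqrt(s * s + sigma_max * sigma_max)


x_T = 1.3 * 10.0  # a made-up starting sample at sigma_max = 10
print(euler_probability_flow(x_T, steps=10))    # coarse: undershoots the exact value
print(euler_probability_flow(x_T, steps=1000))  # fine: lands near the exact value
print(exact_solution(x_T))
```

The gap between the 10-step and 1000-step results is the same discretization error that the "Timesteps" exercise at the end probes. Higher-order solvers like UniPC are designed to shrink it at low step counts.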

In [None]:
scheduler.set_timesteps(num_inference_steps)

def decode_latents(latents):
    # Helper to peek at the image during generation
    unscaled = latents / vae.config.scaling_factor  # undo the SD 1.5 latent scaling (0.18215)
    with torch.no_grad():
        image = vae.decode(unscaled).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image[0]

print("Starting Denoising Loop...")

for i, t in enumerate(tqdm(scheduler.timesteps)):
    # 1. Expand latents for our double-pass (Conditioned + Unconditioned)
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # 2. U-Net Prediction: "What part of this image is noise?"
    # This involves massive matrix multiplications in the ResNet blocks and Attention layers
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

    # 3. Classifier-Free Guidance Math
    # epsilon_pred = epsilon_uncond + s * (epsilon_text - epsilon_uncond)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # 4. Scheduler Step: update x_t -> the latent at the next (lower-noise) timestep
    latents = scheduler.step(noise_pred, t, latents).prev_sample
    
    # Visualization every 5 steps
    if (i + 1) % 5 == 0:
        img = decode_latents(latents)
        plt.figure(figsize=(3,3))
        plt.imshow(img)
        plt.title(f"Step {i+1}/{num_inference_steps}")
        plt.axis('off')
        plt.show()
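The guidance line in the loop is plain element-wise arithmetic on the two halves of the batch. The stdlib-only sketch below uses made-up 3-element "noise predictions" (real ones are 4x64x64 tensors) to show how the scale moves the result along the difference vector.

```python
def classifier_free_guidance(eps_uncond, eps_text, scale):
    """eps = eps_uncond + scale * (eps_text - eps_uncond), element-wise."""
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_text)]


eps_uncond = [0.10, -0.20, 0.30]  # hypothetical unconditional prediction
eps_text = [0.20, -0.10, 0.10]    # hypothetical text-conditioned prediction

print(classifier_free_guidance(eps_uncond, eps_text, 0.0))  # unconditional only
print(classifier_free_guidance(eps_uncond, eps_text, 1.0))  # text-conditioned only
print(classifier_free_guidance(eps_uncond, eps_text, 7.5))  # extrapolated past eps_text
```

Scales above 1 extrapolate beyond the text-conditioned prediction, which is why a large `guidance_scale` strengthens prompt adherence.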

## 5. Final Decode

The loop finishes with a denoised latent. We undo the latent scaling and pass it through the VAE decoder to get the final 512x512 RGB image.
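Two fixed conversions bracket the VAE decoder in `decode_latents`. Before decoding, latents are divided by the SD 1.5 scaling factor 0.18215. After decoding, outputs in roughly [-1, 1] are mapped to 8-bit pixel values, mirroring the `* 255` then `.round()` in the cell below. The stdlib-only sketch applies that arithmetic to a few scalars; the helper names are hypothetical.

```python
SCALING_FACTOR = 0.18215  # SD 1.5 VAE latent scaling (vae.config.scaling_factor)


def unscale_latent(z):
    """Undo the scaling that puts latents at roughly unit variance for the U-Net."""
    return z / SCALING_FACTOR


def to_uint8(v):
    """Map a decoder output in [-1, 1] to an 8-bit pixel value, clamping outliers."""
    v01 = min(max(v / 2 + 0.5, 0.0), 1.0)  # [-1, 1] -> [0, 1], then clamp
    return round(v01 * 255)                # round half to even, like NumPy's .round()


print(unscale_latent(SCALING_FACTOR))                      # 1.0
print([to_uint8(v) for v in (-1.2, -1.0, 0.0, 0.5, 1.0)])  # [0, 0, 128, 191, 255]
```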

In [None]:
final_image_arr = decode_latents(latents)
final_image = Image.fromarray((final_image_arr * 255).round().astype("uint8"))
final_image

## 📚 Independent Study

Try changing the variables above to explore the science:

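1.  **Guidance Scale**: Change `guidance_scale` to `1.0`. At 1.0 the guidance formula reduces to the text-conditioned prediction alone, with no guidance at all. Expect the image to follow the prompt more loosely. Why does extrapolating past the conditional prediction help?
2.  **Timesteps**: Change `num_inference_steps` to `5`. Expect a blurrier, less detailed result. Each solver step now covers a large stretch of the noise schedule, so discretization error grows. This is **under-sampling** the differential equation.
3.  **Seed**: Change the seed used to create `generator` to get a different initial noise tensor, resulting in a completely different composition.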