# 🎓 Lesson 4: Advanced Science Lab

Welcome to the advanced lab! Here we will perform scientific experiments on the model internals.

### Experiments:
1. **Scheduler Shootout**: Euler Ancestral vs DPM++ 2M Karras
2. **LoRA Inspection**: Peeking inside the fine-tuning weights
3. **Step-by-Step Visualization**: Watching the image emerge from noise
4. **The Math of Guidance**: Understanding CFG vectors

In [None]:
# Setup
import sys
import os
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

project_root = Path(os.getcwd()).parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from core.pipeline import pipeline_manager
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler

## 1. Scheduler Shootout: Euler Ancestral vs DPM++ 2M Karras

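A scheduler numerically solves the differential equation that carries a noisy latent to a clean image, one step at a time. Let's compare two popular ones:
- **Euler Ancestral**: First-order Euler steps that inject fresh random noise at every step. "Creative", but less accurate, and the result keeps changing as you add steps instead of settling on one image.
- **DPM++ 2M Karras**: A second-order multistep solver that reuses the previous step's model output to approximate the curve, paired with the Karras noise schedule. Usually fast and smooth at low step counts.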
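
To make the "order" distinction concrete, here is a toy sketch on an ordinary differential equation with a known solution, dx/dt = -x. It compares a first-order Euler solver with a second-order two-step Adams-Bashforth solver, which reuses the slope from the previous step much as a multistep scheduler reuses the previous model output. This only illustrates solver order; it is not the DPM++ algorithm, and the function names are made up for this example.

```python
import math

def euler_solve(f, x0, t0, t1, n_steps):
    """First-order Euler: one slope evaluation per step."""
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = x + h * f(t, x)
        t += h
    return x

def ab2_solve(f, x0, t0, t1, n_steps):
    """Second-order Adams-Bashforth: also one new slope per step, plus the previous one."""
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    slope_prev = f(t, x)
    x = x + h * slope_prev  # bootstrap the first step with plain Euler
    t += h
    for _ in range(n_steps - 1):
        slope = f(t, x)
        x = x + h * (1.5 * slope - 0.5 * slope_prev)
        slope_prev = slope
        t += h
    return x

def decay(t, x):
    """Right-hand side of dx/dt = -x, whose exact solution is exp(-t)."""
    return -x

exact = math.exp(-1.0)
euler_error = abs(euler_solve(decay, 1.0, 0.0, 1.0, 20) - exact)
ab2_error = abs(ab2_solve(decay, 1.0, 0.0, 1.0, 20) - exact)
print(f"Euler error (20 steps): {euler_error:.6f}")
print(f"AB2 error   (20 steps): {ab2_error:.6f}")
```

Both solvers evaluate the slope 20 times, yet the second-order one lands much closer to the exact answer. That is the intuition behind why a second-order scheduler can look cleaner at the same step count.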

In [None]:
prompt = "a macro photo of a mechanical eye, clockwork, steampunk, 8k, detailed"
steps = 20 # Low step count to emphasize efficiency difference
seed = 42

pipe = pipeline_manager.get_txt2img_pipeline()
generator = torch.Generator(device=pipe.device).manual_seed(seed)

# 1. Run with Euler Ancestral
print("Running Euler Ancestral...")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
image_euler = pipe(prompt=prompt, num_inference_steps=steps, generator=generator).images[0]

# 2. Run with DPM++ 2M Karras
print("Running DPM++ 2M Karras...")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=True,
    algorithm_type="dpmsolver++",  # "sde-dpmsolver++" would select the stochastic DPM++ 2M SDE variant
)
generator = torch.Generator(device=pipe.device).manual_seed(seed) # Reset seed
image_dpm = pipe(prompt=prompt, num_inference_steps=steps, generator=generator).images[0]

# Compare
fig, ax = plt.subplots(1, 2, figsize=(15, 7))
ax[0].imshow(image_euler)
ax[0].set_title("Euler Ancestral (20 steps)")
ax[0].axis('off')

ax[1].imshow(image_dpm)
ax[1].set_title("DPM++ 2M Karras (20 steps)")
ax[1].axis('off')
plt.show()
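
The `use_karras_sigmas=True` flag changes how the noise levels are spaced across the steps, following the schedule proposed by Karras et al. (2022). Below is a minimal sketch of that formula, assuming rho = 7 and illustrative endpoint values; in a real run the endpoints come from the model's own noise schedule, and `karras_sigmas` is a name chosen for this example.

```python
import numpy as np

def karras_sigmas(n_steps, sigma_min, sigma_max, rho=7.0):
    """Noise levels spaced as in Karras et al. (2022), highest noise first."""
    ramp = np.linspace(0.0, 1.0, n_steps)
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    # Interpolate linearly in sigma^(1/rho) space, then map back
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# Illustrative endpoints only (assumed, not read from a model)
sigmas = karras_sigmas(n_steps=20, sigma_min=0.03, sigma_max=14.6)
print(np.round(sigmas, 3))
```

The printed values fall quickly at first and then crowd together near the low end, so more of the 20 steps are spent on fine detail at low noise.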

## 2. LoRA Inspection

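LoRA (Low-Rank Adaptation) keeps the big pretrained weight matrix $W$ frozen and learns a small low-rank update on top of it:

$$W' = W + B \cdot A$$

Here $B$ has shape (dim x r) and $A$ has shape (r x dim), with the rank $r$ much smaller than dim. We have no trained LoRA file to load here, so we inspect the config that sets the rank and the target layers instead.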

In [None]:
from fine_tuning.lora_trainer import LoRATrainer

# We can't easily load a LoRA without a file, but we can look at the config class
from peft import LoraConfig

config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_k", "to_q", "to_v", "to_out.0"])

print(f"LoRA Rank (r): {config.r}")
print(f"Target Modules: {config.target_modules}")
print("When we train, we only train matrices of size (dim x 4) and (4 x dim)!")
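
To see what those sizes mean in practice, here is a small NumPy sketch of the update $W' = W + B \cdot A$. The layer width `dim = 320` is a hypothetical value picked for illustration, and both matrices are filled with random numbers so the rank is visible; real LoRA training typically starts with $B$ at zero so the update begins as a no-op. peft also scales the update by `lora_alpha / r`, which equals 1 for the config above.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, r = 320, 4  # dim is a hypothetical layer width; r matches the config above

W = rng.normal(size=(dim, dim))        # frozen pretrained weight
B = rng.normal(size=(dim, r)) * 0.01   # trainable, shape (dim x r)
A = rng.normal(size=(r, dim)) * 0.01   # trainable, shape (r x dim)

delta = B @ A          # low-rank update, same shape as W
W_adapted = W + delta  # W' = W + B·A

full_params = W.size
lora_params = A.size + B.size
print(f"Full matrix parameters: {full_params}")
print(f"LoRA parameters:        {lora_params}")
print(f"Rank of the update:     {np.linalg.matrix_rank(delta)}")
```

The update has the same shape as $W$, but it is described by only 2 x dim x r numbers, which is why LoRA files are so small.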

## 3. Step-by-Step Visualization

How does the image emerge? Let's decode the latents every 10 steps (steps 0, 10, 20, and 30) and compare them with the final result.

In [None]:
# Define a callback function to capture intermediates
latents_history = []

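def decode_latents(latents, pipe):
    """Helper to decode latents to images"""
    with torch.no_grad():
        # Undo the VAE latent scaling (0.18215 for SD 1.x; reading it from the config also covers other models)
        latents = latents / pipe.vae.config.scaling_factor
        image = pipe.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # Convert to PIL (cast to float32 first, since NumPy has no bfloat16)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        return [Image.fromarray(img) for img in image]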

def callback(pipe, step, timestep, callback_kwargs):
    # Capture every 10 steps
    if step % 10 == 0:
        latents = callback_kwargs.get("latents")
        if latents is not None:
            # Decode and store (only the first image in batch)
            images = decode_latents(latents, pipe)
            latents_history.append((step, images[0]))
    return callback_kwargs

print("Generating with step visualization...")
prompt = "a cute robot painting a canvas, high resolution"

# Get pipeline and run
pipe = pipeline_manager.get_txt2img_pipeline()
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# We need to use a generator for reproducibility
generator = torch.Generator(device=pipe.device).manual_seed(42)

# Run generation with callback
output = pipe(
    prompt=prompt,
    num_inference_steps=31,  # step indices run 0..30, so index 30 is captured too
    generator=generator,
    callback_on_step_end=callback
).images[0]

print(f"Captured {len(latents_history)} intermediate steps.")

# Visualize
fig, axs = plt.subplots(1, len(latents_history) + 1, figsize=(20, 5))

# Plot intermediates
for i, (step, img) in enumerate(latents_history):
    axs[i].imshow(img)
    axs[i].set_title(f"Step {step}")
    axs[i].axis('off')

# Plot final
axs[-1].imshow(output)
axs[-1].set_title("Final Result")
axs[-1].axis('off')

plt.show()
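
The last part of `decode_latents` turns the decoder output, which lives roughly in [-1, 1], into ordinary 8-bit pixels. Here is that mapping on its own as a NumPy sketch; `to_uint8_pixels` and the sample values are made up for illustration.

```python
import numpy as np

def to_uint8_pixels(decoded):
    """Map decoder output in roughly [-1, 1] to uint8 pixels in [0, 255]."""
    scaled = np.clip(decoded / 2 + 0.5, 0.0, 1.0)  # shift to [0, 1] and clamp any overshoot
    return np.round(scaled * 255).astype("uint8")

# Hypothetical decoder values, including two that overshoot the nominal range
sample = np.array([-1.5, -1.0, 0.0, 1.0, 1.3])
pixels = to_uint8_pixels(sample)
print(pixels)
```

The clamp matters most for early steps, where the decoded latent is still mostly noise and values outside [-1, 1] are common.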

## 4. The Math of Guidance (Simulator)

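Classifier-Free Guidance (CFG) is the "magic spell" of Stable Diffusion. At each denoising step the model predicts the noise twice, once with the prompt ($\epsilon_{cond}$) and once without it ($\epsilon_{uncond}$), and the guidance weight $w$ scales the difference between the two: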

$$ \epsilon_{final} = \epsilon_{uncond} + w \cdot (\epsilon_{cond} - \epsilon_{uncond}) $$

Let's visualize this with a simple 2D vector simulation.

In [None]:
# === CFG SIMULATOR ===
# Imagine our image is just a point in 2D space for simplicity.

def simulate_cfg(weight):
    # 1. Unconditioned Prediction (Model guessing without prompt)
    # It pushes towards generic "average" images
    uncond_vector = np.array([2.0, 1.0]) 

    # 2. Conditioned Prediction (Model pushing towards "Red Apple")
    cond_vector = np.array([4.0, 5.0])

    # 3. The Difference (The "Concept" of Red Apple)
    # This vector represents pure "Red Apple-ness" without the generic image parts
    concept_vector = cond_vector - uncond_vector

    # 4. Final Vector
    final_vector = uncond_vector + weight * concept_vector
    
    return uncond_vector, cond_vector, final_vector

# Visualize various weights
weights = [1.0, 7.0, 15.0]
colors = ['green', 'orange', 'red']

plt.figure(figsize=(10, 8))
plt.axvline(0, color='gray', alpha=0.3)
plt.axhline(0, color='gray', alpha=0.3)
plt.xlim(0, 40)
plt.ylim(0, 65)

for i, w in enumerate(weights):
    uncond, cond, final = simulate_cfg(w)
    
    # Plot the result
    plt.arrow(0, 0, final[0], final[1], head_width=1, head_length=1, fc=colors[i], ec=colors[i], label=f'CFG {w} (Result)')
    
    # Plot the Original Conditioned (Reference)
    if i == 0:
        plt.arrow(0, 0, cond[0], cond[1], head_width=0.5, head_length=0.5, fc='blue', ec='blue', alpha=0.5, linestyle=':', label='Raw Prompt Target')
        plt.arrow(0, 0, uncond[0], uncond[1], head_width=0.5, head_length=0.5, fc='gray', ec='gray', alpha=0.5, linestyle=':', label='Unconditional Base')

plt.title("Visualizing Classifier-Free Guidance Vectors", fontsize=15)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("Observation:")
print("- CFG 1.0 (Green): Exactly matches the raw prompt target.")
print("- CFG 7.0 (Orange): Extrapolates past the prompt target along (cond - uncond). This is close to the usual Stable Diffusion default of 7.5.")
print("- CFG 15.0 (Red): Pushes extremely far. This can lead to 'frying' (oversaturation).")
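
To check the picture against the formula, the sketch below plugs the simulator's two vectors into the CFG equation and prints the result for several weights. `apply_cfg` is a helper written for this example; in a diffusers pipeline call the weight is passed as `guidance_scale`.

```python
import numpy as np

def apply_cfg(uncond, cond, weight):
    """Combine two noise predictions with classifier-free guidance."""
    return uncond + weight * (cond - uncond)

uncond = np.array([2.0, 1.0])  # same toy vectors as the simulator above
cond = np.array([4.0, 5.0])

for w in [0.0, 1.0, 7.0, 15.0]:
    print(f"w = {w:>4}: {apply_cfg(uncond, cond, w)}")
```

A weight of 0 ignores the prompt entirely, a weight of 1 returns the conditioned prediction unchanged, and anything above 1 moves further along the same direction.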