In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Full Pipeline: Multi-Scale Score Matching and Generation -- Vizuara

## 1. Why Does This Matter?

In the previous notebooks, we trained a score network using a single noise level. This works for simple distributions, but real-world data has complex multi-modal structure. Regions between modes are low-density "deserts" where the score is hard to learn, because almost no training samples land there.

The solution: use **multiple noise scales**. This is the bridge between score matching and modern diffusion models. By training the score network across many noise levels, we get accurate score estimates everywhere -- from the broad global structure (high noise) to fine local details (low noise).

**By the end of this notebook, you will:**
- Implement multi-scale denoising score matching
- Train a noise-conditioned score network
- Implement annealed Langevin dynamics
- Generate samples from a complex multi-modal distribution
- Understand the direct connection to DDPM and Score SDEs

## 2. Building Intuition

### Why Multiple Noise Scales?

Imagine you are trying to find a specific house in a city. With a single level of zoom:
- If you zoom out too far (high noise), you can see the city but not individual houses
- If you zoom in too close (low noise), you can see one house but cannot navigate across neighborhoods

The solution: start zoomed out, get to the right neighborhood, then zoom in progressively.

This is exactly what annealed Langevin dynamics does: start with high noise (global navigation), then progressively reduce noise (local refinement).

### Think About This
- Why would a single low noise level fail for a distribution with widely separated modes?
- Why would a single high noise level produce blurry samples?

## 3. The Mathematics

### Noise-Conditioned Score Network

Instead of one score network, we train a single network conditioned on the noise level $\sigma$:

$$s_\theta(x, \sigma) \approx \nabla_x \log q_\sigma(x)$$

where $q_\sigma(x)$ is the data distribution convolved with Gaussian noise of standard deviation $\sigma$.

### Multi-Scale DSM Loss

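$$J(\theta) = \sum_{i=1}^{L} \lambda(\sigma_i) \cdot \mathbb{E}_{p(x)\, q_{\sigma_i}(\tilde{x}|x)} \left[\left\|s_\theta(\tilde{x}, \sigma_i) + \frac{\tilde{x} - x}{\sigma_i^2}\right\|^2\right], \qquad q_{\sigma_i}(\tilde{x}|x) = \mathcal{N}(\tilde{x};\, x,\, \sigma_i^2 I)$$

Computationally, for a noise level $\sigma_i$:
1. Add noise to the data: $\tilde{x} = x + \sigma_i \epsilon$.
2. Compute the target score $-(\tilde{x} - x)/\sigma_i^2 = -\epsilon/\sigma_i$.
3. Predict the score conditioned on $\sigma_i$.
4. Take the squared error weighted by $\lambda(\sigma_i) = \sigma_i^2$.

This weighting (from Song & Ermon) makes every level contribute on a comparable scale, since the target's magnitude grows like $1/\sigma_i$. Rather than looping over all $L$ levels, the training code draws one random level per example. This gives an unbiased estimate of the sum above, up to the constant factor $1/L$.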

### Annealed Langevin Dynamics

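Start from pure noise and run Langevin dynamics at a sequence of decreasing noise levels. Here the levels are indexed so that $\sigma_L = \sigma_{\max}$ is the largest and $\sigma_1 = \sigma_{\min}$ the smallest; in the code, `sigmas[0]` is $\sigma_{\max}$.

For $i = L, L-1, \ldots, 1$, run $T$ steps of:

$$x_{t+1} = x_t + \eta_i \cdot s_\theta(x_t, \sigma_i) + \sqrt{2\eta_i} \cdot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

Here $\eta_i = c \cdot \sigma_i^2 / \sigma_L^2$ scales the step size with the noise level, and the samples at the end of level $i$ initialize level $i-1$.

The constant $c$ is the step size at the largest noise level. Each level runs for a total "time" of $T\eta_i$, which must be several multiples of $\sigma_i^2$ for the chain to equilibrate. So $c$ has to grow with $\sigma_L^2$: for example, $c = 0.2\,\sigma_L^2$ with $T = 100$ gives every level a time of $20\,\sigma_i^2$.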

## 4. Let's Build It -- Component by Component

### 4.1 Noise-Conditioned Score Network

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

class ConditionalScoreNet(nn.Module):
    """
    Score network conditioned on the noise level.

    The noise level is passed as an integer index into the noise schedule.
    A learned embedding of that index is ADDED to a linear projection of the
    input (it is not concatenated), and the sum is fed through an MLP. This
    lets a single network predict the score at every noise level.

    Args:
        dim: Dimensionality of the data points (input and output size).
        hidden: Width of the hidden layers; also the embedding size, since the
            embedding is summed with the input projection.
        n_sigmas: Number of noise levels (rows of the embedding table); valid
            indices are 0 .. n_sigmas - 1.
    """
    def __init__(self, dim=2, hidden=256, n_sigmas=10):
        super().__init__()
        # One learned vector per noise level, same width as the input projection
        # so the two can be added elementwise in forward().
        self.sigma_embed = nn.Embedding(n_sigmas, hidden)
        self.input_proj = nn.Linear(dim, hidden)

        # Starts with an activation because input_proj (+ embedding) already
        # acts as the first linear layer.
        self.net = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, sigma_idx):
        """
        Args:
            x: Input points (batch, dim)
            sigma_idx: Integer (long) indices into the noise schedule, shape (batch,)
        Returns:
            Predicted score (batch, dim)
        """
        h = self.input_proj(x) + self.sigma_embed(sigma_idx)
        return self.net(h)

### 4.2 Geometric Noise Schedule

In [None]:
def get_noise_schedule(sigma_min=0.01, sigma_max=5.0, n_levels=10):
    """
    Build a geometric (log-uniform) noise schedule.

    Returns a 1-D tensor of `n_levels` values that starts at `sigma_max` and
    shrinks by a constant ratio down to `sigma_min`. Every scale, from coarse
    to fine, therefore receives equal coverage.
    """
    log_start = np.log(sigma_max)
    log_stop = np.log(sigma_min)
    log_sigmas = torch.linspace(log_start, log_stop, n_levels)
    return log_sigmas.exp()

sigmas = get_noise_schedule(sigma_min=0.01, sigma_max=5.0, n_levels=10)
print("Noise schedule (sigma_max -> sigma_min):")
for level, sigma_value in enumerate(sigmas.tolist()):
    print(f"  Level {level}: sigma = {sigma_value:.4f}")

In [None]:
# Visualization: how each noise level blurs a simple two-cluster dataset
cluster_centers = torch.tensor([[-2.0, 0.0], [2.0, 0.0]])
data_demo = cluster_centers.repeat(200, 1)
data_demo = data_demo + torch.randn_like(data_demo) * 0.3

fig, axes = plt.subplots(2, 5, figsize=(18, 7))
for ax, sigma in zip(axes.flat, sigmas):
    noisy = data_demo + torch.randn_like(data_demo) * sigma
    ax.scatter(noisy[:, 0].numpy(), noisy[:, 1].numpy(), s=3, alpha=0.3, c='steelblue')
    ax.set_title(f'sigma={sigma:.3f}', fontsize=11)
    ax.set(xlim=(-8, 8), ylim=(-6, 6))
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.2)

plt.suptitle('Data at Different Noise Levels', fontsize=15)
plt.tight_layout()
plt.show()
print("High noise: modes merge. Low noise: modes are distinct.")
print("The network must learn the score at EVERY level.")

### 4.3 Complex Training Data

In [None]:
def sample_complex_distribution(n, pattern='four_gaussians'):
    """Generate samples from a more complex distribution."""
    if pattern == 'four_gaussians':
        centers = [[-2, -2], [-2, 2], [2, -2], [2, 2]]
    elif pattern == 'ring':
        angles = torch.rand(n) * 2 * np.pi
        r = 2.0 + torch.randn(n) * 0.2
        return torch.stack([r * torch.cos(angles), r * torch.sin(angles)], dim=-1)
    elif pattern == 'two_moons':
        # Upper moon
        n1 = n // 2
        angles1 = torch.linspace(0, np.pi, n1)
        x1 = torch.cos(angles1) + torch.randn(n1) * 0.1
        y1 = torch.sin(angles1) + torch.randn(n1) * 0.1
        # Lower moon
        n2 = n - n1
        angles2 = torch.linspace(0, np.pi, n2)
        x2 = 1 - torch.cos(angles2) + torch.randn(n2) * 0.1
        y2 = -torch.sin(angles2) + 0.5 + torch.randn(n2) * 0.1
        return torch.cat([
            torch.stack([x1, y1], dim=-1),
            torch.stack([x2, y2], dim=-1)
        ], dim=0)
    else:
        centers = [[-2, 0], [2, 0]]

    centers = torch.tensor(centers, dtype=torch.float)
    idx = torch.randint(0, len(centers), (n,))
    return torch.randn(n, 2) * 0.3 + centers[idx]

# Visualize all three patterns side by side
pattern_names = ['four_gaussians', 'ring', 'two_moons']
fig, axes = plt.subplots(1, len(pattern_names), figsize=(15, 4))
for ax, pattern in zip(axes, pattern_names):
    pattern_points = sample_complex_distribution(2000, pattern)
    xs, ys = pattern_points[:, 0].numpy(), pattern_points[:, 1].numpy()
    ax.scatter(xs, ys, s=3, alpha=0.3, c='steelblue')
    ax.set_title(pattern.replace('_', ' ').title(), fontsize=13)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
plt.suptitle('Available Training Distributions', fontsize=14)
plt.tight_layout()
plt.show()

### 4.4 Multi-Scale DSM Training

In [None]:
def multi_scale_dsm_loss(model, x, sigmas):
    """
    Noise-conditioned (multi-scale) denoising score matching loss.

    Each example in the batch is assigned one noise level drawn uniformly
    from `sigmas` and perturbed with Gaussian noise of that scale. The
    model's score prediction is then regressed onto the analytic DSM target
    -noise / sigma. Squared errors are summed over dimensions, weighted by
    sigma^2 (the lambda(sigma) of Song & Ermon), and averaged over the batch.

    Args:
        model: Callable model(x_noisy, sigma_idx) returning a score shaped like x.
        x: Clean data batch of shape (batch, dim).
        sigmas: 1-D tensor noise schedule.

    Returns:
        Scalar loss tensor.
    """
    level_idx = torch.randint(0, len(sigmas), (x.shape[0],))
    level_sigma = sigmas[level_idx]        # (batch,)
    sigma_col = level_sigma[:, None]       # (batch, 1), broadcasts over dim

    eps = torch.randn_like(x)
    x_perturbed = x + sigma_col * eps
    # grad log N(x_perturbed; x, sigma^2 I) = -(x_perturbed - x) / sigma^2 = -eps / sigma
    target_score = -eps / sigma_col

    predicted = model(x_perturbed, level_idx)
    sq_err = ((predicted - target_score) ** 2).sum(dim=-1)
    return (level_sigma ** 2 * sq_err).mean()

In [None]:
# Train the noise-conditioned score network on PATTERN.
# Each "epoch" here is a single optimizer step on a fresh mini-batch.
N_SIGMAS = 10
sigmas = get_noise_schedule(sigma_min=0.01, sigma_max=5.0, n_levels=N_SIGMAS)
PATTERN = 'four_gaussians'

model = ConditionalScoreNet(dim=2, hidden=256, n_sigmas=N_SIGMAS)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = []

N_TRAIN_STEPS = 3000
TRAIN_BATCH = 512
LOG_EVERY = 500

for epoch in range(1, N_TRAIN_STEPS + 1):
    x = sample_complex_distribution(TRAIN_BATCH, PATTERN)
    loss = multi_scale_dsm_loss(model, x, sigmas)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    step_loss = loss.item()
    losses.append(step_loss)
    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch:5d} | Loss: {step_loss:.4f}")

In [None]:
# Visualization checkpoint: loss curve.
# `losses` holds one value per optimizer step (each "epoch" in the training loop is one mini-batch).
SMOOTH_WINDOW = 50
steps = np.arange(1, len(losses) + 1)

fig, ax = plt.subplots(figsize=(10, 4))
ax.semilogy(steps, losses, alpha=0.3, linewidth=0.5, color='blue', label='Per-step loss')
# np.convolve swaps its arguments when the kernel is longer than the data, so only smooth
# once we have at least one full window.
if len(losses) >= SMOOTH_WINDOW:
    smoothed = np.convolve(losses, np.ones(SMOOTH_WINDOW) / SMOOTH_WINDOW, mode='valid')
    # mode='valid' yields len(losses) - SMOOTH_WINDOW + 1 averages; plot each one at the last
    # step of its window so the curve is not shifted left by SMOOTH_WINDOW - 1 steps.
    ax.semilogy(steps[SMOOTH_WINDOW - 1:], smoothed, linewidth=2, color='red',
                label=f'{SMOOTH_WINDOW}-step moving average')
ax.set_xlabel('Training step')
ax.set_ylabel('Multi-Scale DSM Loss (log)')
ax.set_title('Training Loss')
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

## 5. Your Turn -- Implement Annealed Langevin Dynamics

In [None]:
def annealed_langevin_dynamics(model, sigmas, n_samples=500,
                                steps_per_level=100, eps=0.01):
    """
    TODO: Implement annealed Langevin dynamics.

    For each noise level (from highest to lowest):
    1. Set the step size: eta = eps * (sigma_i / sigma_max)^2
    2. Run Langevin dynamics for steps_per_level steps at this noise level
    3. Use model(x, sigma_idx) to get the score conditioned on sigma_i

    Start from random Gaussian noise: x ~ N(0, sigma_max^2 * I)

    Langevin update: x <- x + eta * score + sqrt(2 * eta) * noise, noise ~ N(0, I)

    Args:
        model: Noise-conditioned score network
        sigmas: Noise schedule (highest to lowest)
        n_samples: Number of parallel samples
        steps_per_level: Steps per noise level
        eps: Base step size, i.e. the step size at sigma_max. Each level runs
            for a total time of steps_per_level * eta_i, which must be at least
            comparable to sigma_i^2 for the samples to move. With this
            sigma_max-relative schedule, eps should therefore scale like
            sigma_max^2 (e.g. 0.2 * sigma_max**2), not be a tiny constant.

    Returns:
        x: Final samples (n_samples, 2)
        history: List of snapshots for visualization

    Raises:
        NotImplementedError: until the TODO lines below are filled in.
    """
    dim = 2
    sigma_max = sigmas[0]
    x = torch.randn(n_samples, dim) * sigma_max
    history = [x.clone()]

    for i in range(len(sigmas)):
        sigma = sigmas[i]
        sigma_idx = torch.full((n_samples,), i, dtype=torch.long)

        # ============ TODO ============
        # Replace each `None` below. (The placeholders keep this cell valid Python
        # so "Run All" does not stop here.)
        # Step size scales with noise level
        eta = None  # TODO: eps * (sigma / sigma_max) ** 2

        for step in range(steps_per_level):
            # Compute score and update x
            score = None  # TODO: model(x, sigma_idx) -- use torch.no_grad()!
            noise = None  # TODO: torch.randn_like(x)
            x = None      # TODO: Langevin update
        # ==============================

        if eta is None or x is None:
            raise NotImplementedError(
                "Fill in the TODO lines (eta, score, noise, x) in annealed_langevin_dynamics."
            )

        history.append(x.clone())

    return x, history

In [None]:
# Solution and execution
def annealed_langevin_solution(model, sigmas, n_samples=500,
                                steps_per_level=100, eps=0.005):
    """
    Reference annealed Langevin sampler.

    Starts from x ~ N(0, sigma_max^2 I) and, for each level from sigmas[0]
    (largest) to sigmas[-1] (smallest), runs `steps_per_level` Langevin
    updates  x <- x + eta_i * s(x, i) + sqrt(2 * eta_i) * z,  z ~ N(0, I),
    with step size eta_i = eps * (sigma_i / sigma_max)^2.

    Args:
        model: Noise-conditioned score network called as model(x, sigma_idx).
        sigmas: 1-D noise schedule ordered from highest to lowest.
        n_samples: Number of parallel chains.
        steps_per_level: Langevin steps per noise level.
        eps: Step size at the largest noise level. Level i runs for a total
            time of steps_per_level * eta_i, which must be several multiples of
            sigma_i^2 for the chains to equilibrate. eps ~ 0.2 * sigma_max**2
            (Song & Ermon's setting, expressed relative to sigma_max) gives
            about 20 * sigma_i^2 per level with 100 steps. The default 0.005
            only suits a very small sigma_max (~0.16), so pass eps explicitly.

    Returns:
        x: Final samples, shape (n_samples, 2).
        history: [initial noise] followed by one snapshot after each level
            (len(sigmas) + 1 entries).
    """
    dim = 2
    sigma_max = sigmas[0]
    x = torch.randn(n_samples, dim) * sigma_max
    history = [x.clone()]

    for i, sigma in enumerate(sigmas):
        sigma_idx = torch.full((n_samples,), i, dtype=torch.long)
        eta = eps * (sigma / sigma_max) ** 2

        for _ in range(steps_per_level):
            with torch.no_grad():
                score = model(x, sigma_idx)
            noise = torch.randn_like(x)
            x = x + eta * score + (2 * eta) ** 0.5 * noise

        history.append(x.clone())
        print(f"  Level {i}: sigma={sigma:.4f}, eta={eta:.6f}")

    return x, history


# Step size at sigma_max. 0.2 * sigma_max**2 gives every level ~20 * sigma_i**2 of Langevin
# time, whereas eps=0.005 gave only 0.02 * sigma_i**2 and left the samples at the initial noise.
LANGEVIN_EPS = 0.2 * float(sigmas[0]) ** 2

print("Running annealed Langevin dynamics...")
samples, history = annealed_langevin_solution(
    model, sigmas, n_samples=1000, steps_per_level=100, eps=LANGEVIN_EPS
)

## 5.2 Your Turn -- Visualize the Annealing Process

In [None]:
# TODO: Create a visualization showing samples at each noise level
# The samples should progressively sharpen from blurry blobs to
# well-separated clusters
#
# `history` holds len(sigmas) + 1 snapshots (initial noise + one per level). The grid is
# sized from it, so the final samples are shown rather than cut off at 10 panels.
n_snaps = len(history)
n_cols = (n_snaps + 1) // 2
fig, axes = plt.subplots(2, n_cols, figsize=(3.6 * n_cols, 7))

for i, ax in enumerate(axes.flat):
    if i >= n_snaps:
        ax.axis('off')  # leftover panel when the snapshot count is odd
        continue
    snap = history[i]
    ax.scatter(snap[:, 0].numpy(), snap[:, 1].numpy(),
               s=3, alpha=0.3, c='steelblue')
    if i == 0:
        ax.set_title('Initial Noise', fontsize=11)
    elif i == n_snaps - 1:
        ax.set_title(f'Final (sigma={sigmas[i-1]:.3f})', fontsize=11)
    else:
        ax.set_title(f'After sigma={sigmas[i-1]:.3f}', fontsize=10)
    ax.set_xlim(-6, 6)
    ax.set_ylim(-6, 6)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.2)

plt.suptitle('Annealed Langevin Dynamics: Progressive Refinement', fontsize=15)
plt.tight_layout()
plt.show()
print("Samples start as noise and progressively form the target distribution!")

## 6. Putting It All Together

In [None]:
# Score fields at different noise levels.
# Score magnitudes differ by orders of magnitude across levels (roughly |x| / sigma^2).
# A single fixed quiver scale makes the high-noise fields invisible and the low-noise
# fields overflow, so each panel is autoscaled: compare arrow DIRECTIONS across panels,
# not their lengths.
fig, axes = plt.subplots(2, 5, figsize=(18, 7))
n_g = 15
g = torch.linspace(-5, 5, n_g)
G1, G2 = torch.meshgrid(g, g, indexing='ij')
gp = torch.stack([G1.flatten(), G2.flatten()], dim=-1)

for i, (ax, sigma) in enumerate(zip(axes.flat, sigmas)):
    sigma_idx = torch.full((n_g * n_g,), i, dtype=torch.long)
    with torch.no_grad():
        s = model(gp, sigma_idx)

    ax.quiver(G1.numpy(), G2.numpy(),
              s[:, 0].reshape(n_g, n_g).numpy(),
              s[:, 1].reshape(n_g, n_g).numpy(),
              color='darkblue', width=0.005)
    ax.set_title(f'sigma={sigma:.3f}', fontsize=10)
    ax.set_aspect('equal')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)

plt.suptitle('Learned Score Fields at Each Noise Level (arrows autoscaled per panel)', fontsize=15)
plt.tight_layout()
plt.show()
print("High noise: broad structure. Low noise: fine details.")

## 7. Training and Results

In [None]:
# Compare generated samples with true data
true_data = sample_complex_distribution(2000, PATTERN)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

axes[0].scatter(true_data[:, 0].numpy(), true_data[:, 1].numpy(),
                s=5, alpha=0.3, c='steelblue')
axes[0].set_title('True Data', fontsize=14)

axes[1].scatter(samples[:, 0].numpy(), samples[:, 1].numpy(),
                s=5, alpha=0.3, c='coral')
axes[1].set_title('Generated Samples', fontsize=14)

for ax in axes:
    ax.set_aspect('equal')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.grid(True, alpha=0.3)

plt.suptitle('Multi-Scale Score Matching: True vs Generated', fontsize=15)
plt.tight_layout()
plt.show()
print("The generated samples match the four-mode target distribution!")

## 8. Final Output

In [None]:
# Grand finale: train on the two-moons dataset and generate
print("Training on Two Moons dataset...")
model_moons = ConditionalScoreNet(dim=2, hidden=256, n_sigmas=N_SIGMAS)
opt_moons = torch.optim.Adam(model_moons.parameters(), lr=1e-3)

for epoch in range(3000):
    x = sample_complex_distribution(512, 'two_moons')
    loss = multi_scale_dsm_loss(model_moons, x, sigmas)
    opt_moons.zero_grad()
    loss.backward()
    opt_moons.step()
    if (epoch + 1) % 1000 == 0:
        print(f"  Epoch {epoch+1} | Loss: {loss.item():.4f}")

print("Sampling...")
# eps is the Langevin step size at sigma_max. 0.2 * sigma_max**2 gives each level a total
# time of ~20 * sigma_i**2 so the chains equilibrate (eps=0.005 barely moved the samples).
moon_samples, moon_hist = annealed_langevin_solution(
    model_moons, sigmas, n_samples=2000, steps_per_level=100,
    eps=0.2 * float(sigmas[0]) ** 2,
)

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# True
true_moons = sample_complex_distribution(2000, 'two_moons')
axes[0].scatter(true_moons[:, 0].numpy(), true_moons[:, 1].numpy(), s=3, alpha=0.3, c='steelblue')
axes[0].set_title('True Two Moons', fontsize=14)

# Score field at a moderate noise level. Index 0 is sigma_max, the HIGHEST noise level:
# its scores are ~|x| / sigma_max**2 and render as invisible arrows at this quiver scale.
# The smallest sigma is only trained within ~sigma_min of the data. A level comparable
# to the moons' own width (~0.3) shows the structure best.
vis_idx = int(torch.argmin((sigmas - 0.3).abs()))
n_g = 20
g = torch.linspace(-1.5, 2.5, n_g)
gy = torch.linspace(-1.5, 1.5, n_g)
G1, G2 = torch.meshgrid(g, gy, indexing='ij')
gp = torch.stack([G1.flatten(), G2.flatten()], dim=-1)
with torch.no_grad():
    s = model_moons(gp, torch.full((n_g * n_g,), vis_idx, dtype=torch.long))
axes[1].quiver(G1.numpy(), G2.numpy(),
               s[:, 0].reshape(n_g, n_g).numpy(),
               s[:, 1].reshape(n_g, n_g).numpy(),
               color='darkblue', scale=80, width=0.004)
axes[1].set_title(f'Learned Score Field (sigma={sigmas[vis_idx]:.2f})', fontsize=14)

# Generated
axes[2].scatter(moon_samples[:, 0].numpy(), moon_samples[:, 1].numpy(),
                s=3, alpha=0.3, c='coral')
axes[2].set_title('Generated Samples', fontsize=14)

for ax in axes:
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

plt.suptitle('Score-Based Generation on Two Moons', fontsize=15, y=1.02)
plt.tight_layout()
plt.show()

print("\nThis is exactly what modern diffusion models do, but with MANY more noise levels")
print("and on high-dimensional data (images, audio, video).")
print("\nThe path: EBMs -> Score Function -> DSM -> Multi-Scale DSM -> DDPM -> Score SDE")
print("You have now understood the entire intellectual foundation!")

## 9. Reflection and Next Steps

### Think About These Questions:
1. We used 10 noise levels. What would happen with 100? 1000? (Hint: in the limit, you get a continuous-time SDE)
2. How does the noise schedule (geometric vs linear) affect generation quality?
3. Modern diffusion models (DDPM, Score SDE) use hundreds or thousands of noise levels and a U-Net architecture. What are the key differences from our simple implementation?
4. Could you apply this technique to 1D audio waveforms? What would change?

### Challenge Exercise:
Set `PATTERN = 'ring'` in the training cell (instead of `'four_gaussians'`) and re-run the notebook from that cell onward. Does the model learn the circular structure? How many noise levels do you need?

### The Big Picture:
You have now traced the complete intellectual path from energy-based models to modern diffusion models:
- **Energy functions** define unnormalized probability densities via the Boltzmann distribution, $p(x) \propto e^{-E(x)}$
- **The score function** bypasses the intractable partition function
- **Score matching** learns the score from data alone
- **Denoising score matching** simplifies this to noise prediction
- **Multi-scale DSM** captures structure at all scales
- **DDPM and Score SDEs** are the industrial-strength versions of what we built here

Congratulations -- you now understand the foundations of generative AI!