In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, report versions, and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Score Functions and Energy-Based Models -- Vizuara

## 1. Why Does This Matter?

In generative modeling, we want to learn the probability distribution of our data so that we can generate new, realistic samples. But there is a fundamental problem: for complex, high-dimensional data, computing the probability density directly is **intractable**, because it requires a normalizing constant (the partition function) that integrates over every possible input.

The **score function** offers an elegant escape. Instead of modeling the density itself, we model the gradient of its logarithm -- a vector field that points toward regions of high probability. This simple shift unlocks an entire family of powerful generative models, including the diffusion models behind DALL-E 2, Stable Diffusion, and Sora.

By the end of this notebook, you will:
- Understand what the score function is and why it bypasses the partition function
- Implement and visualize score functions for simple distributions
- Build a neural network that learns the score function using the **tractable score matching** objective
- See why this objective is computationally expensive and why we need something better

## 2. Building Intuition

### The Compass Analogy

Imagine you are standing in a vast, foggy landscape. You know there are treasure chests hidden somewhere, but you cannot see them. All you have is a magical compass that, at any point, tells you the direction where the treasure density is highest.

That compass is the **score function**. It does not tell you the exact probability of finding treasure at your location. Instead, it tells you which direction to walk to increase your chances the most.

### Why Not Just Model the Probability?

If we use an energy function $E_\theta(x)$ to model how "unlikely" a data point is, the probability becomes:

$$p_\theta(x) = \frac{1}{Z(\theta)} \exp(-E_\theta(x))$$

The normalization constant $Z(\theta) = \int \exp(-E_\theta(x)) dx$ requires integrating over every possible input $x$, not just the data we observed. For a 28x28 image, that is an integral over a 784-dimensional space, far beyond what any grid-based numerical integration can handle.
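
To get a feel for the cost, the sketch below approximates $Z$ for a 1D toy energy with the trapezoid rule, then counts how many grid points the same kind of grid would need in higher dimensions. The energy $E(x) = x^2/2$, the integration range, and the helper names are assumptions made for illustration. They are not used elsewhere in this notebook.

```python
import math

def energy(x):
    """Toy 1D energy, assumed for illustration: E(x) = x^2 / 2."""
    return 0.5 * x * x

def partition_function_1d(lo=-10.0, hi=10.0, n_points=2001):
    """Approximate Z = integral of exp(-E(x)) dx with the trapezoid rule."""
    step = (hi - lo) / (n_points - 1)
    values = [math.exp(-energy(lo + i * step)) for i in range(n_points)]
    return step * (sum(values) - 0.5 * (values[0] + values[-1]))

def grid_points_needed(points_per_dim, dim):
    """Number of nodes in a regular grid with the same resolution on every axis."""
    return points_per_dim ** dim

z_1d = partition_function_1d()
print(f"Z in 1D: {z_1d:.6f} (closed form: sqrt(2*pi) = {math.sqrt(2 * math.pi):.6f})")

# Even a very coarse grid of 10 points per axis explodes with dimension
for dim in (1, 2, 3, 784):
    n = grid_points_needed(10, dim)
    print(f"D = {dim:3d}: 10^{len(str(n)) - 1} grid points")
```

In one dimension a couple of thousand evaluations are plenty. In 784 dimensions the same approach would need more grid points than could ever be enumerated.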

But the score function -- the gradient of the log density -- eliminates $Z$ entirely!

## 3. The Mathematics

### Definition of the Score Function

The score function is defined as:

$$s(x) = \nabla_x \log p(x)$$

This gives us a vector at every point in space, pointing in the direction of steepest increase in log-probability.

**Computationally, this means:** for each dimension of x, compute the partial derivative of the log-density. The result is a vector of the same dimension as x.
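
As a minimal sketch of that recipe, the snippet below approximates each partial derivative with a central finite difference. It uses only the standard library so it runs independently of the setup cell. The function names, the test point, and the step size `eps` are illustrative choices, and the density is assumed to be a standard Gaussian.

```python
import math

def log_density(x):
    """Log-density of a standard Gaussian in len(x) dimensions."""
    dim = len(x)
    return -0.5 * sum(xi * xi for xi in x) - 0.5 * dim * math.log(2.0 * math.pi)

def score_finite_difference(log_p, x, eps=1e-5):
    """Approximate each partial derivative of log_p with a central difference."""
    score = []
    for d in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[d] += eps
        x_minus[d] -= eps
        score.append((log_p(x_plus) - log_p(x_minus)) / (2.0 * eps))
    return score

x = [1.5, -0.5]
s = score_finite_difference(log_density, x)
print("x     =", x)
print("score =", [round(v, 4) for v in s])  # one entry per dimension of x
```

For this density the analytic score is $-x$, so the printed vector should agree with the negated input up to finite-difference error. In the rest of the notebook, autograd plays the role of the finite differences.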

### Why the Partition Function Vanishes

Starting from the energy-based density:

$$\log p_\theta(x) = -E_\theta(x) - \log Z(\theta)$$

Taking the gradient with respect to x:

$$\nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x) - \nabla_x \log Z(\theta)$$

Since $Z(\theta)$ does not depend on x, $\nabla_x \log Z(\theta) = 0$, so:

$$s_\theta(x) = -\nabla_x E_\theta(x)$$

The partition function has vanished completely.

### Numerical Example

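Let us verify this with concrete numbers. Suppose $E_\theta(x) = (x - 2)^2$ and, purely for illustration, that the partition function evaluates to $Z = 10$ (as we will see, its actual value does not matter).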

$$\log p_\theta(x) = -(x-2)^2 - \log(10)$$

$$s_\theta(x) = \nabla_x \log p_\theta(x) = -2(x-2)$$

At $x = 5$: $s_\theta(5) = -2(5-2) = -6$ (negative, so it points left, toward the energy minimum at $x=2$).

At $x = 0$: $s_\theta(0) = -2(0-2) = 4$ (positive, so it points right, toward $x=2$).

The $\log(10)$ from the partition function disappeared. This is exactly what we want.
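
The same check can be run numerically. The sketch below differentiates $\log p_\theta(x)$ with a central finite difference for two different values of $Z$. The second value (1000) and the helper names are assumptions added for illustration.

```python
import math

def log_density(x, z):
    """log p(x) = -(x - 2)^2 - log(Z), with Z passed in as a plain number."""
    return -((x - 2.0) ** 2) - math.log(z)

def score(x, z, eps=1e-5):
    """Central-difference approximation of d/dx log p(x)."""
    return (log_density(x + eps, z) - log_density(x - eps, z)) / (2.0 * eps)

for z in (10.0, 1000.0):
    print(f"Z = {z:6.1f}: s(5) = {score(5.0, z):.4f}, s(0) = {score(0.0, z):.4f}")
```

Both rows should report the same scores, because $\log Z$ shifts the log-density by a constant and a constant has zero slope.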

## 4. Let's Build It -- Component by Component

### 4.1 Score Function for a 1D Gaussian

Let us start by computing and visualizing the score function for a simple Gaussian distribution.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Define a 1D Gaussian distribution
mu, sigma = 0.0, 1.0

# Create a grid of x values
x = torch.linspace(-4, 4, 200)

# Compute the probability density
p_x = (1 / (sigma * np.sqrt(2 * np.pi))) * torch.exp(-0.5 * ((x - mu) / sigma) ** 2)

# Compute the score function analytically: s(x) = -(x - mu) / sigma^2
score = -(x - mu) / (sigma ** 2)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Probability density
axes[0].plot(x.numpy(), p_x.numpy(), 'b-', linewidth=2)
axes[0].set_title('Gaussian Probability Density p(x)')
axes[0].set_xlabel('x')
axes[0].set_ylabel('p(x)')
axes[0].grid(True, alpha=0.3)

# Plot 2: Score function
axes[1].plot(x.numpy(), score.numpy(), 'r-', linewidth=2)
axes[1].axhline(y=0, color='k', linestyle='--', alpha=0.3)
axes[1].set_title('Score Function s(x) = -x')
axes[1].set_xlabel('x')
axes[1].set_ylabel('s(x)')
axes[1].grid(True, alpha=0.3)

# Add arrows showing score direction at key points
for xi in [-3, -2, -1, 1, 2, 3]:
    si = -(xi - mu) / sigma ** 2  # same formula as `score`, evaluated at xi
    axes[1].annotate('', xy=(xi + 0.3 * np.sign(si), si),
                     xytext=(xi, si),
                     arrowprops=dict(arrowstyle='->', color='green', lw=2))

plt.tight_layout()
plt.show()

print(f"Score at x=-3: {-(-3):.1f} (points right, toward center)")
print(f"Score at x= 0: {-(0):.1f} (zero, already at peak)")
print(f"Score at x= 3: {-(3):.1f} (points left, toward center)")

### 4.2 Score Field for a 2D Mixture of Gaussians

Now let us compute and visualize the score field for a more interesting distribution -- a mixture of two Gaussians.
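
For reference, the score of a mixture also has a closed form that the autograd result can be compared against. Writing the density as $p(x) = \sum_k w_k\,\mathcal{N}(x; \mu_k, \Sigma_k)$ (notation introduced here for illustration), differentiating the log gives:

$$\nabla_x \log p(x) = \sum_k r_k(x)\,\bigl(-\Sigma_k^{-1}(x - \mu_k)\bigr), \qquad r_k(x) = \frac{w_k\,\mathcal{N}(x; \mu_k, \Sigma_k)}{\sum_j w_j\,\mathcal{N}(x; \mu_j, \Sigma_j)}$$

Each component pulls $x$ toward its own mean, weighted by its responsibility $r_k(x)$ for that point. Close to one cluster, the mixture score is approximately that cluster's Gaussian score.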

In [None]:
# 2D Mixture of Gaussians
def mixture_log_prob(x, means, covs, weights):
    """Compute log probability of a mixture of Gaussians."""
    log_probs = []
    for mean, cov, w in zip(means, covs, weights):
        diff = x - mean
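        # Diagonal covariance: `cov` holds the per-dimension variances
        log_p = -0.5 * (diff ** 2 / cov).sum(dim=-1)
        # Log-weight plus the Gaussian normalizer: -0.5*log|Sigma| - (D/2)*log(2*pi)
        log_p += np.log(w) - 0.5 * torch.log(cov).sum() - 0.5 * x.shape[-1] * np.log(2 * np.pi)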
        log_probs.append(log_p)
    log_probs = torch.stack(log_probs, dim=-1)
    return torch.logsumexp(log_probs, dim=-1)

def compute_score_field(x, means, covs, weights):
    """Compute the score using autograd."""
    x_grad = x.clone().requires_grad_(True)
    log_p = mixture_log_prob(x_grad, means, covs, weights)
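    # Each log_p[i] depends only on row i of x_grad, so the gradient of the
    # sum gives every sample's score in a single backward pass
    log_p_sum = log_p.sum()
    log_p_sum.backward()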
    return x_grad.grad.clone()

# Define mixture parameters
means = [torch.tensor([2.0, 2.0]), torch.tensor([-2.0, -2.0])]
covs = [torch.tensor([0.25, 0.25]), torch.tensor([0.25, 0.25])]
weights = [0.5, 0.5]

# Create grid
n_grid = 20
x_range = torch.linspace(-5, 5, n_grid)
y_range = torch.linspace(-5, 5, n_grid)
xx, yy = torch.meshgrid(x_range, y_range, indexing='ij')
grid = torch.stack([xx.flatten(), yy.flatten()], dim=1)

# Compute score field
scores = compute_score_field(grid, means, covs, weights)

# Also compute density for contour plot
with torch.no_grad():
    n_dense = 100
    x_dense = torch.linspace(-5, 5, n_dense)
    y_dense = torch.linspace(-5, 5, n_dense)
    xx_d, yy_d = torch.meshgrid(x_dense, y_dense, indexing='ij')
    grid_dense = torch.stack([xx_d.flatten(), yy_d.flatten()], dim=1)
    log_p = mixture_log_prob(grid_dense, means, covs, weights)
    density = torch.exp(log_p).reshape(n_dense, n_dense)

# Visualize
plt.figure(figsize=(10, 10))
plt.contourf(xx_d.numpy(), yy_d.numpy(), density.numpy(), levels=20, cmap='Blues', alpha=0.5)
plt.colorbar(label='p(x)')
plt.quiver(grid[:, 0].numpy(), grid[:, 1].numpy(),
           scores[:, 0].numpy(), scores[:, 1].numpy(),
           color='red', alpha=0.8, scale=40)
plt.title('Score Field of a 2D Mixture of Gaussians')
plt.xlabel('x1')
plt.ylabel('x2')
plt.axis('equal')
plt.grid(True, alpha=0.2)
plt.show()

print("Notice: All arrows point TOWARD the high-density regions (blue peaks)")
print("This is the score function acting as a compass!")

## 5. Your Turn -- TODO Exercises

### TODO 1: Score Function for a Different Distribution

In [None]:
def compute_score_uniform_mixture(x, centers, radius=1.0):
    """
    Compute the score function for a mixture of 3 Gaussians
    centered at the given positions.

    Args:
        x: tensor of shape (N, 2) -- query points
        centers: list of 3 tensors, each shape (2,)
        radius: standard deviation of each Gaussian

    Returns:
        scores: tensor of shape (N, 2)
    """
    # ============ TODO ============
    # Step 1: Define covariances (use radius^2 for diagonal)
    # Step 2: Define equal weights (1/3 each)
    # Step 3: Call compute_score_field() to get scores
    # ==============================

    scores = None  # YOUR CODE HERE

    return scores

# Test with 3 clusters at triangle vertices
centers = [
    torch.tensor([0.0, 2.0]),
    torch.tensor([-1.73, -1.0]),
    torch.tensor([1.73, -1.0])
]

In [None]:
# Verification cell
test_point = torch.tensor([[0.0, 0.0]])
test_score = compute_score_uniform_mixture(test_point, centers, radius=1.0)
assert test_score is not None, "You need to implement the function!"
assert test_score.shape == (1, 2), f"Expected shape (1, 2), got {test_score.shape}"
# The score at the origin should be small since it is equidistant from all 3 centers
assert torch.norm(test_score) < 2.0, "Score at center should be small"
print("Your implementation works! Now visualize it below.")

In [None]:
# Visualize your result
n_grid = 20
x_range = torch.linspace(-4, 4, n_grid)
y_range = torch.linspace(-4, 4, n_grid)
xx, yy = torch.meshgrid(x_range, y_range, indexing='ij')
grid = torch.stack([xx.flatten(), yy.flatten()], dim=1)

your_scores = compute_score_uniform_mixture(grid, centers)

plt.figure(figsize=(8, 8))
plt.quiver(grid[:, 0].numpy(), grid[:, 1].numpy(),
           your_scores[:, 0].numpy(), your_scores[:, 1].numpy(),
           color='red', alpha=0.8, scale=40)
for c in centers:
    plt.plot(c[0], c[1], 'b*', markersize=15)
plt.title('Your Score Field for 3-Cluster Mixture')
plt.axis('equal')
plt.grid(True, alpha=0.3)
plt.show()

### TODO 2: Tractable Score Matching Loss
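
Where does this objective come from? Here is a brief sketch, under the assumption (stated here, not in the original text) that $p(x)\,s_\theta(x) \to 0$ as $\|x\| \to \infty$, with $p(x)$ denoting the data distribution. We would like to minimize the distance to the true score, which we cannot evaluate:

$$J(\theta) = \tfrac{1}{2}\,\mathbb{E}_{p(x)}\left[\|s_\theta(x) - \nabla_x \log p(x)\|^2\right] = \mathbb{E}_{p(x)}\left[\tfrac{1}{2}\|s_\theta(x)\|^2 - s_\theta(x)^\top \nabla_x \log p(x)\right] + C$$

where $C$ does not depend on $\theta$. Using $p(x)\,\nabla_x \log p(x) = \nabla_x p(x)$ and integrating by parts:

$$\mathbb{E}_{p(x)}\left[s_\theta(x)^\top \nabla_x \log p(x)\right] = \int s_\theta(x)^\top \nabla_x p(x)\,dx = -\int p(x)\,\mathrm{tr}\left(\nabla_x s_\theta(x)\right)dx$$

Substituting back:

$$J(\theta) = \mathbb{E}_{p(x)}\left[\mathrm{tr}\left(\nabla_x s_\theta(x)\right) + \tfrac{1}{2}\|s_\theta(x)\|^2\right] + C$$

The unknown true score has dropped out. What remains needs only samples from $p(x)$ and the Jacobian $\nabla_x s_\theta(x)$ of the model, written $J_s(x)$ in the docstring of the cell that follows.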

In [None]:
def tractable_score_matching_loss(model, data):
    """
    Compute the tractable score matching loss (Hyvarinen 2005).

    L = E[ tr(J_s(x)) + 0.5 * ||s(x)||^2 ]

    Args:
        model: nn.Module that takes (N, D) and returns (N, D) scores
        data: tensor of shape (N, D)

    Returns:
        loss: scalar tensor
    """
    data = data.clone().requires_grad_(True)
    scores = model(data)

    # ============ TODO ============
    # Step 1: Compute ||s(x)||^2 for each sample
    #         Hint: (scores ** 2).sum(dim=-1)
    #
    # Step 2: Compute the trace of the Jacobian
    #         Hint: For each output dimension d, compute
    #         d(s_d) / d(x_d) using torch.autograd.grad
    #         Sum these diagonal elements to get the trace
    #
    # Step 3: Combine: loss = mean(trace + 0.5 * norm_sq)
    # ==============================

    loss = None  # YOUR CODE HERE

    return loss

In [None]:
# Verification
class SimpleLinearScore(torch.nn.Module):
    """Score model that outputs s(x) = -x (perfect for standard normal)."""
    def forward(self, x):
        return -x

simple_model = SimpleLinearScore()
test_data = torch.randn(100, 2)
test_loss = tractable_score_matching_loss(simple_model, test_data)
assert test_loss is not None, "Implement the loss function first!"
print(f"Loss for perfect score model: {test_loss.item():.4f}")
print("(Should be close to -1.0 for 2D standard normal)")
# For s(x) = -x on N(0, I) in 2D: trace = -2 exactly, and ||x||^2 averages to 2,
# so the loss is about -2 + 0.5 * 2 = -1.0 (up to sampling noise from 100 points)

## 6. Putting It All Together

Let us train a neural network score model using the tractable score matching loss on our 2D mixture data.

In [None]:
import torch.nn as nn

# Generate training data
def generate_mixture_data(n_samples=2000):
    cluster1 = torch.randn(n_samples // 2, 2) * 0.5 + torch.tensor([2.0, 2.0])
    cluster2 = torch.randn(n_samples // 2, 2) * 0.5 + torch.tensor([-2.0, -2.0])
    return torch.cat([cluster1, cluster2], dim=0)

# Score network
class ScoreNet(nn.Module):
    def __init__(self, dim=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim)
        )
    def forward(self, x):
        return self.net(x)

# Training with tractable score matching
data = generate_mixture_data(2000)
model = ScoreNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
for epoch in range(500):
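    # Draw a random mini-batch of 256 points (each "epoch" here is one gradient step)
    idx = torch.randperm(len(data))
    batch = data[idx[:256]]
    batch.requires_grad_(True)  # needed to differentiate the scores w.r.t. the inputs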

    scores = model(batch)

    # Compute ||s||^2
    score_norm = (scores ** 2).sum(dim=-1)

    # Compute trace of Jacobian (diagonal elements only)
    trace = torch.zeros(len(batch))
    for d in range(2):  # For each dimension
        grad_d = torch.autograd.grad(
            scores[:, d].sum(), batch, create_graph=True
        )[0][:, d]
        trace += grad_d

    loss = (trace + 0.5 * score_norm).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

In [None]:
# Visualize training progress
plt.figure(figsize=(10, 4))
plt.plot(losses, alpha=0.7)
plt.title('Tractable Score Matching Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

## 7. Training and Results

In [None]:
# Visualize the learned score field
n_grid = 20
x_range = torch.linspace(-5, 5, n_grid)
y_range = torch.linspace(-5, 5, n_grid)
xx, yy = torch.meshgrid(x_range, y_range, indexing='ij')
grid = torch.stack([xx.flatten(), yy.flatten()], dim=1)

with torch.no_grad():
    learned_scores = model(grid)

plt.figure(figsize=(10, 10))
plt.scatter(data[:, 0], data[:, 1], alpha=0.1, s=3, c='blue')
plt.quiver(grid[:, 0].numpy(), grid[:, 1].numpy(),
           learned_scores[:, 0].numpy(), learned_scores[:, 1].numpy(),
           color='red', alpha=0.7, scale=50)
plt.title('Learned Score Field (Tractable Score Matching)')
plt.xlabel('x1')
plt.ylabel('x2')
plt.axis('equal')
plt.grid(True, alpha=0.3)
plt.show()

print("The arrows should point toward the two data clusters.")
print("But notice: this required computing the Jacobian trace,")
print("which takes D backward passes (roughly O(D^2) work). For images, this is too expensive!")

## 8. Final Output

In [None]:
# Side-by-side comparison: true vs learned score fields
fig, axes = plt.subplots(1, 2, figsize=(18, 8))

# True score field
true_scores = compute_score_field(grid.clone(), means, covs, weights)
axes[0].scatter(data[:, 0], data[:, 1], alpha=0.1, s=3, c='blue')
axes[0].quiver(grid[:, 0].numpy(), grid[:, 1].numpy(),
               true_scores[:, 0].numpy(), true_scores[:, 1].numpy(),
               color='green', alpha=0.7, scale=50)
axes[0].set_title('True Score Field', fontsize=14)
axes[0].set_xlabel('x1')
axes[0].set_ylabel('x2')
axes[0].set_aspect('equal')
axes[0].grid(True, alpha=0.3)

# Learned score field
axes[1].scatter(data[:, 0], data[:, 1], alpha=0.1, s=3, c='blue')
axes[1].quiver(grid[:, 0].numpy(), grid[:, 1].numpy(),
               learned_scores[:, 0].numpy(), learned_scores[:, 1].numpy(),
               color='red', alpha=0.7, scale=50)
axes[1].set_title('Learned Score Field (Tractable SM)', fontsize=14)
axes[1].set_xlabel('x1')
axes[1].set_ylabel('x2')
axes[1].set_aspect('equal')
axes[1].grid(True, alpha=0.3)

plt.suptitle('Score Function Learning via Tractable Score Matching', fontsize=16)
plt.tight_layout()
plt.show()

print("The learned field should closely match the true field near the two clusters.")
print("But the Jacobian trace needed one extra backward pass per input dimension.")
print("Imagine doing this for 784 dimensions (28x28 image)...")
print("This motivates Denoising Score Matching, which we cover next.")

## 9. Reflection and Next Steps

### Think About This

1. **Why does the score function point toward high-probability regions?** Think about what the gradient of log-probability means geometrically.

2. **What happens to the score at a local maximum of the density?** What about at the boundary between two modes?

3. **Why is the Jacobian trace computation O(D^2)?** How many second derivatives do we need to compute?

4. **Can you think of a way to avoid computing the full Jacobian?** (Hint: what if we added some noise to our data first?)
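
If you want to probe question 2 numerically before reasoning it through, here is a minimal sketch for a 1D analogue of the notebook's mixture. The means at $\pm 2$, the standard deviation 0.5, and the function names are assumptions chosen to mirror the 2D example. The score is evaluated with the closed-form mixture expression rather than autograd.

```python
import math

def gaussian_pdf(x, mean, std):
    """Density of a 1D Gaussian with the given mean and standard deviation."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

def mixture_score_1d(x, means=(-2.0, 2.0), std=0.5, weights=(0.5, 0.5)):
    """Score of a 1D Gaussian mixture: density-weighted pulls toward each mean."""
    densities = [w * gaussian_pdf(x, m, std) for w, m in zip(weights, means)]
    pulls = [-(x - m) / std ** 2 for m in means]
    return sum(d * p for d, p in zip(densities, pulls)) / sum(densities)

for x in (-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0):
    print(f"x = {x:+.1f}  score = {mixture_score_1d(x):+.4f}")
```

Compare the sign and size of the score at the two means, at the midpoint between them, and just to either side of the midpoint.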

### Extension Challenge

Try modifying the mixture to have 4 clusters arranged in a square pattern. Does the learned score field still capture the structure? What happens if you reduce the number of hidden units in the network?

### What's Next

The tractable score matching loss works but is computationally prohibitive for high-dimensional data. In the next notebook, we will discover **Denoising Score Matching** -- Pascal Vincent's elegant trick that replaces the expensive Jacobian with a simple noise prediction target. This is the foundation of modern diffusion models.