# From Hopfield Networks to Boltzmann Machines

**The Big Question**: What if neural networks could "remember" patterns and complete them from partial information?

Standard neural networks learn to map inputs → outputs. But what if we want something different: given a *corrupted* or *partial* pattern, recover the original? This is what your brain does constantly—you see half a face and instantly recall the whole person.

Hopfield networks and Boltzmann machines approach this through a beautiful idea borrowed from physics: **energy minimization**. Just as a ball rolls downhill to find the lowest point, these networks evolve their state to minimize an "energy" function—and the stored memories sit at the valleys.

This notebook builds intuition from scratch, culminating in training a production-scale Boltzmann machine and connecting to modern transformer attention.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
import time

# Add parent directory to path for our utils
import sys
sys.path.append('../..')
from silen_lib.utils import utils

utils.set_seed(42)

# Use MPS if available (Apple Silicon), otherwise CUDA, otherwise CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


## Why Energy? The Physics Intuition

Before we touch any neural network code, let's understand **why** physicists discovered that thinking about "energy" is powerful.

Consider a ball on a curved surface. You don't need to compute forces, accelerations, or differential equations to predict where it will end up. You just know: **it will roll to the lowest point**.

This is profound. The universe seems to "want" to minimize energy. Water flows downhill. Hot things cool down. Stretched springs contract. Nature finds the path of least resistance.

The question that leads to Hopfield networks: **Can we design a system where "good" states (stored memories) have low energy, so the system naturally evolves toward them?**


In [None]:
# Let's see energy minimization in action
# A simple energy landscape: E(x) = x^4 - 2x^2 (has two valleys)

x = torch.linspace(-2, 2, 200)


In [None]:
def energy(x):
    return x**4 - 2*x**2


In [None]:
energy(x)


In [None]:
# The gradient tells us which direction is "downhill"
def gradient(x):
    return 4*x**3 - 4*x  # derivative of x^4 - 2x^2


In [None]:
# Watch a "ball" roll downhill from a random starting point
# This is gradient descent: move opposite to gradient

def simulate_ball(start_pos, lr=0.1, steps=30):
    """Simulate a ball rolling down the energy landscape."""
    positions = [start_pos]
    pos = start_pos
    
    for _ in range(steps):
        grad = gradient(pos)
        pos = pos - lr * grad  # move opposite to gradient
        positions.append(pos)
    
    return torch.tensor(positions)

# Start from x = 0.3 (slightly right of the peak)
trajectory = simulate_ball(start_pos=0.3)

# Visualize
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x.numpy(), energy(x).numpy(), 'b-', linewidth=2, label='Energy landscape')
ax.scatter(trajectory.numpy(), energy(trajectory).numpy(), c=range(len(trajectory)), 
           cmap='Reds', s=100, zorder=5)
ax.scatter([trajectory[0]], [energy(trajectory[0])], c='green', s=200, marker='o', label='Start', zorder=6)
ax.scatter([trajectory[-1]], [energy(trajectory[-1])], c='red', s=200, marker='*', label='End', zorder=6)
ax.axhline(y=-1, color='gray', linestyle='--', alpha=0.5, label='Minimum energy = -1')
ax.set_xlabel('State x')
ax.set_ylabel('Energy E(x)')
ax.set_title('A ball rolls to the nearest valley (local minimum)')
ax.legend()
plt.show()

print(f"Started at x = {trajectory[0]:.3f}, ended at x = {trajectory[-1]:.3f}")
print(f"Energy went from {energy(trajectory[0]):.3f} to {energy(trajectory[-1]):.3f}")


**Key insight**: The ball doesn't "know" calculus. It just follows the local slope. Yet it finds a minimum.

Notice it went to x ≈ 1 (the right valley), not x ≈ -1 (the left valley). Starting at x = 0.3, it rolled to the *nearest* valley. This is both a feature and a limitation we'll revisit.

**The deep why**: Energy minimization works because:
1. It's a **scalar function**: one number summarizes "how good" a state is
2. **Negative gradients point toward improvement**: the downhill direction is always known locally
3. **Minima are stable**: once at a minimum, small perturbations return you there

This is exactly what we want for memory: stored patterns should be stable attractors.


In [None]:
# What happens if we start exactly at the peak (x=0)?
trajectory_peak = simulate_ball(start_pos=0.0)
trajectory_peak[-1]  # Stays at 0! Unstable equilibrium.


In [None]:
# But add a tiny perturbation...
trajectory_perturbed = simulate_ball(start_pos=0.001)
trajectory_perturbed[-1]  # Falls into the right valley!


The peak at x=0 has zero gradient, but it's **unstable** — the slightest push sends the ball rolling. Valleys are **stable** — push a little, it returns. This distinction between stable and unstable fixed points is crucial for memory.
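
One way to check the stable/unstable distinction numerically is the second-derivative test: at a point where the gradient vanishes, positive curvature indicates a valley and negative curvature a peak. The sketch below applies it to the same landscape E(x) = x^4 - 2x^2. The helper names `grad_E`, `curvature_E`, and `classify_fixed_point` are illustrative and not part of the notebook's utilities.

```python
def grad_E(x):
    """dE/dx for E(x) = x^4 - 2x^2."""
    return 4 * x**3 - 4 * x

def curvature_E(x):
    """d2E/dx2 for E(x) = x^4 - 2x^2 (hypothetical helper)."""
    return 12 * x**2 - 4

def classify_fixed_point(x, tol=1e-9):
    """Label x as a stable or unstable fixed point, if it is one at all."""
    if abs(grad_E(x)) > tol:
        return "not a fixed point"
    c = curvature_E(x)
    if c > 0:
        return "stable (valley)"
    if c < 0:
        return "unstable (peak)"
    return "inconclusive"

for x0 in (-1.0, 0.0, 0.5, 1.0):
    print(f"x = {x0:+.1f}: {classify_fixed_point(x0)}")
```

The same idea carries over to many dimensions, where the sign of the curvature is replaced by the signs of the Hessian's eigenvalues.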

---

## The Memory Problem

Let's make this concrete. Imagine you want to store some patterns (like images of digits) in a network, and later **recover** them from noisy or partial versions.

First, let's create some simple 5×5 binary patterns representing digits:


In [None]:
# Binary patterns: +1 (white) and -1 (black)
# These are simplified 5x5 "images" of digits

# Pattern for "0" - a ring shape
pattern_0 = torch.tensor([
    [-1, +1, +1, +1, -1],
    [+1, -1, -1, -1, +1],
    [+1, -1, -1, -1, +1],
    [+1, -1, -1, -1, +1],
    [-1, +1, +1, +1, -1]
], dtype=torch.float32)


In [None]:
# Pattern for "1" - a vertical line
pattern_1 = torch.tensor([
    [-1, -1, +1, -1, -1],
    [-1, -1, +1, -1, -1],
    [-1, -1, +1, -1, -1],
    [-1, -1, +1, -1, -1],
    [-1, -1, +1, -1, -1]
], dtype=torch.float32)


In [None]:
# Pattern for "X" - diagonal cross
pattern_x = torch.tensor([
    [+1, -1, -1, -1, +1],
    [-1, +1, -1, +1, -1],
    [-1, -1, +1, -1, -1],
    [-1, +1, -1, +1, -1],
    [+1, -1, -1, -1, +1]
], dtype=torch.float32)


In [None]:
# | export
def show_patterns(patterns, titles=None):
    """Display binary patterns as images."""
    n = len(patterns)
    fig, axes = plt.subplots(1, n, figsize=(3*n, 3))
    if n == 1:
        axes = [axes]
    
    for i, (ax, pattern) in enumerate(zip(axes, patterns)):
        ax.imshow(pattern.numpy(), cmap='gray', vmin=-1, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
        if titles:
            ax.set_title(titles[i], fontsize=14)
    
    plt.tight_layout()
    plt.show()

patterns = [pattern_0, pattern_1, pattern_x]
show_patterns(patterns, titles=['Pattern "0"', 'Pattern "1"', 'Pattern "X"'])


Now here's the problem we want to solve: what if we receive a **corrupted** version of one of these patterns?


In [None]:
# | export
def corrupt_pattern(pattern, noise_fraction=0.2):
    """Flip a fraction of bits randomly to simulate noise/corruption."""
    flat = pattern.flatten().clone()
    n_flip = int(len(flat) * noise_fraction)
    flip_idx = torch.randperm(len(flat))[:n_flip]
    flat[flip_idx] *= -1  # flip the sign
    return flat.view(pattern.shape)


In [None]:
# Corrupt the "0" pattern with 20% noise
corrupted_0 = corrupt_pattern(pattern_0, noise_fraction=0.2)

show_patterns([pattern_0, corrupted_0], titles=['Original "0"', 'Corrupted (20% noise)'])


Can you still tell it's a "0"? Probably. Your brain does pattern completion automatically.

**The Hopfield idea**: Create an energy landscape where:
- Each stored pattern sits at a **valley** (local minimum)
- Corrupted versions have **higher energy**
- Running dynamics will "roll" the corrupted pattern into the nearest valley

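The corrupted "0" should have higher energy than the clean "0". The clean "1" and "X" are stored as well, so they sit in low-energy valleys of their own; what matters is that the corrupted "0" is still closest to the "0" valley, so the dynamics should roll it back there.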

---

## Hebbian Learning: Building the Weight Matrix

How do we construct an energy function where our patterns are minima?

Before diving into the math, let's think about what we need. Imagine each neuron as a person at a party. We want to encode "rules" about who should agree:

- **Rule 1**: If neurons i and j are both ON (+1) in our memory, they're "friends" — they should tend to be ON together
- **Rule 2**: If neuron i is ON but j is OFF (-1), they're "enemies" — they should tend to disagree
- **Rule 3**: The rules should stack — if multiple memories agree that i and j should be friends, the friendship is stronger

The key insight comes from neuroscience: **"Neurons that fire together, wire together."** (Hebb's rule, 1949)

**The elegant solution**: Make the weight between neurons i and j equal to the **product** of their values:
- Both +1? Product = +1 (positive weight = friends)
- Both -1? Product = +1 (negative × negative = positive = friends)
- One +1, one -1? Product = -1 (negative weight = enemies)


In [None]:
# Tiny example: 3 neurons with pattern [+1, -1, +1]
tiny_pattern = torch.tensor([1., -1., 1.])


In [None]:
# The outer product captures "who should agree with whom"
torch.outer(tiny_pattern, tiny_pattern)


**Reading this 3×3 matrix — why the outer product is genius:**

The outer product $\xi \xi^T$ gives us $W_{ij} = \xi_i \cdot \xi_j$. Let's trace through:

- W[0,2] = (+1)(+1) = +1: neurons 0 and 2 are both +1, they should agree ✓
- W[0,1] = (+1)(-1) = -1: neuron 0 is +1 but neuron 1 is -1, they disagree ✓
- W[1,2] = (-1)(+1) = -1: neuron 1 is -1 but neuron 2 is +1, they disagree ✓

**The math captures exactly what we want**: The sign of the product tells us whether neurons should agree (+) or disagree (-).

The diagonal W[i,i] = ξ_i² = 1 always (since ±1 squared = 1). This represents "should neuron i agree with itself?" which is meaningless, so we zero it:


In [None]:
W_tiny = torch.outer(tiny_pattern, tiny_pattern)
W_tiny.fill_diagonal_(0)
W_tiny


For **multiple patterns**, we just add up the outer products. Each pattern contributes its preferences, and they accumulate:


In [None]:
# Flatten our patterns to 1D vectors (25 neurons for 5x5)
p0 = pattern_0.flatten()
p1 = pattern_1.flatten()  
px = pattern_x.flatten()

print(f"Each pattern is a vector of {len(p0)} neurons")


In [None]:
# | export
def build_hopfield_weights(patterns):
    """
    Build weight matrix from a list of patterns using Hebbian learning.
    W = (1/N) * sum of outer products, with zero diagonal.
    """
    n_neurons = patterns[0].numel()
    W = torch.zeros(n_neurons, n_neurons)
    
    for p in patterns:
        flat = p.flatten()
        W += torch.outer(flat, flat)
    
    W /= n_neurons  # normalize by number of neurons
    W.fill_diagonal_(0)  # no self-connections
    
    return W


In [None]:
W = build_hopfield_weights([pattern_0, pattern_1, pattern_x])
W.shape


In [None]:
# Visualize the weight matrix - it encodes all our patterns!
plt.figure(figsize=(8, 6))
plt.imshow(W.numpy(), cmap='RdBu', vmin=-W.abs().max(), vmax=W.abs().max())
plt.colorbar(label='Weight')
plt.title('Hopfield Weight Matrix (25×25)\nBlue = neurons should agree, Red = should disagree')
plt.xlabel('Neuron j')
plt.ylabel('Neuron i')
plt.show()


---

## The Hopfield Energy Function

Now comes the crucial step: defining an energy that's **low when neurons agree with their weights**.

### Deriving the Energy from First Principles

We want an energy function where:
1. Following the rules (friends agree, enemies disagree) = **low energy** = stable
2. Breaking the rules = **high energy** = unstable

Let's build it step by step.

**Step 1: Single pair contribution**

For neurons i and j, we want to reward "correct" configurations:
- If W_ij > 0 (they should agree) and they DO agree (same sign) → reward this
- If W_ij < 0 (they should disagree) and they DO disagree (opposite signs) → reward this

The product $W_{ij} \cdot x_i \cdot x_j$ does exactly this:
- If W_ij = +1 (agree) and x_i = x_j = +1: contribution = +1 (good!)
- If W_ij = +1 (agree) and x_i = +1, x_j = -1: contribution = -1 (bad!)
- If W_ij = -1 (disagree) and x_i = +1, x_j = -1: contribution = (-1)(+1)(-1) = +1 (good!)

**Step 2: Sum over all pairs**

Total "happiness" = $\sum_{i,j} W_{ij} x_i x_j = \mathbf{x}^T W \mathbf{x}$

**Step 3: Flip sign (energy = negative happiness)**

By convention, we want LOW energy to be GOOD, so:

$$E(\mathbf{x}) = -\frac{1}{2} \mathbf{x}^T W \mathbf{x}$$

The 1/2 is because each pair (i,j) appears twice in the sum (once as i,j and once as j,i since W is symmetric).
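
To see the factor of 1/2 at work, here is a small check on the 3-neuron pattern [+1, -1, +1]: summing over each unordered pair once gives the same energy as half the sum over all ordered pairs. It uses plain Python lists so that it runs on its own; the names `xi`, `W_small`, and the two energy functions are illustrative.

```python
xi = [1, -1, 1]  # the stored 3-neuron pattern
n = len(xi)

# Hebbian weights W_ij = xi_i * xi_j with a zero diagonal
W_small = [[xi[i] * xi[j] if i != j else 0 for j in range(n)] for i in range(n)]

def energy_all_ordered_pairs(x, W):
    """E = -1/2 * sum over all (i, j): every pair is counted twice."""
    total = sum(W[i][j] * x[i] * x[j] for i in range(len(x)) for j in range(len(x)))
    return -0.5 * total

def energy_unordered_pairs(x, W):
    """E = -sum over pairs with i < j: every pair is counted once."""
    return -sum(W[i][j] * x[i] * x[j] for i in range(len(x)) for j in range(i + 1, len(x)))

one_flip = [1, 1, 1]  # stored pattern with the middle neuron flipped

for name, x_state in [("stored", xi), ("one flip", one_flip)]:
    print(name, energy_all_ordered_pairs(x_state, W_small), energy_unordered_pairs(x_state, W_small))
```

Both functions agree on each state, and the stored pattern comes out lower in energy than the flipped one.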

### Why this works visually


In [None]:
# Visualize WHY stored patterns have low energy
# Each cell W_ij * x_i * x_j contributes to the energy sum

def visualize_energy_contributions(pattern, W, title=""):
    """Show which neuron pairs contribute positively/negatively to energy."""
    flat = pattern.flatten()
    n = len(flat)
    
    # Compute contribution matrix: -W_ij * x_i * x_j (note the negative for energy)
    contributions = torch.zeros(n, n)
    for i in range(n):
        for j in range(n):
            contributions[i, j] = -W[i, j] * flat[i] * flat[j]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Pattern
    axes[0].imshow(pattern.view(5, 5).numpy(), cmap='gray', vmin=-1, vmax=1)
    axes[0].set_title(f'Pattern\n{title}')
    axes[0].axis('off')
    
    # Contribution matrix (blue = lowers energy, red = raises energy)
    im = axes[1].imshow(contributions.numpy(), cmap='RdBu_r',  # reversed map: blue = negative (lowers E)
                        vmin=-contributions.abs().max(), vmax=contributions.abs().max())
    axes[1].set_title('Energy contributions\nBlue = lowers E (good), Red = raises E (bad)')
    axes[1].set_xlabel('Neuron j')
    axes[1].set_ylabel('Neuron i')
    plt.colorbar(im, ax=axes[1])
    
    # Histogram
    axes[2].hist(contributions.flatten().numpy(), bins=30, edgecolor='black', alpha=0.7)
    axes[2].axvline(x=0, color='r', linestyle='--')
    axes[2].set_xlabel('Energy contribution')
    axes[2].set_ylabel('Count')
    total_E = contributions.sum().item() / 2  # divide by 2 for double counting
    axes[2].set_title(f'Total Energy = {total_E:.2f}')
    
    plt.tight_layout()
    plt.show()

# Compare stored pattern vs corrupted
visualize_energy_contributions(pattern_0, W, "Stored '0' (should be LOW energy)")
visualize_energy_contributions(corrupted_0, W, "Corrupted '0' (should be HIGHER energy)")


In [None]:
# | export
def hopfield_energy(state, W):
    """Compute energy E(x) = -0.5 * x^T W x"""
    flat = state.flatten()
    return -0.5 * flat @ W @ flat


In [None]:
# Energy of stored patterns (should be low)
hopfield_energy(pattern_0, W)


In [None]:
# Energy of corrupted pattern (should be higher)
hopfield_energy(corrupted_0, W)


In [None]:
# Compare energies
print("Stored patterns (should be low):")
print(f"  Pattern 0: {hopfield_energy(pattern_0, W):.4f}")
print(f"  Pattern 1: {hopfield_energy(pattern_1, W):.4f}")
print(f"  Pattern X: {hopfield_energy(pattern_x, W):.4f}")
print(f"\nCorrupted pattern 0 (20% noise): {hopfield_energy(corrupted_0, W):.4f}")
print(f"\nRandom noise: {hopfield_energy(torch.sign(torch.randn(5, 5)), W):.4f}")


Stored patterns have lower energy than corrupted or random states! The energy landscape has valleys at our memories.

---

## Dynamics: Making the Network "Think"

Now we need dynamics that **never increase energy**: each update should lower it or leave it unchanged. This is like letting the ball roll downhill.

### The Local Field: Democratic Voting Among Neurons

Imagine you're neuron $i$, trying to decide: should I be +1 or -1?

You poll each neighbor $j$ who casts a **weighted vote**:

$$h_i = \sum_j W_{ij} x_j$$

**Breaking this down:**

- Neighbor $j$ is currently in state $x_j$ (either +1 or -1)
- Your connection weight $W_{ij}$ says how much you should agree (+) or disagree (-) with $j$
- The vote from $j$ is: $W_{ij} \cdot x_j$

**Example votes:**
- $W_{ij} = +1$ (should agree), $x_j = +1$ → vote = +1 ("be +1 like me!")
- $W_{ij} = +1$ (should agree), $x_j = -1$ → vote = -1 ("be -1 like me!")  
- $W_{ij} = -1$ (should disagree), $x_j = +1$ → vote = -1 ("be opposite to me!")
- $W_{ij} = -1$ (should disagree), $x_j = -1$ → vote = +1 ("be opposite to me!")

**The update rule**: Follow the consensus! Set $x_i = \text{sign}(h_i)$.
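
A worked instance of the vote, using the 3-neuron weights built from the pattern [+1, -1, +1] and a state in which the middle neuron has been flipped. The names `W_small`, `state_small`, and `local_field_of` are illustrative; ties are broken toward +1, which is one possible convention.

```python
# Weights from the stored pattern [+1, -1, +1], zero diagonal
W_small = [
    [0, -1, 1],
    [-1, 0, -1],
    [1, -1, 0],
]
state_small = [1, 1, 1]  # stored pattern with neuron 1 flipped

def local_field_of(i, state, W):
    """h_i = sum_j W_ij * x_j: the weighted vote of neuron i's neighbors."""
    return sum(W[i][j] * state[j] for j in range(len(state)))

def sign_update(h):
    """Follow the vote; a tie (h == 0) is resolved toward +1."""
    return 1 if h >= 0 else -1

fields = [local_field_of(i, state_small, W_small) for i in range(3)]
print("local fields:", fields)

# Asynchronous step: update only the flipped neuron
state_small[1] = sign_update(fields[1])
print("state after updating neuron 1:", state_small)
```

Neurons 0 and 2 receive a tied vote here, while neuron 1 is outvoted by both neighbors and flips back to -1.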


In [None]:
# | export
def local_field(state, W):
    """
    Compute the local field h = W @ x for each neuron.
    
    h_i > 0 means neighbors vote for neuron i to be +1
    h_i < 0 means neighbors vote for neuron i to be -1
    """
    return W @ state.flatten()


In [None]:
# What does the network "want" the corrupted pattern to become?
h = local_field(corrupted_0, W)
h.view(5, 5)  # reshape to see it as an image


In [None]:
# Visualize the local field as "what neighbors want"
def visualize_local_field(state, W, pattern_name=""):
    """Show the local field h and how it guides updates."""
    h = local_field(state, W)
    recommended = torch.sign(h)  # what the network "wants"
    
    fig, axes = plt.subplots(1, 4, figsize=(14, 3.5))
    
    # Current state
    axes[0].imshow(state.view(5, 5).numpy(), cmap='gray', vmin=-1, vmax=1)
    axes[0].set_title('Current State')
    axes[0].axis('off')
    
    # Local field (what neighbors want)
    im = axes[1].imshow(h.view(5, 5).numpy(), cmap='RdBu', 
                        vmin=-h.abs().max(), vmax=h.abs().max())
    axes[1].set_title('Local Field h\n(Red=vote -1, Blue=vote +1)')
    axes[1].axis('off')
    plt.colorbar(im, ax=axes[1], shrink=0.8)
    
    # Recommended next state
    axes[2].imshow(recommended.view(5, 5).numpy(), cmap='gray', vmin=-1, vmax=1)
    axes[2].set_title('sign(h)\n(What network recommends)')
    axes[2].axis('off')
    
    # Difference (what would flip)
    diff = (state.flatten() != recommended).float().view(5, 5)
    axes[3].imshow(diff.numpy(), cmap='Reds', vmin=0, vmax=1)
    axes[3].set_title('Would flip\n(Red = yes)')
    axes[3].axis('off')
    
    plt.suptitle(f'Local Field Voting - {pattern_name}', fontsize=12)
    plt.tight_layout()
    plt.show()

visualize_local_field(corrupted_0, W, "Corrupted '0'")


In [None]:
# The update rule: align with the sign of the local field
torch.sign(h).view(5, 5)


In [None]:
# | export
def hopfield_update(state, W, async_update=True, max_steps=100):
    """
    Run Hopfield dynamics until convergence.
    
    async_update=True: Update one random neuron at a time (guaranteed to converge)
    async_update=False: Update all neurons simultaneously (faster but may oscillate)
    """
    state = state.flatten().clone()
    n = len(state)
    energies = [hopfield_energy(state, W).item()]
    states = [state.clone()]
    
    for step in range(max_steps):
        if async_update:
            # Pick a random neuron and update it
            i = torch.randint(0, n, (1,)).item()
            h_i = W[i] @ state
            new_val = 1.0 if h_i >= 0 else -1.0
            if state[i] == new_val:
                continue  # no change
            state[i] = new_val
        else:
            # Update all neurons at once
            h = W @ state
            new_state = torch.sign(h)
            new_state[new_state == 0] = 1  # tie-break
            if torch.all(state == new_state):
                break
            state = new_state
        
        energies.append(hopfield_energy(state, W).item())
        states.append(state.clone())
    
    return state, energies, states


In [None]:
# Recover the corrupted "0" pattern!
recovered, energies, states = hopfield_update(corrupted_0, W, async_update=False)

show_patterns([corrupted_0, recovered.view(5, 5), pattern_0], 
              titles=['Corrupted input', 'Recovered', 'Original target'])


In [None]:
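# Energy trace of the recovery. With asynchronous updates the energy never increases
# (it is a Lyapunov function); this run used synchronous updates, where that is not strictly guaranteed.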
plt.figure(figsize=(8, 4))
plt.plot(energies, 'b-o', markersize=8)
plt.xlabel('Update step')
plt.ylabel('Energy')
plt.title('Energy Decreases Until Pattern is Recovered')
plt.axhline(y=hopfield_energy(pattern_0, W).item(), color='g', linestyle='--', label='Target pattern energy')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()


### Why Does Energy Always Decrease? (The Lyapunov Proof)

This is beautiful and worth understanding deeply. Let's derive it step by step.

**Setup**: We're updating neuron $i$. Before update, its value is $x_i$. After update, it becomes $x_i^{\text{new}} = \text{sign}(h_i)$.

**Step 1: Energy only depends on neuron i through terms involving i**

$$E = -\frac{1}{2} \sum_{j,k} W_{jk} x_j x_k$$

The terms involving neuron $i$ are: $-\frac{1}{2} \sum_{j \neq i} W_{ij} x_i x_j - \frac{1}{2} \sum_{j \neq i} W_{ji} x_j x_i = -\sum_{j \neq i} W_{ij} x_i x_j = -x_i h_i$

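(The two sums are equal because W is symmetric, $W_{ij} = W_{ji}$, so the factors of 1/2 combine. The $j = i$ term is absent because the diagonal of W is zero.)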

**Step 2: Energy change when we flip neuron i**

$$\Delta E = E^{\text{new}} - E^{\text{old}} = -x_i^{\text{new}} h_i - (-x_i^{\text{old}} h_i) = (x_i^{\text{old}} - x_i^{\text{new}}) h_i$$

**Step 3: Our update rule guarantees non-increase**

Our rule sets $x_i^{\text{new}} = \text{sign}(h_i)$. This means:
- If $h_i > 0$: we set $x_i^{\text{new}} = +1$, so $x_i^{\text{new}} h_i > 0$
- If $h_i < 0$: we set $x_i^{\text{new}} = -1$, so $x_i^{\text{new}} h_i > 0$ (negative × negative)
- If $h_i = 0$: energy doesn't change

In all cases, $x_i^{\text{new}} h_i \geq x_i^{\text{old}} h_i$, so $\Delta E \leq 0$. Energy never increases!

**Why this matters**: For asynchronous updates, the energy is a **Lyapunov function**: a quantity that never increases under the dynamics. Because there are only finitely many states ($2^N$) and the energy is bounded below, this guarantees:
1. The system eventually stops changing (convergence)
2. It stops at a local minimum of the energy
3. That minimum is either a stored pattern or a spurious state
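
The argument above can be checked numerically. The sketch below is a plain-Python illustration rather than the notebook's own implementation: it draws random symmetric weights with a zero diagonal (the two properties the proof relies on), runs asynchronous sign updates, and counts how often the energy goes up. The names `W_rand`, `x_rand`, and `energy_of` are hypothetical.

```python
import random

random.seed(0)
n = 8

# Random symmetric weights with zero diagonal (the assumptions used in the proof)
W_rand = [[0.0] * n for _ in range(n)]
for i in range(n):
    for j in range(i + 1, n):
        W_rand[i][j] = W_rand[j][i] = random.uniform(-1, 1)

def energy_of(x, W):
    """E(x) = -1/2 * sum_ij W_ij x_i x_j."""
    return -0.5 * sum(W[i][j] * x[i] * x[j] for i in range(len(x)) for j in range(len(x)))

x_rand = [random.choice([-1, 1]) for _ in range(n)]
energies = [energy_of(x_rand, W_rand)]

for _ in range(200):
    i = random.randrange(n)                       # pick one neuron at random
    h_i = sum(W_rand[i][j] * x_rand[j] for j in range(n))
    x_rand[i] = 1 if h_i >= 0 else -1             # sign update, ties go to +1
    energies.append(energy_of(x_rand, W_rand))

# Small tolerance guards against floating-point round-off
increases = sum(1 for a, b in zip(energies, energies[1:]) if b > a + 1e-12)
print(f"start E = {energies[0]:.3f}, end E = {energies[-1]:.3f}, increases = {increases}")
```

Replacing the single-neuron update with a simultaneous update of all neurons removes this guarantee, which is why synchronous dynamics can oscillate.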


### Test Problem: Implement Asynchronous Update

The synchronous update (all neurons at once) can oscillate in certain cases. Asynchronous update (one neuron at a time) is guaranteed to converge. Complete the function:


In [None]:
def test_async_update_step(test_state, test_W, test_neuron_idx):
    """
    Perform ONE asynchronous update step on a single neuron.
    
    Args:
        test_state: Current state vector (1D tensor)
        test_W: Weight matrix
        test_neuron_idx: Which neuron to update
    
    Returns:
        Updated state (modified in place is fine)
    """
    # FILL IN CODE HERE
    # 1. Compute the local field h_i = sum_j W[i,j] * x[j]
    # 2. Set state[i] = +1 if h_i >= 0, else -1
    
    pass

# Test: after updating, energy should not increase
test_state = corrupted_0.flatten().clone()
test_orig_energy = hopfield_energy(test_state, W).item()

test_async_update_step(test_state, W, test_neuron_idx=0)

test_new_energy = hopfield_energy(test_state, W).item()
# Uncomment to check: assert test_new_energy <= test_orig_energy + 1e-6, "Energy should not increase!"
# print("✓ Passed: Energy did not increase")


### Animated Recovery

Let's watch the network "think" step by step as it recovers a heavily corrupted pattern:


In [None]:
# Heavily corrupt pattern X (40% noise)
corrupted_x = corrupt_pattern(pattern_x, noise_fraction=0.4)

# Recover using asynchronous updates (many steps since we update one neuron at a time)
recovered_x, energies_x, states_x = hopfield_update(corrupted_x, W, async_update=True, max_steps=500)

# Show key frames of the recovery
n_frames = min(8, len(states_x))
frame_indices = [int(i * (len(states_x) - 1) / (n_frames - 1)) for i in range(n_frames)]

fig, axes = plt.subplots(1, n_frames, figsize=(2.5 * n_frames, 3))
for ax, idx in zip(axes, frame_indices):
    ax.imshow(states_x[idx].view(5, 5).numpy(), cmap='gray', vmin=-1, vmax=1)
    ax.set_title(f'Step {idx}\nE={energies_x[idx]:.2f}', fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Pattern Recovery Animation (X with 40% noise)', fontsize=12)
plt.tight_layout()
plt.show()


---

## Capacity and Spurious States

How many patterns can a Hopfield network store? This is crucial for practical use.

**Theoretical result (Amit, Gutfreund, Sompolinsky, 1985)**: For N neurons, you can reliably store at most ~**0.14N** random patterns.

Beyond this, the network develops **spurious states** — false memories that weren't stored but appear as local minima. These are like hallucinations.


In [None]:
# For our 25-neuron network, theoretical capacity is ~0.14 * 25 ≈ 3.5 patterns
# We stored 3, which is within capacity. Let's try overloading it.

n_neurons = 64  # 8x8 patterns
capacity_limit = int(0.14 * n_neurons)
print(f"For {n_neurons} neurons, theoretical capacity: ~{capacity_limit} patterns")


In [None]:
def test_capacity(n_neurons, n_patterns, noise_fraction=0.2, n_trials=20):
    """Test how well a Hopfield network recovers patterns."""
    # Generate random binary patterns
    patterns = [torch.sign(torch.randn(n_neurons)) for _ in range(n_patterns)]
    
    # Build weight matrix
    W = torch.zeros(n_neurons, n_neurons)
    for p in patterns:
        W += torch.outer(p, p)
    W /= n_neurons
    W.fill_diagonal_(0)
    
    # Test recovery
    successes = 0
    for _ in range(n_trials):
        # Pick a random pattern and corrupt it
        idx = torch.randint(0, n_patterns, (1,)).item()
        original = patterns[idx]
        corrupted = original.clone()
        flip_idx = torch.randperm(n_neurons)[:int(n_neurons * noise_fraction)]
        corrupted[flip_idx] *= -1
        
        # Recover
        recovered, _, _ = hopfield_update(corrupted, W, async_update=False, max_steps=50)
        
        # Check if recovered matches original
        if torch.all(recovered == original):
            successes += 1
    
    return successes / n_trials


In [None]:
# Test different numbers of patterns
n_neurons = 100
pattern_counts = [1, 5, 10, 14, 20, 30, 50]  # 0.14 * 100 = 14

accuracies = []
for n_patterns in pattern_counts:
    acc = test_capacity(n_neurons, n_patterns, noise_fraction=0.1)
    accuracies.append(acc)
    print(f"{n_patterns:3d} patterns: {acc*100:.0f}% recovery")


In [None]:
# Visualize the capacity cliff
plt.figure(figsize=(8, 5))
plt.plot(pattern_counts, accuracies, 'bo-', markersize=10, linewidth=2)
plt.axvline(x=0.14*n_neurons, color='r', linestyle='--', label=f'Theoretical limit (0.14N = {0.14*n_neurons:.0f})')
plt.xlabel('Number of Stored Patterns')
plt.ylabel('Recovery Accuracy')
plt.title(f'Hopfield Network Capacity (N={n_neurons} neurons)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim(0, 1.05)
plt.show()


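Recovery accuracy falls off steeply as the load approaches the theoretical limit (with only 20 trials per point, the exact location of the drop varies from run to run). Beyond capacity, the network settles into spurious patterns instead of the stored ones.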

---

## Comparison to Neural Networks

How does the Hopfield network differ from a standard feedforward neural network?

**Feedforward NN**: Layers, forward pass, one-shot computation x → f(x), trained with gradient descent.

**Hopfield Network**: Single recurrent layer, iterate until equilibrium, Hebbian learning (one-shot storage).

The key conceptual difference: feedforward NNs **compute** answers, while Hopfield networks **settle** into them. It's the difference between calculating 2+2=4 versus recognizing that "4" feels right.

**Autoregressive models** (like GPT) are somewhere in between: they generate one token at a time through forward passes, but the sequential nature creates implicit dynamics. Interestingly, modern Hopfield networks reveal that attention mechanisms can be viewed as a single energy minimization step — we'll see this later.


---

## The Problem: Getting Stuck in Local Minima

Hopfield networks have a critical flaw: they always go "downhill" in energy. This greedy behavior means they can get stuck in spurious local minima (false memories) instead of finding the true stored pattern.

Think about it: if you're in a small dip, you can't see there's a deeper valley nearby because every direction looks "uphill."
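
A minimal sketch of the trap, assuming the quartic landscape E(x) = 0.5x^4 - x^3 - 2x^2 + 3x, which has a deep valley near x ≈ -1.20 and a shallower one near x ≈ 2.11. Plain gradient descent started on the right settles into the shallow valley and never discovers the deeper one. The function names and the learning rate are illustrative choices.

```python
def trap_energy(x):
    return 0.5 * x**4 - x**3 - 2 * x**2 + 3 * x

def trap_gradient(x):
    """Derivative of trap_energy."""
    return 2 * x**3 - 3 * x**2 - 4 * x + 3

def descend(x, lr=0.01, steps=2000):
    """Plain gradient descent: always step downhill, never uphill."""
    for _ in range(steps):
        x = x - lr * trap_gradient(x)
    return x

x_right = descend(2.5)   # starts in the basin of the shallow valley
x_left = descend(-0.5)   # starts in the basin of the deep valley

print(f"start  2.5 -> x = {x_right:.3f}, E = {trap_energy(x_right):.3f}")
print(f"start -0.5 -> x = {x_left:.3f}, E = {trap_energy(x_left):.3f}")
```

Which valley the descent reaches depends only on the starting point, not on which valley is deeper.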


In [None]:
# An energy landscape with a local minimum trap
def trap_energy(x):
    return 0.5 * x**4 - x**3 - 2*x**2 + 3*x

x = torch.linspace(-3, 3, 300)
energy_with_trap = trap_energy(x)

# Minima solve E'(x) = 2x^3 - 3x^2 - 4x + 3 = 0:
# x ≈ 2.11 (shallow) and x ≈ -1.20 (deep); x ≈ 0.59 is the peak between them
x_local, x_global = 2.11, -1.20

plt.figure(figsize=(10, 5))
plt.plot(x.numpy(), energy_with_trap.numpy(), 'b-', linewidth=2)
plt.scatter([x_local], [trap_energy(x_local)], c='red', s=200, zorder=5, label='Local minimum (trap)')
plt.scatter([x_global], [trap_energy(x_global)], c='green', s=200, zorder=5, label='Global minimum (true memory)')
plt.xlabel('State')
plt.ylabel('Energy')
plt.title('Greedy descent gets trapped in local minima')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()


**The solution from physics**: Add temperature!

At temperature T > 0, particles don't just sit at minimum energy. They jiggle around randomly, and occasionally "jump" uphill. Higher temperature = more randomness = can escape traps.

This is the **Boltzmann distribution** from statistical mechanics:

$$P(\text{state}) \propto e^{-E(\text{state}) / T}$$

- At T → 0: Only the lowest energy state has non-zero probability (deterministic)
- At T → ∞: All states equally likely (pure randomness)
- At intermediate T: Lower energy states more likely, but can still sample higher energy states
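
The three regimes can be seen on a toy system with just three states. The energies below are hypothetical values picked for illustration, and `discrete_boltzmann` is an illustrative helper name.

```python
import math

def discrete_boltzmann(energies, T):
    """Normalized P(s) proportional to exp(-E_s / T) over a finite set of states."""
    e_min = min(energies)  # shifting by the minimum avoids overflow at small T
    weights = [math.exp(-(e - e_min) / T) for e in energies]
    z = sum(weights)
    return [w / z for w in weights]

state_energies = [-2.0, -1.0, 0.0]  # three hypothetical states

for T in (0.05, 1.0, 100.0):
    probs = discrete_boltzmann(state_energies, T)
    print(f"T = {T:>6}: " + ", ".join(f"{p:.3f}" for p in probs))
```

Subtracting the minimum energy before exponentiating does not change the normalized probabilities, since the common factor cancels in the ratio.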

Let's visualize this:


In [None]:
# Simple energy landscape: E = x^2 (single minimum at 0)
x = torch.linspace(-3, 3, 200)
E = x**2


In [None]:
# Boltzmann distribution at different temperatures
def boltzmann_prob(E, T):
    """P(x) ∝ exp(-E/T), normalized."""
    unnorm = torch.exp(-E / T)
    return unnorm / unnorm.sum()

temperatures = [0.1, 0.5, 1.0, 2.0, 5.0]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Energy landscape
axes[0].plot(x.numpy(), E.numpy(), 'b-', linewidth=2)
axes[0].set_xlabel('State x')
axes[0].set_ylabel('Energy')
axes[0].set_title('Energy Landscape E(x) = x²')
axes[0].grid(True, alpha=0.3)

# Right: Probability distributions at different T
colors = plt.cm.coolwarm(np.linspace(0, 1, len(temperatures)))
for T, color in zip(temperatures, colors):
    prob = boltzmann_prob(E, T)
    axes[1].plot(x.numpy(), prob.numpy(), linewidth=2, color=color, label=f'T={T}')

axes[1].set_xlabel('State x')
axes[1].set_ylabel('Probability')
axes[1].set_title('Boltzmann Distribution: P(x) ∝ exp(-E/T)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


At low temperature (T=0.1, blue), probability is sharply concentrated at the minimum. At high temperature (T=5, red), it's spread out — the system explores more states.

**Simulated annealing**: Start hot (explore widely), gradually cool down (settle into minimum). This is how Boltzmann machines escape local minima.
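
A rough sketch of annealing on a one-dimensional landscape with a trap. Several things here are assumptions made for illustration rather than anything specified above: the quartic landscape (shallow valley near x ≈ 2.11, deeper valley near x ≈ -1.20), the Metropolis acceptance rule (Boltzmann machines themselves typically use per-neuron Gibbs updates), the geometric cooling schedule, and all parameter values.

```python
import math
import random

def trap_energy(x):
    return 0.5 * x**4 - x**3 - 2 * x**2 + 3 * x

def anneal(x, T_start=3.0, T_end=0.01, steps=5000, step_size=0.3, seed=0):
    """Metropolis moves under a geometric cooling schedule (illustrative settings)."""
    rng = random.Random(seed)
    cooling = (T_end / T_start) ** (1.0 / steps)
    T = T_start
    for _ in range(steps):
        candidate = x + rng.uniform(-step_size, step_size)
        dE = trap_energy(candidate) - trap_energy(x)
        # Downhill moves are always accepted; uphill moves with probability exp(-dE / T)
        if dE <= 0 or rng.random() < math.exp(-dE / T):
            x = candidate
        T *= cooling
    return x

# Start every run at the bottom of the shallow valley
finals = [anneal(2.11, seed=s) for s in range(20)]
escaped = sum(1 for x_final in finals if x_final < 0)
print(f"{escaped}/20 runs ended in the deeper left valley")
```

While the temperature is high, uphill moves are accepted often enough to cross the barrier; as it drops, the walker is effectively confined to whichever valley it occupies at that point.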

---

## Boltzmann Machines: Stochastic Hopfield Networks

### Where Does the Sigmoid Come From?

This derivation is beautiful and connects physics to neural networks. Let's build it from scratch.

**The Physical Setup**: We want to sample from the Boltzmann distribution:

$$P(\text{state } \mathbf{x}) \propto e^{-E(\mathbf{x})/T}$$

**The Question**: Given that we're updating neuron $i$, what's the probability it should be +1 vs -1?

**Step 1: Conditional probability**

We want $P(x_i = +1 | \text{all other neurons})$. With the other neurons held fixed, the definition of conditional probability makes this proportional to $e^{-E(\mathbf{x}^{+})/T}$, where $\mathbf{x}^{+}$ is the state with $x_i = +1$.

**Step 2: Energy difference**

From our earlier derivation, the terms involving $x_i$ contribute $-x_i h_i$ to the energy.

- If $x_i = +1$: contribution = $-h_i$
- If $x_i = -1$: contribution = $+h_i$

**Step 3: The ratio**

$$\frac{P(x_i = +1)}{P(x_i = -1)} = \frac{e^{-(-h_i)/T}}{e^{-(+h_i)/T}} = \frac{e^{h_i/T}}{e^{-h_i/T}} = e^{2h_i/T}$$

**Step 4: Normalize to get probability**

Since $P(+1) + P(-1) = 1$:

$$P(x_i = +1) = \frac{e^{h_i/T}}{e^{h_i/T} + e^{-h_i/T}} = \frac{1}{1 + e^{-2h_i/T}}$$

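This is the **sigmoid function** evaluated at $2h_i/T$: $P(x_i = +1) = \sigma(2h_i/T)$, where $\sigma(z) = \frac{1}{1 + e^{-z}}$. The factor of 2 comes from using ±1 states; it can be absorbed into the temperature, which is why the rule is usually written as $\sigma(h_i/T)$.

**The stunning insight**: The sigmoid function isn't arbitrary. It is exactly the conditional probability needed to resample one neuron from the Boltzmann distribution while the others are held fixed (a Gibbs sampling step).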
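
As a quick numerical check of Step 4, the snippet below compares the normalized pair of Boltzmann weights with a sigmoid evaluated at $2h_i/T$ for a few arbitrary values of the local field and temperature. The function names are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def p_plus_from_boltzmann(h, T):
    """P(x_i = +1 | rest) from the two Boltzmann weights e^{+h/T} and e^{-h/T}."""
    w_plus = math.exp(h / T)
    w_minus = math.exp(-h / T)
    return w_plus / (w_plus + w_minus)

test_points = [(0.0, 1.0), (0.5, 1.0), (-1.5, 0.5), (2.0, 4.0)]  # arbitrary (h, T) pairs

for h_val, T_val in test_points:
    direct = p_plus_from_boltzmann(h_val, T_val)
    via_sigmoid = sigmoid(2 * h_val / T_val)
    print(f"h = {h_val:+.1f}, T = {T_val}: {direct:.6f} vs {via_sigmoid:.6f}")
```

The two columns agree to floating-point precision, and a zero local field gives a probability of exactly one half.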


In [None]:
# Sigmoid at different temperatures
h = torch.linspace(-5, 5, 200)  # local field values

fig, ax = plt.subplots(figsize=(10, 5))
for T, color in zip([0.2, 0.5, 1.0, 2.0, 5.0], plt.cm.coolwarm(np.linspace(0, 1, 5))):
    prob = torch.sigmoid(h / T)
    ax.plot(h.numpy(), prob.numpy(), linewidth=2, color=color, label=f'T={T}')

ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Local field $h_i$')
ax.set_ylabel('P($x_i$ = +1)')
ax.set_title('Stochastic neuron firing probability at different temperatures')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()


### Understanding Temperature: A Unifying Concept

Temperature T appears everywhere in ML, often hidden under different names:

| Concept | Where It Appears | Effect |
|---------|-----------------|--------|
| **Boltzmann temperature** | Boltzmann machine | Exploration vs exploitation |
| **Softmax temperature** | Transformer attention | Sharpness of attention |
| **Sampling temperature** | LLM generation | Creativity vs coherence |
| **Annealing schedule** | Optimization | Escaping local minima |

**The Deep Why**: All of these are the SAME concept from statistical mechanics:

$$P(\text{state}) \propto e^{-E(\text{state})/T}$$

- **T → 0**: Only the minimum energy state has probability (deterministic/greedy)
- **T → ∞**: All states equally likely (random)
- **Finite T**: Explore states proportional to their "quality"

**This is exploration vs exploitation!** High T explores (might find better solutions), low T exploits (commits to current best). The optimal strategy is often to start hot and cool down (annealing).
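
The same knob is easy to see in a softmax. This is a small sketch with made-up logits (treating each logit as a negative energy); `tempered_softmax` is a hypothetical helper name, not something defined elsewhere in this notebook.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])   # hypothetical scores for four options

def tempered_softmax(logits: torch.Tensor, T: float) -> torch.Tensor:
    """Softmax with temperature: P(i) ∝ exp(logit_i / T)."""
    return torch.softmax(logits / T, dim=-1)

for T in (0.05, 1.0, 100.0):
    probs = tempered_softmax(logits, T)
    print(f"T={T}: {[round(p, 3) for p in probs.tolist()]}")
```

With a very small T the distribution is effectively an argmax, and with a very large T it is close to uniform, which matches the two limits listed above.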


In [None]:
# | export
def boltzmann_update(state, W, T=1.0, n_steps=100):
    """
    Stochastic Boltzmann machine update.
    Neurons flip probabilistically based on local field and temperature.
    
    The key insight: instead of x_i = sign(h_i), we use
    P(x_i = +1) = sigmoid(h_i / T)
    
    This samples from the Boltzmann distribution P(state) ∝ exp(-E/T)
    (with the factor of 2 from the ±1 encoding absorbed into T)
    """
    state = state.flatten().clone()
    n = len(state)
    energies = [hopfield_energy(state, W).item()]
    
    for _ in range(n_steps):
        # Pick a random neuron
        i = torch.randint(0, n, (1,)).item()
        
        # Compute local field: what do my neighbors want?
        h_i = W[i] @ state
        
        # Probability of being +1 (derived from Boltzmann distribution!)
        # h_i is already a tensor, so no torch.tensor(...) wrapper is needed
        prob_plus = torch.sigmoid(h_i / T)
        
        # Stochastic update: flip a biased coin
        if torch.rand(1) < prob_plus:
            state[i] = 1.0
        else:
            state[i] = -1.0
        
        energies.append(hopfield_energy(state, W).item())
    
    return state, energies


In [None]:
# Visualize the effect of temperature on sampling
# Low T → greedy/deterministic, High T → random exploration

def sample_many_states(start_state, W, T, n_samples=20, steps_per_sample=100):
    """Sample multiple final states at a given temperature."""
    samples = []
    for _ in range(n_samples):
        final, _ = boltzmann_update(start_state.clone(), W, T=T, n_steps=steps_per_sample)
        samples.append(final)
    return torch.stack(samples)

# Start from random state
random_start = torch.sign(torch.randn(25))

# Sample at different temperatures
temps = [0.1, 0.5, 1.0, 3.0]
n_show = 8

fig, axes = plt.subplots(len(temps), n_show + 1, figsize=(3*(n_show+1), 3*len(temps)))

for row, T in enumerate(temps):
    samples = sample_many_states(random_start, W, T, n_samples=n_show, steps_per_sample=200)
    
    # Show temperature label
    axes[row, 0].text(0.5, 0.5, f'T={T}', fontsize=16, ha='center', va='center', fontweight='bold')
    axes[row, 0].axis('off')
    
    # Show samples
    for col in range(n_show):
        axes[row, col+1].imshow(samples[col].view(5, 5).numpy(), cmap='gray', vmin=-1, vmax=1)
        
        # Check if it matches any stored pattern
        matches = []
        for name, pattern in [("0", pattern_0), ("1", pattern_1), ("X", pattern_x)]:
            if torch.all(samples[col].view(5, 5) == pattern):
                matches.append(name)
        
        title = f"→{matches[0]}" if matches else "spurious"
        axes[row, col+1].set_title(title, fontsize=10)
        axes[row, col+1].axis('off')

# Single backslash: '\n' is a real line break in the title
plt.suptitle('Boltzmann Sampling at Different Temperatures\n'
             'Low T: Always finds stored patterns (but can get stuck)\n'
             'High T: Explores more (but may not settle)', fontsize=12)
plt.tight_layout()
plt.show()


In [None]:
# Compare energy trajectories at different temperatures
corrupted = corrupt_pattern(pattern_0, noise_fraction=0.3).flatten()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, T in zip(axes, [0.1, 1.0, 5.0]):
    _, energies = boltzmann_update(corrupted.clone(), W, T=T, n_steps=300)
    ax.plot(energies, alpha=0.7)
    ax.axhline(y=hopfield_energy(pattern_0, W).item(), color='g', linestyle='--', label='Target energy')
    ax.set_xlabel('Step')
    ax.set_ylabel('Energy')
    ax.set_title(f'T = {T}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Boltzmann Machine Energy at Different Temperatures', fontsize=12)
plt.tight_layout()
plt.show()


See the tradeoff:
- **Low T (0.1)**: Energy decreases quickly but might get stuck
- **High T (5.0)**: Explores widely but never settles
- **Medium T (1.0)**: A balance, but not optimal

**Simulated annealing** solves this: start hot, gradually cool:


In [None]:
# | export
def simulated_annealing(state, W, T_start=5.0, T_end=0.1, n_steps=500):
    """
    Run Boltzmann dynamics with decreasing temperature.
    
    The key insight: high T → explore (escape local minima)
                     low T → exploit (settle into minimum)
    Gradually cooling combines the best of both!
    """
    state = state.flatten().clone()
    n = len(state)
    energies = [hopfield_energy(state, W).item()]
    temperatures = []
    
    for step in range(n_steps):
        # Linear cooling schedule (exponential schedule also popular)
        T = T_start - (T_start - T_end) * step / n_steps
        temperatures.append(T)
        
        # Pick a random neuron and do stochastic update
        i = torch.randint(0, n, (1,)).item()
        h_i = W[i] @ state
        prob_plus = torch.sigmoid(h_i / T)  # h_i is already a tensor
        state[i] = 1.0 if torch.rand(1) < prob_plus else -1.0
        
        energies.append(hopfield_energy(state, W).item())
    
    return state, energies, temperatures


In [None]:
# Run simulated annealing on corrupted pattern
state_annealed, energies_annealed, temps = simulated_annealing(corrupted, W, n_steps=500)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# Energy
ax1.plot(energies_annealed, 'b-', alpha=0.7, label='Energy')
ax1.axhline(y=hopfield_energy(pattern_0, W).item(), color='g', linestyle='--', label='Target energy')
ax1.set_ylabel('Energy')
ax1.legend()
ax1.set_title('Simulated Annealing: Start Hot, Cool Down')
ax1.grid(True, alpha=0.3)

# Temperature
ax2.plot(temps, 'r-', linewidth=2)
ax2.set_xlabel('Step')
ax2.set_ylabel('Temperature')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final energy: {energies_annealed[-1]:.4f}")
print(f"Target energy: {hopfield_energy(pattern_0, W).item():.4f}")


---

## Restricted Boltzmann Machines (RBMs)

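The full Boltzmann machine is powerful but has a practical problem: **exact training is intractable**. The gradient needs expectations under the model distribution. Computing them exactly means summing over exponentially many states, and approximating them by Gibbs sampling is slow in a fully connected network, where units must be updated one at a time and the chain can take a long time to mix.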

**Restricted Boltzmann Machines (RBMs)** solve this by introducing a clever restriction:
- Divide neurons into **visible** units (data we observe) and **hidden** units (latent features)
- **No connections within a layer** — only visible ↔ hidden connections
- This bipartite structure makes sampling tractable!

Why does removing intra-layer connections help? Given the visible units, hidden units become **conditionally independent** (no connections between them), so we can sample them all in parallel. Same for visible given hidden.


In [None]:
# Visualize RBM architecture
fig, ax = plt.subplots(figsize=(10, 6))

# Visible layer (bottom)
n_visible, n_hidden = 6, 4
for i in range(n_visible):
    circle = plt.Circle((i * 1.5 + 1, 1), 0.3, color='steelblue', ec='black', linewidth=2)
    ax.add_patch(circle)
    ax.text(i * 1.5 + 1, 1, f'v{i}', ha='center', va='center', fontsize=10, color='white')

# Hidden layer (top)
for j in range(n_hidden):
    circle = plt.Circle((j * 2 + 1.75, 4), 0.3, color='coral', ec='black', linewidth=2)
    ax.add_patch(circle)
    ax.text(j * 2 + 1.75, 4, f'h{j}', ha='center', va='center', fontsize=10)

# Connections (visible ↔ hidden only)
for i in range(n_visible):
    for j in range(n_hidden):
        ax.plot([i * 1.5 + 1, j * 2 + 1.75], [1.3, 3.7], 'gray', alpha=0.3, linewidth=1)

ax.text(4.5, 5, 'Hidden Layer (latent features)', ha='center', fontsize=12, fontweight='bold')
ax.text(4.5, 0, 'Visible Layer (data)', ha='center', fontsize=12, fontweight='bold')

ax.set_xlim(-1, 10)
ax.set_ylim(-1, 6)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Restricted Boltzmann Machine\n(No intra-layer connections)', fontsize=14)
plt.show()


The RBM energy function:

$$E(\mathbf{v}, \mathbf{h}) = -\mathbf{a}^T\mathbf{v} - \mathbf{b}^T\mathbf{h} - \mathbf{v}^T W \mathbf{h}$$

where:
- $\mathbf{v}$ = visible units, $\mathbf{h}$ = hidden units
- $\mathbf{a}$ = visible biases, $\mathbf{b}$ = hidden biases
- $W$ = weights connecting visible and hidden layers

The conditional distributions are simple sigmoids:
- $P(h_j = 1 | \mathbf{v}) = \sigma(b_j + \sum_i v_i W_{ij})$
- $P(v_i = 1 | \mathbf{h}) = \sigma(a_i + \sum_j W_{ij} h_j)$

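Note that the units here are 0/1 rather than ±1, so no factor of 2 appears inside the sigmoid. Alternating between the two conditionals is called **block Gibbs sampling**: sample all hidden units given the visible units, then all visible units given the hidden units.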
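
The claim that the hidden units are independent given the visible units can be checked by enumeration on a tiny RBM. The sketch below assumes the energy function above with 0/1 units; the sizes, random parameters, and the particular visible vector are arbitrary choices for illustration.

```python
import itertools
import torch

torch.manual_seed(0)
n_visible, n_hidden = 3, 2
W = torch.randn(n_visible, n_hidden)
a = torch.randn(n_visible)           # visible biases
b = torch.randn(n_hidden)            # hidden biases
v = torch.tensor([1.0, 0.0, 1.0])    # an arbitrary visible configuration

def energy(v, h):
    return -(a @ v) - (b @ h) - (v @ W @ h)

# Exact P(h | v): enumerate all 2^n_hidden hidden configurations
hidden_states = torch.tensor(list(itertools.product([0.0, 1.0], repeat=n_hidden)))
unnorm = torch.exp(-torch.stack([energy(v, h) for h in hidden_states]))
p_joint = unnorm / unnorm.sum()

# Factorized form: one independent Bernoulli per hidden unit
p_on = torch.sigmoid(b + v @ W)      # P(h_j = 1 | v)
p_factorized = torch.stack([
    torch.prod(torch.where(h == 1, p_on, 1 - p_on)) for h in hidden_states
])
print(p_joint)
print(p_factorized)
```

Both vectors should match, which is why a single matrix multiply followed by `torch.bernoulli` is enough to sample a whole layer at once.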


In [None]:
# | export
class RBM(nn.Module):
    """
    Restricted Boltzmann Machine with binary visible and hidden units.
    
    The "restriction": No connections within layers (only visible ↔ hidden).
    This makes sampling tractable: given visible, all hidden are independent!
    """
    def __init__(self, n_visible, n_hidden):
        super().__init__()
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        
        # Initialize weights and biases
        # Small random weights, zero biases
        self.W = nn.Parameter(torch.randn(n_visible, n_hidden) * 0.01)
        self.a = nn.Parameter(torch.zeros(n_visible))  # visible bias
        self.b = nn.Parameter(torch.zeros(n_hidden))   # hidden bias
    
    def sample_hidden(self, v):
        """Sample hidden units given visible units."""
        # P(h_j = 1 | v) = sigmoid(b_j + sum_i v_i W_ij)
        # This is the same sigmoid derivation as before!
        prob_h = torch.sigmoid(F.linear(v, self.W.t(), self.b))
        return torch.bernoulli(prob_h), prob_h
    
    def sample_visible(self, h):
        """Sample visible units given hidden units."""
        # P(v_i = 1 | h) = sigmoid(a_i + sum_j W_ij h_j)
        prob_v = torch.sigmoid(F.linear(h, self.W, self.a))
        return torch.bernoulli(prob_v), prob_v
    
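    def energy(self, v, h):
        """Compute energy E(v, h) for a single (v, h) pair or a batch of pairs."""
        # Sum over the last dimension so batched inputs (batch x units) work too
        return -(v @ self.a) - (h @ self.b) - torch.sum((v @ self.W) * h, dim=-1)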
    
    def free_energy(self, v):
        """
        Free energy F(v) = -log sum_h exp(-E(v,h))
        For binary hidden units: F(v) = -a^T v - sum_j log(1 + exp(b_j + v^T W_j))
        
        This is analytically tractable because hidden units are independent given v!
        """
        vbias_term = v @ self.a
        wx_b = F.linear(v, self.W.t(), self.b)
        hidden_term = torch.sum(F.softplus(wx_b), dim=-1)
        return -vbias_term - hidden_term


### Contrastive Divergence: The Training Algorithm

This is where the magic of learning happens. Let's derive it from scratch.

**Goal**: Adjust weights so that training data has LOW energy (high probability).

**The likelihood gradient**: We want to maximize $\log P(\mathbf{v}_{\text{data}})$. Taking the derivative:

$$\frac{\partial \log P(\mathbf{v})}{\partial W_{ij}} = \langle v_i h_j \rangle_{\text{data}} - \langle v_i h_j \rangle_{\text{model}}$$

**Intuition for this gradient:**

- **Positive term** $\langle v_i h_j \rangle_{\text{data}}$: "When I see training data, how often do $v_i$ and $h_j$ fire together?"
- **Negative term** $\langle v_i h_j \rangle_{\text{model}}$: "When the model is left to run freely, how often do they fire together?"

**The learning rule**: 
- If $v_i$ and $h_j$ fire together MORE in data than in the model → increase $W_{ij}$
- If they fire together LESS in data than in the model → decrease $W_{ij}$

This is like saying: "Make the model's 'dreams' match reality!"

**The problem**: Computing $\langle v_i h_j \rangle_{\text{model}}$ requires running the Gibbs chain until it reaches equilibrium, which can take a very large number of steps for every single gradient update.

**Contrastive Divergence (Hinton, 2002)**: Just run k steps of Gibbs sampling starting from the data! The intuition:
- Starting from data, after 1 step, you get something "close" to the data
- The difference between data and this reconstruction tells you which direction to push

CD-1 (k=1) is: data → sample hidden → sample visible → done. Surprisingly effective!
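
For an RBM small enough to enumerate, the likelihood gradient can be computed exactly and compared against both the formula above and a CD-1 estimate. This is a minimal sketch under assumed toy sizes and random weights; `free_energy` and `hidden_probs` are local helpers written for this example, with the hidden units summed out analytically.

```python
import itertools
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_visible, n_hidden = 3, 2
W = (0.5 * torch.randn(n_visible, n_hidden)).requires_grad_()
a = torch.zeros(n_visible)
b = torch.zeros(n_hidden)

def free_energy(v):
    # F(v) = -a.v - sum_j softplus(b_j + (vW)_j); works for one row or a batch
    return -(v @ a) - F.softplus(v @ W + b).sum(dim=-1)

def hidden_probs(v):
    return torch.sigmoid(v @ W + b)          # P(h_j = 1 | v)

all_v = torch.tensor(list(itertools.product([0.0, 1.0], repeat=n_visible)))
v_data = torch.tensor([1.0, 1.0, 0.0])

# Exact log-likelihood: log P(v) = -F(v) - log Z, with Z summed over all visible states
log_Z = torch.logsumexp(-free_energy(all_v), dim=0)
log_p = -free_energy(v_data) - log_Z
(grad_autograd,) = torch.autograd.grad(log_p, W)

with torch.no_grad():
    p_v = torch.softmax(-free_energy(all_v), dim=0)            # exact model P(v)
    positive = torch.outer(v_data, hidden_probs(v_data))       # <v h>_data
    negative = (p_v[:, None, None] * all_v[:, :, None]
                * hidden_probs(all_v)[:, None, :]).sum(dim=0)  # <v h>_model
    grad_formula = positive - negative

    # CD-1: replace the model expectation with one Gibbs step started at the data
    h0 = torch.bernoulli(hidden_probs(v_data))
    v1 = torch.bernoulli(torch.sigmoid(a + W @ h0))
    grad_cd1 = positive - torch.outer(v1, hidden_probs(v1))

print(grad_autograd, grad_formula, grad_cd1, sep="\n")
```

The first two matrices should agree to floating-point precision. The CD-1 matrix is a noisy, biased estimate that changes from sample to sample, so it will generally differ from them.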


In [None]:
# Visualize Contrastive Divergence step by step
def visualize_cd_steps(rbm, v_data, k=3):
    """Show what happens during CD-k: data → hidden → visible → hidden → ..."""
    fig, axes = plt.subplots(2, k+1, figsize=(3*(k+1), 6))
    
    v = v_data.clone().unsqueeze(0) if v_data.dim() == 1 else v_data[:1]
    
    # Show original data
    axes[0, 0].imshow(v.view(8, 8).detach().numpy(), cmap='gray', vmin=0, vmax=1)
    axes[0, 0].set_title('v₀ (data)')
    axes[0, 0].axis('off')
    axes[1, 0].axis('off')
    axes[1, 0].text(0.5, 0.5, 'Start', ha='center', va='center', fontsize=12)
    
    for step in range(k):
        # Sample hidden
        h, prob_h = rbm.sample_hidden(v)
        
        # Sample visible
        v, prob_v = rbm.sample_visible(h)
        
        # Show hidden probabilities
        axes[0, step+1].bar(range(rbm.n_hidden), prob_h.squeeze().detach().numpy(), alpha=0.7)
        axes[0, step+1].set_ylim(0, 1)
        axes[0, step+1].set_title(f'h_{step} probs')
        
        # Show reconstructed visible
        axes[1, step+1].imshow(prob_v.view(8, 8).detach().numpy(), cmap='gray', vmin=0, vmax=1)
        axes[1, step+1].set_title(f'v_{step+1} (recon)')
        axes[1, step+1].axis('off')
    
    # Single backslash: '\n' is a real line break in the title
    plt.suptitle('Contrastive Divergence: Data → Hidden → Visible → ...\n'
                 'We compare v₀ (data) with v_k (reconstruction)', fontsize=12)
    plt.tight_layout()
    plt.show()

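# Use one of our comparison patterns
# Note: `rbm` and `train_data` are created in the comparison section further down;
# run those cells first, then come back to this one.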
visualize_cd_steps(rbm, train_data[0], k=3)


In [None]:
# | export
def contrastive_divergence(rbm, v_data, k=1, lr=0.1):
    """
    One step of Contrastive Divergence training.
    
    The algorithm:
    1. Positive phase: Clamp visible to data, sample hidden
    2. Negative phase: Run k steps of Gibbs sampling  
    3. Update: ΔW ∝ ⟨v·h⟩_data - ⟨v·h⟩_reconstruction
    
    Args:
        rbm: RBM model
        v_data: Batch of visible data (batch_size x n_visible)
        k: Number of Gibbs sampling steps
        lr: Learning rate
    
    Returns:
        Reconstruction error (for monitoring)
    """
    batch_size = v_data.shape[0]
    
    # Positive phase: sample hidden from data (what the model "sees")
    h_data, prob_h_data = rbm.sample_hidden(v_data)
    
    # Negative phase: k steps of Gibbs sampling
    v_neg = v_data.clone()
    for _ in range(k):
        h_neg, _ = rbm.sample_hidden(v_neg)
        v_neg, prob_v_neg = rbm.sample_visible(h_neg)
    
    # Final hidden probabilities for negative phase
    _, prob_h_neg = rbm.sample_hidden(v_neg)
    
    # Compute gradients
    # dW = <v_data h_data> - <v_neg h_neg>
    positive_grad = v_data.t() @ prob_h_data / batch_size
    negative_grad = v_neg.t() @ prob_h_neg / batch_size
    
    # Update parameters
    with torch.no_grad():
        rbm.W += lr * (positive_grad - negative_grad)
        rbm.a += lr * (v_data.mean(0) - v_neg.mean(0))
        rbm.b += lr * (prob_h_data.mean(0) - prob_h_neg.mean(0))
    
    # Reconstruction error
    recon_error = F.mse_loss(v_neg, v_data)
    return recon_error.item()


---

## Comparison: Hopfield Network vs RBM

Let's compare these two models on the same task: learning and reconstructing patterns from a small dataset. We'll use 8×8 binary images.


In [None]:
# Create a dataset of simple binary patterns
# These represent simple shapes: vertical line, horizontal line, diagonal, cross, etc.

def create_pattern_dataset(n_patterns=10, size=8):
    """Create diverse binary patterns for testing."""
    patterns = []
    
    # Vertical lines at different positions
    for i in [2, 5]:
        p = torch.zeros(size, size)
        p[:, i] = 1
        patterns.append(p)
    
    # Horizontal lines
    for i in [2, 5]:
        p = torch.zeros(size, size)
        p[i, :] = 1
        patterns.append(p)
    
    # Diagonals
    p = torch.zeros(size, size)
    for i in range(size):
        p[i, i] = 1
    patterns.append(p)
    
    p = torch.zeros(size, size)
    for i in range(size):
        p[i, size-1-i] = 1
    patterns.append(p)
    
    # Border
    p = torch.zeros(size, size)
    p[0, :] = 1
    p[-1, :] = 1
    p[:, 0] = 1
    p[:, -1] = 1
    patterns.append(p)
    
    # Center cross
    p = torch.zeros(size, size)
    p[size//2, :] = 1
    p[:, size//2] = 1
    patterns.append(p)
    
    # Checkerboard
    p = torch.zeros(size, size)
    p[::2, ::2] = 1
    p[1::2, 1::2] = 1
    patterns.append(p)
    
    # X shape
    p = torch.zeros(size, size)
    for i in range(size):
        p[i, i] = 1
        p[i, size-1-i] = 1
    patterns.append(p)
    
    return patterns[:n_patterns]

comparison_patterns = create_pattern_dataset(n_patterns=8, size=8)
print(f"Created {len(comparison_patterns)} patterns of size 8x8 = 64 neurons")


In [None]:
# Visualize the patterns
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, (ax, p) in enumerate(zip(axes.flat, comparison_patterns)):
    ax.imshow(p.numpy(), cmap='gray_r', vmin=0, vmax=1)
    ax.set_title(f'Pattern {i}')
    ax.axis('off')
plt.suptitle('Comparison Dataset: 8 Binary Patterns', fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# Convert to ±1 format for Hopfield (0/1 → -1/+1)
hopfield_patterns = [2*p - 1 for p in comparison_patterns]

# Build Hopfield weight matrix
W_comparison = build_hopfield_weights(hopfield_patterns)
print(f"Hopfield weights: {W_comparison.shape}")


In [None]:
# Train RBM on the patterns
n_visible = 64  # 8x8
n_hidden = 32   # Half the visible units

rbm = RBM(n_visible, n_hidden)

# Stack patterns for batch training
train_data = torch.stack([p.flatten() for p in comparison_patterns])

# Train for several epochs
n_epochs = 500
recon_errors = []

for epoch in range(n_epochs):
    # Shuffle data
    perm = torch.randperm(len(train_data))
    shuffled = train_data[perm]
    
    error = contrastive_divergence(rbm, shuffled, k=1, lr=0.1)
    recon_errors.append(error)
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Reconstruction error = {error:.4f}")


In [None]:
# RBM training curve
plt.figure(figsize=(10, 4))
plt.plot(recon_errors)
plt.xlabel('Epoch')
plt.ylabel('Reconstruction Error (MSE)')
plt.title('RBM Training: Reconstruction Error Over Time')
plt.grid(True, alpha=0.3)
plt.show()


In [None]:
def test_pattern_recovery(pattern_idx, noise_fraction=0.2):
    """Compare Hopfield and RBM pattern recovery."""
    
    # Get pattern
    original = comparison_patterns[pattern_idx]
    original_hopfield = hopfield_patterns[pattern_idx]
    
    # Corrupt pattern
    corrupted_0_1 = original.clone().flatten()
    n_flip = int(len(corrupted_0_1) * noise_fraction)
    flip_idx = torch.randperm(len(corrupted_0_1))[:n_flip]
    corrupted_0_1[flip_idx] = 1 - corrupted_0_1[flip_idx]  # flip 0↔1
    
    corrupted_pm1 = 2 * corrupted_0_1 - 1  # convert to ±1
    
    # Hopfield recovery
    hopfield_recovered, _, _ = hopfield_update(corrupted_pm1, W_comparison, async_update=False)
    hopfield_recovered = ((hopfield_recovered + 1) / 2).view(8, 8)  # back to 0/1
    
    # RBM recovery: run Gibbs sampling
    v = corrupted_0_1.unsqueeze(0)
    for _ in range(10):  # 10 Gibbs steps
        h, _ = rbm.sample_hidden(v)
        v, _ = rbm.sample_visible(h)
    rbm_recovered = v.squeeze().view(8, 8)
    
    # Visualize
    fig, axes = plt.subplots(1, 4, figsize=(12, 3))
    
    axes[0].imshow(original.numpy(), cmap='gray_r', vmin=0, vmax=1)
    axes[0].set_title('Original')
    axes[0].axis('off')
    
    axes[1].imshow(corrupted_0_1.view(8, 8).numpy(), cmap='gray_r', vmin=0, vmax=1)
    axes[1].set_title(f'Corrupted ({noise_fraction*100:.0f}% noise)')
    axes[1].axis('off')
    
    axes[2].imshow(hopfield_recovered.numpy(), cmap='gray_r', vmin=0, vmax=1)
    hopfield_acc = (hopfield_recovered == original).float().mean()
    axes[2].set_title(f'Hopfield\n({hopfield_acc*100:.0f}% match)')
    axes[2].axis('off')
    
    axes[3].imshow(rbm_recovered.detach().numpy(), cmap='gray_r', vmin=0, vmax=1)
    rbm_acc = (rbm_recovered.round() == original).float().mean()
    axes[3].set_title(f'RBM\n({rbm_acc*100:.0f}% match)')
    axes[3].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return hopfield_acc.item(), rbm_acc.item()


In [None]:
# Test on several patterns
for i in [0, 4, 6]:  # vertical line, diagonal, border
    print(f"\n--- Pattern {i} ---")
    test_pattern_recovery(i, noise_fraction=0.25)


**Key differences observed:**

| Aspect | Hopfield | RBM |
|--------|----------|-----|
| **Training** | One-shot (Hebbian) | Iterative (CD) |
| **Storage** | Patterns in weights directly | Learned latent representation |
| **Recovery** | Deterministic (may get stuck) | Stochastic (explores) |
| **Capacity** | ~0.14N patterns | Can model distributions, not just patterns |

The Hopfield network directly memorizes patterns; the RBM learns a generative model of the data distribution. RBMs can generate new samples that look like training data, while Hopfield only retrieves stored patterns.
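
The capacity row can be probed with a quick experiment: store random patterns with the Hebbian rule and count how many of them remain exact fixed points after one synchronous update. This is a rough sketch using random ±1 patterns and hypothetical names (`fraction_stable`, `N`, `P`). Note that the ~0.14N figure refers to retrieval with a small fraction of wrong bits, so exact stability, which is what this code measures, is a stricter criterion and starts to fail earlier.

```python
import torch

torch.manual_seed(0)

def fraction_stable(n_neurons: int, n_patterns: int) -> float:
    """Store random ±1 patterns (Hebbian rule); return the fraction that are exact fixed points."""
    patterns = torch.sign(torch.randn(n_patterns, n_neurons))
    W = patterns.T @ patterns / n_neurons
    W.fill_diagonal_(0)                           # no self-connections
    recalled = torch.sign(patterns @ W)           # one synchronous update of every pattern
    return (recalled == patterns).all(dim=1).float().mean().item()

N = 200
for P in (5, 20, 28, 60):                         # 0.14 * 200 = 28
    print(f"P={P:3d} (P/N={P/N:.2f}): {fraction_stable(N, P):.2f} of patterns are stable")
```

The comparison above stores 8 patterns in 64 neurons (P/N ≈ 0.125) and those patterns are structured rather than random, so the Hopfield network there is operating close to its limit.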

---

## Training a Production-Scale Boltzmann Machine

Now let's build and train a larger RBM on real data (MNIST digits), with proper production monitoring using Weights & Biases.


In [None]:
# Production RBM with proper monitoring
import wandb
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms


In [None]:
# | export
class ProductionRBM(nn.Module):
    """
    Production-ready RBM with:
    - Proper initialization (Xavier-like for stable training)
    - Momentum-based training (faster convergence)
    - Weight regularization (prevent explosion)
    - Monitoring utilities (dead units, weight stats)
    """
    
    def __init__(self, n_visible, n_hidden, device='cpu'):
        super().__init__()
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        self.device = device
        
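        # Xavier-like initialization scaled for RBMs
        # 4*sqrt(6/(n_in + n_out)) is commonly quoted as the bound of a *uniform* range for sigmoid units.
        # Here it is used as the std of a normal, which gives fairly large initial weights;
        # if many hidden units start out saturated, try a smaller std (e.g. 0.01, as in the toy RBM above).
        std = 4 * np.sqrt(6. / (n_visible + n_hidden))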
        self.W = nn.Parameter(torch.randn(n_visible, n_hidden, device=device) * std)
        self.a = nn.Parameter(torch.zeros(n_visible, device=device))
        self.b = nn.Parameter(torch.zeros(n_hidden, device=device))
        
        # Momentum terms for faster convergence
        self.W_momentum = torch.zeros_like(self.W)
        self.a_momentum = torch.zeros_like(self.a)
        self.b_momentum = torch.zeros_like(self.b)
        
    def sample_hidden(self, v):
        prob_h = torch.sigmoid(F.linear(v, self.W.t(), self.b))
        return torch.bernoulli(prob_h), prob_h
    
    def sample_visible(self, h):
        prob_v = torch.sigmoid(F.linear(h, self.W, self.a))
        return torch.bernoulli(prob_v), prob_v
    
    def free_energy(self, v):
        """Free energy for monitoring."""
        vbias_term = torch.sum(v * self.a, dim=-1)
        wx_b = F.linear(v, self.W.t(), self.b)
        hidden_term = torch.sum(F.softplus(wx_b), dim=-1)
        return -vbias_term - hidden_term
    
    def get_weight_stats(self):
        """Get weight statistics for monitoring training health."""
        with torch.no_grad():
            return {
                'weight_mean': self.W.mean().item(),
                'weight_std': self.W.std().item(),
                'weight_max': self.W.abs().max().item(),
                'weight_norm': self.W.norm().item(),
                'hidden_bias_mean': self.b.mean().item(),
                'visible_bias_mean': self.a.mean().item(),
            }
    
    def get_hidden_activation_stats(self, v_batch):
        """Analyze hidden unit activation patterns."""
        with torch.no_grad():
            _, prob_h = self.sample_hidden(v_batch)
            active = (prob_h > 0.5).float()
            
            return {
                'hidden_prob_mean': prob_h.mean().item(),
                'hidden_prob_std': prob_h.std().item(),
                'fraction_active': active.mean().item(),
                'dead_units': (prob_h.mean(0) < 0.01).sum().item(),  # units rarely active
                'saturated_units': (prob_h.mean(0) > 0.99).sum().item(),  # units always active
            }


In [None]:
# | export
def train_rbm_production(
    rbm, 
    train_loader, 
    n_epochs=10,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0001,
    k=1,  # CD-k
    use_wandb=False
):
    """
    Train RBM with production best practices.
    
    Uses momentum SGD with weight decay, and optionally logs to wandb.
    """
    history = {
        'recon_error': [],
        'free_energy': [],
    }
    
    for epoch in range(n_epochs):
        epoch_recon = []
        epoch_fe = []
        
        for batch_idx, (data, _) in enumerate(train_loader):
            # Flatten and binarize (for MNIST)
            v_data = (data.view(-1, rbm.n_visible) > 0.5).float().to(rbm.device)
            batch_size = v_data.shape[0]
            
            # Positive phase
            h_data, prob_h_data = rbm.sample_hidden(v_data)
            
            # Negative phase (CD-k)
            v_neg = v_data.clone()
            for _ in range(k):
                h_neg, _ = rbm.sample_hidden(v_neg)
                v_neg, prob_v_neg = rbm.sample_visible(h_neg)
            _, prob_h_neg = rbm.sample_hidden(v_neg)
            
            # Gradients
            pos_grad_W = v_data.t() @ prob_h_data / batch_size
            neg_grad_W = v_neg.t() @ prob_h_neg / batch_size
            
            pos_grad_a = v_data.mean(0)
            neg_grad_a = v_neg.mean(0)
            
            pos_grad_b = prob_h_data.mean(0)
            neg_grad_b = prob_h_neg.mean(0)
            
            # Update with momentum and weight decay
            with torch.no_grad():
                # Momentum update
                rbm.W_momentum = momentum * rbm.W_momentum + lr * (pos_grad_W - neg_grad_W - weight_decay * rbm.W)
                rbm.a_momentum = momentum * rbm.a_momentum + lr * (pos_grad_a - neg_grad_a)
                rbm.b_momentum = momentum * rbm.b_momentum + lr * (pos_grad_b - neg_grad_b)
                
                rbm.W += rbm.W_momentum
                rbm.a += rbm.a_momentum
                rbm.b += rbm.b_momentum
            
            # Metrics
            recon_error = F.mse_loss(prob_v_neg, v_data).item()
            free_energy = rbm.free_energy(v_data).mean().item()
            
            epoch_recon.append(recon_error)
            epoch_fe.append(free_energy)
        
        # Epoch metrics
        avg_recon = np.mean(epoch_recon)
        avg_fe = np.mean(epoch_fe)
        
        history['recon_error'].append(avg_recon)
        history['free_energy'].append(avg_fe)
        
        # Get detailed stats
        weight_stats = rbm.get_weight_stats()
        activation_stats = rbm.get_hidden_activation_stats(v_data)
        
        # Log to wandb
        if use_wandb:
            wandb.log({
                'epoch': epoch,
                'recon_error': avg_recon,
                'free_energy': avg_fe,
                **weight_stats,
                **activation_stats
            })
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch}: Recon={avg_recon:.4f}, FE={avg_fe:.1f}, "
                  f"Dead units={activation_stats['dead_units']}, "
                  f"W_norm={weight_stats['weight_norm']:.2f}")
    
    return history


In [None]:
# Load MNIST
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Smaller batch for M3 Max - adjust based on your system
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Image size: 28x28 = 784 visible units")


### Training Configuration

For a production run on 2x Tesla T4 (16GB each, 32GB VRAM total), we'd use:
- **n_hidden = 2048** (large hidden layer for rich representations)
- **batch_size = 512** (utilize GPU parallelism)
- **n_epochs = 100+** (several hours of training)
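
For a rough sense of scale, the parameter count of an RBM is just the weight matrix plus the two bias vectors. The sketch below compares a 500-unit hidden layer (the local setting used in this notebook) with the 2048-unit production setting; `rbm_param_count` is a hypothetical helper, and float32 storage is an assumption.

```python
def rbm_param_count(n_visible: int, n_hidden: int) -> int:
    """Weights plus visible and hidden biases."""
    return n_visible * n_hidden + n_visible + n_hidden

for name, n_hidden in [("local", 500), ("production", 2048)]:
    n_params = rbm_param_count(784, n_hidden)
    mib = n_params * 4 / 2**20          # assuming float32: 4 bytes per parameter
    print(f"{name}: {n_params:,} parameters, about {mib:.1f} MiB in float32")
```

Even the larger model is only a few MiB, so GPU memory is unlikely to be the constraint here; the bigger settings mostly cost compute time.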

For local M3 Max testing, we'll use smaller settings:


In [None]:
# Configuration for local training (M3 Max)
config = {
    'n_visible': 784,
    'n_hidden': 500,  # Moderate size for local
    'batch_size': 128,
    'n_epochs': 30,
    'lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 0.0001,
    'k': 1,  # CD-1
}

# For production on Tesla T4s, uncomment:
# config = {
#     'n_visible': 784,
#     'n_hidden': 2048,
#     'batch_size': 512,
#     'n_epochs': 100,
#     'lr': 0.005,
#     'momentum': 0.9,
#     'weight_decay': 0.0001,
#     'k': 5,  # CD-5 for better gradient estimates
# }

print(f"Config: {config}")


In [None]:
# Initialize wandb (set use_wandb=True to log, False for local testing)
use_wandb = False  # Set to True to log to wandb

if use_wandb:
    wandb.init(
        project="boltzmann-machines",
        config=config,
        name=f"rbm-h{config['n_hidden']}-cd{config['k']}"
    )


In [None]:
# Create model
rbm_prod = ProductionRBM(
    n_visible=config['n_visible'],
    n_hidden=config['n_hidden'],
    device=device
)

print(f"Model parameters: {sum(p.numel() for p in rbm_prod.parameters()):,}")


In [None]:
# Train!
history = train_rbm_production(
    rbm_prod,
    train_loader,
    n_epochs=config['n_epochs'],
    lr=config['lr'],
    momentum=config['momentum'],
    weight_decay=config['weight_decay'],
    k=config['k'],
    use_wandb=use_wandb
)

if use_wandb:
    wandb.finish()


In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(history['recon_error'])
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Reconstruction Error')
ax1.set_title('Training: Reconstruction Error')
ax1.grid(True, alpha=0.3)

ax2.plot(history['free_energy'])
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Free Energy')
ax2.set_title('Training: Free Energy')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:
# Visualize learned features (hidden unit weight vectors = learned "features")
def visualize_rbm_features(rbm, n_features=100):
    """Show what features the hidden units have learned."""
    W = rbm.W.detach().cpu()  # (n_visible, n_hidden)
    
    n_show = min(n_features, W.shape[1])
    n_cols = 10
    n_rows = (n_show + n_cols - 1) // n_cols  # ceiling division so a partial last row is still shown
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 1.5 * n_rows))
    
    for i, ax in enumerate(axes.flat):
        if i < n_show:
            feature = W[:, i].view(28, 28)
            ax.imshow(feature.numpy(), cmap='RdBu', vmin=-feature.abs().max(), vmax=feature.abs().max())
        ax.axis('off')
    
    plt.suptitle('Learned RBM Features (Hidden Unit Weights)', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_rbm_features(rbm_prod, n_features=100)


In [None]:
# Generate new samples from the RBM
def generate_samples(rbm, n_samples=10, n_gibbs=1000):
    """Generate samples by running Gibbs sampling from random initialization."""
    # Start from random visible state
    v = torch.bernoulli(torch.ones(n_samples, rbm.n_visible) * 0.5).to(rbm.device)
    
    # Run Gibbs sampling
    for _ in range(n_gibbs):
        h, _ = rbm.sample_hidden(v)
        v, _ = rbm.sample_visible(h)
    
    return v

samples = generate_samples(rbm_prod, n_samples=20, n_gibbs=1000)

fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].detach().cpu().view(28, 28).numpy(), cmap='gray')
    ax.axis('off')
plt.suptitle('Samples Generated by RBM (Gibbs Sampling)', fontsize=14)
plt.tight_layout()
plt.show()


---

## Production Considerations

### How to Know if Training is Going Well

Unlike discriminative models where we track accuracy/loss directly, RBMs require more nuanced monitoring:

**Key metrics to watch:**

1. **Reconstruction Error**: Should decrease steadily. If it plateaus, try increasing k in CD-k.

2. **Free Energy Gap**: Compare average free energy of training data vs random samples. Training data should have lower free energy.

3. **Hidden Unit Statistics**:
   - **Dead units**: Units that almost never activate (<1% of the time). Indicates wasted capacity.
   - **Saturated units**: Units that are almost always active (>99% of the time). Also wasted capacity.
   - Ideal: diverse activation patterns with mean activation around 0.1-0.5.

4. **Weight Statistics**:
   - Weight norm shouldn't explode (use weight decay)
   - Mean should stay near 0
   - If weights become very large, gradients may become unstable
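The free energy gap from item 2 has a closed form for a Bernoulli RBM, because the hidden units can be summed out analytically. The sketch below is illustrative only and does not touch `ProductionRBM`: the function names `rbm_free_energy` and `free_energy_gap`, the explicit `W`, `b_visible`, `b_hidden` arguments, and the toy sizes are all assumptions made for this example.

```python
import torch
import torch.nn.functional as F

def rbm_free_energy(v, W, b_visible, b_hidden):
    """Free energy of a Bernoulli RBM with the hidden units summed out.

    F(v) = -b_visible . v - sum_j softplus(b_hidden_j + (v W)_j)
    v: (batch, n_visible), W: (n_visible, n_hidden)
    """
    visible_term = v @ b_visible                              # (batch,)
    hidden_term = F.softplus(v @ W + b_hidden).sum(dim=-1)    # (batch,)
    return -visible_term - hidden_term

def free_energy_gap(v_data, v_reference, W, b_visible, b_hidden):
    """Mean F(reference) - mean F(data); positive means data has lower free energy."""
    f_ref = rbm_free_energy(v_reference, W, b_visible, b_hidden).mean()
    f_data = rbm_free_energy(v_data, W, b_visible, b_hidden).mean()
    return (f_ref - f_data).item()

# Toy usage with random parameters (hypothetical sizes, not the notebook's model)
torch.manual_seed(0)
n_visible, n_hidden = 6, 3
W = 0.5 * torch.randn(n_visible, n_hidden)
b_visible = torch.zeros(n_visible)
b_hidden = torch.zeros(n_hidden)
v_a = torch.bernoulli(torch.full((8, n_visible), 0.5))
v_b = torch.bernoulli(torch.full((8, n_visible), 0.5))
print(f"Free energy gap: {free_energy_gap(v_a, v_b, W, b_visible, b_hidden):.4f}")
```

For a trained model, `v_a` would be a batch of binarized training images and `v_b` a batch of random binary vectors; a healthy model should give a clearly positive gap.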


In [None]:
# Analyze hidden unit activation distribution
sample_batch = next(iter(train_loader))[0][:256]
v_sample = (sample_batch.view(-1, 784) > 0.5).float().to(device)

with torch.no_grad():
    _, prob_h = rbm_prod.sample_hidden(v_sample)
    
# Histogram of mean activations per hidden unit
mean_activations = prob_h.mean(0).cpu().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.hist(mean_activations, bins=50, edgecolor='black', alpha=0.7)
ax1.axvline(x=0.01, color='r', linestyle='--', label='Dead threshold (1%)')
ax1.axvline(x=0.99, color='orange', linestyle='--', label='Saturated threshold (99%)')
ax1.set_xlabel('Mean Activation per Hidden Unit')
ax1.set_ylabel('Count')
ax1.set_title('Hidden Unit Activation Distribution')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Weight distribution
weights = rbm_prod.W.detach().cpu().flatten().numpy()
ax2.hist(weights, bins=100, edgecolor='black', alpha=0.7)
ax2.set_xlabel('Weight Value')
ax2.set_ylabel('Count')
ax2.set_title(f'Weight Distribution (mean={weights.mean():.3f}, std={weights.std():.3f})')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Dead units: {(mean_activations < 0.01).sum()}")
print(f"Saturated units: {(mean_activations > 0.99).sum()}")


### Time Complexity

**Full Boltzmann Machine**: 
- Inference requires sampling from the full joint distribution
- Exact inference is intractable in general: the partition function is a sum over all $2^N$ joint states
- Why? All units are connected, so we can't factorize the distribution

**Restricted Boltzmann Machine**:
- Inference is $O(N_{visible} \times N_{hidden})$ per Gibbs step (matrix multiply)
- The bipartite structure means we can sample all hidden (or visible) units in parallel
- Training with CD-k: $O(k \times N_v \times N_h \times \text{batch\_size})$ per iteration
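To make the per-step cost concrete, here is a minimal sketch of one block Gibbs step for a Bernoulli RBM. It is a standalone illustration rather than the notebook's `ProductionRBM` code: `block_gibbs_step` and the parameter tensors are assumed names. Each half-step is a single matrix multiply, which is where the $O(N_v \times N_h)$ cost per sample comes from.

```python
import torch

def block_gibbs_step(v, W, b_visible, b_hidden):
    """One full Gibbs step v -> h -> v' for a Bernoulli RBM."""
    # Hidden units are conditionally independent given v: one matmul samples them all
    prob_h = torch.sigmoid(v @ W + b_hidden)         # (batch, n_hidden)
    h = torch.bernoulli(prob_h)
    # Visible units are conditionally independent given h: one more matmul
    prob_v = torch.sigmoid(h @ W.t() + b_visible)    # (batch, n_visible)
    v_new = torch.bernoulli(prob_v)
    return v_new, h

torch.manual_seed(0)
batch_size, n_visible, n_hidden = 4, 784, 256
W = 0.01 * torch.randn(n_visible, n_hidden)
b_visible = torch.zeros(n_visible)
b_hidden = torch.zeros(n_hidden)

v0 = torch.bernoulli(torch.full((batch_size, n_visible), 0.5))
v1, h0 = block_gibbs_step(v0, W, b_visible, b_hidden)

# Two matmuls per step, counting 2 FLOPs per multiply-add
flops_per_step = 2 * (2 * batch_size * n_visible * n_hidden)
print(v1.shape, h0.shape, f"~{flops_per_step:,} FLOPs per Gibbs step")
```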

### Training Stability Tricks

**From transformer land (adapted for RBMs):**

1. **Weight initialization**: Xavier-like scaling prevents vanishing/exploding activations
2. **Learning rate warmup**: Start with small LR, increase gradually
3. **Gradient clipping**: Clip large gradients to prevent instability
4. **Weight decay**: L2 regularization prevents weight explosion
5. **Momentum**: SGD with momentum smooths gradient updates
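A minimal sketch of how warmup, clipping, weight decay, and momentum could fit together for manually computed RBM gradients. The helper names (`warmup_lr`, `clip_by_norm`, `momentum_step`) are hypothetical, and the update is written as gradient *ascent* on the log-likelihood, which is an assumption about how the gradient estimate is signed.

```python
import torch

def warmup_lr(step, base_lr, warmup_steps):
    """Linear warmup: ramps from base_lr / warmup_steps up to base_lr."""
    return base_lr * min(1.0, (step + 1) / warmup_steps)

def clip_by_norm(grad, max_norm):
    """Rescale grad so its L2 norm is at most max_norm."""
    norm = grad.norm()
    if norm > max_norm:
        grad = grad * (max_norm / norm)
    return grad

def momentum_step(param, grad, velocity, lr, momentum=0.9, weight_decay=1e-4):
    """Ascent update with momentum and L2 weight decay (decay pulls param toward 0)."""
    velocity = momentum * velocity + lr * (grad - weight_decay * param)
    return param + velocity, velocity

# Toy usage: random tensors stand in for a CD gradient estimate
torch.manual_seed(0)
W = 0.01 * torch.randn(4, 3)
velocity = torch.zeros_like(W)
for step in range(3):
    grad = clip_by_norm(torch.randn(4, 3), max_norm=1.0)
    lr = warmup_lr(step, base_lr=0.01, warmup_steps=100)
    W, velocity = momentum_step(W, grad, velocity, lr)
    print(f"step {step}: lr={lr:.5f}, |W|={W.norm().item():.4f}")
```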

**RBM-specific:**

1. **CD-k**: Higher k gives better gradient estimates (but slower). Start with CD-1, increase if needed.
2. **Persistent Contrastive Divergence (PCD)**: Maintain Gibbs chains across batches for better mixing
3. **Sparsity penalty**: Encourage sparse hidden activations for better features
4. **Per-parameter learning rates**: Different LR for weights vs biases
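Persistent Contrastive Divergence is easiest to see in code. The sketch below is a standalone illustration under stated assumptions: `pcd_step` is a hypothetical helper, parameters are plain tensors updated in place, and plain gradient ascent is used with no momentum or weight decay. The only difference from CD-k is that the negative-phase chains start from where they ended on the previous batch, not from the data.

```python
import torch

def pcd_step(v_data, chains, W, b_visible, b_hidden, lr=0.01, k=1):
    """One PCD update (parameters modified in place). Returns the advanced chains.

    v_data: (batch, n_visible) binary data
    chains: (n_chains, n_visible) persistent states carried across calls
    """
    # Positive phase: hidden probabilities driven by the data
    prob_h_data = torch.sigmoid(v_data @ W + b_hidden)

    # Negative phase: advance the persistent chains by k Gibbs steps
    v_model = chains
    for _ in range(k):
        h_model = torch.bernoulli(torch.sigmoid(v_model @ W + b_hidden))
        v_model = torch.bernoulli(torch.sigmoid(h_model @ W.t() + b_visible))
    prob_h_model = torch.sigmoid(v_model @ W + b_hidden)

    # Log-likelihood gradient estimate: <v h>_data - <v h>_model
    W += lr * (v_data.t() @ prob_h_data / v_data.shape[0]
               - v_model.t() @ prob_h_model / v_model.shape[0])
    b_visible += lr * (v_data.mean(0) - v_model.mean(0))
    b_hidden += lr * (prob_h_data.mean(0) - prob_h_model.mean(0))
    return v_model

# Toy usage: random binary batches stand in for a data loader
torch.manual_seed(0)
n_visible, n_hidden = 12, 4
W = 0.01 * torch.randn(n_visible, n_hidden)
b_visible = torch.zeros(n_visible)
b_hidden = torch.zeros(n_hidden)
W_initial = W.clone()

chains = torch.bernoulli(torch.full((16, n_visible), 0.5))
for _ in range(10):
    v_batch = torch.bernoulli(torch.full((16, n_visible), 0.3))
    chains = pcd_step(v_batch, chains, W, b_visible, b_hidden)
print(f"Weight change after 10 PCD steps: {(W - W_initial).norm().item():.4f}")
```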


---

## Modern Hopfield Networks and Transformer Attention

The classic Hopfield network stores ~0.14N patterns. But in 2020, Ramsauer et al. showed that **continuous Hopfield networks with exponential energy** can store **exponentially many patterns**—and their update rule is exactly **transformer attention**!

### Why Change the Energy Function?

The classic Hopfield energy $E = -\frac{1}{2} \mathbf{x}^T W \mathbf{x}$ is **quadratic** in x. This limits capacity.

**Key insight**: Make the energy **exponential** in the similarity between query and patterns. Why exponential?

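1. **Sharper separation**: Exponentials amplify differences. The ratio of exponential weights depends on the *difference* in similarity, so it grows without bound as the gap widens. If pattern A has similarity 10 and B has similarity 5, quadratic gives ratio $100:25 = 4:1$, but exponential gives ratio $e^{10}:e^{5} = e^{5} \approx 148:1$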
2. **Softer than argmax**: Unlike hard maximum, exponential allows "soft" retrieval of multiple similar patterns
3. **Natural normalization**: logsumexp is a "soft maximum" — it approximates max while being differentiable
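A quick numerical check of the "soft maximum" claim, using made-up similarity values. For $n$ scores, the scaled logsumexp is sandwiched between the true maximum and the maximum plus $\log(n)/\beta$, so it approaches the hard max as $\beta$ grows.

```python
import math
import torch

similarities = torch.tensor([10.0, 5.0, 1.0])   # illustrative values only
n = similarities.numel()
hard_max = similarities.max().item()

for beta in [0.1, 1.0, 10.0]:
    soft_max = (torch.logsumexp(beta * similarities, dim=0) / beta).item()
    # Sandwich bound: max <= (1/beta) * logsumexp(beta * s) <= max + log(n) / beta
    upper = hard_max + math.log(n) / beta
    print(f"beta={beta:>4}: max={hard_max:.3f}  soft max={soft_max:.3f}  upper bound={upper:.3f}")

# Softmax weights depend only on differences in similarity
weights = torch.softmax(similarities, dim=0)
ratio = (weights[0] / weights[1]).item()   # mathematically exp(10 - 5)
print(f"weight ratio for similarities 10 vs 5: {ratio:.1f}")
```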

### Deriving the Modern Hopfield Energy

**Goal**: Energy should be LOW when query x is similar to a stored pattern.

**Step 1: Measure similarity**

For stored patterns $\Xi = [\xi_1, ..., \xi_N]$, similarity scores are $\Xi^T \mathbf{x}$ (dot products).

**Step 2: We want energy low when ANY pattern is similar**

The "soft minimum over all patterns" is the negative of "soft maximum of similarities":

$$E = -\text{logsumexp}(\beta \cdot \Xi^T \mathbf{x})$$

where $\beta$ controls sharpness (like inverse temperature).

**Step 3: Add regularization**

To prevent x from blowing up (just making itself larger increases all similarities), add $\frac{1}{2}\|\mathbf{x}\|^2$:

$$E(\mathbf{x}) = -\text{logsumexp}(\beta \cdot \boldsymbol{\Xi}^T \mathbf{x}) + \frac{1}{2}\|\mathbf{x}\|^2$$

### The Update Rule = Attention!

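To find the minimum energy state, we take the gradient and set it to zero. Differentiating $\text{logsumexp}(\beta \cdot \boldsymbol{\Xi}^T \mathbf{x})$ brings down a factor of $\beta$, so, following Ramsauer et al., we scale that term by $1/\beta$, i.e. $E(\mathbf{x}) = -\beta^{-1}\,\text{logsumexp}(\beta \cdot \boldsymbol{\Xi}^T \mathbf{x}) + \frac{1}{2}\|\mathbf{x}\|^2$. Then:

$$\nabla_{\mathbf{x}} E = -\boldsymbol{\Xi} \cdot \text{softmax}(\beta \cdot \boldsymbol{\Xi}^T \mathbf{x}) + \mathbf{x} = 0$$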

Solving:

$$\mathbf{x}^{\text{new}} = \boldsymbol{\Xi} \cdot \text{softmax}(\beta \cdot \boldsymbol{\Xi}^T \mathbf{x})$$

**This IS attention!** Compare to transformer attention: $\text{Attention}(Q, K, V) = \text{softmax}(QK^T / \sqrt{d}) \cdot V$

The correspondence:
- Query x ↔ Q
- Patterns Ξ ↔ Keys K (and also Values V when K=V)
- $\beta$ ↔ $1/\sqrt{d}$
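The derivation can be sanity-checked numerically. The sketch below assumes the energy with the logsumexp term scaled by $1/\beta$ (the form used by Ramsauer et al.), stores patterns as *rows* (so `patterns @ x` plays the role of $\boldsymbol{\Xi}^T \mathbf{x}$), and uses hypothetical helper names. It compares the autograd gradient to the closed form and checks that repeated updates never increase the energy.

```python
import torch
import torch.nn.functional as F

def hopfield_energy(x, patterns, beta):
    """E(x) = -(1/beta) * logsumexp(beta * patterns @ x) + 0.5 * ||x||^2."""
    return -torch.logsumexp(beta * (patterns @ x), dim=0) / beta + 0.5 * (x ** 2).sum()

def hopfield_update(x, patterns, beta):
    """x_new = patterns^T softmax(beta * patterns @ x)."""
    return patterns.t() @ F.softmax(beta * (patterns @ x), dim=0)

torch.manual_seed(0)
dim, n_stored, beta = 16, 5, 4.0
patterns = F.normalize(torch.randn(n_stored, dim), dim=1)   # rows are stored patterns
x = torch.randn(dim, requires_grad=True)

# Autograd gradient vs. the closed form x - Xi softmax(beta * Xi^T x)
(grad_autograd,) = torch.autograd.grad(hopfield_energy(x, patterns, beta), x)
grad_closed_form = x.detach() - hopfield_update(x.detach(), patterns, beta)
print("gradients match:", torch.allclose(grad_autograd, grad_closed_form, atol=1e-5))

# Iterating the update should never increase the energy
state = x.detach()
energies = [hopfield_energy(state, patterns, beta).item()]
for _ in range(5):
    state = hopfield_update(state, patterns, beta)
    energies.append(hopfield_energy(state, patterns, beta).item())
print("energies:", [round(e, 4) for e in energies])
```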


In [None]:
# Visualize: Classic vs Modern Hopfield Energy Landscapes (2D state space)
# This shows WHY modern Hopfield can store exponentially more patterns

# Two stored patterns in 2D for visualization
patterns_2d = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # Two orthogonal patterns

# Create grid of possible states
x_range = torch.linspace(-1.5, 1.5, 100)
y_range = torch.linspace(-1.5, 1.5, 100)
X, Y = torch.meshgrid(x_range, y_range, indexing='ij')
states = torch.stack([X.flatten(), Y.flatten()], dim=1)  # (10000, 2)

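# Classic Hopfield energy: E = -0.5 * x^T W x where W = sum of outer products.
# Note: for these two one-hot patterns W is the identity matrix, so zeroing the
# diagonal (as is done for binary units) would leave W = 0 and a completely flat
# landscape. The diagonal is kept for this continuous 2D illustration.
W_classic = patterns_2d.t() @ patterns_2d  # (2, 2)
E_classic = -0.5 * (states @ W_classic * states).sum(dim=1)  # (10000,)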

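# Modern Hopfield energy: E = -(1/beta) * logsumexp(beta * Xi^T x) + 0.5 ||x||^2
# beta = 8 is sharp enough here that each stored pattern sits in its own minimum
beta_vis = 8.0
scores = states @ patterns_2d.t()  # (10000, 2)
E_modern = -torch.logsumexp(beta_vis * scores, dim=1) / beta_vis + 0.5 * (states ** 2).sum(dim=1)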

# Reshape for plotting
E_classic_2d = E_classic.view(100, 100)
E_modern_2d = E_modern.view(100, 100)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Classic Hopfield
im1 = axes[0].contourf(X.numpy(), Y.numpy(), E_classic_2d.numpy(), levels=30, cmap='viridis')
axes[0].scatter(patterns_2d[:, 0].numpy(), patterns_2d[:, 1].numpy(), 
                c='red', s=200, marker='*', label='Stored patterns', zorder=5)
axes[0].set_xlabel('State dimension 1')
axes[0].set_ylabel('State dimension 2')
axes[0].set_title('Classic Hopfield Energy\n(Quadratic: shallow, broad valleys)')
axes[0].legend()
axes[0].set_aspect('equal')
plt.colorbar(im1, ax=axes[0], label='Energy')

# Modern Hopfield  
im2 = axes[1].contourf(X.numpy(), Y.numpy(), E_modern_2d.numpy(), levels=30, cmap='viridis')
axes[1].scatter(patterns_2d[:, 0].numpy(), patterns_2d[:, 1].numpy(), 
                c='red', s=200, marker='*', label='Stored patterns', zorder=5)
axes[1].set_xlabel('State dimension 1')
axes[1].set_ylabel('State dimension 2')
axes[1].set_title('Modern Hopfield Energy\n(Exponential: deep, sharp valleys)')
axes[1].legend()
axes[1].set_aspect('equal')
plt.colorbar(im2, ax=axes[1], label='Energy')

plt.tight_layout()
plt.show()

print("Notice: Modern Hopfield has SHARPER minima around stored patterns!")
print("This is why it can store exponentially more patterns without interference.")


In [None]:
# Visualize: logsumexp as a "soft maximum"
# This is the key insight connecting Hopfield energy to attention

x = torch.linspace(-3, 3, 100)

# Consider logsumexp of [x, 0] - how does it compare to max(x, 0)?
lse = torch.logsumexp(torch.stack([x, torch.zeros_like(x)], dim=1), dim=1)
hard_max = torch.maximum(x, torch.zeros_like(x))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Compare logsumexp to max
ax1.plot(x.numpy(), lse.numpy(), 'b-', linewidth=2, label='logsumexp(x, 0)')
ax1.plot(x.numpy(), hard_max.numpy(), 'r--', linewidth=2, label='max(x, 0)')
ax1.axhline(y=0, color='gray', linestyle=':', alpha=0.5)
ax1.axvline(x=0, color='gray', linestyle=':', alpha=0.5)
ax1.set_xlabel('x')
ax1.set_ylabel('Output')
ax1.set_title('logsumexp is a "Soft Maximum"\n(Smooth, differentiable everywhere)')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: Effect of beta (inverse temperature) on softmax
betas = [0.5, 1.0, 2.0, 5.0, 10.0]
scores = torch.tensor([1.0, 0.5, 0.3, 0.1, 0.05])  # Different similarities

for beta in betas:
    probs = F.softmax(beta * scores, dim=0).numpy()
    ax2.plot(range(5), probs, 'o-', linewidth=2, markersize=8, label=f'β={beta}')

ax2.set_xlabel('Pattern index')
ax2.set_ylabel('Attention weight')
ax2.set_title('Higher β → Sharper Attention\n(More like argmax)')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Key insight: β controls the exploration/exploitation tradeoff")
print("  - Low β: Soft attention, consider all patterns (exploration)")
print("  - High β: Hard attention, pick the best match (exploitation)")


In [None]:
# | export
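def modern_hopfield_energy(x, patterns, beta=1.0):
    """
    Modern Hopfield energy with exponential interaction.
    
    E(x) = -(1/beta) * logsumexp(beta * Xi^T x) + 0.5 * ||x||^2
    
    The 1/beta factor (as in Ramsauer et al.) makes the stationary points of
    this energy coincide with the fixed points of the softmax update
    x_new = Xi * softmax(beta * Xi^T x).
    
    For well-separated patterns and large enough beta, each stored pattern
    gets its own local minimum, and the number of patterns that can be stored
    this way grows exponentially with the dimension d (vs. O(N) for classic
    Hopfield).
    
    Args:
        x: Query state (d,) or (batch, d)
        patterns: Stored patterns (n_patterns, d)
        beta: Inverse temperature (higher = sharper attention)
    """
    # beta * Xi^T x: scaled similarity scores (n_patterns,) or (batch, n_patterns)
    scores = x @ patterns.t() * beta
    return -torch.logsumexp(scores, dim=-1) / beta + 0.5 * (x ** 2).sum(-1)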


In [None]:
# | export
def modern_hopfield_update(x, patterns, beta=1.0):
    """
    One step of Modern Hopfield dynamics.
    
    x_new = Xi * softmax(beta * Xi^T x)
    
    This IS attention: Query x attends to keys/values Xi.
    The output is a weighted sum of patterns, with weights determined
    by similarity to the query.
    """
    # Attention scores: how similar is x to each pattern?
    scores = x @ patterns.t() * beta  # (batch, n_patterns)
    attention_weights = F.softmax(scores, dim=-1)  # (batch, n_patterns)
    
    # Weighted combination of patterns
    x_new = attention_weights @ patterns  # (batch, d)
    
    return x_new, attention_weights


In [None]:
# Demo: Modern Hopfield pattern retrieval

# Create some random "memory" patterns (like tokens in a sequence)
d = 64  # dimension
n_patterns = 10

# Stored patterns (normalized for stability)
stored = torch.randn(n_patterns, d)
stored = stored / stored.norm(dim=1, keepdim=True)  # normalize

print(f"Stored {n_patterns} patterns of dimension {d}")


In [None]:
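# Create a noisy query (corrupted pattern 3).
# stored[3] has unit norm, so the noise is scaled by 1/sqrt(d) to give it norm ~0.5;
# unscaled noise of std 0.5 per entry would be about 4x larger than the pattern itself.
query = stored[3] + 0.5 * torch.randn(d) / d ** 0.5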
query = query / query.norm()
query = query.unsqueeze(0)  # add batch dim

print(f"Query similarity to stored patterns:")
print((query @ stored.t()).squeeze())


In [None]:
# Run modern Hopfield update (= one attention step)
retrieved, attn_weights = modern_hopfield_update(query, stored, beta=8.0)

print(f"\nAttention weights (should peak at pattern 3):")
print(attn_weights.squeeze())
print(f"\nMax attention on pattern: {attn_weights.argmax().item()}")


In [None]:
# Compare: recovered pattern vs original
print(f"Cosine similarity to pattern 3: {F.cosine_similarity(retrieved, stored[3].unsqueeze(0)).item():.4f}")
print(f"Original query similarity to pattern 3: {F.cosine_similarity(query, stored[3].unsqueeze(0)).item():.4f}")


In [None]:
# Visualize how beta (inverse temperature) affects attention sharpness
betas = [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for ax, beta in zip(axes.flat, betas):
    _, attn = modern_hopfield_update(query, stored, beta=beta)
    attn_np = attn.squeeze().numpy()
    
    ax.bar(range(len(attn_np)), attn_np, color='steelblue')
    ax.axvline(x=3, color='red', linestyle='--', alpha=0.7, label='True pattern (3)')
    ax.set_xlabel('Pattern Index')
    ax.set_ylabel('Attention Weight')
    ax.set_title(f'β = {beta} (higher = sharper)')
    ax.set_ylim(0, 1)
    ax.legend()

plt.suptitle('Attention Sharpness vs Inverse Temperature (β)\nHigher β → more like argmax → retrieves single pattern', fontsize=14)
plt.tight_layout()
plt.show()


### The Transformer Connection

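In a transformer, the pieces of attention map onto Hopfield retrieval as follows:

**Keys K** = stored patterns (what to match against)

**Values V** = what to retrieve (in a transformer, K and V are different learned projections of the same tokens; the Hopfield update is the special case K = V)

**Queries Q** = what we're looking for

Standard attention: $\text{Attention}(Q, K, V) = \text{softmax}(QK^T / \sqrt{d}) \cdot V$

The $1/\sqrt{d}$ scaling plays the role of $\beta = 1/\sqrt{d}$ in Hopfield terms.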

**What this means:**
- Each attention head is doing **content-based memory retrieval**
- The query asks "what stored patterns am I similar to?"
- The output is a weighted average of the matching values
- This is exactly what Hopfield networks do, but continuously and differentiably

**The deep insight:** Transformers are doing **energy minimization on an exponential energy landscape**, where stored patterns (context tokens) are attractors. Attention is one step of this minimization.


In [None]:
# Show that attention IS the Hopfield update
def standard_attention(Q, K, V):
    """Standard scaled dot-product attention."""
    d_k = K.shape[-1]
    scores = Q @ K.t() / np.sqrt(d_k)
    attn_weights = F.softmax(scores, dim=-1)
    return attn_weights @ V, attn_weights

# Use same patterns as keys and values
K = stored  # (n_patterns, d)
V = stored  # (n_patterns, d)
Q = query   # (1, d)

# Standard attention
attn_out, attn_weights_std = standard_attention(Q, K, V)

# Modern Hopfield (with matching beta)
beta_equivalent = 1.0 / np.sqrt(d)
hopfield_out, hopfield_weights = modern_hopfield_update(Q, stored, beta=beta_equivalent)

print("Attention weights:")
print(attn_weights_std.squeeze()[:5].detach().numpy(), "...")
print("\nHopfield weights:")
print(hopfield_weights.squeeze()[:5].detach().numpy(), "...")

print(f"\nAre they identical? {torch.allclose(attn_weights_std, hopfield_weights)}")


### Exponential Storage Capacity

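Classic Hopfield: $O(N)$ patterns for $N$ neurons (about $0.14N$).

Modern Hopfield with exponential energy: capacity grows exponentially with the dimension $d$, on the order of $2^{d/2}$ patterns. For continuous patterns, the exponential bound proved by Ramsauer et al. holds when the stored patterns are well separated.

Taken at face value, a 512-dimensional pattern space could theoretically store around $2^{256}$ patterns. In practice, transformers work with context lengths of thousands of tokens, which is far below this theoretical limit.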
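A small experiment makes the capacity gap tangible. This sketch stores 1000 random unit-norm patterns in only 64 dimensions and retrieves each one from a noisy query with a single update. The sizes, noise level, and `beta` are arbitrary choices for illustration, and the comparison with the classic $0.14N$ figure is loose because that figure applies to binary patterns.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, n_stored, beta = 64, 1000, 32.0   # classic estimate would be ~0.14 * 64 ≈ 9 patterns

patterns = F.normalize(torch.randn(n_stored, dim), dim=1)
# One noisy query per pattern: unit-norm pattern plus noise of norm ~0.1
noise = 0.1 * torch.randn(n_stored, dim) / dim ** 0.5
queries = F.normalize(patterns + noise, dim=1)

# One modern Hopfield update (= one attention step) for all queries at once
weights = F.softmax(beta * (queries @ patterns.t()), dim=-1)   # (n_stored, n_stored)
retrieved = weights @ patterns

accuracy = (weights.argmax(dim=1) == torch.arange(n_stored)).float().mean().item()
mean_cosine = F.cosine_similarity(retrieved, patterns, dim=1).mean().item()
print(f"retrieval accuracy: {accuracy:.3f}, mean cosine to target: {mean_cosine:.4f}")
```

Lowering `beta` or increasing the noise makes retrieval degrade into blends of several patterns, which is the "metastable" regime of the softmax update.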


---

## Other Cool Energy-Based Insights

### Energy-Based Models (EBMs) as a Unifying Framework

Many ML models can be viewed through the lens of energy functions:

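1. **Hopfield/Boltzmann**: $E(\mathbf{x}) = -\frac{1}{2}\mathbf{x}^T W \mathbf{x}$ (quadratic)
2. **Modern Hopfield/Attention**: $E(\mathbf{x}) = -\text{logsumexp}(\beta \cdot \boldsymbol{\Xi}^T \mathbf{x}) + \frac{1}{2}\|\mathbf{x}\|^2$ (exponential)
3. **Variational Autoencoders**: The negative ELBO is a variational free energy, an upper bound on $-\log p(\mathbf{x})$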
4. **Diffusion Models**: Learn to denoise by estimating score = gradient of log probability = -gradient of energy
5. **Contrastive Learning**: Pushes positive pairs to low energy, negatives to high energy

### Connection to Diffusion Models

Diffusion models learn the **score function**: the gradient of log probability.

Since P(x) ∝ exp(-E(x)), we have:
- log P(x) = -E(x) + const
- Score = ∇ log P(x) = -∇E(x)

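Denoising therefore moves samples **downhill in energy**: following the score is a gradient-descent step on $E$ (samplers such as Langevin dynamics add noise at each step). The model learns to push noisy samples toward low-energy (high probability) regions.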
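The score/energy relationship can be checked on a toy energy where everything is known in closed form. The sketch below uses a unit-variance Gaussian energy (an assumption chosen for simplicity; a diffusion model would learn the score with a network) and then runs unadjusted Langevin dynamics, which follows the score plus injected noise.

```python
import torch

def energy(x, mu):
    """Quadratic energy of a unit-variance Gaussian centred at mu."""
    return 0.5 * ((x - mu) ** 2).sum(dim=-1)

torch.manual_seed(0)
mu = torch.tensor([2.0, -1.0])
x = torch.randn(5, 2, requires_grad=True)

# Score = gradient of log p(x) = -gradient of E(x); for this energy it equals mu - x
(grad_energy,) = torch.autograd.grad(energy(x, mu).sum(), x)
score = -grad_energy
print("score matches mu - x:", torch.allclose(score, mu - x.detach()))

# Langevin dynamics: step along the score, then add Gaussian noise
samples = torch.randn(2000, 2)
step_size = 0.05
for _ in range(200):
    noise = torch.randn_like(samples)
    samples = samples + step_size * (mu - samples) + (2 * step_size) ** 0.5 * noise
print("sample mean:", samples.mean(dim=0))
```

The sample mean should land close to `mu`, the minimum of the energy, even though the chain started from samples centred at the origin.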

### Hopfield Networks in Biological Memory

The energy landscape metaphor connects to neuroscience:
- **Attractor dynamics**: Brain states settle into stable "attractors"
- **Memory consolidation**: During sleep, the brain may be doing something like simulated annealing
- **Pattern completion**: Hippocampus does content-addressable memory retrieval

This isn't just a metaphor—there's real evidence that neural populations exhibit attractor dynamics during memory retrieval.


---

## Summary

We've built intuition from first principles:

1. **Energy minimization** is nature's universal optimization algorithm
2. **Hopfield networks** encode memories as valleys in an energy landscape
3. **Hebbian learning** creates these valleys through outer products
4. **Dynamics** evolve states toward nearest memory (attractor)
5. **Capacity limits** exist (~0.14N patterns for classic Hopfield)
6. **Boltzmann machines** add temperature for exploration
7. **RBMs** make this tractable with bipartite structure
8. **Modern Hopfield networks** use exponential energy for massive capacity
9. **Transformer attention IS Hopfield retrieval** — one step of energy minimization

The energy perspective unifies many seemingly disparate ideas in machine learning, from memory networks to diffusion models to attention mechanisms.


### Final Challenge: Build a Multi-Head Hopfield Attention Layer

Implement a multi-head version of modern Hopfield attention that matches PyTorch's MultiheadAttention interface:


In [None]:
class TestHopfieldMultiheadAttention(nn.Module):
    """
    Multi-head attention implemented as multi-head Hopfield retrieval.
    
    TODO: Complete the forward method.
    """
    def __init__(self, test_embed_dim, test_num_heads):
        super().__init__()
        self.test_embed_dim = test_embed_dim
        self.test_num_heads = test_num_heads
        self.test_head_dim = test_embed_dim // test_num_heads
        
        # Linear projections for Q, K, V
        self.test_W_q = nn.Linear(test_embed_dim, test_embed_dim)
        self.test_W_k = nn.Linear(test_embed_dim, test_embed_dim)
        self.test_W_v = nn.Linear(test_embed_dim, test_embed_dim)
        self.test_W_out = nn.Linear(test_embed_dim, test_embed_dim)
    
    def forward(self, test_query, test_key, test_value):
        """
        Args:
            test_query: (batch, seq_q, embed_dim)
            test_key: (batch, seq_k, embed_dim)
            test_value: (batch, seq_v, embed_dim)
        
        Returns:
            output: (batch, seq_q, embed_dim)
        """
        # FILL IN CODE HERE
        # 1. Project Q, K, V
        # 2. Reshape for multi-head: (batch, seq, n_heads, head_dim) -> (batch, n_heads, seq, head_dim)
        # 3. For each head, apply Hopfield update: softmax(Q @ K^T / sqrt(d)) @ V
        # 4. Concatenate heads and apply output projection
        
        pass

# Test: should work like regular attention
test_batch, test_seq = 2, 5
test_embed = 64
test_heads = 4

test_model = TestHopfieldMultiheadAttention(test_embed, test_heads)
test_x = torch.randn(test_batch, test_seq, test_embed)

# Uncomment after implementing:
# test_out = test_model(test_x, test_x, test_x)  # self-attention
# assert test_out.shape == (test_batch, test_seq, test_embed), f"Expected shape {(test_batch, test_seq, test_embed)}, got {test_out.shape}"
# print("✓ Passed: Output shape is correct")


---

## Shortcuts and Libraries

Now that you understand the fundamentals, here are production shortcuts:

**For RBMs:**
- `sklearn.neural_network.BernoulliRBM`: Simple RBM implementation
- `pytorch-rbm`: GPU-accelerated RBM with CD-k

**For Attention as Hopfield:**
- `hopfield-layers` package: Implements modern Hopfield networks from the 2020 paper
- PyTorch's `nn.MultiheadAttention`: Standard attention (which IS Hopfield)

**For Energy-Based Models:**
- `ebm-torch`: General EBM training
- Diffusion libraries (e.g., `diffusers`): EBMs trained with score matching

The Hopfield perspective helps debug attention: if attention weights are too uniform (high temperature), patterns aren't being retrieved sharply. If weights collapse to single tokens, you're over-indexing on specific memories.
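One way to turn this debugging advice into a number is the entropy of each attention row. This is a sketch of one possible diagnostic, not a standard API: `attention_entropy` is a hypothetical helper and the scores are made up. Entropy near $\log(n)$ means near-uniform attention; entropy near 0 means attention has collapsed onto one token.

```python
import math
import torch
import torch.nn.functional as F

def attention_entropy(attn_weights, eps=1e-12):
    """Entropy (in nats) of each row of attention weights; rows must sum to 1."""
    return -(attn_weights * (attn_weights + eps).log()).sum(dim=-1)

scores = torch.tensor([[1.0, 0.5, 0.3, 0.1, 0.05]])   # illustrative similarities
max_entropy = math.log(scores.shape[-1])               # entropy of uniform attention

for beta in [0.01, 1.0, 100.0]:
    entropy = attention_entropy(F.softmax(beta * scores, dim=-1)).item()
    print(f"beta={beta:>6}: entropy={entropy:.3f} (uniform would be {max_entropy:.3f})")
```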

## Further Reading

1. **Original Hopfield Paper**: Hopfield, J.J. (1982). "Neural networks and physical systems with emergent collective computational abilities"

2. **Boltzmann Machines**: Hinton & Sejnowski (1986). "Learning and relearning in Boltzmann machines"

3. **Contrastive Divergence**: Hinton (2002). "Training products of experts by minimizing contrastive divergence"

4. **Modern Hopfield Networks**: Ramsauer et al. (2020). "Hopfield Networks is All You Need" — The key paper connecting Hopfield to attention

5. **Energy-Based Models**: LeCun et al. (2006). "A Tutorial on Energy-Based Learning"

6. **Score Matching & Diffusion**: Song & Ermon (2019). "Generative Modeling by Estimating Gradients of the Data Distribution"
