# Introduction to Equilibrium Propagation (EP)

Equilibrium Propagation (EP) is a learning algorithm for energy-based networks, often presented as a more biologically plausible alternative to backpropagation. Instead of propagating error signals backward through the network (which requires a dedicated backward pass that reuses the forward weights), EP uses local learning rules based on the contrast between two phases of network dynamics:

1. **Free Phase**: The network settles to a state that minimizes its internal energy, given an input $x$.
2. **Nudged Phase**: The output layer is nudged towards the target $y$, and the network settles to a new state.

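For a Hopfield-style energy, the gradient of the loss with respect to the weight $W_{ij}$ connecting units $i$ and $j$ is estimated from the difference in local activity products between the two phases, where $\beta$ is the nudging strength:
$$ \frac{\partial L}{\partial W_{ij}} \approx -\frac{1}{\beta} \left( s_i^{\text{nudged}} s_j^{\text{nudged}} - s_i^{\text{free}} s_j^{\text{free}} \right) $$
A gradient-descent step therefore moves $W_{ij}$ in the direction of the nudged-minus-free contrast.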
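
As a sanity check on the two-phase idea, the following sketch runs EP on a toy energy small enough to solve by hand. Everything in it is an assumption chosen for illustration, not the `mep` implementation: the energy $E(s) = \tfrac{1}{2}\|s\|^2 - s^\top W x$, the cost $C(s) = \tfrac{1}{2}\|s - y\|^2$, and the helper names `settle` and `ep_gradient`. Because this toy has only input-to-state weights, the local product is $s_i x_j$ rather than $s_i s_j$, and since $\partial E / \partial W = -s x^\top$, the loss gradient is the negative of the nudged-minus-free contrast.

```python
import numpy as np

def settle(W, x, y, beta, steps=200, step_size=0.5):
    """Gradient descent on F(s) = E(s) + beta * C(s) until s stops moving.

    Toy energy (assumed): E(s) = 0.5 * ||s||^2 - s^T W x
    Toy cost   (assumed): C(s) = 0.5 * ||s - y||^2
    """
    s = np.zeros(W.shape[0])
    for _ in range(steps):
        grad_s = s - W @ x + beta * (s - y)  # dF/ds
        s = s - step_size * grad_s
    return s

def ep_gradient(W, x, y, beta):
    """Estimate dL/dW from the contrast between the nudged and free phases."""
    s_free = settle(W, x, y, beta=0.0)     # free phase: beta = 0
    s_nudged = settle(W, x, y, beta=beta)  # nudged phase: output pulled toward y
    # dE/dW = -outer(s, x), so the estimate is minus the Hebbian contrast / beta
    return -(np.outer(s_nudged, x) - np.outer(s_free, x)) / beta

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
x = rng.normal(size=4)
y = rng.normal(size=3)

# At the free fixed point s = W x, so the loss is 0.5 * ||W x - y||^2
true_grad = np.outer(W @ x - y, x)
ep_grad_small = ep_gradient(W, x, y, beta=0.01)
ep_grad_large = ep_gradient(W, x, y, beta=0.5)

print("max abs error, beta=0.01:", np.abs(ep_grad_small - true_grad).max())
print("max abs error, beta=0.5: ", np.abs(ep_grad_large - true_grad).max())
```

For this toy the estimate works out to the true gradient divided by $(1 + \beta)$, so it becomes exact only as $\beta \to 0$; a larger nudging strength gives a larger contrast between the phases at the cost of more bias.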

This tutorial demonstrates how to use the `mep` library to train a simple neural network using EP.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from mep import smep

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Data Loading
We use the standard MNIST dataset.

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("../data", train=True, download=True, transform=transform),
    batch_size=64, shuffle=True
)

# Visualize one batch
data_iter = iter(train_loader)
images, labels = next(data_iter)
plt.imshow(images[0].reshape(28, 28), cmap="gray")
plt.title(f"Label: {labels[0].item()}")
plt.show()

## 2. Model Definition
We define a simple Multi-Layer Perceptron (MLP) as a standard `nn.Sequential`; no EP-specific layers are needed.

In [None]:
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
).to(device)

## 3. Optimizer Setup (SMEP)
We use the `smep` preset, which combines:
- **Spectral Normalization**: Ensures stability of the fixed point dynamics.
- **Muon Update**: Orthogonalizes gradients for better conditioning.
- **Equilibrium Propagation**: Computes gradients without backprop.
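
To make the first two ingredients concrete, here is a minimal NumPy sketch of what they compute in general. It is not the `mep` implementation: the function names `spectral_norm`, `spectrally_normalize`, and `orthogonalize` are hypothetical, and the orthogonalization uses the classic cubic Newton-Schulz iteration, which may differ from the variant the library applies.

```python
import numpy as np

def spectral_norm(W, iters=50, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=W.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(iters):
        u = W @ v
        u /= np.linalg.norm(u)
        v = W.T @ u
        v /= np.linalg.norm(v)
    return float(u @ W @ v)

def spectrally_normalize(W, max_norm=1.0):
    """Rescale W so that its largest singular value is at most max_norm."""
    sigma = spectral_norm(W)
    return W if sigma <= max_norm else W * (max_norm / sigma)

def orthogonalize(G, iters=20):
    """Push all singular values of G toward 1 with Newton-Schulz iterations."""
    X = G / np.linalg.norm(G)  # Frobenius norm, so every singular value is <= 1
    for _ in range(iters):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

W = np.array([[3.0, 0.0], [0.0, 1.0]])
G = np.array([[2.0, 1.0], [0.0, 1.0]])

W_sn = spectrally_normalize(W)  # largest singular value rescaled to 1
G_orth = orthogonalize(G)       # approximately orthogonal matrix

print("spectral norm of W:", spectral_norm(W))
print("G_orth @ G_orth.T:\n", G_orth @ G_orth.T)
```

Bounding the spectral norm limits how much a layer can amplify its input, which is one common way to keep iterative settling from diverging. Orthogonalizing a gradient matrix equalizes its singular values, so no single direction dominates the update.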

Crucially, we set `mode="ep"`.

In [None]:
optimizer = smep(
    model.parameters(),
    model=model,        # Pass model for EP to access structure
    mode="ep",          # Enable EP
    lr=0.05,
    beta=0.5,           # Nudging strength
    settle_steps=15,    # Steps for settling dynamics
    loss_type="cross_entropy"
)

## 4. Training Loop
The training loop is standard, but `optimizer.step()` handles the forward passes (free and nudged phases) internally. We don't call `loss.backward()`.
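
The loop in this section reports running accuracy on training batches only. For a held-out check, a small helper along these lines could be defined alongside it. This is a sketch under assumptions: `evaluate` is a hypothetical name, not part of `mep`, and it assumes (as the training loop does) that a plain forward pass `model(data)` is the right way to read out predictions.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def evaluate(model, loader, device):
    """Return the classification accuracy of `model` over all batches in `loader`."""
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        pred = model(data).argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)
    model.train(was_training)  # restore the previous train/eval mode
    return correct / total

# Tiny synthetic check: an identity "classifier" on 2-D inputs
toy_model = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    toy_model.weight.copy_(torch.eye(2))

toy_loader = [
    (torch.tensor([[1.0, 0.0], [0.0, 1.0]]), torch.tensor([0, 1])),
    (torch.tensor([[2.0, 1.0], [0.5, 3.0]]), torch.tensor([0, 0])),
]
toy_acc = evaluate(toy_model, toy_loader, torch.device("cpu"))
print(f"Toy accuracy: {toy_acc:.2f}")
```

After training, the same helper could be called with a loader built from `datasets.MNIST("../data", train=False, transform=transform)` to measure accuracy on the test split.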

In [None]:
model.train()
epochs = 1

for epoch in range(epochs):
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # EP Step: Free phase -> Nudged phase -> Update
        optimizer.step(x=data, target=target)
        
        # Calculate accuracy (optional, requires extra forward pass)
        with torch.no_grad():
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1} | Batch {batch_idx}/{len(train_loader)} | Acc: {100. * correct / total:.2f}%")
            correct = 0
            total = 0

## Conclusion
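You have now trained a neural network using Equilibrium Propagation. The method replaces backpropagation's backward error pass, the main source of its biological implausibility, with local updates computed from two settling phases, and EP has been reported to reach competitive accuracy on simple tasks such as MNIST.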