A PyTorch library showcasing Energy-Based Models for structured reasoning tasks.
Energy-Based Models (EBMs) define a scalar energy function E(x) over configurations x. The key idea is simple:
- Lower energy = more compatible/likely configuration
- Higher energy = less compatible/unlikely configuration
For reasoning, this translates to:
- E(problem, reasoning_chain, answer) → low energy if the reasoning is valid
- E(problem, reasoning_chain, answer) → high energy if the reasoning is flawed
Traditional approaches generate reasoning directly (autoregressive). EBMs offer a complementary paradigm:
| Aspect | Autoregressive | Energy-Based |
|---|---|---|
| Output | Generates one solution | Scores any solution |
| Verification | Implicit | Explicit |
| Multiple solutions | Sequential sampling | Parallel scoring |
| Inference compute | Fixed | Scalable (more search = better) |
Key advantages:
- Verification without generation: Check if reasoning is valid without generating it
- Best-of-N selection: Generate multiple candidates, pick the lowest-energy one
- Inference-time scaling: More compute at inference → better results
- Principled uncertainty: Energy naturally captures confidence
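To make these advantages concrete, here is a toy sketch of verification and best-of-N selection, deliberately independent of this library's API: the hypothetical `energy` function below simply scores a candidate answer by its arithmetic error, so the valid answer is the lowest-energy one.

```python
# Toy illustration (not the library API): best-of-N selection with a
# hand-written energy function. "Energy" here is the arithmetic error of a
# candidate answer, so the correct answer scores lowest.
def energy(problem, answer):
    a, b = problem
    return abs((a + b) - answer)  # 0 for the correct sum, larger otherwise

problem = (17, 25)
candidates = [41, 42, 43, 52]  # e.g. sampled in parallel from a generator
scores = [energy(problem, c) for c in candidates]
best = candidates[min(range(len(candidates)), key=scores.__getitem__)]
print(best)  # 42: the lowest-energy candidate
```

Note that scoring all candidates is embarrassingly parallel, and adding more candidates (more inference compute) can only improve the selected answer under a good energy function.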
```bash
# Clone the repository
git clone https://github.com/yourusername/ebrm.git
cd ebrm

# Install with uv (recommended)
uv sync

# Or with pip
pip install -e .
```

```python
from ebrm import EBM, MLPEnergy, contrastive_loss
import torch

# Create an energy function
energy_fn = MLPEnergy(input_dim=10, hidden_dims=[64, 64])
ebm = EBM(energy_fn)

# Training: push down the energy of positive examples, up for negatives
positive = torch.randn(32, 10)  # "good" configurations
negative = torch.randn(32, 10)  # "bad" configurations
loss = contrastive_loss(energy_fn, positive, negative, margin=1.0)
loss.backward()

# Inference: find low-energy configurations via Langevin dynamics
samples = ebm.langevin_sample(
    x_init=torch.randn(10, 10),
    n_steps=100,
    step_size=0.01,
)
```

```bash
# Main conceptual demo
python main.py

# Arithmetic reasoning
python examples/arithmetic_reasoning.py

# Logical inference
python examples/logic_reasoning.py

# Chain-of-thought verification
python examples/chain_of_thought.py
```

Energy functions are the foundation of EBMs, mapping configurations to scalar energies:
```python
from ebrm import MLPEnergy, JointEnergy

# Simple energy function over vectors
energy = MLPEnergy(input_dim=64, hidden_dims=[256, 256])
E_x = energy(x)  # Shape: (batch_size,)

# Joint energy over (input, output) pairs
joint_energy = JointEnergy(input_dim=32, output_dim=16)
E_xy = joint_energy(x, y)  # For conditional modeling
```

EBMs are trained by shaping the energy landscape:
```python
from ebrm import contrastive_loss, noise_contrastive_estimation, denoising_score_matching

# Contrastive: enforce E(positive) < E(negative) - margin
loss = contrastive_loss(energy_fn, positive, negative, margin=1.0)

# Noise Contrastive Estimation: distinguish data from noise
loss = noise_contrastive_estimation(energy_fn, data, noise_distribution)

# Denoising Score Matching: learn to denoise corrupted samples
loss = denoising_score_matching(energy_fn, data, noise_std=0.1)
```

Inference finds low-energy configurations:
```python
from ebrm import EBM, langevin_dynamics

ebm = EBM(energy_fn)

# Langevin dynamics: MCMC sampling from p(x) ∝ exp(-E(x))
samples = ebm.langevin_sample(x_init, n_steps=100, step_size=0.01)

# MAP inference: gradient descent to minimize energy
x_star = ebm.map_inference(x_init, n_steps=100)
```

The `ReasoningEBM` class models multi-step reasoning:
```python
from ebrm import ReasoningEBM

model = ReasoningEBM(
    state_dim=128,    # Dimension of reasoning states
    problem_dim=256,  # Dimension of problem encoding
    answer_dim=64,    # Dimension of answer encoding
    max_steps=10,     # Maximum reasoning steps
)

# Score a reasoning chain
energy = model.energy(problem, reasoning_steps, answer)

# Generate low-energy reasoning
steps, answer = model.generate_chain(problem, n_steps=5)

# Best-of-N: generate multiple candidates, pick the lowest energy
best_steps, best_answer = model.best_of_n(problem, n_candidates=8)
```

```
ebrm/
├── __init__.py          # Package exports
├── core.py              # EnergyFunction, EBM, ConditionalEBM
├── reasoning.py         # ReasoningEBM, ChainEnergyFunction, VerifierEBM
├── training.py          # Loss functions and training utilities
└── visualization.py     # Plotting utilities

examples/
├── arithmetic_reasoning.py  # Multi-step arithmetic
├── logic_reasoning.py       # Logical inference verification
└── chain_of_thought.py      # CoT verification and ranking
```
| Class | Description |
|---|---|
| `EnergyFunction` | Abstract base for energy functions |
| `MLPEnergy` | MLP-based energy function |
| `JointEnergy` | Energy over (input, output) pairs |
| `EBM` | Energy model with sampling methods |
| `ConditionalEBM` | Conditional model for structured prediction |
| `ReasoningEBM` | Multi-step reasoning with chain energy |
| `VerifierEBM` | Verifies reasoning chains |
| Loss | Description | Use Case |
|---|---|---|
| `contrastive_loss` | Margin loss between pos/neg pairs | When you have explicit pairs |
| `infoNCE_loss` | Softmax over multiple negatives | Self-supervised learning |
| `noise_contrastive_estimation` | Classify data vs. noise | When negatives are expensive |
| `denoising_score_matching` | Learn to denoise corrupted samples | Stable training |
| `persistent_contrastive_divergence` | PCD with persistent chains | Better gradient estimates |
EBMs define a probability distribution via the Boltzmann distribution:
p(x) = exp(-E(x)) / Z
where Z is the partition function (normalizing constant). Lower energy → higher probability.
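For a small discrete state space, Z can be computed explicitly, which makes the energy-to-probability mapping easy to see. A minimal sketch (the state names and energy values below are made up for illustration):

```python
import math

# Illustrative only: with a tiny discrete state space we can compute the
# partition function Z exactly and recover p(x) = exp(-E(x)) / Z.
energies = {"valid_chain": 0.5, "flawed_chain": 2.0, "nonsense": 4.0}

Z = sum(math.exp(-E) for E in energies.values())
probs = {x: math.exp(-E) / Z for x, E in energies.items()}

assert abs(sum(probs.values()) - 1.0) < 1e-12
# Lower energy -> higher probability
assert probs["valid_chain"] > probs["flawed_chain"] > probs["nonsense"]
```

In realistic, high-dimensional settings Z is intractable, which is exactly why the training losses below avoid computing it.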
Training pushes the energy of positive examples down and the energy of negatives up:
L = max(0, E(x⁺) - E(x⁻) + margin)
This shapes the energy landscape without computing Z explicitly.
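The margin loss above can be worked through numerically; this minimal sketch handles a single pair of scalar energies (a real implementation averages over a batch):

```python
# Minimal sketch of the margin loss L = max(0, E(x+) - E(x-) + margin)
# for one (positive, negative) pair of scalar energies.
def hinge(E_pos, E_neg, margin=1.0):
    return max(0.0, E_pos - E_neg + margin)

# Negative already sits more than `margin` above the positive: zero loss,
# no gradient signal.
print(hinge(E_pos=0.2, E_neg=2.0))  # 0.0

# Ordering violated: loss grows linearly with the violation.
print(hinge(E_pos=1.5, E_neg=0.5))  # 2.0
```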
Langevin dynamics samples from p(x) ∝ exp(-E(x)) by iterating:
x_{t+1} = x_t - ε∇E(x_t) + √(2ε)·noise
As t → ∞, samples converge to the target distribution.
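The update above can be checked on a 1-D example where the target is known. With E(x) = x²/2 we have ∇E(x) = x, and p(x) ∝ exp(-x²/2) is a standard normal, so a long chain should match its moments (up to a small discretization bias at finite ε). A self-contained sketch:

```python
import math
import random

random.seed(0)

# 1-D Langevin dynamics for E(x) = x^2 / 2, i.e. grad E(x) = x.
# Target: p(x) ∝ exp(-x^2/2), the standard normal.
def langevin(x, eps=0.05, n_steps=50_000, burn_in=1_000):
    samples = []
    for t in range(n_steps):
        # x_{t+1} = x_t - eps * grad E(x_t) + sqrt(2 * eps) * noise
        x = x - eps * x + math.sqrt(2 * eps) * random.gauss(0.0, 1.0)
        if t >= burn_in:
            samples.append(x)
    return samples

samples = langevin(x=5.0)  # start far from the mode
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)

# The chain forgets its initialization and matches the target's moments.
assert abs(mean) < 0.15 and 0.8 < var < 1.3
```

The finite step size ε introduces a small bias (here the stationary variance is ≈ 1/(1 - ε/2) rather than exactly 1); Metropolis-adjusted variants correct for this.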
Given a problem, valid reasoning is a low-energy path:
E_total = E_start(problem → step_1)
+ Σ E_step(step_i → step_{i+1})
+ E_end(step_T → answer)
Inference finds the reasoning chain that minimizes this total energy.
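The decomposition above can be sketched with a toy pairwise energy (not the library's internals); here each term is just a squared distance between consecutive 1-D states, so smooth problem-to-answer chains score lower than erratic ones:

```python
# Toy sketch of the chain-energy decomposition: total energy is the sum of
# pairwise compatibility terms along problem -> steps -> answer. The 1-D
# squared-distance pair energy is purely illustrative.
def chain_energy(problem, steps, answer, e_pair=lambda a, b: (a - b) ** 2):
    nodes = [problem, *steps, answer]
    return sum(e_pair(a, b) for a, b in zip(nodes, nodes[1:]))

# A smooth chain from problem to answer scores lower than an erratic one.
smooth = chain_energy(0.0, [1.0, 2.0, 3.0], 4.0)    # 4 unit steps -> 4.0
erratic = chain_energy(0.0, [3.0, -1.0, 5.0], 4.0)  # large jumps
assert smooth < erratic
```

Because the total is a sum over adjacent pairs, minimizing it couples each step only to its neighbors, which is what makes gradient-based inference over whole chains tractable.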
- LeCun, Y., et al. "A Tutorial on Energy-Based Learning" (2006)
- Song & Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution" (2019)
- Du & Mordatch. "Implicit Generation and Generalization in Energy-Based Models" (2019)
- Cobbe et al. "Training Verifiers to Solve Math Word Problems" (2021) - ORM/PRM
- Lightman et al. "Let's Verify Step by Step" (2023) - Process Reward Models
MIT License - see LICENSE file for details.