Energy-Based Reasoning Models (EBRM)

A PyTorch library showcasing Energy-Based Models for structured reasoning tasks.

What are Energy-Based Models?

Energy-Based Models (EBMs) define a scalar energy function E(x) over configurations x. The key idea is simple:

  • Lower energy = more compatible/likely configuration
  • Higher energy = less compatible/unlikely configuration

For reasoning, this translates to (see the sketch below):

  • E(problem, reasoning_chain, answer) → low energy if the reasoning is valid
  • E(problem, reasoning_chain, answer) → high energy if the reasoning is flawed
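
A minimal sketch of what such a joint scorer could look like in PyTorch (the module, encoder dimensions, and mean-pooling below are illustrative assumptions, not this library's API):

import torch
import torch.nn as nn

# Hypothetical sketch of E(problem, reasoning_chain, answer); not the library's API.
class ReasoningEnergySketch(nn.Module):
    def __init__(self, problem_dim=64, step_dim=32, answer_dim=16):
        super().__init__()
        self.scorer = nn.Sequential(
            nn.Linear(problem_dim + step_dim + answer_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # scalar energy
        )

    def forward(self, problem, chain, answer):
        chain_summary = chain.mean(dim=1)  # pool (batch, n_steps, step_dim) into one vector
        features = torch.cat([problem, chain_summary, answer], dim=-1)
        return self.scorer(features).squeeze(-1)  # (batch,) energies: lower = more valid

energy_fn = ReasoningEnergySketch()
E = energy_fn(torch.randn(4, 64), torch.randn(4, 3, 32), torch.randn(4, 16))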

Why EBMs for Reasoning?

Traditional approaches generate reasoning directly (autoregressive). EBMs offer a complementary paradigm:

Aspect             | Autoregressive         | Energy-Based
-------------------|------------------------|--------------------------------
Output             | Generates one solution | Scores any solution
Verification       | Implicit               | Explicit
Multiple solutions | Sequential sampling    | Parallel scoring
Inference compute  | Fixed                  | Scalable (more search = better)

Key advantages:

  1. Verification without generation: Check if reasoning is valid without generating it
  2. Best-of-N selection: Generate multiple candidates, pick the lowest-energy one (sketched after this list)
  3. Inference-time scaling: More compute at inference → better results
  4. Principled uncertainty: Energy naturally captures confidence
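
Best-of-N selection reduces to scoring candidates and taking the arg-min energy. A sketch, assuming the candidates come from some external generator (e.g. an autoregressive model) that is not part of the EBM itself:

import torch

def best_of_n(energy_fn, problem, candidates):
    # candidates: list of (reasoning_chain, answer) pairs from any generator;
    # each candidate is scored as a single (1,)-shaped energy.
    with torch.no_grad():
        energies = torch.tensor([float(energy_fn(problem, chain, answer))
                                 for chain, answer in candidates])
    best = int(energies.argmin())
    return candidates[best], energies[best]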

Installation

# Clone the repository
git clone https://github.com/Shr1ftyy/ebrm.git
cd ebrm

# Install with uv (recommended)
uv sync

# Or with pip
pip install -e .

Quick Start

from ebrm import EBM, MLPEnergy, contrastive_loss
import torch

# Create an energy function
energy_fn = MLPEnergy(input_dim=10, hidden_dims=[64, 64])
ebm = EBM(energy_fn)

# Training: push energy down for positive examples, up for negatives
positive = torch.randn(32, 10)  # "good" configurations
negative = torch.randn(32, 10)  # "bad" configurations

loss = contrastive_loss(energy_fn, positive, negative, margin=1.0)
loss.backward()

# Inference: find low-energy configurations via Langevin dynamics
samples = ebm.langevin_sample(
    x_init=torch.randn(10, 10),
    n_steps=100,
    step_size=0.01
)

Running the Demos

# Main conceptual demo
python main.py

# Arithmetic reasoning
python examples/arithmetic_reasoning.py

# Logical inference
python examples/logic_reasoning.py

# Chain-of-thought verification
python examples/chain_of_thought.py

Core Concepts

1. Energy Functions

The foundation of EBMs: a learned function that maps configurations to scalar energies:

from ebrm import MLPEnergy, JointEnergy

# Simple energy function over vectors
energy = MLPEnergy(input_dim=64, hidden_dims=[256, 256])
E_x = energy(x)  # Shape: (batch_size,)

# Joint energy over (input, output) pairs
joint_energy = JointEnergy(input_dim=32, output_dim=16)
E_xy = joint_energy(x, y)  # For conditional modeling

2. Training Methods

EBMs are trained by shaping the energy landscape:

from ebrm import contrastive_loss, noise_contrastive_estimation, denoising_score_matching

# Contrastive: E(positive) < E(negative) - margin
loss = contrastive_loss(energy_fn, positive, negative, margin=1.0)

# Noise Contrastive Estimation: distinguish data from noise
loss = noise_contrastive_estimation(energy_fn, data, noise_distribution)

# Denoising Score Matching: learn to denoise corrupted samples
loss = denoising_score_matching(energy_fn, data, noise_std=0.1)

3. Inference

Find low-energy configurations:

from ebrm import EBM, langevin_dynamics

ebm = EBM(energy_fn)

# Langevin dynamics: MCMC sampling from p(x) ∝ exp(-E(x))
samples = ebm.langevin_sample(x_init, n_steps=100, step_size=0.01)

# MAP inference: gradient descent to minimize energy
x_star = ebm.map_inference(x_init, n_steps=100)

4. Reasoning with EBMs

The ReasoningEBM class models multi-step reasoning:

from ebrm import ReasoningEBM

model = ReasoningEBM(
    state_dim=128,      # Dimension of reasoning states
    problem_dim=256,    # Dimension of problem encoding
    answer_dim=64,      # Dimension of answer encoding
    max_steps=10        # Maximum reasoning steps
)

# Score a reasoning chain
energy = model.energy(problem, reasoning_steps, answer)

# Generate low-energy reasoning
steps, answer = model.generate_chain(problem, n_steps=5)

# Best-of-N: generate multiple, pick lowest energy
best_steps, best_answer = model.best_of_n(problem, n_candidates=8)

Architecture

ebrm/
├── __init__.py          # Package exports
├── core.py              # EnergyFunction, EBM, ConditionalEBM
├── reasoning.py         # ReasoningEBM, ChainEnergyFunction, VerifierEBM
├── training.py          # Loss functions and training utilities
└── visualization.py     # Plotting utilities

examples/
├── arithmetic_reasoning.py  # Multi-step arithmetic
├── logic_reasoning.py       # Logical inference verification
└── chain_of_thought.py      # CoT verification and ranking

Key Classes

Class          | Description
---------------|---------------------------------------------
EnergyFunction | Abstract base class for energy functions
MLPEnergy      | MLP-based energy function
JointEnergy    | Energy over (input, output) pairs
EBM            | Energy model with sampling methods
ConditionalEBM | Conditional model for structured prediction
ReasoningEBM   | Multi-step reasoning with chain energy
VerifierEBM    | Verifies reasoning chains

Training Losses

Loss                              | Description                                  | Use Case
----------------------------------|----------------------------------------------|------------------------------
contrastive_loss                  | Margin loss between positive/negative pairs  | When you have explicit pairs
infoNCE_loss                      | Softmax over multiple negatives              | Self-supervised learning
noise_contrastive_estimation      | Classify data vs. noise                      | When negatives are expensive
denoising_score_matching          | Learn to denoise corrupted samples           | Stable training
persistent_contrastive_divergence | PCD with persistent chains                   | Better gradient estimates
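
Of these, only contrastive_loss appears in the examples above. As a generic sketch of the InfoNCE-style objective (one positive scored against K negatives, with negated energies used as logits), which may differ from how infoNCE_loss is actually implemented here:

import torch
import torch.nn.functional as F

def info_nce_sketch(energy_fn, positive, negatives):
    # positive:  (batch, dim)     one correct configuration per example
    # negatives: (batch, K, dim)  K incorrect configurations per example
    e_pos = energy_fn(positive).unsqueeze(1)                                 # (batch, 1)
    e_neg = energy_fn(negatives.flatten(0, 1)).view(negatives.shape[0], -1)  # (batch, K)
    logits = -torch.cat([e_pos, e_neg], dim=1)  # lower energy -> higher logit
    targets = torch.zeros(logits.shape[0], dtype=torch.long)  # positive sits at index 0
    return F.cross_entropy(logits, targets)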

Theoretical Background

Energy and Probability

EBMs define a probability distribution via the Boltzmann distribution:

p(x) = exp(-E(x)) / Z

where Z is the partition function (normalizing constant). Lower energy → higher probability.
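
Over a finite set of candidates, these probabilities are simply a softmax of the negative energies:

import torch

energies = torch.tensor([0.5, 2.0, 4.0])  # E(x) for three candidate configurations
probs = torch.softmax(-energies, dim=0)   # p(x) = exp(-E(x)) / Z
# probs ≈ [0.80, 0.18, 0.02]: the lowest-energy candidate is the most probable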

Contrastive Learning

Training pushes the energy of positive examples down and the energy of negatives up:

L = max(0, E(x⁺) - E(x⁻) + margin)

This shapes the energy landscape without computing Z explicitly.
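
Written out in plain PyTorch, this is a hinge on the energy gap (a sketch of the objective, not necessarily the library's exact contrastive_loss implementation):

import torch

def margin_contrastive(energy_fn, positive, negative, margin=1.0):
    # Penalize whenever E(x+) is not at least `margin` below E(x-).
    e_pos = energy_fn(positive)  # (batch,)
    e_neg = energy_fn(negative)  # (batch,)
    return torch.clamp(e_pos - e_neg + margin, min=0.0).mean()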

Langevin Dynamics

Samples from p(x) ∝ exp(-E(x)) by iterating:

x_{t+1} = x_t - ε∇E(x_t) + √(2ε)·noise

As the step size ε shrinks and t → ∞, the samples converge to the target distribution.
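
The update rule maps directly onto autograd (a sketch of the idea; the library's langevin_sample may add details such as gradient clipping or noise schedules):

import torch

def langevin_sketch(energy_fn, x_init, n_steps=100, step_size=0.01):
    # x_{t+1} = x_t - eps * grad E(x_t) + sqrt(2 * eps) * noise
    x = x_init.clone().requires_grad_(True)
    for _ in range(n_steps):
        grad, = torch.autograd.grad(energy_fn(x).sum(), x)
        with torch.no_grad():
            x = x - step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)
        x.requires_grad_(True)
    return x.detach()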

Reasoning as Energy Minimization

Given a problem, valid reasoning is a low-energy path:

E_total = E_start(problem → step_1)
        + Σ E_step(step_i → step_{i+1})
        + E_end(step_T → answer)

Inference finds the reasoning chain that minimizes this total energy.
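
As a sketch, the chain energy is a straightforward sum of pairwise step energies (e_start, e_step, and e_end below mirror the decomposition above and are illustrative callables, not the library's API):

def chain_energy(e_start, e_step, e_end, problem, steps, answer):
    # E_total = E_start(problem, step_1) + sum_i E_step(step_i, step_{i+1}) + E_end(step_T, answer)
    total = e_start(problem, steps[0])
    for prev, nxt in zip(steps[:-1], steps[1:]):
        total = total + e_step(prev, nxt)
    return total + e_end(steps[-1], answer)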

References

  • LeCun, Y., et al. "A Tutorial on Energy-Based Learning" (2006)
  • Song & Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution" (2019)
  • Du & Mordatch. "Implicit Generation and Generalization in Energy-Based Models" (2019)
  • Cobbe et al. "Training Verifiers to Solve Math Word Problems" (2021) - ORM/PRM
  • Lightman et al. "Let's Verify Step by Step" (2023) - Process Reward Models

License

MIT License - see LICENSE file for details.
