# Unified Mamba-Hopfield-DEQ Demo

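This notebook demonstrates the key capabilities of the unified Mamba-Hopfield-DEQ architecture: solving for an equilibrium, retrieving from associative memory, validating convergence, visualizing the energy landscape, and testing associative recall.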

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from src.models.unified import UnifiedMambaHopfieldDEQ
from experiments.theory.convergence_proofs import ConvergenceValidator
from experiments.theory.energy_analysis import EnergyLandscapeAnalyzer

# Construct the model used throughout this notebook.
model = UnifiedMambaHopfieldDEQ(
    vocab_size=1000,
    d_model=128,
    d_state=32,
    memory_size=500,
    max_iterations=20,
)

# Report the total number of parameter elements, with thousands separators.
print("Model initialized!")
print("Parameters: {:,}".format(sum(map(torch.Tensor.numel, model.parameters()))))

## Basic Usage

Let's run a simple forward pass and examine the equilibrium.

In [None]:
# Create dummy input: one sequence of 20 token ids drawn from [0, 1000),
# matching the vocab_size=1000 the model was built with.
input_ids = torch.randint(0, 1000, (1, 20))

# Forward with diagnostics. This is inference only, so run it under no_grad
# like the later cells do; it also keeps any tensors in the diagnostics
# detached, which numpy/matplotlib require when converting them for plotting.
with torch.no_grad():
    logits, diag = model(input_ids, return_diagnostics=True)

solver_info = diag['solver_info']
print("Forward pass complete!")
print(f"Converged: {solver_info['converged']}")
print(f"Iterations: {solver_info['iterations']}")
print(f"Final energy: {solver_info['final_energy']:.4f}")

# Visualize convergence, but only if the solver recorded an energy history.
if 'energy_history' in solver_info:
    plt.figure(figsize=(10, 4))
    plt.plot(solver_info['energy_history'], 'o-')
    plt.xlabel('Iteration')
    plt.ylabel('Energy')
    plt.title('Energy During Convergence')
    plt.grid(True, alpha=0.3)
    plt.show()

## Memory Dynamics

Examine how the equilibrium state retrieves from the stored memory patterns, i.e. how its attention is distributed across them.

In [None]:
# Check current memory usage (statistics reported by the forward pass above).
memory_stats = diag['memory_usage']
print(f"Memory attention entropy: {memory_stats['attention_entropy']:.4f}")
print(f"Top-10 pattern mass: {memory_stats['top_10_mass']:.4f}")

# Visualize attention over memory: softmax of the dot-product similarity
# between the equilibrium state and every stored pattern.
# NOTE(review): no inverse temperature (beta) is applied here, so these
# weights may differ from the model's own retrieval weights — confirm.
z_eq = diag['z_equilibrium']
with torch.no_grad():
    similarities = torch.matmul(z_eq, model.memory_patterns.T)
    attention = torch.softmax(similarities, dim=-1)

# Weights for the first batch element. Assumes z_eq is (batch, d_model), so
# this is 1-D with one entry per memory pattern — TODO confirm.
first_weights = attention[0].cpu()
num_patterns = first_weights.shape[-1]

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.bar(range(num_patterns), first_weights.numpy())
plt.xlabel('Memory Pattern Index')
plt.ylabel('Attention Weight')
plt.title('Memory Retrieval Pattern')

plt.subplot(1, 2, 2)
# Clamp k so topk never asks for more patterns than exist (it would raise).
top_k = min(20, num_patterns)
top = first_weights.topk(top_k)  # one call yields both values and indices
top_values = top.values.numpy()
top_indices = top.indices.numpy()
plt.barh(range(top_k), top_values)
# Label each bar with the memory-pattern index it corresponds to.
plt.yticks(range(top_k), [str(idx) for idx in top_indices])
plt.xlabel('Attention Weight')
plt.ylabel('Memory Pattern Index')
plt.title(f'Top {top_k} Retrieved Patterns')
plt.gca().invert_yaxis()  # topk is sorted descending: highest weight on top
plt.tight_layout()
plt.show()

## Theoretical Validation

Empirically test the solver's convergence properties: contraction, monotonic energy descent, and fixed-point stability.

In [None]:
# Run the empirical convergence checks provided by ConvergenceValidator.
validator = ConvergenceValidator(model)

print('Running convergence tests...')
print('\n1. Testing contraction property...')
contraction_results = validator.test_contraction_property(num_samples=20)

print('\n2. Testing energy descent...')
descent_results = validator.test_energy_descent(num_trajectories=10, num_steps=20)

print('\n3. Testing fixed-point stability...')
stability_results = validator.test_fixed_point_stability(num_fixed_points=5)

# Show a check mark only for properties that held and a cross otherwise, so
# the summary never reports a pass for a failed check.
summary = [
    ('Contraction', contraction_results['is_contraction']),
    ('Energy Descent', descent_results['monotonic_descent']),
    ('Stability', stability_results['is_stable']),
]
print('\n' + '=' * 50)
print('RESULTS SUMMARY')
print('=' * 50)
for label, passed in summary:
    print(f"{'✓' if passed else '✗'} {label}: {passed}")

## Energy Landscape Visualization

Visualize the energy surface on a 2D slice around an equilibrium point.

In [None]:
# Analyzer for the model's energy landscape.
analyzer = EnergyLandscapeAnalyzer(model)

# Get an equilibrium point: run the solver from a random (1, d_model) start
# against a random (1, 10, d_model) context. No gradients are needed.
with torch.no_grad():
    z_init = torch.randn((1, model.d_model))
    context = torch.randn((1, 10, model.d_model))
    solver_output = model.solver.solve(z_init, context, model.memory_patterns)
    z_eq, _ = solver_output  # second element is not used here

# Visualize a 2D slice of the landscape, presumably centred on z_eq.
print('Computing energy landscape (this may take a minute)...')
energies, residuals = analyzer.visualize_2d_slice(
    z_eq,
    context,
    radius=2.0,
    resolution=30,
)

print('Landscape visualization saved!')

## Associative Memory Capabilities

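Test key-value retrieval. The model sees a sequence of key/value token pairs followed by one of the keys, and must predict the value paired with it. The model is never trained in this notebook, so expect accuracy near chance.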

In [None]:
def test_associative_recall(model, num_pairs=10, *, key_vocab=500, vocab_size=1000):
    """Run one associative-recall probe and report whether the model got it.

    Builds the token sequence ``k1 v1 k2 v2 ... kN vN q``. The ``N`` keys are
    distinct ids drawn from ``[0, key_vocab)``, and the values are drawn from
    ``[key_vocab, vocab_size)``. The query ``q`` is one of the keys, chosen at
    random. The prediction is the argmax of the model's logits at the last
    position.

    Args:
        model: Callable mapping a ``(1, seq_len)`` int64 tensor of token ids
            to logits indexable as ``logits[0, -1]`` (presumably
            ``(batch, seq_len, vocab)``).
        num_pairs: Number of key/value pairs placed in the context.
        key_vocab: Exclusive upper bound on key ids; also the lower bound on
            value ids.
        vocab_size: Exclusive upper bound on value ids.

    Returns:
        ``(correct, prediction, target_value)``, where ``correct`` is a bool
        and the other two are Python ints.

    Raises:
        ValueError: if the id ranges are empty, or if ``num_pairs`` is not in
            ``[1, key_vocab]`` (distinct keys would be impossible).
    """
    if not 0 < key_vocab < vocab_size:
        raise ValueError(
            f"need 0 < key_vocab < vocab_size, got key_vocab={key_vocab}, "
            f"vocab_size={vocab_size}"
        )
    if not 1 <= num_pairs <= key_vocab:
        raise ValueError(
            f"num_pairs must be in [1, {key_vocab}] so keys can be distinct, "
            f"got {num_pairs}"
        )

    # Keys must be distinct: a repeated key paired with two different values
    # would make the recall target ambiguous.
    keys = torch.randperm(key_vocab)[:num_pairs]
    values = torch.randint(key_vocab, vocab_size, (num_pairs,))

    # Interleave into k1 v1 k2 v2 ...
    sequence = torch.stack((keys, values), dim=1).reshape(-1).tolist()

    query_idx = torch.randint(0, num_pairs, (1,)).item()
    query_key = keys[query_idx].item()
    target_value = values[query_idx].item()

    sequence.append(query_key)
    input_ids = torch.tensor(sequence).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_ids)
        prediction = logits[0, -1].argmax().item()

    correct = prediction == target_value
    return correct, prediction, target_value


# A demo helper, not a pytest test despite the ``test_`` prefix; this stops
# pytest from collecting it if the notebook is exported to a .py module.
test_associative_recall.__test__ = False

# Run repeated recall probes and report the fraction answered correctly.
# NOTE: nothing in this notebook trains the model, so accuracy near chance
# is presumably expected.
print('Testing associative recall...')
num_trials = 20
correct_count = 0

for trial in range(num_trials):
    correct, pred, target = test_associative_recall(model, num_pairs=10)
    correct_count += int(correct)
    if trial < 5:  # only echo the first few trials to keep output short
        print(f"Trial {trial+1}: Pred={pred}, Target={target}, {'✓' if correct else '✗'}")

accuracy = correct_count / num_trials
print(f"\nAccuracy: {accuracy:.1%} ({correct_count}/{num_trials})")

## Interactive Exploration

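Use the sliders to change the solver parameters and observe their effect on convergence. The changes are applied to the shared `model`, so they persist into any cell run afterwards.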

In [None]:
from ipywidgets import interact, FloatSlider, IntSlider

@interact(
    beta=FloatSlider(min=0.1, max=5.0, step=0.1, value=2.0),
    max_iter=IntSlider(min=5, max=50, step=5, value=20),
    # The default '.2f' readout would show every value in this range as
    # 0.00 or 0.01, so display the tolerance in scientific notation.
    tolerance=FloatSlider(min=1e-4, max=1e-2, step=1e-4, value=1e-3,
                          readout_format='.1e'),
)
def explore_parameters(beta, max_iter, tolerance):
    """Interactive parameter exploration.

    Applies the slider values to the model's energy function, dynamics and
    solver. It then re-runs a forward pass on a fixed input, prints the solver
    diagnostics, and plots the energy history if one was recorded.

    NOTE: the values are assigned onto the shared ``model`` and stay in
    effect for any cell run afterwards.
    """
    model.energy_fn.beta = beta
    model.dynamics.beta = beta
    model.solver.max_iter = max_iter
    model.solver.tol_fp = tolerance
    model.solver.tol_energy = tolerance

    # Use the same seeded input on every slider change, so that differences
    # between runs reflect the parameters rather than a new random sequence.
    generator = torch.Generator().manual_seed(0)
    input_ids = torch.randint(0, 1000, (1, 20), generator=generator)
    with torch.no_grad():
        logits, diag = model(input_ids, return_diagnostics=True)

    info = diag['solver_info']
    print(f"Converged: {info['converged']}")
    print(f"Iterations: {info['iterations']}")
    print(f"Final energy: {info['final_energy']:.4f}")

    if 'energy_history' in info:
        plt.figure(figsize=(8, 4))
        plt.plot(info['energy_history'], 'o-')
        plt.xlabel('Iteration')
        plt.ylabel('Energy')
        # Format the floats so slider step arithmetic (e.g. 0.30000000000000004)
        # does not leak long decimal tails into the title.
        plt.title(f'Convergence (β={beta:.1f}, max_iter={max_iter}, tol={tolerance:.1e})')
        plt.grid(True, alpha=0.3)
        plt.show()