# Grokking + LLC with mHC (Manifold-Constrained Hyper-Connections)

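This notebook compares a standard MLP against an MLP with mHC on the modular addition grokking task, and estimates the local learning coefficient (LLC) of both models on training checkpoints.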

## What is mHC?

**mHC = Manifold-Constrained Hyper-Connections** (DeepSeek, Dec 2025)

Instead of standard residual connections `x_{l+1} = x_l + F(x_l)`, mHC uses:

```
x_{l+1} = H_res @ x_l + H_post^T @ F(H_pre @ x_l)
```

Where:
- `H_res`: **Doubly stochastic matrix** (rows & cols sum to 1) via Sinkhorn-Knopp
- `H_pre`, `H_post`: Non-negative mixing matrices via softmax

**Benefits:**
- Prevents training instability
- Prevents signal explosion/collapse
- Better scaling to large models

## Research Questions

1. Does mHC help with grokking on modular addition?
2. Does mHC prevent the 0% test accuracy issue?
3. Does mHC lead to different LLC trajectories?
4. Is the learned solution simpler (lower LLC) or more complex (higher LLC)?

## Implementation

Using the PyTorch implementation from: https://github.com/tokenbender/mHC-manifold-constrained-hyper-connections

Paper: https://arxiv.org/abs/2512.24880

## 1. Setup and Install Dependencies

In [None]:
# Install required packages
%pip install devinterp scipy einops torch

import random
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path
from scipy import stats

from einops import rearrange, einsum
from einops.layers.torch import Rearrange, Reduce

from devinterp.optim.sgld import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import evaluate_ce

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Output directory for figures; the path is relative to the current working directory.
RESULTS_DIR = Path("../results/mhc_grokking")
# Create the directory (and missing parents) up front; a no-op when it already exists.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

def save_fig(fig, name):
    """Write ``fig`` to ``RESULTS_DIR / name`` and then display it.

    Args:
        fig: Matplotlib figure to save.
        name: File name (including extension) inside ``RESULTS_DIR``.
    """
    destination = RESULTS_DIR / name
    # 300 dpi with a tight bounding box so labels are not clipped in the file.
    fig.savefig(destination, dpi=300, bbox_inches='tight')
    print(f"Saved: {name}")
    plt.show()

# Echo the resolved device and output directory so the cell output records them.
print(f"Device: {DEVICE}")
print(f"Results: {RESULTS_DIR}")

## 2. mHC Implementation (from GitHub repo)

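Simplified version adapted for our modular addition task: only the doubly stochastic `H_res` mixing is implemented (there is no `H_pre` / `H_post`), and it is used to merge the left- and right-operand streams into a single hidden vector.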

In [None]:
# mHC helper functions

def exists(v):
    """Return ``True`` for any value other than ``None`` (falsy values count as existing)."""
    if v is None:
        return False
    return True

def default(v, d):
    """Return ``v`` unless it is ``None``, in which case return the fallback ``d``.

    Only ``None`` triggers the fallback; falsy values such as ``0`` or ``""``
    are returned unchanged.
    """
    if v is None:
        return d
    return v

def sinkhorn_log(logits, num_iters=10, tau=0.05):
    """Run Sinkhorn-Knopp normalisation on ``logits`` in log-space.

    The logits are divided by the temperature ``tau`` and then alternately
    row- and column-normalised with ``logsumexp`` (which avoids overflow
    compared with normalising ``exp(logits / tau)`` directly). The result
    has non-negative entries. Each iteration ends with the column step, so
    column sums are 1 up to rounding and row sums approach 1 as
    ``num_iters`` grows.

    Args:
        logits: Input matrix of unnormalised scores. The column potential is
            allocated with the shape of the row potential, so the last two
            dimensions are expected to be equal (square matrix).
        num_iters: Number of row/column normalisation rounds (default: 10).
        tau: Temperature; smaller values give a sharper result (default: 0.05).

    Returns:
        Tensor with the same shape as ``logits`` that is (approximately)
        doubly stochastic.
    """
    scaled = logits / tau
    size = logits.shape[-1]
    # Target marginals are all ones, i.e. zeros in log-space.
    log_target = torch.zeros((size,), device=logits.device, dtype=logits.dtype)

    # Log-domain scaling potentials for rows and columns respectively.
    row_pot = torch.zeros(logits.shape[:-1], device=scaled.device, dtype=scaled.dtype)
    col_pot = torch.zeros_like(row_pot)

    for _ in range(num_iters):
        # Row step: make every row sum to one given the current column potential.
        row_pot = log_target - torch.logsumexp(scaled + col_pot.unsqueeze(-2), dim=-1)
        # Column step: make every column sum to one given the new row potential.
        col_pot = log_target - torch.logsumexp(scaled + row_pot.unsqueeze(-1), dim=-2)

    log_plan = scaled + row_pot.unsqueeze(-1) + col_pot.unsqueeze(-2)
    return torch.exp(log_plan)


class SimpleHyperConnection(nn.Module):
    """Simplified mHC layer that merges two inputs (left and right) into one.

    Instead of:  x = x_left + x_right                         (plain addition)
    this uses:   x = H_res[0, 0] * x_left + H_res[0, 1] * x_right

    ``H_res`` is obtained by running ``sinkhorn_log`` on the learnable
    ``H_res_logits``, which constrains it to be (approximately) doubly
    stochastic. Only the first row of ``H_res`` is used to produce the single
    output stream, so the output is a weighted combination of the two inputs
    whose weights sum to roughly one.

    NOTE(review): with the near-identity initialisation the first row starts
    at approximately ``[1, 0]``, i.e. the layer initially passes ``x_left``
    through and almost ignores ``x_right``. Confirm this is the intended
    starting point for a task that needs both operands.

    Args:
        dim: Feature dimension of each input. Stored for reference only; the
            mixing itself does not depend on it.
        num_streams: Side length of the mixing matrix. ``forward`` always
            stacks exactly two inputs, so only the default of 2 matches it.
        mhc_num_iters: Sinkhorn iterations used in ``forward``.
        mhc_tau: Sinkhorn temperature used in ``forward``.
    """

    def __init__(self, dim, num_streams=2, mhc_num_iters=20, mhc_tau=0.05):
        super().__init__()
        self.dim = dim
        self.num_streams = num_streams
        self.mhc_num_iters = mhc_num_iters
        self.mhc_tau = mhc_tau

        # Initialize H_res_logits
        # Start with near-identity: diagonal = 0, off-diagonal = -8
        # After sinkhorn, this becomes approximately I (identity)
        init_h_res = torch.full((num_streams, num_streams), -8.0)
        init_h_res.fill_diagonal_(0.0)
        self.H_res_logits = nn.Parameter(init_h_res)

    def forward(self, x_left, x_right):
        """Combine left and right inputs using the first row of ``H_res``.

        Args:
            x_left: [batch, dim] - left embedding
            x_right: [batch, dim] - right embedding

        Returns:
            [batch, dim] - mixed output
        """
        # Stack inputs: [batch, 2, dim]
        x_stacked = torch.stack([x_left, x_right], dim=1)

        # Apply Sinkhorn to get the (approximately) doubly stochastic matrix
        H_res = sinkhorn_log(
            self.H_res_logits,
            num_iters=self.mhc_num_iters,
            tau=self.mhc_tau
        )

        # The two streams are reduced to one output stream, so only the first
        # row of H_res is applied: out[b, d] = sum_s H_res[0, s] * x_stacked[b, s, d]
        x_mixed = einsum(H_res[0], x_stacked, 's, b s d -> b d')

        return x_mixed


print("✅ mHC implementation loaded!")

# Quick test: push two random batches through a fresh layer and report shapes.
test_layer = SimpleHyperConnection(dim=16)
x_l = torch.randn(4, 16)
x_r = torch.randn(4, 16)
out = test_layer(x_l, x_r)
print(f"Test successful: input shapes ({x_l.shape}, {x_r.shape}) -> output shape {out.shape}")

# Verify doubly stochastic property.
# Use the layer's own Sinkhorn settings so the matrix checked here is the one
# forward() actually applies, and check columns as well as rows.
with torch.no_grad():
    H = sinkhorn_log(
        test_layer.H_res_logits,
        num_iters=test_layer.mhc_num_iters,
        tau=test_layer.mhc_tau,
    )
    row_sums = H.sum(dim=1)
    col_sums = H.sum(dim=0)
    expected_sums = torch.ones(test_layer.num_streams)
    is_doubly_stochastic = (
        torch.allclose(row_sums, expected_sums, atol=1e-5)
        and torch.allclose(col_sums, expected_sums, atol=1e-5)
    )
    print(f"\nDoubly stochastic verification:")
    print(f"  Row sums: {row_sums.numpy()}")
    print(f"  Col sums: {col_sums.numpy()}")
    print(f"  All close to 1.0? {is_doubly_stochastic}")

## 3. Model Architectures

Baseline MLP vs MLP with mHC

In [None]:
class BaselineMLP(nn.Module):
    """Reference MLP for modular addition that sums the two operand branches.

    Both operands share one embedding table; each is then projected by its own
    linear layer, the projections are added, passed through GELU, and mapped
    to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, embed_dim=12, hidden_size=48):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Separate first-layer projections for the left and right operand.
        self.linear1l = nn.Linear(embed_dim, hidden_size, bias=True)
        self.linear1r = nn.Linear(embed_dim, hidden_size, bias=True)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        """Return logits for operand pairs stored in the last dimension of ``x``."""
        # Move the indices to wherever the model's weights live.
        tokens = x.to(self.embedding.weight.device)
        left_hidden = self.linear1l(self.embedding(tokens[..., 0]))
        right_hidden = self.linear1r(self.embedding(tokens[..., 1]))
        # Standard addition of the two branches.
        combined = left_hidden + right_hidden
        return self.linear2(self.act(combined))


class MLPWithMHC(nn.Module):
    """MLP for modular addition whose two operand branches are merged by mHC.

    Same layers as the baseline MLP, except that the two first-layer
    projections are combined by a ``SimpleHyperConnection`` instead of being
    added together.
    """

    def __init__(self, vocab_size, embed_dim=12, hidden_size=48,
                 mhc_num_iters=20, mhc_tau=0.05):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.linear1l = nn.Linear(embed_dim, hidden_size, bias=True)
        self.linear1r = nn.Linear(embed_dim, hidden_size, bias=True)

        # Learnable mixing layer that replaces the plain sum of the branches.
        mixer_config = dict(
            dim=hidden_size,
            num_streams=2,
            mhc_num_iters=mhc_num_iters,
            mhc_tau=mhc_tau,
        )
        self.mhc = SimpleHyperConnection(**mixer_config)

        self.act = nn.GELU()
        self.linear2 = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        """Return logits for operand pairs stored in the last dimension of ``x``."""
        tokens = x.to(self.embedding.weight.device)
        left_hidden = self.linear1l(self.embedding(tokens[..., 0]))
        right_hidden = self.linear1r(self.embedding(tokens[..., 1]))

        # mHC mixing takes the place of ``left_hidden + right_hidden``.
        mixed = self.mhc(left_hidden, right_hidden)

        return self.linear2(self.act(mixed))


# Compare parameter counts
def _count_parameters(module):
    """Total number of scalar parameters registered on ``module``."""
    return sum(p.numel() for p in module.parameters())


# Instantiate both architectures at the experiment's vocabulary size.
baseline = BaselineMLP(vocab_size=64)
mhc_model = MLPWithMHC(vocab_size=64)

baseline_params = _count_parameters(baseline)
mhc_params = _count_parameters(mhc_model)

# Report absolute sizes and the relative overhead of the mixing matrix.
print(f"Baseline MLP parameters: {baseline_params:,}")
print(f"MLP with mHC parameters: {mhc_params:,}")
print(f"Overhead: {mhc_params - baseline_params:,} (+{100*(mhc_params-baseline_params)/baseline_params:.2f}%)")
print(f"\n(mHC adds a 2×2 learnable mixing matrix = 4 parameters)")

## 4. Training Utilities

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same seed for the three CPU-side generators, in the original order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Also seed every CUDA device when a GPU is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and ask for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def test(model, dataset, device):
    n_correct = 0
    total_loss = 0
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in dataset:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = loss_fn(out, y)
            total_loss += loss.item()
            pred = torch.argmax(out, dim=-1)
            if pred == y:
                n_correct += 1
    return n_correct / len(dataset), total_loss / len(dataset)

def train(train_dataset, test_dataset, model, params, verbose=True):
    """Train ``model`` with Adam and record checkpoints and metrics.

    Args:
        train_dataset: Sequence of ``(x, y)`` pairs used for optimisation and
            for the periodic training-set evaluation.
        test_dataset: Sequence of ``(x, y)`` pairs used for validation.
        model: Module to optimise in place.
        params: Object providing ``lr``, ``weight_decay``, ``batch_size``,
            ``n_batches``, ``print_times``, ``n_checkpoints`` and ``device``.
        verbose: When true, show a tqdm progress bar with the latest accuracies.

    Returns:
        Tuple ``(all_models, df)``: deep copies of the model taken every
        ``n_batches // n_checkpoints`` steps, and a DataFrame with one row per
        evaluation (``batch``, ``train_loss``, ``train_acc``, ``val_loss``,
        ``val_acc``).
    """
    optimizer = torch.optim.Adam(
        model.parameters(), weight_decay=params.weight_decay, lr=params.lr
    )
    loss_fn = nn.CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True)

    # Clamp to 1: when n_batches is smaller than print_times / n_checkpoints
    # the integer division is 0 and the modulo below would raise.
    print_every = max(1, params.n_batches // params.print_times)
    checkpoint_every = max(1, params.n_batches // params.n_checkpoints)

    all_models = []
    loss_data = []

    pbar = tqdm(total=params.n_batches, desc="Training") if verbose else None

    for i in range(params.n_batches):
        # The evaluation helper switches the model to eval mode, so make sure
        # every optimisation step runs in training mode.
        model.train()

        # NOTE: a fresh iterator is created each step, so every step draws the
        # first batch of a new shuffle (one random batch per step) rather than
        # sweeping the dataset epoch by epoch. Kept as-is so results stay
        # comparable with earlier runs.
        X, Y = next(iter(train_loader))
        X, Y = X.to(params.device), Y.to(params.device)

        optimizer.zero_grad()
        out = model(X)
        loss = loss_fn(out, Y)
        loss.backward()
        optimizer.step()

        if (i + 1) % checkpoint_every == 0:
            all_models.append(deepcopy(model))

        if (i + 1) % print_every == 0:
            val_acc, val_loss = test(model, test_dataset, params.device)
            train_acc, train_loss = test(model, train_dataset, params.device)
            loss_data.append({
                "batch": i + 1,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_loss,
                "val_acc": val_acc,
            })
            if pbar is not None:
                pbar.set_postfix({
                    "train_acc": f"{train_acc:.4f}",
                    "val_acc": f"{val_acc:.4f}",
                })
                pbar.update(print_every)

    if pbar is not None:
        # Account for trailing steps when n_batches is not a multiple of print_every.
        pbar.update(params.n_batches - pbar.n)
        pbar.close()

    df = pd.DataFrame(loss_data)
    return all_models, df

def make_dataset(modulus, max_input_val):
    """Enumerate every modular-addition example over a square input grid.

    Args:
        modulus: Modulus applied to the sum (e.g., 64).
        max_input_val: Largest value (inclusive) either operand takes
            (e.g., 63 for the full mod-64 task).

    Returns:
        List of ``(x, y)`` tuples where ``x`` is ``tensor([a, b])`` and ``y``
        is the scalar tensor ``(a + b) % modulus``, ordered with ``a`` as the
        outer loop and ``b`` as the inner loop.
    """
    operands = range(max_input_val + 1)
    return [
        (torch.tensor([a, b]), torch.tensor((a + b) % modulus))
        for a in operands
        for b in operands
    ]

def train_test_split(dataset, train_frac, seed):
    """Shuffle ``dataset`` deterministically and split it into train and test lists.

    Args:
        dataset: Indexable sequence of examples.
        train_frac: Fraction of examples assigned to the training split; the
            training size is ``int(train_frac * len(dataset))``.
        seed: Seed for the private ``random.Random`` used for shuffling, so the
            global RNG state is left untouched.

    Returns:
        Tuple ``(train_examples, test_examples)``.
    """
    total = len(dataset)
    cutoff = int(train_frac * total)

    order = list(range(total))
    random.Random(seed).shuffle(order)

    train_part = [dataset[idx] for idx in order[:cutoff]]
    test_part = [dataset[idx] for idx in order[cutoff:]]
    return train_part, test_part

# Confirmation message so the cell output shows the definitions above ran.
print("Training utilities defined!")

## 5. Experiment Parameters

In [None]:
@dataclass
class Params:
    """Hyperparameters shared by the baseline and mHC training runs."""

    # Modulus of the addition task; also passed as the models' vocab_size.
    modulus: int = 64
    # Total number of optimisation steps (one batch per step).
    n_batches: int = 100000  # More training for grokking
    # Number of model snapshots kept; a snapshot is taken every n_batches // n_checkpoints steps.
    n_checkpoints: int = 100
    # Number of evaluations; metrics are logged every n_batches // print_times steps.
    print_times: int = 100
    # Adam learning rate.
    lr: float = 0.001  # Slightly lower LR for stability
    # Examples per training batch.
    batch_size: int = 128
    # Width of the shared embedding table.
    embed_dim: int = 12
    # Width of the hidden layer.
    hidden_size: int = 48
    # Passed to torch.optim.Adam as weight_decay.
    weight_decay: float = 1.0  # Higher for grokking
    # Device batches are moved to; defaults to the notebook-wide DEVICE.
    device: str = DEVICE

SEEDS = [0, 1, 2]  # Run 3 seeds

# Create dataset (full task: every operand pair in [0, modulus)).
# The modulus is read from Params so the dataset and the models' vocab_size
# cannot drift apart if the default is changed.
_default_params = Params()
dataset = make_dataset(
    modulus=_default_params.modulus,
    max_input_val=_default_params.modulus - 1,
)
# The split seed is fixed so every training seed sees the same train/test partition.
train_data, test_data = train_test_split(dataset, train_frac=0.3, seed=0)

print(f"Dataset: {len(dataset)} pairs total")
print(f"Train: {len(train_data)} pairs ({100*len(train_data)/len(dataset):.1f}%)")
print(f"Test: {len(test_data)} pairs ({100*len(test_data)/len(dataset):.1f}%)")
print(f"\nTraining for {_default_params.n_batches:,} batches with {len(SEEDS)} seeds")

## 6. Run Baseline MLP

In [None]:
# Per-seed outputs of the baseline runs: {seed: {'checkpoints': [...], 'df': DataFrame}}.
baseline_results = {}

for seed in SEEDS:
    # Banner so each seed's log is easy to locate in the cell output.
    print("\n" + "="*80)
    print(f"BASELINE MLP - SEED {seed}")
    print("="*80)

    # Seed before building the model so the initial weights are reproducible.
    set_seed(seed)
    params = Params()

    model = BaselineMLP(
        vocab_size=params.modulus,
        embed_dim=params.embed_dim,
        hidden_size=params.hidden_size
    )
    model = model.to(params.device)

    # Move both splits to the target device once, ahead of training.
    train_ds, test_ds = (
        [(x.to(DEVICE), y.to(DEVICE)) for x, y in split]
        for split in (train_data, test_data)
    )

    checkpoints, df = train(train_ds, test_ds, model, params, verbose=True)
    baseline_results[seed] = dict(checkpoints=checkpoints, df=df)

print("\n✅ Baseline training complete!")

## 7. Run MLP with mHC

In [None]:
# Per-seed outputs of the mHC runs: {seed: {'checkpoints': [...], 'df': DataFrame}}.
mhc_results = {}

for seed in SEEDS:
    # Banner so each seed's log is easy to locate in the cell output.
    print("\n" + "="*80)
    print(f"MLP WITH mHC - SEED {seed}")
    print("="*80)

    # Seed before building the model so the initial weights are reproducible.
    set_seed(seed)
    params = Params()

    model = MLPWithMHC(
        vocab_size=params.modulus,
        embed_dim=params.embed_dim,
        hidden_size=params.hidden_size,
        mhc_num_iters=20,
        mhc_tau=0.05
    )
    model = model.to(params.device)

    # Move both splits to the target device once, ahead of training.
    train_ds, test_ds = (
        [(x.to(DEVICE), y.to(DEVICE)) for x, y in split]
        for split in (train_data, test_data)
    )

    checkpoints, df = train(train_ds, test_ds, model, params, verbose=True)
    mhc_results[seed] = dict(checkpoints=checkpoints, df=df)

print("\n✅ mHC training complete!")

## 8. Compare Training Dynamics

In [None]:
# Aggregate across seeds
def aggregate_dfs(results_dict):
    """Compute per-step mean and std of the logged metrics across runs.

    The runs are taken from ``results_dict`` itself (not from a global seed
    list), so the function works for any set of keys, including a partial run.

    Args:
        results_dict: Mapping from run key (e.g. seed) to a dict holding a
            DataFrame under ``'df'`` with columns ``train_acc``, ``val_acc``,
            ``train_loss`` and ``val_loss``.

    Returns:
        Tuple ``(metrics, min_len)``. ``metrics`` maps ``'<column>_mean'`` and
        ``'<column>_std'`` to arrays of length ``min_len``, where ``min_len``
        is the shortest run length; longer runs are truncated to it. The std
        is the population standard deviation (NumPy default, ``ddof=0``).

    Raises:
        ValueError: If ``results_dict`` contains no runs.
    """
    all_dfs = [run['df'] for run in results_dict.values()]
    if not all_dfs:
        raise ValueError("aggregate_dfs() requires at least one run")
    min_len = min(len(df) for df in all_dfs)

    metrics = {}
    for col in ('train_acc', 'val_acc', 'train_loss', 'val_loss'):
        # Rows are runs, columns are evaluation steps.
        values = np.array([df[col].values[:min_len] for df in all_dfs])
        metrics[col + '_mean'] = np.mean(values, axis=0)
        metrics[col + '_std'] = np.std(values, axis=0)
    return metrics, min_len

baseline_metrics, _ = aggregate_dfs(baseline_results)
mhc_metrics, _ = aggregate_dfs(mhc_results)

# Plot test accuracy: mean curve with a ±1 std band for each model.
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(baseline_metrics['val_acc_mean']))

for run_metrics, run_label, run_color in (
    (baseline_metrics, 'Baseline MLP', 'blue'),
    (mhc_metrics, 'MLP with mHC', 'orange'),
):
    acc_mean = run_metrics['val_acc_mean']
    acc_std = run_metrics['val_acc_std']
    ax.plot(x, acc_mean, label=run_label, color=run_color, linewidth=2)
    ax.fill_between(x, acc_mean - acc_std, acc_mean + acc_std,
                    alpha=0.3, color=run_color)

ax.set_xlabel('Checkpoint', fontsize=12)
ax.set_ylabel('Test Accuracy', fontsize=12)
ax.set_title(f'Grokking: Baseline vs mHC (mean ± std, n={len(SEEDS)} seeds)', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1.05])
plt.tight_layout()
save_fig(fig, 'baseline_vs_mhc_accuracy.png')

# Plot generalization gap (train - test accuracy) of the seed-averaged curves.
fig, ax = plt.subplots(figsize=(12, 6))
baseline_gap = baseline_metrics['train_acc_mean'] - baseline_metrics['val_acc_mean']
mhc_gap = mhc_metrics['train_acc_mean'] - mhc_metrics['val_acc_mean']

for gap_curve, gap_label, gap_color in (
    (baseline_gap, 'Baseline Gap', 'blue'),
    (mhc_gap, 'mHC Gap', 'orange'),
):
    ax.plot(x, gap_curve, label=gap_label, color=gap_color, linewidth=2)
# Zero gap means train and test accuracy coincide.
ax.axhline(y=0, color='black', linestyle='--', alpha=0.3, label='Perfect Generalization')

ax.set_xlabel('Checkpoint', fontsize=12)
ax.set_ylabel('Generalization Gap (Train - Test Acc)', fontsize=12)
ax.set_title('Generalization Gap: Baseline vs mHC', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
save_fig(fig, 'generalization_gap.png')

## 9. Statistical Comparison

In [None]:
# Compare final test accuracy
# Final logged validation accuracy of every seed, per model.
baseline_final = [baseline_results[seed]['df']['val_acc'].iloc[-1] for seed in SEEDS]
mhc_final = [mhc_results[seed]['df']['val_acc'].iloc[-1] for seed in SEEDS]

# Independent two-sample t-test (SciPy default: equal variances assumed).
t_stat, p_value = stats.ttest_ind(baseline_final, mhc_final)

print("\n" + "="*80)
print("STATISTICAL COMPARISON")
print("="*80)
print(f"\nFinal Test Accuracy (n={len(SEEDS)} seeds):")
print(f"  Baseline: {np.mean(baseline_final):.4f} ± {np.std(baseline_final):.4f}")
print(f"  mHC:      {np.mean(mhc_final):.4f} ± {np.std(mhc_final):.4f}")
print(f"\nTwo-sample t-test:")
print(f"  t-statistic: {t_stat:.4f}")
print(f"  p-value:     {p_value:.4f}")

if np.isnan(p_value):
    # Happens e.g. when every run in both groups ends at the same accuracy:
    # the test is undefined, which is different from "not significant".
    print("\n⚠️  t-test is undefined (p-value is NaN); cannot assess significance")
elif p_value < 0.05:
    print("\n✅ Difference is statistically significant (p < 0.05)")
    if np.mean(mhc_final) > np.mean(baseline_final):
        print("   mHC performs BETTER than baseline")
    else:
        print("   mHC performs WORSE than baseline")
else:
    print("\n⚠️  No significant difference (p >= 0.05)")

# Cohen's d with the pooled *sample* standard deviation (ddof=1); the NumPy
# default ddof=0 underestimates the spread for a handful of seeds.
pooled_std = np.sqrt(
    (np.var(baseline_final, ddof=1) + np.var(mhc_final, ddof=1)) / 2
)
mean_difference = np.mean(mhc_final) - np.mean(baseline_final)
if pooled_std > 0:
    cohens_d = mean_difference / pooled_std
    print(f"\nEffect size (Cohen's d): {cohens_d:.4f}")
else:
    # Zero (or undefined) pooled spread: the ratio is not meaningful.
    cohens_d = float('nan')
    print("\nEffect size (Cohen's d): undefined (zero pooled variance)")
print("="*80)

## 10. LLC Estimation (Seed 0 Only)

Due to computational cost, we estimate LLC for seed 0 only, and only on every 10th checkpoint.

In [None]:
# LLC hyperparameters
llc_params = {
    'lr': 3e-3,
    'nbeta': 2.0,
    'gamma': 10.0,  # Adjusted for mod-64
    'num_chains': 3,
    'num_draws': 1000,
}

seed = 0
# Loader over the training split; the sampler draws its batches from it.
train_loader = DataLoader(
    [(x.to(DEVICE), y.to(DEVICE)) for x, y in train_data],
    batch_size=128,
    shuffle=True
)


def _estimate_llc_trajectory(checkpoints, stride=10):
    """Estimate the LLC for every ``stride``-th checkpoint.

    Uses the module-level ``llc_params``, ``train_loader`` and ``DEVICE`` so
    that the baseline and mHC models are sampled with identical settings.

    Args:
        checkpoints: List of model snapshots in training order.
        stride: Step between the checkpoints that are evaluated.

    Returns:
        List with one ``estimate_learning_coeff_with_summary`` result per
        evaluated checkpoint, in checkpoint order.
    """
    trajectory = []
    for i in range(0, len(checkpoints), stride):
        print(f"Checkpoint {i+1}/{len(checkpoints)}")
        llc_stats = estimate_learning_coeff_with_summary(
            checkpoints[i],
            loader=train_loader,
            evaluate=evaluate_ce,
            sampling_method=SGLD,
            optimizer_kwargs=dict(
                lr=llc_params['lr'],
                nbeta=llc_params['nbeta'],
                localization=llc_params['gamma']
            ),
            num_chains=llc_params['num_chains'],
            num_draws=llc_params['num_draws'],
            device=DEVICE,
            online=False,
        )
        trajectory.append(llc_stats)
    return trajectory


# Estimate LLC for baseline (sample every 10th checkpoint)
print("\n" + "="*80)
print("LLC ESTIMATION - BASELINE")
print("="*80)

baseline_checkpoints = baseline_results[seed]['checkpoints']
baseline_llcs = _estimate_llc_trajectory(baseline_checkpoints)

# Estimate LLC for mHC (same stride and sampler settings)
print("\n" + "="*80)
print("LLC ESTIMATION - mHC")
print("="*80)

mhc_checkpoints = mhc_results[seed]['checkpoints']
mhc_llcs = _estimate_llc_trajectory(mhc_checkpoints)

print("\n✅ LLC estimation complete!")

## 11. Plot LLC Trajectories

In [None]:
# Pull the mean LLC estimate out of each summary dict.
baseline_llc_values = [summary['llc/mean'] for summary in baseline_llcs]
mhc_llc_values = [summary['llc/mean'] for summary in mhc_llcs]

# Get corresponding accuracy values (sampled every 10th, matching the LLC stride)
baseline_df = baseline_results[seed]['df']
mhc_df = mhc_results[seed]['df']

baseline_acc_sampled = baseline_df['val_acc'].values[::10]
mhc_acc_sampled = mhc_df['val_acc'].values[::10]

# Plot LLC vs Accuracy: one panel per model, accuracy on the left axis and
# LLC on a twin right axis.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

panels = (
    (ax1, baseline_acc_sampled, baseline_llc_values, 'blue', 'Baseline MLP: LLC vs Accuracy'),
    (ax2, mhc_acc_sampled, mhc_llc_values, 'orange', 'MLP with mHC: LLC vs Accuracy'),
)
for panel_ax, acc_series, llc_series, color, panel_title in panels:
    twin_ax = panel_ax.twinx()
    panel_ax.plot(acc_series, label='Test Acc', color=color)
    twin_ax.plot(llc_series, label='LLC', color='green', linewidth=2)
    panel_ax.set_xlabel('Checkpoint (×10)', fontsize=11)
    panel_ax.set_ylabel('Test Accuracy', color=color, fontsize=11)
    twin_ax.set_ylabel('LLC (λ̂)', color='green', fontsize=11)
    panel_ax.set_title(panel_title, fontsize=12)
    panel_ax.tick_params(axis='y', labelcolor=color)
    twin_ax.tick_params(axis='y', labelcolor='green')
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
save_fig(fig, 'llc_vs_accuracy.png')

# Direct LLC comparison: both trajectories on shared axes.
fig, ax = plt.subplots(figsize=(12, 6))
for llc_series, series_label, series_color in (
    (baseline_llc_values, 'Baseline MLP', 'blue'),
    (mhc_llc_values, 'MLP with mHC', 'orange'),
):
    ax.plot(llc_series, label=series_label, color=series_color, linewidth=2)
ax.set_xlabel('Checkpoint (×10)', fontsize=12)
ax.set_ylabel('LLC (λ̂)', fontsize=12)
ax.set_title('LLC Trajectories: Baseline vs mHC', fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
save_fig(fig, 'llc_comparison.png')

# First and last *evaluated* checkpoints of each trajectory.
print("\nLLC Summary:")
print(f"Baseline - Initial: {baseline_llc_values[0]:.2f}, Final: {baseline_llc_values[-1]:.2f}")
print(f"mHC      - Initial: {mhc_llc_values[0]:.2f}, Final: {mhc_llc_values[-1]:.2f}")

## 12. Summary

This experiment compared:
- **Baseline MLP**: Standard addition to combine embeddings
- **MLP with mHC**: Doubly stochastic mixing via Sinkhorn-Knopp

Questions to answer from the results above:
1. Training stability (did mHC prevent 0% test accuracy?)
2. Grokking behavior (faster/slower generalization?)
3. LLC evolution (simpler/more complex solutions?)
4. Final performance comparison

All figures are saved to `RESULTS_DIR` (`../results/mhc_grokking`).