# Toroidal Geometry in EB-JEPA Representations

Self-contained notebook for:
1. **Phase 1**: Training a baseline VICReg EB-JEPA and topological analysis of its embeddings
2. **Phase 2**: Training with toroidal regularization
3. **Evaluation**: Linear probing + topology metrics

Works on RunPod (A100/H100), Colab (T4), or any CUDA machine.

In [None]:
# Install dependencies
!pip install -q torch torchvision ripser persim umap-learn scikit-learn matplotlib pyyaml

In [None]:
import sys, os
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path

# Clone repo if not present
REPO_DIR = Path("jepa-torus")
if not REPO_DIR.exists():
    !git clone https://github.com/Paraxiom/jepa-torus.git

sys.path.insert(0, str(REPO_DIR))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")

## Phase 1: Train Baseline EB-JEPA

In [None]:
from src.train import train

# Baseline run with the "vicreg" loss (lower "epochs" for a quick test).
# The config is assembled section by section and then handed to `train`.
baseline_model_cfg = {"embed_dim": 512, "hidden_dim": 1024, "ema_decay": 0.996}

baseline_training_cfg = {
    "epochs": 300,
    "batch_size": 256,
    "lr": 1e-3,
    "weight_decay": 0.05,
    "warmup_epochs": 10,
    "num_workers": 4,
    "seed": 42,
}

baseline_loss_cfg = {"type": "vicreg", "lambda_std": 25.0, "lambda_cov": 1.0}

baseline_config = {
    "model": baseline_model_cfg,
    "training": baseline_training_cfg,
    "loss": baseline_loss_cfg,
    "data": {"data_dir": "./data"},
    # Checkpoints for this run go under ./checkpoints/baseline_vicreg
    "output": {"dir": "./checkpoints/baseline_vicreg", "save_every": 50},
}

train(baseline_config)

## Phase 1: Topological Analysis of Baseline Embeddings

In [None]:
from src.eval import evaluate

# Run `evaluate` on the baseline's best checkpoint with the linear probe,
# analysis and UMAP steps all switched on; outputs go to ./results/baseline_vicreg.
baseline_eval_kwargs = {
    "checkpoint_path": "./checkpoints/baseline_vicreg/best.pt",
    "run_linear_probe": True,
    "run_analysis": True,
    "run_umap": True,
    "grid_size": 12,
    "output_dir": "./results/baseline_vicreg",
}
baseline_results = evaluate(**baseline_eval_kwargs)

In [None]:
# Persistence diagrams
from ripser import ripser
from persim import plot_diagrams

embeddings = np.load("./results/baseline_vicreg/embeddings.npy")

# Subsample for tractability.
# A seeded generator makes the diagram reproducible across kernel restarts, and
# capping the sample size avoids the ValueError that `choice(..., replace=False)`
# raises when fewer than 1000 embeddings are available.
subsample_rng = np.random.default_rng(42)
n_points = min(1000, len(embeddings))
idx = subsample_rng.choice(len(embeddings), size=n_points, replace=False)
X = embeddings[idx]

# L2-normalize each row. The norm is clamped from below so that an all-zero
# embedding stays a zero row instead of turning into NaNs.
row_norms = np.linalg.norm(X, axis=1, keepdims=True)
X = X / np.maximum(row_norms, 1e-12)

# Persistent homology up to dimension 2 (H0, H1, H2).
result = ripser(X, maxdim=2)

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
plot_diagrams(result["dgms"], ax=ax)
ax.set_title("Persistence Diagram — Baseline VICReg EB-JEPA")
plt.tight_layout()
plt.savefig("./results/baseline_vicreg/persistence.png", dpi=150)
plt.show()

In [None]:
# Load the saved 3-D UMAP coordinates and draw them as a scatter plot.
umap_3d = np.load("./results/baseline_vicreg/umap_3d.npy")
xs, ys, zs = umap_3d[:, 0], umap_3d[:, 1], umap_3d[:, 2]

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(1, 1, 1, projection="3d")
# Tiny, semi-transparent markers keep dense regions readable.
ax.scatter(xs, ys, zs, s=1, alpha=0.3)
ax.set_title("UMAP 3D — Baseline VICReg EB-JEPA")
fig.tight_layout()
fig.savefig("./results/baseline_vicreg/umap_3d.png", dpi=150)
plt.show()

## Phase 2: Train with Toroidal Regularization

In [None]:
# Toroidal run with grid size N=12 (Tonnetz geometry).
# Model and optimizer settings use the same values as the baseline run;
# only the loss section and the output directory differ.
toroidal_model_cfg = {"embed_dim": 512, "hidden_dim": 1024, "ema_decay": 0.996}

toroidal_training_cfg = {
    "epochs": 300,
    "batch_size": 256,
    "lr": 1e-3,
    "weight_decay": 0.05,
    "warmup_epochs": 10,
    "num_workers": 4,
    "seed": 42,
}

toroidal_loss_cfg = {
    "type": "toroidal",
    "grid_size": 12,
    "lambda_std": 25.0,
    "lambda_torus": 1.0,
    "penalty_mode": "distance",
}

toroidal_config = {
    "model": toroidal_model_cfg,
    "training": toroidal_training_cfg,
    "loss": toroidal_loss_cfg,
    "data": {"data_dir": "./data"},
    "output": {"dir": "./checkpoints/toroidal_N12", "save_every": 50},
}

train(toroidal_config)

In [None]:
# Run `evaluate` on the toroidal N=12 best checkpoint with the linear probe,
# analysis and UMAP steps all switched on; outputs go to ./results/toroidal_N12.
toroidal_eval_kwargs = {
    "checkpoint_path": "./checkpoints/toroidal_N12/best.pt",
    "run_linear_probe": True,
    "run_analysis": True,
    "run_umap": True,
    "grid_size": 12,
    "output_dir": "./results/toroidal_N12",
}
toroidal_results = evaluate(**toroidal_eval_kwargs)

## Phase 2: Ablations — Grid Size Sweep

In [None]:
def _grid_sweep_config(n):
    """Return a freshly built toroidal training config for grid size ``n``.

    Only the loss "grid_size" and the checkpoint directory depend on ``n``;
    every other value is fixed. A new dict is built on each call.
    """
    return {
        "model": {"embed_dim": 512, "hidden_dim": 1024, "ema_decay": 0.996},
        "training": {
            "epochs": 300,
            "batch_size": 256,
            "lr": 1e-3,
            "weight_decay": 0.05,
            "warmup_epochs": 10,
            "num_workers": 4,
            "seed": 42,
        },
        "loss": {
            "type": "toroidal",
            "grid_size": n,
            "lambda_std": 25.0,
            "lambda_torus": 1.0,
            "penalty_mode": "distance",
        },
        "data": {"data_dir": "./data"},
        "output": {"dir": f"./checkpoints/toroidal_N{n}", "save_every": 50},
    }


# Train N=8 and N=16, one run per grid size.
for grid_size in (8, 16):
    cfg = _grid_sweep_config(grid_size)
    train(cfg)

## Phase 2: Lambda Sweep

In [None]:
# Sweep `lambda_torus` for N=12: one training run per value, each writing to
# its own checkpoint directory.
for lam in (0.01, 0.1, 1.0, 10.0):
    sweep_model_cfg = {"embed_dim": 512, "hidden_dim": 1024, "ema_decay": 0.996}
    sweep_training_cfg = {
        "epochs": 300,
        "batch_size": 256,
        "lr": 1e-3,
        "weight_decay": 0.05,
        "warmup_epochs": 10,
        "num_workers": 4,
        "seed": 42,
    }
    sweep_loss_cfg = {
        "type": "toroidal",
        "grid_size": 12,
        "lambda_std": 25.0,
        "lambda_torus": lam,  # the only loss value that varies
        "penalty_mode": "distance",
    }
    cfg = {
        "model": sweep_model_cfg,
        "training": sweep_training_cfg,
        "loss": sweep_loss_cfg,
        "data": {"data_dir": "./data"},
        "output": {"dir": f"./checkpoints/toroidal_N12_lam{lam}", "save_every": 100},
    }
    train(cfg)

## Comparison: All Models

In [None]:
import json
from pathlib import Path

def effective_rank_of(res):
    """Return the effective rank stored in one `evaluate` result dict.

    This notebook reads the value from the "topology" section in the summary
    table and from the "covariance" section in the JSON export. To keep the two
    in agreement, "topology" is tried first and "covariance" is the fallback.
    Returns 0 when neither section carries an "effective_rank" entry.
    """
    for section in ("topology", "covariance"):
        value = res.get(section, {}).get("effective_rank")
        if value is not None:
            return value
    return 0


# Evaluate all checkpoints
results_all = {}
ckpt_root = Path("./checkpoints")
# With no checkpoint directory there is nothing to evaluate; print an empty
# table instead of failing with FileNotFoundError from `iterdir()`.
ckpt_dirs = sorted(ckpt_root.iterdir()) if ckpt_root.is_dir() else []
for ckpt_dir in ckpt_dirs:
    best = ckpt_dir / "best.pt"
    if not best.exists():
        # Entries without a best checkpoint are skipped.
        continue
    name = ckpt_dir.name
    print(f"\n{'='*60}")
    print(f"Evaluating: {name}")
    print(f"{'='*60}")
    # NOTE(review): grid_size is fixed at 12 for every checkpoint, including
    # the N=8 / N=16 runs — confirm this is the intended reference grid.
    res = evaluate(
        checkpoint_path=str(best),
        run_linear_probe=True,
        run_analysis=True,
        grid_size=12,
        output_dir=f"./results/{name}",
    )
    results_all[name] = res

# Summary table
print(f"\n{'='*80}")
print(f"{'Config':<30} {'Accuracy':>10} {'Eff. Rank':>10} {'Spec. Gap':>10} {'Torus Score':>12}")
print(f"{'='*80}")
for name, res in results_all.items():
    topology = res.get('topology', {})
    acc = res.get('linear_probe', {}).get('best_test_acc', 0)
    erank = effective_rank_of(res)
    sgap = topology.get('spectral_gap', 0)
    tscore = topology.get('torus_score', 0)
    print(f"{name:<30} {acc:>10.4f} {erank:>10.2f} {sgap:>10.4f} {tscore:>12.4f}")

In [None]:
# Save combined results
def _to_builtin(obj):
    """`json.dump` fallback: turn NumPy scalars/arrays into plain Python values.

    Anything else is rejected with the same TypeError `json` would raise, so
    unexpected objects are not silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Build the whole summary BEFORE opening the file: opening with "w" truncates
# it, so an error while collecting metrics would otherwise leave an empty or
# partial summary.json behind.
summary = {}
for name, res in results_all.items():
    topology = res.get('topology', {})
    # The summary table reads effective_rank from "topology" while this export
    # reads it from "covariance"; fall back to the other section so both
    # report the same number whichever one `evaluate` fills in.
    erank = res.get('covariance', {}).get('effective_rank')
    if erank is None:
        erank = topology.get('effective_rank', 0)
    summary[name] = {
        "accuracy": res.get('linear_probe', {}).get('best_test_acc', 0),
        "effective_rank": erank,
        "spectral_gap": topology.get('spectral_gap', 0),
        "spectral_gap_ratio": topology.get('spectral_gap_ratio', 0),
        "betti": [topology.get(f'betti_{i}', 0) for i in range(3)],
        "torus_score": topology.get('torus_score', 0),
        "intrinsic_dim": topology.get('intrinsic_dim', 0),
    }

summary_path = Path("./results/summary.json")
# Make sure the target directory exists even if no evaluation created it.
summary_path.parent.mkdir(parents=True, exist_ok=True)
with open(summary_path, "w") as f:
    # `default` only kicks in for values json cannot encode natively
    # (e.g. NumPy float32/int64 scalars, should `evaluate` return them).
    json.dump(summary, f, indent=2, default=_to_builtin)
print("Saved to ./results/summary.json")