# JEPA Reasoner — A100 Training on Colab

Full training pipeline for the I-JEPA cross-domain latent reasoning system.

**Before running:** Go to `Runtime → Change runtime type → A100 GPU`

In [None]:
# 1. Clone repo and install dependencies
!git clone https://github.com/akshai0296/jepa_reasoner.git
%cd jepa_reasoner
!pip install -q torch transformers sentence-transformers datasets tokenizers scikit-learn tqdm wandb openai anthropic

In [None]:
# 2. Verify GPU
# Report the PyTorch build and the first CUDA device (name + total memory),
# or warn if no GPU is attached to the runtime.
import torch

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Device properties expose the byte count as `total_memory`; `total_mem`
    # does not exist and raised AttributeError on every GPU runtime.
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Go to Runtime → Change runtime type → A100")

In [None]:
# 3. (Optional) Set API keys for LLM feedback loop
# NOTE: the `import os` below is also used later by the Drive-saving cell
# (os.makedirs), so run this cell even if you set no keys.
import os
# Uncomment one line to put a key in the process environment; presumably it is
# picked up by the openai / anthropic clients installed in cell 1 — confirm.
# Tip: prefer getpass.getpass() or Colab Secrets over a literal key, so the
# secret is not saved into the notebook file.
# os.environ["OPENAI_API_KEY"] = "sk-..."       # for self-improvement loop
# os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..." # alternative

In [None]:
# 4. Load data
# Load GSM8K, MATH and synthetic arithmetic examples, pool them, and split the
# pool into train/val sets with a reproducible shuffle.
import random
import sys
sys.path.insert(0, '.')

from src.utils.data_loading import load_gsm8k, load_math_dataset, generate_synthetic_math

# Fixed seed: without it every run produces a different train/val split,
# so validation metrics are not comparable across runs or kernel restarts.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9

print("Loading GSM8K...")
gsm8k_data = load_gsm8k("train", max_samples=7000)
print(f"  GSM8K: {len(gsm8k_data)} examples")

print("Loading MATH competition dataset...")
math_data = load_math_dataset("train", max_samples=5000)
print(f"  MATH: {len(math_data)} examples")

print("Generating synthetic arithmetic...")
synth_data = generate_synthetic_math(3000)
print(f"  Synthetic: {len(synth_data)} examples")

all_data = gsm8k_data + math_data + synth_data
# A dedicated RNG keeps the split reproducible without reseeding the global
# `random` module state that other code may depend on.
random.Random(SPLIT_SEED).shuffle(all_data)
split = int(len(all_data) * TRAIN_FRACTION)
train_data = all_data[:split]
val_data = all_data[split:]
print(f"\nTotal: {len(train_data)} train, {len(val_data)} val")

In [None]:
# 5. Build model (full-size for A100)
# Instantiate the JEPA reasoner and report its parameter counts in millions.
from src.models import JEPAReasoner

# Settings for the "transformer" predictor.
predictor_cfg = dict(
    hidden_dim=1024,
    num_layers=6,
    num_heads=8,
    dropout=0.1,
    use_latent_z=False,
    num_predictor_tokens=8,
)

# Settings for the "scratch" decoder (presumably trained from scratch rather
# than a pretrained LM — confirm in src.models).
decoder_cfg = dict(
    hidden_dim=768,
    num_layers=6,
    num_heads=8,
    max_seq_len=512,
    num_latent_tokens=8,
    latent_dim=768,
)

model = JEPAReasoner(
    domain="math",
    latent_dim=768,
    predictor_type="transformer",
    predictor_kwargs=predictor_cfg,
    decoder_type="scratch",
    decoder_kwargs=decoder_cfg,
    freeze_backbone=False,
    ema_decay=0.996,
)

# Count all parameters and the trainable subset in a single pass.
n_all = n_grad = 0
for param in model.parameters():
    count = param.numel()
    n_all += count
    if param.requires_grad:
        n_grad += count
total = n_all / 1e6
trainable = n_grad / 1e6
print(f"Total params: {total:.1f}M | Trainable: {trainable:.1f}M")

In [None]:
# 6. Train — Full 3-stage pipeline
# Build the Trainer config from per-stage pieces and run all three stages.
from src.train import Trainer

# Per-stage schedules, listed in the order the pipeline names them.
predictor_stage = {"predictor_epochs": 40, "lr_predictor": 1e-4}  # Stage 1: Predictor
decoder_stage = {"decoder_epochs": 25, "lr_decoder": 5e-5}        # Stage 2: Decoder
joint_stage = {"finetune_epochs": 15, "lr": 3e-5}                 # Stage 3: Joint

# Batching / optimisation settings.
training_opts = dict(
    batch_size=32,
    max_context_len=256,
    max_target_len=256,
    use_amp=True,
    grad_clip=1.0,
    weight_decay=0.01,
)

# Loss settings.
loss_opts = {"loss_type": "l2", "contrastive_weight": 0.1}

# latent_dim and freeze_backbone repeat the values passed to JEPAReasoner in
# cell 5; keep the two in sync when changing either.
config = {
    "domains": ["math"],
    "latent_dim": 768,
    "device": "cuda",
    **predictor_stage,
    **decoder_stage,
    **joint_stage,
    **training_opts,
    **loss_opts,
    "freeze_backbone": False,
    "output_dir": "checkpoints/a100_full",
}

trainer = Trainer(model, train_data, val_data, config)
results = trainer.run_full_pipeline()
print(f"\nFinal metrics: {results}")

In [None]:
# 7. Test on math problems
# Run encode -> predict -> decode on a few hand-written problems, print the
# decoded answers, then summarise the predicted latents with latent_space_stats.
import torch.nn.functional as F
from src.utils.metrics import latent_space_stats, cosine_similarity_score

model.eval()
tokenizer = model.context_encoder.tokenizer
device = next(model.parameters()).device

test_problems = [
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells every duck egg at the farmers' market daily for $2. How much in dollars does she make every day at the farmers' market?",
    "What is 15 + 27?",
    "A store has 45 apples. If 18 are sold, how many remain?",
    "Calculate: 12 * 5 + 3",
    "If a train travels at 60 mph for 2.5 hours, how far does it go?",
]

rule = "=" * 70
print(rule)
print("  INFERENCE RESULTS")
print(rule)

all_s_y = []
for problem in test_problems:
    # Tokenise to a fixed 256-token window and move every tensor to the model's device.
    raw = tokenizer(problem, return_tensors="pt", padding="max_length", max_length=256, truncation=True)
    enc = {name: tensor.to(device) for name, tensor in raw.items()}

    with torch.no_grad():
        s_x = model.encode_context(enc["input_ids"], enc["attention_mask"])
        s_y_hat = model.predict(s_x)
        decoded = model.decoder.generate(s_y_hat, max_new_tokens=128, temperature=0.5)
        all_s_y.append(s_y_hat.cpu())

    # decoded[0] assumes generate() returns one string per batch item — TODO confirm.
    suffix = "..." if len(problem) > 80 else ""
    print(f"\n  Q: {problem[:80]}{suffix}")
    print(f"  A: {decoded[0][:200]}")

# Latent health
s_y_cat = torch.cat(all_s_y)
stats = latent_space_stats(s_y_cat)
var, avg_cos, collapsed = stats['variance'], stats['avg_pairwise_cosine'], stats['collapse_detected']
print(f"\nLatent health: var={var:.2f}, avg_cos={avg_cos:.4f}, collapse={collapsed}")

In [None]:
# 8. Latent space visualization
# Encode up to 200 validation examples, project target and predicted latents
# into a shared 2-D PCA space, and plot pairing, difficulty and error histograms.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from src.utils.data_loading import ReasoningDataset
from torch.utils.data import DataLoader

model.eval()
sample = val_data[:200]
dataset = ReasoningDataset(sample, tokenizer, max_context_len=256, max_target_len=256)
loader = DataLoader(dataset, batch_size=32, shuffle=False)

# Collect context, target and predicted latents batch by batch (on CPU).
all_sx, all_sy, all_sy_hat, difficulties = [], [], [], []
with torch.no_grad():
    for batch in loader:
        batch = {
            key: (val.to(device) if isinstance(val, torch.Tensor) else val)
            for key, val in batch.items()
        }
        sx = model.encode_context(batch["context_ids"], batch["context_mask"])
        sy = model.encode_target(batch["target_ids"], batch["target_mask"])
        sy_hat = model.predict(sx)
        for bucket, latent in ((all_sx, sx), (all_sy, sy), (all_sy_hat, sy_hat)):
            bucket.append(latent.cpu())

sx_cat, sy_cat, sy_hat_cat = (torch.cat(chunks).numpy() for chunks in (all_sx, all_sy, all_sy_hat))
diffs = [item.get("difficulty", 0) for item in sample[:len(sx_cat)]]

# PCA visualization: fit on targets and predictions together so both share axes.
# Assumes latents are 2-D (batch, dim) — PCA and norm(axis=1) need that; TODO confirm.
pca = PCA(n_components=2)
combined = np.vstack([sy_cat, sy_hat_cat])
proj = pca.fit_transform(combined)
n = len(sy_cat)
target_xy, pred_xy = proj[:n], proj[n:]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
ax_pair, ax_diff, ax_err = axes

# Plot 1: Target vs Predicted, with the first 30 pairs joined by faint lines.
ax_pair.scatter(target_xy[:, 0], target_xy[:, 1], c="blue", alpha=0.5, label="Target s_y", s=20)
ax_pair.scatter(pred_xy[:, 0], pred_xy[:, 1], c="red", alpha=0.5, label="Predicted ŝ_y", s=20)
for (tx, ty), (px, py) in zip(target_xy[:30], pred_xy[:30]):
    ax_pair.plot([tx, px], [ty, py], 'gray', alpha=0.2, linewidth=0.5)
ax_pair.legend()
ax_pair.set_title("Target vs Predicted Latents (PCA)")

# Plot 2: target-latent projections coloured by difficulty level.
colors = {0: "green", 1: "orange", 2: "red"}
labels = {0: "Easy", 1: "Medium", 2: "Hard"}
for level in sorted(set(diffs)):
    mask = [idx for idx, value in enumerate(diffs) if value == level]
    ax_diff.scatter(proj[mask, 0], proj[mask, 1], c=colors.get(level, "gray"), alpha=0.6, label=labels.get(level, f"Diff {level}"), s=20)
ax_diff.legend()
ax_diff.set_title("Latent Space by Difficulty")

# Plot 3: per-example L2 distance between target and predicted latents.
errors = np.linalg.norm(sy_cat - sy_hat_cat, axis=1)
mean_err = errors.mean()
ax_err.hist(errors, bins=30, color="steelblue", edgecolor="white")
ax_err.axvline(mean_err, color="red", linestyle="--", label=f"Mean: {mean_err:.3f}")
ax_err.legend()
ax_err.set_title("Prediction Error Distribution")
ax_err.set_xlabel("L2 distance (s_y, ŝ_y)")

plt.tight_layout()
plt.savefig("latent_analysis.png", dpi=150)
plt.show()
print(f"Mean prediction error: {mean_err:.4f} ± {errors.std():.4f}")
print(f"Explained variance by PCA: {pca.explained_variance_ratio_.sum()*100:.1f}%")

In [None]:
# 9. Save checkpoint to Google Drive
# Mount Drive and copy the final checkpoint plus the latent-analysis figure into
# a fixed folder under MyDrive.
import os  # imported here too: the only other `import os` is in the optional API-key cell
import shutil

from google.colab import drive

drive.mount('/content/drive')

save_dir = "/content/drive/MyDrive/jepa_reasoner_checkpoints"
os.makedirs(save_dir, exist_ok=True)
for src_path in ("checkpoints/a100_full/final.pt", "latent_analysis.png"):
    shutil.copy(src_path, os.path.join(save_dir, os.path.basename(src_path)))
print(f"Saved to {save_dir}")