<a href="https://colab.research.google.com/github/AlperYildirim1/geometric-grokking/blob/main/Grokking_Architectural_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Grokking Experiment: Phase Attention vs Standard Attention
===========================================================
Task: Modular addition mod P=113
Setup: Evaluates 3 distinct topological architectures on the grokking threshold.
  1. Standard Transformer (Absolute PE)
  2. RoPE Transformer (Rotary PE)
  3. Phase Transformer (|z|=1 enforced, protected optimization, pure wave interference)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from tqdm.auto import tqdm

# ==========================================
# CONFIGURATION
# ==========================================
P = 113                  # Prime modulus; token id P is appended to every input as a separator
FRAC_TRAIN = 0.3         # 30% train, 70% test
D_MODEL = 128            # Residual width (complex width for the Phase model)
NUM_HEADS = 4            # d_head = D_MODEL // NUM_HEADS
MLP_DIM = 512            # Hidden width of the MLP / PhaseFFN angle network
LR = 5e-3                # AdamW learning rate
WEIGHT_DECAY = 1.0       # High decay to force grokking
EPOCHS = 20000           # Full-batch optimizer steps (one step per "epoch")
LOG_EVERY = 50           # Evaluate on the test split every LOG_EVERY epochs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42                # Used for the data split and for model initialization

# Permit reduced-precision (e.g. TF32) float32 matmul kernels where the backend supports them.
torch.set_float32_matmul_precision("high")

def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (plus every CUDA device, if any)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ==========================================
# DATASET GENERATION
# ==========================================
def make_dataset(p, frac_train, seed=42):
    """Build a shuffled train/test split for the task (a + b) mod p.

    Each ordered pair (a, b) with 0 <= a, b < p becomes the token sequence
    [a, b, p]. The pairs are shuffled with a private `random.Random(seed)`, so
    the global RNG is not touched. The first int(p * p * frac_train) pairs form
    the training set and the remainder form the test set.

    Returns:
        (train_x, train_y, test_x, test_y): long tensors; x rows are [a, b, p]
        and y holds (a + b) % p.
    """
    pairs = [(a, b) for a in range(p) for b in range(p)]
    random.Random(seed).shuffle(pairs)
    cut = int(len(pairs) * frac_train)

    def encode(subset):
        inputs = torch.tensor([[a, b, p] for a, b in subset], dtype=torch.long)
        targets = torch.tensor([(a + b) % p for a, b in subset], dtype=torch.long)
        return inputs, targets

    train_x, train_y = encode(pairs[:cut])
    test_x, test_y = encode(pairs[cut:])
    return train_x, train_y, test_x, test_y

# ==========================================
# BASELINE 1: STANDARD TRANSFORMER (Nanda)
# ==========================================
class StandardTransformer(nn.Module):
    """One-layer transformer (attention + MLP, no LayerNorm) with learned absolute positions.

    Residual stream: h = tok + pos; h = h + attn(h); h = h + mlp(h). Input is a
    (B, 3) LongTensor of token ids. Logits over the p residues are read out from
    position 2 only.
    """
    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        # d_model must be divisible by num_heads or the .view() in forward fails
        self.d_head = d_model // num_heads

        # p value tokens plus one extra id (p) used as the final input token
        self.tok_embed = nn.Embedding(p + 1, d_model)
        # Exactly 3 learned positions, so inputs must have length 3
        self.pos_embed = nn.Embedding(3, d_model)

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        # x: (B, L) token ids; L must be 3 (pos_embed has 3 rows, readout uses index 2)
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0)
        h = self.tok_embed(x) + self.pos_embed(pos)

        # Split heads: (B, L, d_model) -> (B, num_heads, L, d_head)
        Q = self.W_Q(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_K(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_V(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention; no causal mask is applied
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        # Merge heads back to (B, L, d_model), then output projection
        attn_out = self.W_O((attn @ V).transpose(1, 2).contiguous().view(B, L, self.d_model))

        h = h + attn_out
        h = h + self.mlp_out(F.relu(self.mlp_in(h)))
        # Only the residual at position 2 is unembedded
        return self.unembed(h[:, 2, :])

# ==========================================
# BASELINE 2: ROPE TRANSFORMER
# ==========================================
class RoPETransformer(StandardTransformer):
    """StandardTransformer with the absolute position embedding replaced by rotary (RoPE) Q/K rotations.

    The parent constructor still builds `pos_embed` (consuming the same RNG
    draws as StandardTransformer) before it is discarded. All other weights are
    therefore initialized exactly as the baseline's under the same seed.
    """
    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__(p, d_model, num_heads, mlp_dim)
        # Remove absolute pos embed, replace with RoPE frequencies
        self.pos_embed = None
        # One inverse frequency per pair of head dimensions, base 10000
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.d_head, 2).float() / self.d_head))
        self.register_buffer("inv_freq", inv_freq)

    def apply_rope(self, x, seq_len):
        """Rotate `x` by position-dependent angles using the rotate-half convention.

        Assumes x is (B, num_heads, seq_len, d_head), which is how forward calls it.
        Dimension j in the first half is paired with dimension j + d_head/2.
        """
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Align to [1, 1, L, d_head] to broadcast across Batch and Heads
        cos = emb.cos().unsqueeze(0).unsqueeze(1)
        sin = emb.sin().unsqueeze(0).unsqueeze(1)

        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)

    def forward(self, x):
        B, L = x.shape
        # No additive position embedding: position enters only through RoPE on Q and K
        h = self.tok_embed(x)

        Q = self.W_Q(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_K(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_V(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

        # V is left unrotated
        Q = self.apply_rope(Q, L)
        K = self.apply_rope(K, L)

        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        attn_out = self.W_O((attn @ V).transpose(1, 2).contiguous().view(B, L, self.d_model))

        h = h + attn_out
        h = h + self.mlp_out(F.relu(self.mlp_in(h)))
        return self.unembed(h[:, 2, :])

# ==========================================
# THE MENACE: PHASE TRANSFORMER (|z|=1)
# ==========================================
def strictly_phase(z, eps=1e-6):
    """Project complex values onto (approximately) the unit circle.

    Each element is divided by its modulus plus `eps`. The `eps` keeps the
    division finite at z == 0, which maps to 0, and makes |output| slightly
    below 1.
    """
    magnitude = z.abs()
    return z / (magnitude + eps)

class StrictComplexLinear(nn.Module):
    """Bias-free complex linear map W = W_real + i * W_imag, stored as two real nn.Linear layers.

    Both weight matrices are re-initialized from N(0, init_scale**2).
    """
    def __init__(self, d_in, d_out, init_scale=0.02):
        super().__init__()
        self.W_real = nn.Linear(d_in, d_out, bias=False)
        self.W_imag = nn.Linear(d_in, d_out, bias=False)
        nn.init.normal_(self.W_real.weight, std=init_scale)
        nn.init.normal_(self.W_imag.weight, std=init_scale)

    def forward(self, z):
        # (W_r + i W_i)(z_r + i z_i) = (W_r z_r - W_i z_i) + i (W_r z_i + W_i z_r)
        return torch.complex(
            self.W_real(z.real) - self.W_imag(z.imag),
            self.W_real(z.imag) + self.W_imag(z.real)
        )

class PhaseAttention(nn.Module):
    """Multi-head attention over complex, (near) unit-modulus token states.

    Q, K and V are complex projections whose elements are snapped to the unit
    circle. Scores are Re(q . conj(k)) * d_head ** -0.5. The real softmax
    weights, cast to complex, mix the values, and that mixture is snapped again
    before W_O. The W_O output itself is intentionally left un-snapped.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = StrictComplexLinear(d_model, d_model)
        self.W_K = StrictComplexLinear(d_model, d_model)
        self.W_V = StrictComplexLinear(d_model, d_model)
        # Initialize W_O very small so Superposition starts as Identity
        # (PhaseTransformer adds this output to z before snapping, so z + W_O(...) ~= z at init)
        self.W_O = StrictComplexLinear(d_model, d_model, init_scale=0.001)

    def forward(self, z):
        # z: complex (B, L, D); D must equal num_heads * d_head for the views below
        B, L, D = z.shape
        # Project, snap to |.| ~= 1, split heads -> (B, num_heads, L, d_head)
        q = strictly_phase(self.W_Q(z)).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        k = strictly_phase(self.W_K(z)).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        v = strictly_phase(self.W_V(z)).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

        # For unit-modulus entries, Re(sum_j q_j * conj(k_j)) ~= sum_j cos(angle(q_j) - angle(k_j))
        scores = (q @ k.conj().transpose(-2, -1)).real * (self.d_head ** -0.5)
        # Cast real attention weights to the complex dtype so they can mix complex values
        attn_out = strictly_phase(F.softmax(scores, dim=-1).to(v.dtype) @ v)
        # Merge heads back to (B, L, D)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)

        # DO NOT SNAP HERE. Return scaled vector for superposition.
        return self.W_O(attn_out)

class PhaseFFN(nn.Module):
    """Predict a per-dimension rotation exp(i * angle) from the complex state.

    A real MLP on [Re z, Im z] (width 2 * d_model) outputs d_model angles. The
    caller multiplies z by the returned unit-modulus factor.
    """
    def __init__(self, d_model, mlp_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model * 2, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, d_model)
        )
        # STRICT IDENTITY INIT: Forces angle=0.0 at Step 0
        # (zero weight and bias -> angles == 0 -> returned rotation is exactly 1)
        nn.init.zeros_(self.net[2].weight)
        nn.init.zeros_(self.net[2].bias)

    def forward(self, z):
        features = torch.cat([z.real, z.imag], dim=-1)
        angles = self.net(features)
        # Unit-modulus complex factor; real angles -> complex output
        return torch.exp(1j * angles)

class PhaseTransformer(nn.Module):
    """One-layer transformer whose hidden state is a complex vector with |z_i| ~= 1.

    Forward pipeline for x of shape (B, 3):
      1. Token embedding: separate real/imag tables, combined and snapped to the unit circle.
      2. Position: learned per-dimension phase rotation exp(i * angle), zero-initialized (identity).
      3. Attention: additive residual z + attn(z), then snap.
      4. FFN: multiplicative residual z * exp(i * angles(z)), then snap.
      5. Readout: [Re z, Im z] -> bridge (2 * d_model -> d_model, real) -> unembed at position 2.
    """
    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        self.tok_embed_real = nn.Embedding(p + 1, d_model)
        self.tok_embed_imag = nn.Embedding(p + 1, d_model)

        # One learned phase angle per (position, dimension); exactly 3 positions supported
        self.pos_angles = nn.Embedding(3, d_model)
        # STRICT IDENTITY INIT: Don't force random phase shifts at Step 0
        nn.init.zeros_(self.pos_angles.weight)

        self.attn = PhaseAttention(d_model, num_heads)
        self.ffn = PhaseFFN(d_model, mlp_dim)

        # Maps the concatenated real/imag parts back to a real d_model vector
        self.bridge = nn.Linear(d_model * 2, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        B, L = x.shape
        z = strictly_phase(torch.complex(self.tok_embed_real(x), self.tok_embed_imag(x)))

        # Positional rotation
        pos = torch.arange(L, device=x.device).unsqueeze(0)
        z = strictly_phase(z * torch.exp(1j * self.pos_angles(pos)))

        # Superposition Residual: Add, THEN snap
        z = strictly_phase(z + self.attn(z))

        # Multiplicative Residual: Rotate
        z = strictly_phase(z * self.ffn(z))

        h_real = self.bridge(torch.cat([z.real, z.imag], dim=-1))
        # Only position 2 is unembedded
        return self.unembed(h_real[:, 2, :])

# ==========================================
# OPTIMIZER ROUTER & TRAINING LOGIC
# ==========================================
def create_optimizer(model, lr, weight_decay, is_phase_model=False):
    """Build the AdamW optimizer (betas=(0.9, 0.98)) used for every model.

    Non-phase models get a single group with `weight_decay` on all parameters.
    Phase models get two groups. Trainable parameters whose qualified name
    contains 'tok_embed_real', 'tok_embed_imag', 'W_Q', 'W_K' or 'W_V' get no
    decay; every other trainable parameter gets `weight_decay`. Frozen
    parameters (requires_grad=False) are omitted in the phase case.

    NOTE(review): those modules' outputs are passed through strictly_phase,
    which removes their overall scale. That is presumably why they are
    exempted from decay; confirm the intent.
    """
    adam_betas = (0.9, 0.98)
    if is_phase_model:
        exempt_keys = ['tok_embed_real', 'tok_embed_imag', 'W_Q', 'W_K', 'W_V']
        trainable = [(pname, t) for pname, t in model.named_parameters() if t.requires_grad]
        is_exempt = [any(key in pname for key in exempt_keys) for pname, _ in trainable]
        decayed = [t for (_, t), skip in zip(trainable, is_exempt) if not skip]
        undecayed = [t for (_, t), skip in zip(trainable, is_exempt) if skip]
        return torch.optim.AdamW([
            {"params": decayed, "weight_decay": weight_decay},
            {"params": undecayed, "weight_decay": 0.0},
        ], lr=lr, betas=adam_betas)
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=adam_betas)

def train_model(model, name, is_phase, train_x, train_y, test_x, test_y):
    """Full-batch train `model` for EPOCHS steps and report when it first groks.

    Every epoch is one optimizer step on the whole training set. Every
    LOG_EVERY epochs, and at the final epoch, the test split is evaluated. The
    training accuracy shown is computed from that epoch's pre-update logits.

    Returns:
        The first logged epoch whose test accuracy exceeds 0.95, or the string
        f">{EPOCHS}" if that threshold was never crossed.
    """
    optimizer = create_optimizer(model, LR, WEIGHT_DECAY, is_phase)
    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    pbar = tqdm(range(EPOCHS), desc=f"Training {name}")

    for epoch in pbar:
        model.train()
        logits = model(train_x)
        loss = criterion(logits, train_y)

        optimizer.zero_grad()
        loss.backward()

        # Live Gradient Tracking: max_norm=inf means nothing is clipped; this only measures the total norm.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')).item()
        optimizer.step()

        if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
            model.eval()
            with torch.no_grad():
                # `logits` predate optimizer.step(), so this is the pre-update train accuracy.
                train_acc = (logits.argmax(-1) == train_y).float().mean().item()
                test_logits = model(test_x)
                test_loss = criterion(test_logits, test_y).item()
                test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

            pbar.set_postfix({
                'tr_acc': f"{train_acc:.3f}",
                'te_acc': f"{test_acc:.3f}",
                'te_loss': f"{test_loss:.3f}",
                '|g|': f"{grad_norm:.3f}"
            })

            if grok_epoch is None and test_acc > 0.95:
                grok_epoch = epoch
                tqdm.write(f"\n⚡ {name} GROKKED at epoch {epoch}! (test_acc={test_acc:.4f})")

    # `is not None`: grokking at epoch 0 is a valid (if unlikely) result, not "never".
    return grok_epoch if grok_epoch is not None else f">{EPOCHS}"

# ==========================================
# EXECUTION
# ==========================================
if __name__ == "__main__":
    set_seed(SEED)
    print("=" * 60)
    print("GROKKING LITMUS TEST: Magnitude vs. Phase Physics")
    print(f"Modulus P={P} | d_model={D_MODEL} | wd={WEIGHT_DECAY} | Device={DEVICE}")
    print("=" * 60)

    # The split uses its own random.Random(SEED), so it does not consume global RNG state.
    tr_x, tr_y, te_x, te_y = make_dataset(P, FRAC_TRAIN, seed=SEED)

    # (display name, model class, uses the phase-aware optimizer).
    # Each model is constructed only AFTER re-seeding, so every architecture's
    # initial weights come from the same RNG state. Building all models up
    # front would make each later model's init depend on the draws consumed by
    # the ones before it, and a reseed after construction cannot affect init.
    model_specs = [
        ("Standard Transformer", StandardTransformer, False),
        ("RoPE Transformer", RoPETransformer, False),
        ("Phase Transformer", PhaseTransformer, True),
    ]

    results = {}
    for name, model_cls, is_phase in model_specs:
        set_seed(SEED)  # Reset seed before each run for fair init
        model = model_cls(P, D_MODEL, NUM_HEADS, MLP_DIM).to(DEVICE)
        results[name] = train_model(model, name, is_phase, tr_x, tr_y, te_x, te_y)

    print("\n" + "=" * 60)
    print(f"{'MODEL TOPOLOGY':<25} | {'EPOCHS TO GROK (>95% Acc)'}")
    print("-" * 60)
    for name, grok_ep in results.items():
        print(f"{name:<25} | {grok_ep}")
    print("=" * 60)

In [None]:
import matplotlib.pyplot as plt


def _history_for(run_results, name):
    """Return the per-epoch history logged for model `name`, or None.

    Training cells in this notebook return different shapes. Some return only
    the grok epoch (an int or a ">N" string); others return
    {"grok_epoch": ..., "history": ...}. Only the dict form carries a history,
    and a given run may not include every model.
    """
    entry = run_results.get(name)
    if isinstance(entry, dict):
        return entry.get("history")
    return None


# Extract histories (RoPE is optional: not every training cell trains it)
std_hist = _history_for(results, "Standard Transformer")
rope_hist = _history_for(results, "RoPE Transformer")
phase_hist = _history_for(results, "Phase Transformer")

if std_hist is None or phase_hist is None:
    raise RuntimeError(
        "No training history found for the Standard and/or Phase Transformer. "
        "Run a training cell whose train_model returns "
        "{'grok_epoch': ..., 'history': ...} before plotting."
    )

# Create a 2x2 grid for deep mechanistic comparison
fig, axs = plt.subplots(2, 2, figsize=(16, 12))

# ==========================================
# PLOT 1: The Standard Memorization Trap
# ==========================================
axs[0, 0].plot(std_hist["epochs"], std_hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2)
axs[0, 0].plot(std_hist["epochs"], std_hist["test_acc"], label="Test Acc", color="#d62728", linewidth=2)
axs[0, 0].fill_between(std_hist["epochs"], std_hist["train_acc"], std_hist["test_acc"], color='red', alpha=0.1)
axs[0, 0].set_title("Standard Transformer (Abs PE)\nNotice the massive memorization gap", fontsize=14, pad=10)
axs[0, 0].set_ylabel("Accuracy", fontsize=12)
axs[0, 0].grid(True, alpha=0.3)
axs[0, 0].legend(loc="lower right")

# ==========================================
# PLOT 2: The Phase Physics Bypass
# ==========================================
axs[0, 1].plot(phase_hist["epochs"], phase_hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2)
axs[0, 1].plot(phase_hist["epochs"], phase_hist["test_acc"], label="Test Acc", color="#2ca02c", linewidth=2)
axs[0, 1].fill_between(phase_hist["epochs"], phase_hist["train_acc"], phase_hist["test_acc"], color='green', alpha=0.1)
axs[0, 1].set_title("Phase Transformer (|z|=1)\nMemorization is physically blocked", fontsize=14, pad=10)
axs[0, 1].set_ylabel("Accuracy", fontsize=12)
axs[0, 1].grid(True, alpha=0.3)
axs[0, 1].legend(loc="lower right")

# ==========================================
# PLOT 3: Head-to-Head Test Accuracy
# ==========================================
axs[1, 0].plot(std_hist["epochs"], std_hist["test_acc"], label="Standard (Abs PE)", color="#d62728", linestyle="--", alpha=0.8)
if rope_hist is not None:
    axs[1, 0].plot(rope_hist["epochs"], rope_hist["test_acc"], label="RoPE Baseline", color="#ff7f0e", linewidth=2)
axs[1, 0].plot(phase_hist["epochs"], phase_hist["test_acc"], label="Phase (|z|=1)", color="#2ca02c", linewidth=2.5)
axs[1, 0].set_title("Test Accuracy Comparison", fontsize=14, pad=10)
axs[1, 0].set_xlabel("Training Epochs", fontsize=12)
axs[1, 0].set_ylabel("Test Accuracy", fontsize=12)
axs[1, 0].grid(True, alpha=0.3)
axs[1, 0].legend(loc="lower right")

# ==========================================
# PLOT 4: Head-to-Head Test Loss (Log Scale)
# ==========================================
axs[1, 1].plot(std_hist["epochs"], std_hist["test_loss"], label="Standard", color="#d62728", linestyle="--", alpha=0.8)
if rope_hist is not None:
    axs[1, 1].plot(rope_hist["epochs"], rope_hist["test_loss"], label="RoPE", color="#ff7f0e", linewidth=2)
axs[1, 1].plot(phase_hist["epochs"], phase_hist["test_loss"], label="Phase", color="#2ca02c", linewidth=2.5)
axs[1, 1].set_title("Test Loss Trajectory (Log Scale)", fontsize=14, pad=10)
axs[1, 1].set_xlabel("Training Epochs", fontsize=12)
axs[1, 1].set_ylabel("Cross Entropy Loss", fontsize=12)
axs[1, 1].set_yscale("log")
axs[1, 1].grid(True, alpha=0.3)
axs[1, 1].legend(loc="upper right")

plt.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# ==========================================
# CONFIGURATION
# ==========================================
P = 113                  # Prime modulus; token id P is appended to every input as a separator
FRAC_TRAIN = 0.3         # 30% train, 70% test
D_MODEL = 128            # Residual width (complex width for the Phase model)
NUM_HEADS = 4            # d_head = D_MODEL // NUM_HEADS
MLP_DIM = 512            # Hidden width of the MLP / PhaseFFN angle network
LR = 1e-3                # STRICTLY NANDA'S LR
WEIGHT_DECAY = 1.0       # STRICTLY NANDA'S DECAY
EPOCHS = 35000           # Increased slightly to give 1e-3 Standard time to grok
LOG_EVERY = 100          # Evaluate on the test split every LOG_EVERY epochs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42                # Used for the data split and for model initialization

# Permit reduced-precision (e.g. TF32) float32 matmul kernels where the backend supports them.
torch.set_float32_matmul_precision("high")

def set_seed(seed):
    """Reset every RNG this notebook uses: `random`, NumPy, torch CPU and (if present) CUDA."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# ==========================================
# DATASET GENERATION (FULL BATCH)
# ==========================================
def make_dataset(p, frac_train, seed=42):
    """Return a full-batch train/test split for (a + b) mod p.

    Inputs are [a, b, p] token triples over all p * p ordered pairs. The pairs
    are shuffled deterministically by `random.Random(seed)`, and the leading
    int(p * p * frac_train) of them become the training set.

    Returns:
        (train_x, train_y, test_x, test_y) as torch.long tensors.
    """
    rng = random.Random(seed)
    shuffled = [(a, b) for a in range(p) for b in range(p)]
    rng.shuffle(shuffled)
    boundary = int(len(shuffled) * frac_train)

    splits = []
    for chunk in (shuffled[:boundary], shuffled[boundary:]):
        splits.append(torch.tensor([[a, b, p] for a, b in chunk], dtype=torch.long))
        splits.append(torch.tensor([(a + b) % p for a, b in chunk], dtype=torch.long))
    train_x, train_y, test_x, test_y = splits
    return train_x, train_y, test_x, test_y

# ==========================================
# BASELINE: STANDARD TRANSFORMER (Nanda)
# ==========================================
class StandardTransformer(nn.Module):
    """One-layer transformer (attention + MLP, no LayerNorm) with learned absolute positions.

    Input is a (B, 3) LongTensor of token ids. Logits over the p residues are
    read out from position 2 only.
    """
    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        # d_model must be divisible by num_heads or the .view() in forward fails
        self.d_head = d_model // num_heads

        # p value tokens plus one extra id (p) used as the final input token
        self.tok_embed = nn.Embedding(p + 1, d_model)
        self.pos_embed = nn.Embedding(3, d_model) # ABSOLUTE PE

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        # x: (B, L) token ids; L must be 3 (pos_embed has 3 rows, readout uses index 2)
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0)
        h = self.tok_embed(x) + self.pos_embed(pos)

        # Split heads: (B, L, d_model) -> (B, num_heads, L, d_head)
        Q = self.W_Q(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_K(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_V(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention (no mask), heads merged back to (B, L, d_model)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn_out = self.W_O((F.softmax(scores, dim=-1) @ V).transpose(1, 2).contiguous().view(B, L, self.d_model))

        h = h + attn_out
        h = h + self.mlp_out(F.relu(self.mlp_in(h)))
        # Only the residual at position 2 is unembedded
        return self.unembed(h[:, 2, :])

# ==========================================
# THE MENACE: PHASE TRANSFORMER (|z|=1)
# ==========================================
def strictly_phase(z, eps=1e-6):
    """Normalize each complex element to (nearly) unit modulus: z / (|z| + eps).

    `eps` guards the zero element (0 stays 0) and leaves |output| just under 1.
    """
    return torch.div(z, z.abs().add(eps))

class StrictComplexLinear(nn.Module):
    """Bias-free complex linear map W = W_real + i * W_imag, stored as two real nn.Linear layers.

    Both weight matrices are re-initialized from N(0, init_scale**2).
    """
    def __init__(self, d_in, d_out, init_scale=0.02):
        super().__init__()
        self.W_real = nn.Linear(d_in, d_out, bias=False)
        self.W_imag = nn.Linear(d_in, d_out, bias=False)
        nn.init.normal_(self.W_real.weight, std=init_scale)
        nn.init.normal_(self.W_imag.weight, std=init_scale)

    def forward(self, z):
        # (W_r + i W_i)(z_r + i z_i) = (W_r z_r - W_i z_i) + i (W_r z_i + W_i z_r)
        return torch.complex(
            self.W_real(z.real) - self.W_imag(z.imag),
            self.W_real(z.imag) + self.W_imag(z.real)
        )

class PhaseAttention(nn.Module):
    """Multi-head attention over complex, (near) unit-modulus token states.

    Q, K and V are snapped to the unit circle. Scores are
    Re(q . conj(k)) * d_head ** -0.5. The softmax-weighted values are snapped
    again before W_O, whose output is returned without snapping.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = StrictComplexLinear(d_model, d_model)
        self.W_K = StrictComplexLinear(d_model, d_model)
        self.W_V = StrictComplexLinear(d_model, d_model)
        # Very small W_O: PhaseTransformer adds this output to z, so the residual starts ~= identity
        self.W_O = StrictComplexLinear(d_model, d_model, init_scale=0.001)

    def forward(self, z):
        # z: complex (B, L, D); D must equal num_heads * d_head for the views below
        B, L, D = z.shape
        # Project, snap to |.| ~= 1, split heads -> (B, num_heads, L, d_head)
        q = strictly_phase(self.W_Q(z)).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        k = strictly_phase(self.W_K(z)).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        v = strictly_phase(self.W_V(z)).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

        # For unit-modulus entries, Re(sum_j q_j * conj(k_j)) ~= sum_j cos(angle(q_j) - angle(k_j))
        scores = (q @ k.conj().transpose(-2, -1)).real * (self.d_head ** -0.5)
        # Real softmax weights cast to the complex dtype to mix complex values
        attn_out = strictly_phase(F.softmax(scores, dim=-1).to(v.dtype) @ v)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        # Deliberately not snapped: the caller adds it to the residual, then snaps
        return self.W_O(attn_out)

class PhaseFFN(nn.Module):
    """Predict a per-dimension rotation exp(i * angle) from the complex state.

    A real MLP on [Re z, Im z] (width 2 * d_model) outputs d_model angles. The
    returned factor has unit modulus.
    """
    def __init__(self, d_model, mlp_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model * 2, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, d_model)
        )
        # Zero output layer -> angles == 0 -> rotation is exactly 1 at init (identity residual)
        nn.init.zeros_(self.net[2].weight)
        nn.init.zeros_(self.net[2].bias)

    def forward(self, z):
        features = torch.cat([z.real, z.imag], dim=-1)
        angles = self.net(features)
        return torch.exp(1j * angles)

class PhaseTransformer(nn.Module):
    """One-layer transformer whose hidden state is a complex vector with |z_i| ~= 1.

    Forward: snap the complex token embedding, apply a learned positional phase
    rotation, add attention then snap, multiply by the FFN rotation then snap.
    Finally map [Re z, Im z] through `bridge` and unembed position 2.
    """
    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        self.tok_embed_real = nn.Embedding(p + 1, d_model)
        self.tok_embed_imag = nn.Embedding(p + 1, d_model)

        self.pos_angles = nn.Embedding(3, d_model) # ABSOLUTE PE (Phase Eq)
        # Zero angles -> exp(0) = 1, so positions apply no rotation at init
        nn.init.zeros_(self.pos_angles.weight)

        self.attn = PhaseAttention(d_model, num_heads)
        self.ffn = PhaseFFN(d_model, mlp_dim)

        # Maps concatenated real/imag parts (2 * d_model) back to a real d_model vector
        self.bridge = nn.Linear(d_model * 2, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        B, L = x.shape
        z = strictly_phase(torch.complex(self.tok_embed_real(x), self.tok_embed_imag(x)))

        # Positional rotation (x must have length 3 to match pos_angles)
        pos = torch.arange(L, device=x.device).unsqueeze(0)
        z = strictly_phase(z * torch.exp(1j * self.pos_angles(pos)))

        # Additive (superposition) residual, then multiplicative (rotation) residual
        z = strictly_phase(z + self.attn(z))
        z = strictly_phase(z * self.ffn(z))

        h_real = self.bridge(torch.cat([z.real, z.imag], dim=-1))
        return self.unembed(h_real[:, 2, :])

# ==========================================
# OPTIMIZER ROUTER & TRAINING LOGIC
# ==========================================
def create_optimizer(model, lr, weight_decay, is_phase_model=False):
    """Return AdamW(lr, betas=(0.9, 0.98)); phase models get a decay-exempt group.

    For phase models, trainable parameters whose names contain
    'tok_embed_real', 'tok_embed_imag', 'W_Q', 'W_K' or 'W_V' use
    weight_decay=0.0, while all remaining trainable parameters use
    `weight_decay`. Non-phase models apply `weight_decay` to every parameter.
    """
    if not is_phase_model:
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.98))

    shielded = ('tok_embed_real', 'tok_embed_imag', 'W_Q', 'W_K', 'W_V')
    buckets = {False: [], True: []}  # is_shielded -> parameters, in named_parameters order
    for pname, tensor in model.named_parameters():
        if tensor.requires_grad:
            buckets[any(tag in pname for tag in shielded)].append(tensor)

    param_groups = [
        {"params": buckets[False], "weight_decay": weight_decay},
        {"params": buckets[True], "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.98))

def train_model(model, name, is_phase, train_x, train_y, test_x, test_y):
    """Full-batch train `model` for EPOCHS steps, logging curves and the grok epoch.

    Every LOG_EVERY epochs, and at the final epoch, this records the epoch,
    train accuracy/loss (from that epoch's pre-update forward pass) and test
    accuracy/loss.

    Returns:
        {"grok_epoch": first logged epoch with test accuracy > 0.95, or the
         string f">{EPOCHS}" if never reached,
         "history": dict of parallel lists keyed by 'epochs', 'train_acc',
         'test_acc', 'train_loss', 'test_loss'}.
    """
    optimizer = create_optimizer(model, LR, WEIGHT_DECAY, is_phase)
    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    history = {'epochs': [], 'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}

    pbar = tqdm(range(EPOCHS), desc=f"Training {name}")

    for epoch in pbar:
        model.train()
        logits = model(train_x)
        loss = criterion(logits, train_y)

        optimizer.zero_grad()
        loss.backward()
        # max_norm=inf: nothing is clipped; this only measures the total gradient norm.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')).item()
        optimizer.step()

        if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
            model.eval()
            with torch.no_grad():
                # `logits` / `loss` predate optimizer.step(): pre-update train metrics.
                train_acc = (logits.argmax(-1) == train_y).float().mean().item()
                test_logits = model(test_x)
                test_loss = criterion(test_logits, test_y).item()
                test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

            history['epochs'].append(epoch)
            history['train_acc'].append(train_acc)
            history['test_acc'].append(test_acc)
            history['train_loss'].append(loss.item())
            history['test_loss'].append(test_loss)

            pbar.set_postfix({
                'tr_acc': f"{train_acc:.3f}",
                'te_acc': f"{test_acc:.3f}",
                'te_loss': f"{test_loss:.3f}",
                '|g|': f"{grad_norm:.3f}"
            })

            if grok_epoch is None and test_acc > 0.95:
                grok_epoch = epoch
                tqdm.write(f"\n⚡ {name} GROKKED at epoch {epoch}! (test_acc={test_acc:.4f})")

    # `is not None`: epoch 0 is a valid grok epoch; the fallback tracks EPOCHS instead of a stale literal.
    return {"grok_epoch": grok_epoch if grok_epoch is not None else f">{EPOCHS}", "history": history}

# ==========================================
# EXECUTION & PLOTTING
# ==========================================
if __name__ == "__main__":
    set_seed(SEED)
    print("=" * 60)
    print("GROKKING LITMUS TEST: Standard vs. Phase Physics")
    print(f"Modulus P={P} | d_model={D_MODEL} | wd={WEIGHT_DECAY} | lr={LR}")
    print("=" * 60)

    # The split uses its own random.Random(SEED), so it does not consume global RNG state.
    tr_x, tr_y, te_x, te_y = make_dataset(P, FRAC_TRAIN, seed=SEED)

    # (display name, model class, uses the phase-aware optimizer).
    # Models are constructed only AFTER re-seeding so each one's initial weights
    # come from the same RNG state. A reseed after construction, as when every
    # model is built up front, cannot influence initialization.
    model_specs = [
        ("Standard Transformer", StandardTransformer, False),
        ("Phase Transformer", PhaseTransformer, True),
    ]

    results = {}
    for name, model_cls, is_phase in model_specs:
        set_seed(SEED)  # Reset seed before each run for fair init
        model = model_cls(P, D_MODEL, NUM_HEADS, MLP_DIM).to(DEVICE)
        results[name] = train_model(model, name, is_phase, tr_x, tr_y, te_x, te_y)

    print("\n" + "=" * 60)
    print(f"{'MODEL TOPOLOGY':<25} | {'EPOCHS TO GROK (>95% Acc)'}")
    print("-" * 60)
    for name, res in results.items():
        print(f"{name:<25} | {res['grok_epoch']}")
    print("=" * 60)

    # --- THE PLOT ---
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))
    std_hist = results["Standard Transformer"]["history"]
    phase_hist = results["Phase Transformer"]["history"]

    # PLOT 1: Standard Memorization Trap
    axs[0, 0].plot(std_hist["epochs"], std_hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2)
    axs[0, 0].plot(std_hist["epochs"], std_hist["test_acc"], label="Test Acc", color="#d62728", linewidth=2)
    axs[0, 0].fill_between(std_hist["epochs"], std_hist["train_acc"], std_hist["test_acc"], color='red', alpha=0.1)
    axs[0, 0].set_title("Standard Transformer (Abs PE)\nThe Memorization Trap", fontsize=14, pad=10)
    axs[0, 0].set_ylabel("Accuracy", fontsize=12)
    axs[0, 0].grid(True, alpha=0.3)
    axs[0, 0].legend(loc="lower right")

    # PLOT 2: Phase Physics Bypass
    axs[0, 1].plot(phase_hist["epochs"], phase_hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2)
    axs[0, 1].plot(phase_hist["epochs"], phase_hist["test_acc"], label="Test Acc", color="#2ca02c", linewidth=2)
    axs[0, 1].fill_between(phase_hist["epochs"], phase_hist["train_acc"], phase_hist["test_acc"], color='green', alpha=0.1)
    axs[0, 1].set_title("Phase Transformer (|z|=1)\nMemorization Physically Blocked", fontsize=14, pad=10)
    axs[0, 1].set_ylabel("Accuracy", fontsize=12)
    axs[0, 1].grid(True, alpha=0.3)
    axs[0, 1].legend(loc="lower right")

    # PLOT 3: Test Accuracy
    axs[1, 0].plot(std_hist["epochs"], std_hist["test_acc"], label="Standard (Abs PE)", color="#d62728", linestyle="--", alpha=0.8)
    axs[1, 0].plot(phase_hist["epochs"], phase_hist["test_acc"], label="Phase (|z|=1)", color="#2ca02c", linewidth=2.5)
    axs[1, 0].set_title("Test Accuracy Comparison", fontsize=14, pad=10)
    axs[1, 0].set_xlabel("Training Epochs", fontsize=12)
    axs[1, 0].set_ylabel("Test Accuracy", fontsize=12)
    axs[1, 0].grid(True, alpha=0.3)
    axs[1, 0].legend(loc="lower right")

    # PLOT 4: Test Loss
    axs[1, 1].plot(std_hist["epochs"], std_hist["test_loss"], label="Standard", color="#d62728", linestyle="--", alpha=0.8)
    axs[1, 1].plot(phase_hist["epochs"], phase_hist["test_loss"], label="Phase", color="#2ca02c", linewidth=2.5)
    axs[1, 1].set_title("Test Loss Trajectory (Log Scale)", fontsize=14, pad=10)
    axs[1, 1].set_xlabel("Training Epochs", fontsize=12)
    axs[1, 1].set_ylabel("Cross Entropy Loss", fontsize=12)
    axs[1, 1].set_yscale("log")
    axs[1, 1].grid(True, alpha=0.3)
    axs[1, 1].legend(loc="upper right")

    plt.tight_layout()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# ==========================================
# CONFIGURATION
# ==========================================
P = 113                  # Prime modulus; token id P is appended to every input as a separator
FRAC_TRAIN = 0.3         # 30% train, 70% test
D_MODEL = 128            # Residual width
NUM_HEADS = 4            # d_head = D_MODEL // NUM_HEADS
MLP_DIM = 512            # MLP hidden width
LR = 1e-3                # AdamW learning rate
WEIGHT_DECAY = 1.0       # AdamW (decoupled) weight decay
EPOCHS = 40000           # Pushed slightly to ensure standard grokking finishes
LOG_EVERY = 200          # Evaluate on the test split every LOG_EVERY epochs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42                # Author's observation (not verified here): seed 42 usually groks around 25k-35k with default betas

# Permit reduced-precision (e.g. TF32) float32 matmul kernels where the backend supports them.
torch.set_float32_matmul_precision("high")

def set_seed(seed):
    """Make runs reproducible by seeding `random`, NumPy and torch (CPU + any CUDA devices)."""
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

# ==========================================
# DATASET GENERATION
# ==========================================
def make_dataset(p, frac_train, seed=42):
    """Split all (a, b) pairs mod p into train/test tensors for (a + b) mod p.

    Each input row is [a, b, p]. The shuffle is driven by a dedicated
    `random.Random(seed)`, and the first int(p * p * frac_train) shuffled
    pairs are used for training.

    Returns:
        (train_x, train_y, test_x, test_y) as torch.long tensors.
    """
    ordered = [(a, b) for a in range(p) for b in range(p)]
    random.Random(seed).shuffle(ordered)
    n_train = int(len(ordered) * frac_train)
    train_pairs, test_pairs = ordered[:n_train], ordered[n_train:]

    def as_inputs(pairs):
        return torch.tensor([[a, b, p] for a, b in pairs], dtype=torch.long)

    def as_labels(pairs):
        return torch.tensor([(a + b) % p for a, b in pairs], dtype=torch.long)

    return as_inputs(train_pairs), as_labels(train_pairs), as_inputs(test_pairs), as_labels(test_pairs)

# ==========================================
# STANDARD TRANSFORMER
# ==========================================
class StandardTransformer(nn.Module):
    """One-layer transformer (attention + MLP, no LayerNorm) with learned absolute positions.

    If `normalize_hiddens` is True, the residual stream is L2-normalized along
    the feature dimension after the embedding, after the attention residual,
    and after the MLP residual. This is intended as an apples-to-apples
    comparison with the unit-modulus Phase model. Input is a (B, 3) LongTensor
    of token ids; logits are read from position 2.
    """
    def __init__(self, p, d_model, num_heads, mlp_dim, normalize_hiddens=False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.normalize_hiddens = normalize_hiddens # Apples-to-apples comparison with the |z|=1 Phase model

        # p value tokens plus one extra id (p) used as the final input token
        self.tok_embed = nn.Embedding(p + 1, d_model)
        self.pos_embed = nn.Embedding(3, d_model)

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        # x: (B, L) token ids; L must be 3 (pos_embed has 3 rows, readout uses index 2)
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0)
        h = self.tok_embed(x) + self.pos_embed(pos)

        if self.normalize_hiddens:
            h = F.normalize(h, dim=-1) # Projects each position's vector onto the unit L2 sphere

        # Split heads: (B, L, d_model) -> (B, num_heads, L, d_head)
        Q = self.W_Q(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_K(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_V(h).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention (no mask), heads merged back to (B, L, d_model)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn_out = self.W_O((F.softmax(scores, dim=-1) @ V).transpose(1, 2).contiguous().view(B, L, self.d_model))

        h = h + attn_out
        if self.normalize_hiddens:
            h = F.normalize(h, dim=-1)

        h = h + self.mlp_out(F.relu(self.mlp_in(h)))
        if self.normalize_hiddens:
            h = F.normalize(h, dim=-1)

        # Only the residual at position 2 is unembedded
        return self.unembed(h[:, 2, :])

# ==========================================
# FOURIER INITIALIZATION (The Trojan Horse)
# ==========================================
def inject_fourier_bias(model, p, key_freqs=(14, 35, 41, 42, 52)):
    """Overwrite the leading token-embedding columns with cos/sin waves.

    For frequency ``k`` at index ``i``, rows ``x = 0..p-1`` of
    ``model.tok_embed.weight`` receive ``cos(2*pi*k*x/p)`` in column ``2*i`` and
    ``sin(2*pi*k*x/p)`` in column ``2*i + 1``. The default frequencies are the five
    "key frequencies" attributed to Nanda et al. Rows >= p (the extra input token)
    and all remaining columns are left untouched, so the architecture and parameter
    count are unchanged.

    Args:
        model: module exposing a ``tok_embed`` ``nn.Embedding`` with at least p rows.
        p: modulus.
        key_freqs: frequencies to inject (defaults to the original five).

    Raises:
        ValueError: if the embedding has fewer than p rows or fewer than
            ``2 * len(key_freqs)`` columns.
    """
    weight = model.tok_embed.weight
    freqs = list(key_freqs)
    n_rows, n_cols = weight.shape
    if p > n_rows:
        raise ValueError(f"embedding has {n_rows} rows, need at least p={p}")
    if 2 * len(freqs) > n_cols:
        raise ValueError(
            f"need {2 * len(freqs)} embedding dims for {len(freqs)} frequencies, got {n_cols}"
        )
    with torch.no_grad():
        # Angles are computed in float64 (like the original math.cos path), then cast.
        xs = torch.arange(p, dtype=torch.float64, device=weight.device)
        for i, k in enumerate(freqs):
            angle = 2 * math.pi * k * xs / p
            weight[:p, 2 * i] = torch.cos(angle).to(weight.dtype)
            weight[:p, 2 * i + 1] = torch.sin(angle).to(weight.dtype)

# ==========================================
# TRAINING LOGIC
# ==========================================
def train_model(model, name, train_x, train_y, test_x, test_y, grok_threshold=0.95):
    """Train ``model`` full-batch with AdamW and record when it generalizes.

    Reads the module-level LR, WEIGHT_DECAY, EPOCHS, LOG_EVERY and DEVICE settings.
    Each epoch is a single optimizer step over the whole training set. Every
    LOG_EVERY epochs (and on the final epoch), train/test accuracy and loss are
    appended to the history.

    Args:
        model: module whose output on ``train_x`` is a (N, num_classes) logit tensor.
        name: label shown in the progress bar and log messages.
        train_x, train_y, test_x, test_y: data tensors; moved to DEVICE here.
        grok_threshold: the first logged epoch whose test accuracy exceeds this value
            is reported as the generalization epoch (default 0.95).

    Returns:
        dict with ``"grok_epoch"`` (int epoch, or the string ``f">{EPOCHS}"`` if the
        threshold was never crossed) and ``"history"`` (lists keyed by 'epochs',
        'train_acc', 'test_acc', 'train_loss', 'test_loss').
    """
    # AdamW with PyTorch's default betas (0.9, 0.999).
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    history = {'epochs': [], 'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}

    pbar = tqdm(range(EPOCHS), desc=f"Training {name}")

    for epoch in pbar:
        model.train()
        logits = model(train_x)
        loss = criterion(logits, train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
            model.eval()
            with torch.no_grad():
                # Train accuracy/loss reuse this epoch's forward pass (taken before
                # optimizer.step()); test metrics are computed after the step.
                train_acc = (logits.argmax(-1) == train_y).float().mean().item()
                test_logits = model(test_x)
                test_loss = criterion(test_logits, test_y).item()
                test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

            history['epochs'].append(epoch)
            history['train_acc'].append(train_acc)
            history['test_acc'].append(test_acc)
            history['train_loss'].append(loss.item())
            history['test_loss'].append(test_loss)

            pbar.set_postfix({
                'tr_acc': f"{train_acc:.3f}",
                'te_acc': f"{test_acc:.3f}",
                'te_loss': f"{test_loss:.3f}"
            })

            if grok_epoch is None and test_acc > grok_threshold:
                grok_epoch = epoch
                tqdm.write(f"\n⚡ {name} generalized at epoch {epoch}! (test_acc={test_acc:.4f})")

    # Compare against None explicitly: epoch 0 is a valid (but falsy) grok epoch.
    # The "never" label reflects the configured EPOCHS, not a hard-coded count.
    return {"grok_epoch": grok_epoch if grok_epoch is not None else f">{EPOCHS}", "history": history}

# ==========================================
# EXECUTION & PLOTTING
# ==========================================
if __name__ == "__main__":
    # pyplot is needed for the plots below; import it here so this cell does not
    # depend on an import made in some other cell.
    import matplotlib.pyplot as plt

    print("=" * 60)
    print("GROKKING: Inductive Bias vs FLOP-Matched Baselines")
    print("=" * 60)

    tr_x, tr_y, te_x, te_y = make_dataset(P, FRAC_TRAIN, seed=SEED)

    # 1. Standard model (baseline).
    set_seed(SEED)
    std_model = StandardTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM).to(DEVICE)
    res_std = train_model(std_model, "Standard", tr_x, tr_y, te_x, te_y)

    # 2. Fourier-initialized model: identical architecture, with the embedding
    #    pre-loaded with the key frequencies.
    set_seed(SEED)
    fourier_model = StandardTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM).to(DEVICE)
    inject_fourier_bias(fourier_model, P)
    res_fourier = train_model(fourier_model, "Fourier Init", tr_x, tr_y, te_x, te_y)

    # 3. Spherical model: same architecture with the residual stream L2-normalized
    #    (a real-valued analogue of the phase model's |z| = 1 constraint).
    set_seed(SEED)
    sphere_model = StandardTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM, normalize_hiddens=True).to(DEVICE)
    res_sphere = train_model(sphere_model, "Spherical Norm", tr_x, tr_y, te_x, te_y)

    # --- PLOTTING ---
    runs = [("Standard", res_std), ("Fourier Init", res_fourier), ("Spherical Norm", res_sphere)]
    fig, axs = plt.subplots(1, len(runs), figsize=(18, 5))

    for ax, (name, res) in zip(axs, runs):
        hist = res["history"]
        ax.plot(hist["epochs"], hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2)
        ax.plot(hist["epochs"], hist["test_acc"], label="Test Acc", color="#d62728", linewidth=2)
        # Shade the train/test gap (the memorization phase when the curves separate).
        ax.fill_between(hist["epochs"], hist["train_acc"], hist["test_acc"], color='red', alpha=0.1)
        ax.set_title(f"{name} Transformer\nGeneralized at: {res['grok_epoch']}", fontsize=12)
        ax.set_xlabel("Epochs")
        ax.set_ylabel("Accuracy")
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.tight_layout()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# ==========================================
# CONFIGURATION
# ==========================================
# Task: (a + b) mod P over all P * P ordered pairs.
P = 113
FRAC_TRAIN = 0.3        # fraction of the pairs used for training; the rest is held out

# Model size
D_MODEL = 128
NUM_HEADS = 4
MLP_DIM = 512

# Optimization
LR = 1e-3
WEIGHT_DECAY = 1.0
EPOCHS = 40_000
LOG_EVERY = 200         # evaluation / logging interval, in epochs

# Runtime
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# Allow faster reduced-precision internal float32 matmuls (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch, plus every CUDA device if present."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ==========================================
# DATASET GENERATION
# ==========================================
def make_dataset(p, frac_train, seed=42):
    """Build shuffled train/test splits for the task (a + b) mod p.

    Every ordered pair (a, b) with 0 <= a, b < p appears exactly once across the two
    splits, shuffled with ``random.Random(seed)``. The first
    ``int(p * p * frac_train)`` pairs form the training split. Each input row is
    ``[a, b, p]`` (token id p is the extra, presumably "=", token) and each target
    is ``(a + b) % p``.

    Returns:
        (train_x, train_y, test_x, test_y) as ``torch.long`` tensors.
    """
    shuffler = random.Random(seed)
    pairs = [(a, b) for a in range(p) for b in range(p)]
    shuffler.shuffle(pairs)
    cut = int(len(pairs) * frac_train)

    def encode(split):
        inputs = torch.tensor([[a, b, p] for a, b in split], dtype=torch.long)
        targets = torch.tensor([(a + b) % p for a, b in split], dtype=torch.long)
        return inputs, targets

    train_x, train_y = encode(pairs[:cut])
    test_x, test_y = encode(pairs[cut:])
    return train_x, train_y, test_x, test_y

# ==========================================
# 1. STANDARD & SPHERICAL TRANSFORMER
# ==========================================
class StandardTransformer(nn.Module):
    """One-layer transformer: multi-head softmax attention plus a ReLU MLP, both residual.

    Inputs are integer token ids of shape (batch, seq_len). Token ids range over
    p + 1 values and there are exactly 3 learned absolute positions. Logits over
    the p classes are read from position 2.

    Args:
        p: modulus; sets the vocabulary size (p + 1) and the number of classes (p).
        d_model: residual-stream width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        mlp_dim: hidden width of the MLP.
        normalize_hiddens: when True, the residual stream is L2-normalized after the
            embedding, after attention and after the MLP (the "Spherical Norm" variant).
    """

    def __init__(self, p, d_model, num_heads, mlp_dim, normalize_hiddens=False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Apples-to-apples, real-valued counterpart of the phase model's |z| = 1 trick.
        self.normalize_hiddens = normalize_hiddens

        self.tok_embed = nn.Embedding(p + 1, d_model)
        self.pos_embed = nn.Embedding(3, d_model)

        # Bias-free Q/K/V/O projections, created (and therefore randomly initialized)
        # in exactly this order: Q, K, V, O.
        self.W_Q, self.W_K, self.W_V, self.W_O = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(4)
        )

        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def _maybe_normalize(self, h):
        # Project onto the unit hypersphere only for the spherical variant.
        return F.normalize(h, dim=-1) if self.normalize_hiddens else h

    def _split_heads(self, t, batch, length):
        # (batch, length, d_model) -> (batch, heads, length, d_head)
        return t.view(batch, length, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        batch, length = x.shape
        positions = torch.arange(length, device=x.device).unsqueeze(0)
        h = self._maybe_normalize(self.tok_embed(x) + self.pos_embed(positions))

        # Scaled dot-product attention over all heads at once.
        q = self._split_heads(self.W_Q(h), batch, length)
        k = self._split_heads(self.W_K(h), batch, length)
        v = self._split_heads(self.W_V(h), batch, length)
        weights = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.d_head), dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, length, self.d_model)
        h = self._maybe_normalize(h + self.W_O(merged))

        # Position-wise ReLU MLP with a residual connection.
        h = self._maybe_normalize(h + self.mlp_out(F.relu(self.mlp_in(h))))

        # Logits are read from the third position only.
        return self.unembed(h[:, 2, :])

def inject_fourier_bias(model, p, key_freqs=(14, 35, 41, 42, 52)):
    """Overwrite the leading token-embedding columns with cos/sin waves.

    For frequency ``k`` at index ``i``, rows ``x = 0..p-1`` of
    ``model.tok_embed.weight`` receive ``cos(2*pi*k*x/p)`` in column ``2*i`` and
    ``sin(2*pi*k*x/p)`` in column ``2*i + 1``. The default frequencies are the five
    "key frequencies" attributed to Nanda et al. Rows >= p (the extra input token)
    and all remaining columns are left untouched, so the architecture and parameter
    count are unchanged.

    Args:
        model: module exposing a ``tok_embed`` ``nn.Embedding`` with at least p rows.
        p: modulus.
        key_freqs: frequencies to inject (defaults to the original five).

    Raises:
        ValueError: if the embedding has fewer than p rows or fewer than
            ``2 * len(key_freqs)`` columns.
    """
    weight = model.tok_embed.weight
    freqs = list(key_freqs)
    n_rows, n_cols = weight.shape
    if p > n_rows:
        raise ValueError(f"embedding has {n_rows} rows, need at least p={p}")
    if 2 * len(freqs) > n_cols:
        raise ValueError(
            f"need {2 * len(freqs)} embedding dims for {len(freqs)} frequencies, got {n_cols}"
        )
    with torch.no_grad():
        # Angles are computed in float64 (like the original math.cos path), then cast.
        xs = torch.arange(p, dtype=torch.float64, device=weight.device)
        for i, k in enumerate(freqs):
            angle = 2 * math.pi * k * xs / p
            weight[:p, 2 * i] = torch.cos(angle).to(weight.dtype)
            weight[:p, 2 * i + 1] = torch.sin(angle).to(weight.dtype)

# ==========================================
# 2. THE STRICT PHASE TRANSFORMER (|z|=1)
# ==========================================
def strictly_phase(z, eps=1e-6):
    """Project complex tensor ``z`` onto (nearly) the unit circle.

    Computes ``z / (|z| + eps)``. ``eps`` keeps the division finite at ``z == 0``
    (which maps to 0), so the result's modulus is ``|z| / (|z| + eps)``, i.e. just
    under 1 unless ``|z|`` is tiny.
    """
    modulus = z.abs()
    return z / (modulus + eps)

class StrictComplexLinear(nn.Module):
    """Bias-free complex linear map ``W = W_real + i * W_imag`` applied to complex input.

    The complex weight is stored as two real ``nn.Linear`` layers, so only real
    parameters are optimized. Both are re-initialized from N(0, init_scale**2).
    """

    def __init__(self, d_in, d_out, init_scale=0.02):
        super().__init__()
        self.W_real = nn.Linear(d_in, d_out, bias=False)
        self.W_imag = nn.Linear(d_in, d_out, bias=False)
        for layer in (self.W_real, self.W_imag):
            nn.init.normal_(layer.weight, std=init_scale)

    def forward(self, z):
        # (A + iB)(x + iy) = (Ax - By) + i(Ay + Bx)
        x, y = z.real, z.imag
        real_part = self.W_real(x) - self.W_imag(y)
        imag_part = self.W_real(y) + self.W_imag(x)
        return torch.complex(real_part, imag_part)

class PhaseAttention(nn.Module):
    """Multi-head attention over unit-modulus complex activations.

    Queries, keys and values come from ``StrictComplexLinear`` projections and are
    pushed back onto |z| = 1 with ``strictly_phase``. Scores are
    ``Re(q . conj(k)) * d_head**-0.5``. The softmax-weighted sum of values is
    re-projected onto the unit circle before the output projection ``W_O``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = StrictComplexLinear(d_model, d_model)
        self.W_K = StrictComplexLinear(d_model, d_model)
        self.W_V = StrictComplexLinear(d_model, d_model)
        # Small init (std 0.001) keeps this module's output small at initialization.
        self.W_O = StrictComplexLinear(d_model, d_model, init_scale=0.001)

    def _to_heads(self, t, B, L):
        # Unit-normalize, then (B, L, D) -> (B, heads, L, d_head).
        return strictly_phase(t).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, z):
        B, L, D = z.shape
        q, k, v = (self._to_heads(proj(z), B, L) for proj in (self.W_Q, self.W_K, self.W_V))

        # For unit-modulus entries, Re(q . conj(k)) = sum_j cos(angle(q_j) - angle(k_j)).
        scale = self.d_head ** -0.5
        scores = (q @ k.conj().transpose(-2, -1)).real * scale
        probs = F.softmax(scores, dim=-1).to(v.dtype)
        mixed = strictly_phase(probs @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.W_O(mixed)

class PhaseFFN(nn.Module):
    """Feed-forward block that emits a unit-modulus phase rotor per feature.

    A real ReLU MLP reads ``[Re z, Im z]`` and outputs one angle per feature; the
    block returns ``exp(i * angle)``. The last layer is zero-initialized, so at
    initialization every angle is 0 and the rotor is exactly 1.
    """

    def __init__(self, d_model, mlp_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model * 2, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, d_model),
        )
        out_layer = self.net[-1]
        nn.init.zeros_(out_layer.weight)
        nn.init.zeros_(out_layer.bias)

    def forward(self, z):
        angles = self.net(torch.cat((z.real, z.imag), dim=-1))
        return torch.exp(1j * angles)

class PhaseTransformer(nn.Module):
    """One-layer transformer whose hidden state is kept on the complex unit circle.

    Token embeddings are two real tables (real and imaginary parts), combined and
    normalized to |z| = 1. Positions are learned phase rotations, zero-initialized
    (no rotation at the start). After the additive attention residual and after the
    multiplicative FFN rotation, the state is re-projected with ``strictly_phase``.
    ``bridge`` then maps ``[Re z, Im z]`` to a real vector, and ``unembed`` produces
    p logits from position 2.
    """

    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        # Complex token embedding stored as separate real / imaginary tables.
        self.tok_embed_real = nn.Embedding(p + 1, d_model)
        self.tok_embed_imag = nn.Embedding(p + 1, d_model)
        # One learned angle per feature for each of the 3 positions (zero = identity).
        self.pos_angles = nn.Embedding(3, d_model)
        nn.init.zeros_(self.pos_angles.weight)

        self.attn = PhaseAttention(d_model, num_heads)
        self.ffn = PhaseFFN(d_model, mlp_dim)

        # Real-valued read-out: [Re, Im] -> d_model -> p logits.
        self.bridge = nn.Linear(d_model * 2, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        B, L = x.shape
        z = strictly_phase(torch.complex(self.tok_embed_real(x), self.tok_embed_imag(x)))
        positions = torch.arange(L, device=x.device).unsqueeze(0)
        rotation = torch.exp(1j * self.pos_angles(positions))
        z = strictly_phase(z * rotation)

        z = strictly_phase(z + self.attn(z))  # additive residual, back onto |z| = 1
        z = strictly_phase(z * self.ffn(z))   # multiplicative unit rotation, re-projected

        readout = self.bridge(torch.cat((z.real, z.imag), dim=-1))
        return self.unembed(readout[:, 2, :])

# ==========================================
# TRAINING LOGIC
# ==========================================
def train_model(model, name, train_x, train_y, test_x, test_y, grok_threshold=0.95):
    """Train ``model`` full-batch with AdamW and record when it generalizes.

    Reads the module-level LR, WEIGHT_DECAY, EPOCHS, LOG_EVERY and DEVICE settings.
    Each epoch is a single optimizer step over the whole training set. Every
    LOG_EVERY epochs (and on the final epoch), train/test accuracy and loss are
    appended to the history.

    Args:
        model: module whose output on ``train_x`` is a (N, num_classes) logit tensor.
        name: label shown in the progress bar and log messages.
        train_x, train_y, test_x, test_y: data tensors; moved to DEVICE here.
        grok_threshold: the first logged epoch whose test accuracy exceeds this value
            is reported as the generalization epoch (default 0.95).

    Returns:
        dict with ``"grok_epoch"`` (int epoch, or the string ``f">{EPOCHS}"`` if the
        threshold was never crossed) and ``"history"`` (lists keyed by 'epochs',
        'train_acc', 'test_acc', 'train_loss', 'test_loss').
    """
    # Standard AdamW used uniformly for fair landscape comparison
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    history = {'epochs': [], 'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}

    pbar = tqdm(range(EPOCHS), desc=f"Training {name}")

    for epoch in pbar:
        model.train()
        logits = model(train_x)
        loss = criterion(logits, train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
            model.eval()
            with torch.no_grad():
                # Train accuracy/loss reuse this epoch's forward pass (taken before
                # optimizer.step()); test metrics are computed after the step.
                train_acc = (logits.argmax(-1) == train_y).float().mean().item()
                test_logits = model(test_x)
                test_loss = criterion(test_logits, test_y).item()
                test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

            history['epochs'].append(epoch)
            history['train_acc'].append(train_acc)
            history['test_acc'].append(test_acc)
            history['train_loss'].append(loss.item())
            history['test_loss'].append(test_loss)

            pbar.set_postfix({
                'tr_acc': f"{train_acc:.3f}",
                'te_acc': f"{test_acc:.3f}"
            })

            if grok_epoch is None and test_acc > grok_threshold:
                grok_epoch = epoch
                tqdm.write(f"\n⚡ {name} generalized at epoch {epoch}! (test_acc={test_acc:.4f})")

    # Compare against None explicitly: epoch 0 is a valid (but falsy) grok epoch.
    return {"grok_epoch": grok_epoch if grok_epoch is not None else f">{EPOCHS}", "history": history}

# ==========================================
# EXECUTION & PLOTTING (2x2 Grid)
# ==========================================
if __name__ == "__main__":
    banner = "=" * 60
    print(banner, "GROKKING LITMUS TEST: The 4 Quadrants of Generalization", banner, sep="\n")

    tr_x, tr_y, te_x, te_y = make_dataset(P, FRAC_TRAIN, seed=SEED)
    dims = (P, D_MODEL, NUM_HEADS, MLP_DIM)

    # Re-seed before every construction so all four variants start from the same RNG state.
    set_seed(SEED)
    models = [("Standard", StandardTransformer(*dims, False))]            # baseline

    set_seed(SEED)
    f_model = StandardTransformer(*dims, False)                            # "treasure map"
    inject_fourier_bias(f_model, P)
    models.append(("Fourier Init", f_model))

    set_seed(SEED)
    models.append(("Spherical Norm", StandardTransformer(*dims, True)))    # L2 straitjacket

    set_seed(SEED)
    models.append(("Strict Phase (|z|=1)", PhaseTransformer(*dims)))       # complex-geometry straitjacket

    results = {}
    for name, model in models:
        model = model.to(DEVICE)
        results[name] = train_model(model, name, tr_x, tr_y, te_x, te_y)

    # --- PLOTTING 2x2 GRID ---
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))
    axs = axs.flatten()

    # One panel per variant: train vs. test accuracy, with the gap between them shaded
    # red to highlight the memorization phase.
    for ax, (name, res) in zip(axs, results.items()):
        hist = res["history"]
        epochs = hist["epochs"]
        ax.plot(epochs, hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2.5)
        ax.plot(epochs, hist["test_acc"], label="Test Acc", color="#d62728", linewidth=2.5)
        ax.fill_between(epochs, hist["train_acc"], hist["test_acc"], color='red', alpha=0.1)

        ax.set_title(f"{name} Transformer\nGeneralized at: {res['grok_epoch']}", fontsize=14, pad=10)
        ax.set_xlabel("Epochs", fontsize=12)
        ax.set_ylabel("Accuracy", fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(loc="lower right")

    plt.tight_layout()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# ==========================================
# CONFIGURATION
# ==========================================
# Task: (a + b) mod P over all P * P ordered pairs.
P = 113
FRAC_TRAIN = 0.3        # fraction of the pairs used for training; the rest is held out

# Model size
D_MODEL = 128
NUM_HEADS = 4
MLP_DIM = 512

# Optimization (weight decay disabled in this run)
LR = 1e-3
WEIGHT_DECAY = 0.0
EPOCHS = 40_000
LOG_EVERY = 200         # evaluation / logging interval, in epochs

# Runtime
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# Allow faster reduced-precision internal float32 matmuls (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch, plus every CUDA device if present."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ==========================================
# DATASET GENERATION
# ==========================================
def make_dataset(p, frac_train, seed=42):
    """Build shuffled train/test splits for the task (a + b) mod p.

    Every ordered pair (a, b) with 0 <= a, b < p appears exactly once across the two
    splits, shuffled with ``random.Random(seed)``. The first
    ``int(p * p * frac_train)`` pairs form the training split. Each input row is
    ``[a, b, p]`` (token id p is the extra, presumably "=", token) and each target
    is ``(a + b) % p``.

    Returns:
        (train_x, train_y, test_x, test_y) as ``torch.long`` tensors.
    """
    shuffler = random.Random(seed)
    pairs = [(a, b) for a in range(p) for b in range(p)]
    shuffler.shuffle(pairs)
    cut = int(len(pairs) * frac_train)

    def encode(split):
        inputs = torch.tensor([[a, b, p] for a, b in split], dtype=torch.long)
        targets = torch.tensor([(a + b) % p for a, b in split], dtype=torch.long)
        return inputs, targets

    train_x, train_y = encode(pairs[:cut])
    test_x, test_y = encode(pairs[cut:])
    return train_x, train_y, test_x, test_y

# ==========================================
# 1. STANDARD & SPHERICAL TRANSFORMER
# ==========================================
class StandardTransformer(nn.Module):
    """One-layer transformer: multi-head softmax attention plus a ReLU MLP, both residual.

    Inputs are integer token ids of shape (batch, seq_len). Token ids range over
    p + 1 values and there are exactly 3 learned absolute positions. Logits over
    the p classes are read from position 2.

    Args:
        p: modulus; sets the vocabulary size (p + 1) and the number of classes (p).
        d_model: residual-stream width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        mlp_dim: hidden width of the MLP.
        normalize_hiddens: when True, the residual stream is L2-normalized after the
            embedding, after attention and after the MLP (the "Spherical Norm" variant).
    """

    def __init__(self, p, d_model, num_heads, mlp_dim, normalize_hiddens=False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Apples-to-apples, real-valued counterpart of the phase model's |z| = 1 trick.
        self.normalize_hiddens = normalize_hiddens

        self.tok_embed = nn.Embedding(p + 1, d_model)
        self.pos_embed = nn.Embedding(3, d_model)

        # Bias-free Q/K/V/O projections, created (and therefore randomly initialized)
        # in exactly this order: Q, K, V, O.
        self.W_Q, self.W_K, self.W_V, self.W_O = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(4)
        )

        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def _maybe_normalize(self, h):
        # Project onto the unit hypersphere only for the spherical variant.
        return F.normalize(h, dim=-1) if self.normalize_hiddens else h

    def _split_heads(self, t, batch, length):
        # (batch, length, d_model) -> (batch, heads, length, d_head)
        return t.view(batch, length, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        batch, length = x.shape
        positions = torch.arange(length, device=x.device).unsqueeze(0)
        h = self._maybe_normalize(self.tok_embed(x) + self.pos_embed(positions))

        # Scaled dot-product attention over all heads at once.
        q = self._split_heads(self.W_Q(h), batch, length)
        k = self._split_heads(self.W_K(h), batch, length)
        v = self._split_heads(self.W_V(h), batch, length)
        weights = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.d_head), dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, length, self.d_model)
        h = self._maybe_normalize(h + self.W_O(merged))

        # Position-wise ReLU MLP with a residual connection.
        h = self._maybe_normalize(h + self.mlp_out(F.relu(self.mlp_in(h))))

        # Logits are read from the third position only.
        return self.unembed(h[:, 2, :])

def inject_fourier_bias(model, p, key_freqs=(14, 35, 41, 42, 52)):
    """Overwrite the leading token-embedding columns with cos/sin waves.

    For frequency ``k`` at index ``i``, rows ``x = 0..p-1`` of
    ``model.tok_embed.weight`` receive ``cos(2*pi*k*x/p)`` in column ``2*i`` and
    ``sin(2*pi*k*x/p)`` in column ``2*i + 1``. The default frequencies are the five
    "key frequencies" attributed to Nanda et al. Rows >= p (the extra input token)
    and all remaining columns are left untouched, so the architecture and parameter
    count are unchanged.

    Args:
        model: module exposing a ``tok_embed`` ``nn.Embedding`` with at least p rows.
        p: modulus.
        key_freqs: frequencies to inject (defaults to the original five).

    Raises:
        ValueError: if the embedding has fewer than p rows or fewer than
            ``2 * len(key_freqs)`` columns.
    """
    weight = model.tok_embed.weight
    freqs = list(key_freqs)
    n_rows, n_cols = weight.shape
    if p > n_rows:
        raise ValueError(f"embedding has {n_rows} rows, need at least p={p}")
    if 2 * len(freqs) > n_cols:
        raise ValueError(
            f"need {2 * len(freqs)} embedding dims for {len(freqs)} frequencies, got {n_cols}"
        )
    with torch.no_grad():
        # Angles are computed in float64 (like the original math.cos path), then cast.
        xs = torch.arange(p, dtype=torch.float64, device=weight.device)
        for i, k in enumerate(freqs):
            angle = 2 * math.pi * k * xs / p
            weight[:p, 2 * i] = torch.cos(angle).to(weight.dtype)
            weight[:p, 2 * i + 1] = torch.sin(angle).to(weight.dtype)

# ==========================================
# 2. THE STRICT PHASE TRANSFORMER (|z|=1)
# ==========================================
def strictly_phase(z, eps=1e-6):
    """Project complex tensor ``z`` onto (nearly) the unit circle.

    Computes ``z / (|z| + eps)``. ``eps`` keeps the division finite at ``z == 0``
    (which maps to 0), so the result's modulus is ``|z| / (|z| + eps)``, i.e. just
    under 1 unless ``|z|`` is tiny.
    """
    modulus = z.abs()
    return z / (modulus + eps)

class StrictComplexLinear(nn.Module):
    """Bias-free complex linear map ``W = W_real + i * W_imag`` applied to complex input.

    The complex weight is stored as two real ``nn.Linear`` layers, so only real
    parameters are optimized. Both are re-initialized from N(0, init_scale**2).
    """

    def __init__(self, d_in, d_out, init_scale=0.02):
        super().__init__()
        self.W_real = nn.Linear(d_in, d_out, bias=False)
        self.W_imag = nn.Linear(d_in, d_out, bias=False)
        for layer in (self.W_real, self.W_imag):
            nn.init.normal_(layer.weight, std=init_scale)

    def forward(self, z):
        # (A + iB)(x + iy) = (Ax - By) + i(Ay + Bx)
        x, y = z.real, z.imag
        real_part = self.W_real(x) - self.W_imag(y)
        imag_part = self.W_real(y) + self.W_imag(x)
        return torch.complex(real_part, imag_part)

class PhaseAttention(nn.Module):
    """Multi-head attention over unit-modulus complex activations.

    Queries, keys and values come from ``StrictComplexLinear`` projections and are
    pushed back onto |z| = 1 with ``strictly_phase``. Scores are
    ``Re(q . conj(k)) * d_head**-0.5``. The softmax-weighted sum of values is
    re-projected onto the unit circle before the output projection ``W_O``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = StrictComplexLinear(d_model, d_model)
        self.W_K = StrictComplexLinear(d_model, d_model)
        self.W_V = StrictComplexLinear(d_model, d_model)
        # Small init (std 0.001) keeps this module's output small at initialization.
        self.W_O = StrictComplexLinear(d_model, d_model, init_scale=0.001)

    def _to_heads(self, t, B, L):
        # Unit-normalize, then (B, L, D) -> (B, heads, L, d_head).
        return strictly_phase(t).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, z):
        B, L, D = z.shape
        q, k, v = (self._to_heads(proj(z), B, L) for proj in (self.W_Q, self.W_K, self.W_V))

        # For unit-modulus entries, Re(q . conj(k)) = sum_j cos(angle(q_j) - angle(k_j)).
        scale = self.d_head ** -0.5
        scores = (q @ k.conj().transpose(-2, -1)).real * scale
        probs = F.softmax(scores, dim=-1).to(v.dtype)
        mixed = strictly_phase(probs @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.W_O(mixed)

class PhaseFFN(nn.Module):
    """Feed-forward block that emits a unit-modulus phase rotor per feature.

    A real ReLU MLP reads ``[Re z, Im z]`` and outputs one angle per feature; the
    block returns ``exp(i * angle)``. The last layer is zero-initialized, so at
    initialization every angle is 0 and the rotor is exactly 1.
    """

    def __init__(self, d_model, mlp_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model * 2, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, d_model),
        )
        out_layer = self.net[-1]
        nn.init.zeros_(out_layer.weight)
        nn.init.zeros_(out_layer.bias)

    def forward(self, z):
        angles = self.net(torch.cat((z.real, z.imag), dim=-1))
        return torch.exp(1j * angles)

class PhaseTransformer(nn.Module):
    """One-layer transformer whose hidden state is kept on the complex unit circle.

    Token embeddings are two real tables (real and imaginary parts), combined and
    normalized to |z| = 1. Positions are learned phase rotations, zero-initialized
    (no rotation at the start). After the additive attention residual and after the
    multiplicative FFN rotation, the state is re-projected with ``strictly_phase``.
    ``bridge`` then maps ``[Re z, Im z]`` to a real vector, and ``unembed`` produces
    p logits from position 2.
    """

    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        # Complex token embedding stored as separate real / imaginary tables.
        self.tok_embed_real = nn.Embedding(p + 1, d_model)
        self.tok_embed_imag = nn.Embedding(p + 1, d_model)
        # One learned angle per feature for each of the 3 positions (zero = identity).
        self.pos_angles = nn.Embedding(3, d_model)
        nn.init.zeros_(self.pos_angles.weight)

        self.attn = PhaseAttention(d_model, num_heads)
        self.ffn = PhaseFFN(d_model, mlp_dim)

        # Real-valued read-out: [Re, Im] -> d_model -> p logits.
        self.bridge = nn.Linear(d_model * 2, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        B, L = x.shape
        z = strictly_phase(torch.complex(self.tok_embed_real(x), self.tok_embed_imag(x)))
        positions = torch.arange(L, device=x.device).unsqueeze(0)
        rotation = torch.exp(1j * self.pos_angles(positions))
        z = strictly_phase(z * rotation)

        z = strictly_phase(z + self.attn(z))  # additive residual, back onto |z| = 1
        z = strictly_phase(z * self.ffn(z))   # multiplicative unit rotation, re-projected

        readout = self.bridge(torch.cat((z.real, z.imag), dim=-1))
        return self.unembed(readout[:, 2, :])

# ==========================================
# TRAINING LOGIC
# ==========================================
def train_model(model, name, train_x, train_y, test_x, test_y, grok_threshold=0.95):
    """Train ``model`` full-batch with AdamW and record when it generalizes.

    Reads the module-level LR, WEIGHT_DECAY, EPOCHS, LOG_EVERY and DEVICE settings.
    Each epoch is a single optimizer step over the whole training set. Every
    LOG_EVERY epochs (and on the final epoch), train/test accuracy and loss are
    appended to the history.

    Args:
        model: module whose output on ``train_x`` is a (N, num_classes) logit tensor.
        name: label shown in the progress bar and log messages.
        train_x, train_y, test_x, test_y: data tensors; moved to DEVICE here.
        grok_threshold: the first logged epoch whose test accuracy exceeds this value
            is reported as the generalization epoch (default 0.95).

    Returns:
        dict with ``"grok_epoch"`` (int epoch, or the string ``f">{EPOCHS}"`` if the
        threshold was never crossed) and ``"history"`` (lists keyed by 'epochs',
        'train_acc', 'test_acc', 'train_loss', 'test_loss').
    """
    # Standard AdamW used uniformly for fair landscape comparison
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    history = {'epochs': [], 'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}

    pbar = tqdm(range(EPOCHS), desc=f"Training {name}")

    for epoch in pbar:
        model.train()
        logits = model(train_x)
        loss = criterion(logits, train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
            model.eval()
            with torch.no_grad():
                # Train accuracy/loss reuse this epoch's forward pass (taken before
                # optimizer.step()); test metrics are computed after the step.
                train_acc = (logits.argmax(-1) == train_y).float().mean().item()
                test_logits = model(test_x)
                test_loss = criterion(test_logits, test_y).item()
                test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

            history['epochs'].append(epoch)
            history['train_acc'].append(train_acc)
            history['test_acc'].append(test_acc)
            history['train_loss'].append(loss.item())
            history['test_loss'].append(test_loss)

            pbar.set_postfix({
                'tr_acc': f"{train_acc:.3f}",
                'te_acc': f"{test_acc:.3f}"
            })

            if grok_epoch is None and test_acc > grok_threshold:
                grok_epoch = epoch
                tqdm.write(f"\n⚡ {name} generalized at epoch {epoch}! (test_acc={test_acc:.4f})")

    # Compare against None explicitly: epoch 0 is a valid (but falsy) grok epoch.
    return {"grok_epoch": grok_epoch if grok_epoch is not None else f">{EPOCHS}", "history": history}

# ==========================================
# EXECUTION & PLOTTING (2x2 Grid)
# ==========================================
if __name__ == "__main__":
    banner = "=" * 60
    print(banner, "GROKKING LITMUS TEST: The 4 Quadrants of Generalization", banner, sep="\n")

    tr_x, tr_y, te_x, te_y = make_dataset(P, FRAC_TRAIN, seed=SEED)
    dims = (P, D_MODEL, NUM_HEADS, MLP_DIM)

    # Re-seed before every construction so all four variants start from the same RNG state.
    set_seed(SEED)
    models = [("Standard", StandardTransformer(*dims, False))]            # baseline

    set_seed(SEED)
    f_model = StandardTransformer(*dims, False)                            # "treasure map"
    inject_fourier_bias(f_model, P)
    models.append(("Fourier Init", f_model))

    set_seed(SEED)
    models.append(("Spherical Norm", StandardTransformer(*dims, True)))    # L2 straitjacket

    set_seed(SEED)
    models.append(("Strict Phase (|z|=1)", PhaseTransformer(*dims)))       # complex-geometry straitjacket

    results = {}
    for name, model in models:
        model = model.to(DEVICE)
        results[name] = train_model(model, name, tr_x, tr_y, te_x, te_y)

    # --- PLOTTING 2x2 GRID ---
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))
    axs = axs.flatten()

    # One panel per variant: train vs. test accuracy, with the gap between them shaded
    # red to highlight the memorization phase.
    for ax, (name, res) in zip(axs, results.items()):
        hist = res["history"]
        epochs = hist["epochs"]
        ax.plot(epochs, hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2.5)
        ax.plot(epochs, hist["test_acc"], label="Test Acc", color="#d62728", linewidth=2.5)
        ax.fill_between(epochs, hist["train_acc"], hist["test_acc"], color='red', alpha=0.1)

        ax.set_title(f"{name} Transformer\nGeneralized at: {res['grok_epoch']}", fontsize=14, pad=10)
        ax.set_xlabel("Epochs", fontsize=12)
        ax.set_ylabel("Accuracy", fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(loc="lower right")

    plt.tight_layout()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# ==========================================
# CONFIGURATION
# ==========================================
# Task: (a + b) mod P over all P * P ordered pairs.
P = 113
FRAC_TRAIN = 0.3        # fraction of the pairs used for training; the rest is held out

# Model size
D_MODEL = 128
NUM_HEADS = 4
MLP_DIM = 512

# Optimization. No global WEIGHT_DECAY here: decay is set per parameter group
# inside train_surgical_phase_model.
LR = 1e-3
EPOCHS = 40_000
LOG_EVERY = 200         # evaluation / logging interval, in epochs

# Runtime
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# Allow faster reduced-precision internal float32 matmuls (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch, plus every CUDA device if present."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ==========================================
# DATASET GENERATION
# ==========================================
def make_dataset(p, frac_train, seed=42):
    """Build shuffled train/test splits for the task (a + b) mod p.

    Every ordered pair (a, b) with 0 <= a, b < p appears exactly once across the two
    splits, shuffled with ``random.Random(seed)``. The first
    ``int(p * p * frac_train)`` pairs form the training split. Each input row is
    ``[a, b, p]`` (token id p is the extra, presumably "=", token) and each target
    is ``(a + b) % p``.

    Returns:
        (train_x, train_y, test_x, test_y) as ``torch.long`` tensors.
    """
    shuffler = random.Random(seed)
    pairs = [(a, b) for a in range(p) for b in range(p)]
    shuffler.shuffle(pairs)
    cut = int(len(pairs) * frac_train)

    def encode(split):
        inputs = torch.tensor([[a, b, p] for a, b in split], dtype=torch.long)
        targets = torch.tensor([(a + b) % p for a, b in split], dtype=torch.long)
        return inputs, targets

    train_x, train_y = encode(pairs[:cut])
    test_x, test_y = encode(pairs[cut:])
    return train_x, train_y, test_x, test_y

# ==========================================
# THE STRICT PHASE TRANSFORMER (|z|=1)
# ==========================================
def strictly_phase(z, eps=1e-6):
    """Project complex tensor ``z`` onto (nearly) the unit circle.

    Computes ``z / (|z| + eps)``. ``eps`` keeps the division finite at ``z == 0``
    (which maps to 0), so the result's modulus is ``|z| / (|z| + eps)``, i.e. just
    under 1 unless ``|z|`` is tiny.
    """
    modulus = z.abs()
    return z / (modulus + eps)

class StrictComplexLinear(nn.Module):
    """Bias-free complex linear map ``W = W_real + i * W_imag`` applied to complex input.

    The complex weight is stored as two real ``nn.Linear`` layers, so only real
    parameters are optimized. Both are re-initialized from N(0, init_scale**2).
    """

    def __init__(self, d_in, d_out, init_scale=0.02):
        super().__init__()
        self.W_real = nn.Linear(d_in, d_out, bias=False)
        self.W_imag = nn.Linear(d_in, d_out, bias=False)
        for layer in (self.W_real, self.W_imag):
            nn.init.normal_(layer.weight, std=init_scale)

    def forward(self, z):
        # (A + iB)(x + iy) = (Ax - By) + i(Ay + Bx)
        x, y = z.real, z.imag
        real_part = self.W_real(x) - self.W_imag(y)
        imag_part = self.W_real(y) + self.W_imag(x)
        return torch.complex(real_part, imag_part)

class PhaseAttention(nn.Module):
    """Multi-head attention over unit-modulus complex activations.

    Queries, keys and values come from ``StrictComplexLinear`` projections and are
    pushed back onto |z| = 1 with ``strictly_phase``. Scores are
    ``Re(q . conj(k)) * d_head**-0.5``. The softmax-weighted sum of values is
    re-projected onto the unit circle before the output projection ``W_O``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = StrictComplexLinear(d_model, d_model)
        self.W_K = StrictComplexLinear(d_model, d_model)
        self.W_V = StrictComplexLinear(d_model, d_model)
        # Small init (std 0.001) keeps this module's output small at initialization.
        self.W_O = StrictComplexLinear(d_model, d_model, init_scale=0.001)

    def _to_heads(self, t, B, L):
        # Unit-normalize, then (B, L, D) -> (B, heads, L, d_head).
        return strictly_phase(t).view(B, L, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, z):
        B, L, D = z.shape
        q, k, v = (self._to_heads(proj(z), B, L) for proj in (self.W_Q, self.W_K, self.W_V))

        # For unit-modulus entries, Re(q . conj(k)) = sum_j cos(angle(q_j) - angle(k_j)).
        scale = self.d_head ** -0.5
        scores = (q @ k.conj().transpose(-2, -1)).real * scale
        probs = F.softmax(scores, dim=-1).to(v.dtype)
        mixed = strictly_phase(probs @ v).transpose(1, 2).contiguous().view(B, L, D)
        return self.W_O(mixed)

class PhaseFFN(nn.Module):
    """Feed-forward block that emits a unit-modulus phase rotor per feature.

    A real ReLU MLP reads ``[Re z, Im z]`` and outputs one angle per feature; the
    block returns ``exp(i * angle)``. The last layer is zero-initialized, so at
    initialization every angle is 0 and the rotor is exactly 1.
    """

    def __init__(self, d_model, mlp_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model * 2, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, d_model),
        )
        out_layer = self.net[-1]
        nn.init.zeros_(out_layer.weight)
        nn.init.zeros_(out_layer.bias)

    def forward(self, z):
        angles = self.net(torch.cat((z.real, z.imag), dim=-1))
        return torch.exp(1j * angles)

class PhaseTransformer(nn.Module):
    """One-layer transformer whose hidden state is kept on the complex unit circle.

    Token embeddings are two real tables (real and imaginary parts), combined and
    normalized to |z| = 1. Positions are learned phase rotations, zero-initialized
    (no rotation at the start). After the additive attention residual and after the
    multiplicative FFN rotation, the state is re-projected with ``strictly_phase``.
    ``bridge`` then maps ``[Re z, Im z]`` to a real vector, and ``unembed`` produces
    p logits from position 2.
    """

    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        # Complex token embedding stored as separate real / imaginary tables.
        self.tok_embed_real = nn.Embedding(p + 1, d_model)
        self.tok_embed_imag = nn.Embedding(p + 1, d_model)
        # One learned angle per feature for each of the 3 positions (zero = identity).
        self.pos_angles = nn.Embedding(3, d_model)
        nn.init.zeros_(self.pos_angles.weight)

        self.attn = PhaseAttention(d_model, num_heads)
        self.ffn = PhaseFFN(d_model, mlp_dim)

        # ---> THE ONLY PLACES MAGNITUDE CAN GROW <---
        # bridge and unembed are the only layers whose outputs are not re-projected
        # onto |z| = 1: [Re, Im] -> d_model -> p logits.
        self.bridge = nn.Linear(d_model * 2, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        B, L = x.shape
        z = strictly_phase(torch.complex(self.tok_embed_real(x), self.tok_embed_imag(x)))
        positions = torch.arange(L, device=x.device).unsqueeze(0)
        rotation = torch.exp(1j * self.pos_angles(positions))
        z = strictly_phase(z * rotation)

        z = strictly_phase(z + self.attn(z))  # additive residual, back onto |z| = 1
        z = strictly_phase(z * self.ffn(z))   # multiplicative unit rotation, re-projected

        readout = self.bridge(torch.cat((z.real, z.imag), dim=-1))
        return self.unembed(readout[:, 2, :])

# ==========================================
# SURGICAL OPTIMIZER ROUTING
# ==========================================
def train_surgical_phase_model(model, train_x, train_y, test_x, test_y, *,
                               weight_decay=1.0, grok_threshold=0.95):
    """Train ``model`` full-batch, applying weight decay only to the read-out layers.

    Trainable parameters whose qualified name contains ``'bridge'`` or
    ``'unembed'`` get AdamW weight decay ``weight_decay``. All other trainable
    parameters get none. Training runs for ``EPOCHS`` full-batch steps at
    learning rate ``LR``. Metrics are logged every ``LOG_EVERY`` epochs and on
    the final epoch.

    Args:
        model: module mapping token-id tensors to ``(N, num_classes)`` logits.
        train_x, train_y, test_x, test_y: dataset tensors; moved to ``DEVICE``.
        weight_decay: decay for the bridge/unembed group. Defaults to 1.0, the
            value previously hard-coded here.
        grok_threshold: test accuracy that must be strictly exceeded to record
            the generalisation epoch. Defaults to 0.95.

    Returns:
        A dict with two keys:
        ``"grok_epoch"``: the first logged epoch whose test accuracy exceeded
        ``grok_threshold``, or the string ``">{EPOCHS}"`` if that never happened.
        ``"history"``: lists keyed ``epochs``, ``train_acc``, ``test_acc``,
        ``train_loss`` and ``test_loss``, one entry per logged epoch.
    """
    # Route weight decay ONLY to the exit pipeline (where magnitudes can grow).
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Target the final linear layers where magnitudes can explode
        if 'bridge' in name or 'unembed' in name:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    optimizer = torch.optim.AdamW([
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ], lr=LR, betas=(0.9, 0.999))

    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    history = {'epochs': [], 'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}

    pbar = tqdm(range(EPOCHS), desc="Training Phase (Targeted WD)")

    for epoch in pbar:
        model.train()
        logits = model(train_x)
        loss = criterion(logits, train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
            model.eval()
            with torch.no_grad():
                # NOTE: train acc/loss reuse the forward pass from *before* this
                # epoch's optimizer.step(); test metrics use the updated weights.
                train_acc = (logits.argmax(-1) == train_y).float().mean().item()
                test_logits = model(test_x)
                test_loss = criterion(test_logits, test_y).item()
                test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

            history['epochs'].append(epoch)
            history['train_acc'].append(train_acc)
            history['test_acc'].append(test_acc)
            history['train_loss'].append(loss.item())
            history['test_loss'].append(test_loss)

            pbar.set_postfix({
                'tr_acc': f"{train_acc:.3f}",
                'te_acc': f"{test_acc:.3f}"
            })

            if grok_epoch is None and test_acc > grok_threshold:
                grok_epoch = epoch
                tqdm.write(f"\n⚡ Surgical Phase Model generalized at epoch {epoch}! (test_acc={test_acc:.4f})")

    # Use `is not None`: epoch 0 is falsy and would otherwise be misreported as ">EPOCHS".
    return {"grok_epoch": grok_epoch if grok_epoch is not None else f">{EPOCHS}", "history": history}

# ==========================================
# EXECUTION & PLOTTING
# ==========================================
if __name__ == "__main__":
    # ---- Banner ----
    _rule = "=" * 60
    for _line in (
        _rule,
        "TARGETED WEIGHT DECAY: Phase Transformer",
        "WD=1.0 on [Bridge, Unembed], WD=0.0 on [Embeds, Attn, FFN]",
        _rule,
    ):
        print(_line)

    # ---- Data, model, training ----
    # make_dataset shuffles with its own random.Random(seed), so the split is
    # independent of the global RNG. set_seed then fixes model initialisation.
    tr_x, tr_y, te_x, te_y = make_dataset(P, FRAC_TRAIN, seed=SEED)
    set_seed(SEED)
    model = PhaseTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM).to(DEVICE)
    res = train_surgical_phase_model(model, tr_x, tr_y, te_x, te_y)

    # ---- Plotting ----
    hist = res["history"]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    _epochs = hist["epochs"]
    _blue, _red = "#1f77b4", "#d62728"

    # Left panel: train vs. test accuracy, with the gap between them shaded.
    for _key, _label, _color in (("train_acc", "Train Acc", _blue),
                                 ("test_acc", "Test Acc", _red)):
        ax1.plot(_epochs, hist[_key], label=_label, color=_color, linewidth=2.5)
    ax1.fill_between(_epochs, hist["train_acc"], hist["test_acc"], color='red', alpha=0.1)
    ax1.set_title(f"Targeted WD Phase Transformer\nGeneralized at: {res['grok_epoch']}", fontsize=14)
    ax1.set_ylabel("Accuracy", fontsize=12)
    ax1.legend(loc="lower right")

    # Right panel: cross-entropy curves on a logarithmic y-axis.
    for _key, _label, _color in (("train_loss", "Train Loss", _blue),
                                 ("test_loss", "Test Loss", _red)):
        ax2.plot(_epochs, hist[_key], label=_label, color=_color, linewidth=2.5)
    ax2.set_title("Cross-Entropy Loss (Log Scale)", fontsize=14)
    ax2.set_ylabel("Loss", fontsize=12)
    ax2.set_yscale("log")
    ax2.legend(loc="upper right")

    # Cosmetics shared by both panels.
    for _ax in (ax1, ax2):
        _ax.set_xlabel("Epochs", fontsize=12)
        _ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()