<a href="https://colab.research.google.com/github/AlperYildirim1/geometric-grokking/blob/main/Grokking_Modular_Last.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import json

# Output folder for the per-run JSON histories and PNG plots. The path is a
# Google Drive location, so presumably Drive is mounted (Colab) when this runs.
SAVE_DIR = '/content/drive/MyDrive/grokking_logs_detailed'
os.makedirs(SAVE_DIR, exist_ok=True)


# ==========================================
# CONFIGURATION
# ==========================================
P = 113               # modulus of the addition task; also the separator token id
FRAC_TRAIN = 0.3      # fraction of the P*P (a, b) pairs used for training
D_MODEL = 128         # residual-stream width
NUM_HEADS = 4         # attention heads (D_MODEL must divide evenly)
MLP_DIM = 512         # hidden width of the MLP sublayer
LR = 0.001            # AdamW learning rate
WEIGHT_DECAY = 1.0    # AdamW weight decay
LOG_EVERY = 200       # epochs between metric snapshots
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Left commented out; enabling it lets float32 matmuls use faster,
# lower-precision (TF32) kernels:
# torch.set_float32_matmul_precision("high")

# ==========================================
# BULLETPROOF DETERMINISM
# ==========================================
def set_seed(seed):
    """Seed every RNG this script uses and request deterministic kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (all
    devices), then asks cuDNN / PyTorch for deterministic execution.

    Caveats (why this is "best effort", not a guarantee):
      * ``PYTHONHASHSEED`` only affects interpreters started *after* it is
        set (e.g. subprocesses); hashing in the current process is unchanged.
      * ``CUBLAS_WORKSPACE_CONFIG`` only takes effect if it is set before the
        first cuBLAS call in the process.
      * ``use_deterministic_algorithms(True, warn_only=True)`` merely *warns*
        when an op has no deterministic implementation instead of raising.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Required for deterministic cuBLAS (not cuDNN) GEMM kernels.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed already seeds every device; manual_seed_all is kept as
    # an explicit call for all visible GPUs (torch.cuda.manual_seed was redundant).
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)

# ==========================================
# DATASET GENERATION
# ==========================================
def make_dataset(p, frac_train, seed=42):
    """Build a shuffled train/test split of the full modular-addition table.

    Every pair (a, b) with 0 <= a, b < p is encoded as the token sequence
    ``[a, b, p]`` (token ``p`` is an extra separator token) and labelled with
    ``(a + b) % p``. The pairs are shuffled by a private
    ``random.Random(seed)``; the first ``int(p * p * frac_train)`` become the
    training set and the rest the test set.

    Side effect: calls ``set_seed(seed)``, resetting the global RNGs and
    determinism flags.

    Returns:
        ``(train_x, train_y, test_x, test_y)`` as ``torch.long`` tensors;
        the inputs have shape (N, 3) and the labels shape (N,).
    """
    set_seed(seed)  # keep global RNG state reproducible per seed
    pairs = [(a, b) for a in range(p) for b in range(p)]
    random.Random(seed).shuffle(pairs)

    cut = int(len(pairs) * frac_train)

    def encode(chunk):
        # Turn a list of (a, b) pairs into (inputs, labels) tensors.
        inputs = torch.tensor([[a, b, p] for a, b in chunk], dtype=torch.long)
        labels = torch.tensor([(a + b) % p for a, b in chunk], dtype=torch.long)
        return inputs, labels

    train_x, train_y = encode(pairs[:cut])
    test_x, test_y = encode(pairs[cut:])
    return train_x, train_y, test_x, test_y

# ==========================================
# 1. STANDARD & SPHERICAL TRANSFORMER
# ==========================================
class StandardTransformer(nn.Module):
    """One-layer, LayerNorm-free transformer for the 3-token modular task.

    Token + learned positional embeddings feed one multi-head self-attention
    sublayer and one ReLU MLP sublayer, each wrapped in a residual connection.
    Logits are read from position 2 via an unembedding matrix. No attention
    mask is applied.

    With ``normalize_hiddens=True`` the residual stream is L2-normalised
    (projected onto the unit sphere) after the embedding and after each
    sublayer.

    Assumes ``d_model`` is divisible by ``num_heads`` and that inputs have
    exactly 3 positions (the positional table has 3 rows and logits are read
    at index 2).
    """

    def __init__(self, p, d_model, num_heads, mlp_dim, normalize_hiddens=False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.normalize_hiddens = normalize_hiddens

        # Submodules are created in a fixed order: seeded initialisation
        # depends on it.
        self.tok_embed = nn.Embedding(p + 1, d_model)  # p residues + 1 extra token
        self.pos_embed = nn.Embedding(3, d_model)

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def _project_heads(self, proj, h):
        """Apply ``proj`` and reshape (B, L, d_model) -> (B, heads, L, d_head)."""
        batch, seq, _ = h.shape
        return proj(h).view(batch, seq, self.num_heads, self.d_head).transpose(1, 2)

    def _maybe_normalize(self, h):
        """L2-normalise the last dim when the spherical variant is enabled."""
        return F.normalize(h, dim=-1) if self.normalize_hiddens else h

    def forward(self, x):
        batch, seq = x.shape
        positions = torch.arange(seq, device=x.device).unsqueeze(0)
        resid = self._maybe_normalize(self.tok_embed(x) + self.pos_embed(positions))

        # Attention sublayer.
        q = self._project_heads(self.W_Q, resid)
        k = self._project_heads(self.W_K, resid)
        v = self._project_heads(self.W_V, resid)
        weights = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.d_head), dim=-1)
        mixed = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, self.d_model)
        resid = self._maybe_normalize(resid + self.W_O(mixed))

        # MLP sublayer.
        mlp_update = self.mlp_out(F.relu(self.mlp_in(resid)))
        resid = self._maybe_normalize(resid + mlp_update)

        return self.unembed(resid[:, 2, :])

def inject_fourier_bias(model, p, key_freqs=(14, 35, 41, 42, 52)):
    """Overwrite the leading token-embedding dims with cos/sin Fourier features.

    For the i-th frequency ``k`` in ``key_freqs`` and each residue token
    ``x`` in ``0..p-1``, the rows of ``model.tok_embed.weight`` get::

        column 2*i     <- cos(2*pi*k*x / p)
        column 2*i + 1 <- sin(2*pi*k*x / p)

    Row ``p`` (the extra token) and all columns from ``2 * len(key_freqs)``
    onward keep their existing values. The write is done in place under
    ``torch.no_grad()``.

    Args:
        model: Module exposing a ``tok_embed`` ``nn.Embedding`` with at least
            ``p`` rows.
        p: Modulus, i.e. the number of residue rows to overwrite.
        key_freqs: Frequencies to inject. The default is the original
            hard-coded set, so existing callers are unaffected.

    Raises:
        ValueError: If the embedding has fewer than ``2 * len(key_freqs)``
            columns. The original loop failed with an opaque IndexError here.
    """
    weight = model.tok_embed.weight
    needed = 2 * len(key_freqs)
    if weight.shape[1] < needed:
        raise ValueError(
            f"tok_embed has {weight.shape[1]} dims but {needed} are needed "
            f"for {len(key_freqs)} Fourier frequencies"
        )

    with torch.no_grad():
        # Angles are computed in float64 on CPU, matching the original
        # math.cos/math.sin precision. They are rounded to the parameter's
        # dtype/device only on assignment.
        xs = torch.arange(p, dtype=torch.float64)
        for i, k in enumerate(key_freqs):
            angles = 2 * math.pi * k * xs / p
            weight[:p, 2 * i] = torch.cos(angles).to(device=weight.device, dtype=weight.dtype)
            weight[:p, 2 * i + 1] = torch.sin(angles).to(device=weight.device, dtype=weight.dtype)

# ==========================================
# 2. THE STRICT PHASE TRANSFORMER (|z|=1)
# ==========================================
def strictly_phase(z, eps=1e-6):
    """Project a complex tensor elementwise onto (approximately) the unit circle.

    Each entry is divided by its modulus plus ``eps``. ``eps`` guards against
    division by zero, so exact zeros stay 0 and every other entry ends up
    with modulus slightly below 1.
    """
    magnitude = z.abs()
    return z / (magnitude + eps)

class StrictComplexLinear(nn.Module):
    """Bias-free complex linear map built from two real ``nn.Linear`` layers.

    Represents W = W_real + i * W_imag and computes W z as::

        Re(out) = W_real Re(z) - W_imag Im(z)
        Im(out) = W_real Im(z) + W_imag Re(z)

    Both real weight matrices are re-initialised from N(0, init_scale**2).
    """

    def __init__(self, d_in, d_out, init_scale=0.02):
        super().__init__()
        self.W_real = nn.Linear(d_in, d_out, bias=False)
        self.W_imag = nn.Linear(d_in, d_out, bias=False)
        for layer in (self.W_real, self.W_imag):
            nn.init.normal_(layer.weight, std=init_scale)

    def forward(self, z):
        re, im = z.real, z.imag
        out_re = self.W_real(re) - self.W_imag(im)
        out_im = self.W_real(im) + self.W_imag(re)
        return torch.complex(out_re, out_im)

class PhaseAttention(nn.Module):
    """Multi-head self-attention over unit-modulus complex activations.

    Q, K and V are complex projections snapped back to the unit circle with
    ``strictly_phase``. Attention scores are Re(q . conj(k)) / sqrt(d_head).
    The softmax weights (cast to V's complex dtype) mix V, and the mixture is
    re-projected onto the unit circle. Heads are then merged and passed
    through ``W_O``, which is initialised with std 0.001 rather than the
    StrictComplexLinear default of 0.02.

    Assumes ``d_model`` is divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = StrictComplexLinear(d_model, d_model)
        self.W_K = StrictComplexLinear(d_model, d_model)
        self.W_V = StrictComplexLinear(d_model, d_model)
        self.W_O = StrictComplexLinear(d_model, d_model, init_scale=0.001)

    def _heads(self, proj, z):
        """Project, snap to unit modulus and split into (B, heads, L, d_head)."""
        batch, seq, _ = z.shape
        unit = strictly_phase(proj(z))
        return unit.view(batch, seq, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, z):
        batch, seq, width = z.shape
        q = self._heads(self.W_Q, z)
        k = self._heads(self.W_K, z)
        v = self._heads(self.W_V, z)

        logits = (q @ k.conj().transpose(-2, -1)).real * (self.d_head ** -0.5)
        weights = F.softmax(logits, dim=-1).to(v.dtype)
        mixed = strictly_phase(weights @ v)
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.W_O(merged)

class PhaseFFN(nn.Module):
    """Feed-forward block that predicts a per-channel phase rotation.

    The real and imaginary parts of ``z`` are concatenated
    (2 * d_model features) and a ReLU MLP maps them to one angle per channel.
    The block returns the unit phasor exp(i * angle). The output layer is
    zero-initialised, so at initialisation every angle is 0 and the output is
    all ones.
    """

    def __init__(self, d_model, mlp_dim):
        super().__init__()
        hidden = nn.Linear(d_model * 2, mlp_dim)
        head = nn.Linear(mlp_dim, d_model)
        self.net = nn.Sequential(hidden, nn.ReLU(), head)
        nn.init.zeros_(head.weight)
        nn.init.zeros_(head.bias)

    def forward(self, z):
        angles = self.net(torch.cat((z.real, z.imag), dim=-1))
        return torch.exp(1j * angles)

class PhaseTransformer(nn.Module):
    """One-layer transformer whose residual stream is kept at unit modulus.

    Forward pass:

    1. Tokens are embedded as complex numbers (separate real/imag tables) and
       snapped to the unit circle.
    2. The stream is rotated by learned per-position angles, which start at
       zero.
    3. A ``PhaseAttention`` residual update is applied, followed by a
       multiplicative ``PhaseFFN`` rotation, re-projecting with
       ``strictly_phase`` after each step.
    4. A real ``bridge`` layer maps [Re, Im] back to d_model features, and
       logits are read from position 2.
    """

    def __init__(self, p, d_model, num_heads, mlp_dim):
        super().__init__()
        # Creation order is fixed: seeded initialisation depends on it.
        self.tok_embed_real = nn.Embedding(p + 1, d_model)
        self.tok_embed_imag = nn.Embedding(p + 1, d_model)
        self.pos_angles = nn.Embedding(3, d_model)
        nn.init.zeros_(self.pos_angles.weight)  # positions start as zero rotation

        self.attn = PhaseAttention(d_model, num_heads)
        self.ffn = PhaseFFN(d_model, mlp_dim)

        self.bridge = nn.Linear(d_model * 2, d_model)
        self.unembed = nn.Linear(d_model, p, bias=False)

    def forward(self, x):
        seq = x.shape[1]
        raw = torch.complex(self.tok_embed_real(x), self.tok_embed_imag(x))
        positions = torch.arange(seq, device=x.device).unsqueeze(0)
        rotation = torch.exp(1j * self.pos_angles(positions))

        stream = strictly_phase(strictly_phase(raw) * rotation)
        stream = strictly_phase(stream + self.attn(stream))
        stream = strictly_phase(stream * self.ffn(stream))

        features = self.bridge(torch.cat([stream.real, stream.imag], dim=-1))
        return self.unembed(features[:, 2, :])

# ==========================================
# TRAINING LOGIC
# ==========================================
def train_model(model, name, train_x, train_y, test_x, test_y, epochs, grok_threshold=0.95):
    """Full-batch AdamW training loop that records when the model generalizes.

    Each epoch is one full-batch optimizer step on ``(train_x, train_y)``.
    Every ``LOG_EVERY`` epochs, and on the final epoch, metrics are recorded:

    * ``train_loss`` / ``train_acc`` come from that epoch's forward pass,
      i.e. from the parameters *before* the optimizer step;
    * ``test_loss`` / ``test_acc`` are computed in eval mode *after* the step.

    Args:
        model: Classifier producing (N, num_classes) logits. The data is moved
            to ``DEVICE`` here but the model is not, so it presumably already
            lives there.
        name: Label used in the progress bar and log messages.
        train_x, train_y, test_x, test_y: Input / label tensors.
        epochs: Number of full-batch steps.
        grok_threshold: Test accuracy that must be strictly exceeded for the
            run to count as generalized. Defaults to the original 0.95.

    Returns:
        A dict with two keys:

        * ``"grok_epoch"``: the first logged epoch whose test accuracy
          exceeded ``grok_threshold``, or the string ``">{epochs}"`` if none
          did.
        * ``"history"``: a dict of parallel lists keyed by ``epochs``,
          ``train_acc``, ``test_acc``, ``train_loss`` and ``test_loss``.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    history = {'epochs': [], 'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}

    pbar = tqdm(range(epochs), desc=f"Training {name}")

    for epoch in pbar:
        model.train()
        logits = model(train_x)
        loss = criterion(logits, train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % LOG_EVERY == 0 or epoch == epochs - 1:
            model.eval()
            with torch.no_grad():
                # Train accuracy reuses this epoch's (pre-step) logits.
                train_acc = (logits.argmax(-1) == train_y).float().mean().item()
                test_logits = model(test_x)
                test_loss = criterion(test_logits, test_y).item()
                test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

            history['epochs'].append(epoch)
            history['train_acc'].append(train_acc)
            history['test_acc'].append(test_acc)
            history['train_loss'].append(loss.item())
            history['test_loss'].append(test_loss)

            pbar.set_postfix({
                'tr_acc': f"{train_acc:.3f}",
                'te_acc': f"{test_acc:.3f}"
            })

            if grok_epoch is None and test_acc > grok_threshold:
                grok_epoch = epoch
                tqdm.write(f"\n⚡ {name} generalized at epoch {epoch}! (test_acc={test_acc:.4f})")

    # Compare against None explicitly: generalizing at epoch 0 is a valid
    # result, not a failure (the old truthiness check misreported it).
    return {"grok_epoch": grok_epoch if grok_epoch is not None else f">{epochs}", "history": history}

# ==========================================
# EXECUTION & PLOTTING (1x4, 2x2, and Zoomed Grids)
# ==========================================
if __name__ == "__main__":
    print("=" * 60)
    print("GROKKING LITMUS TEST: The 4 Quadrants of Generalization")
    print("=" * 60)

    SEEDS = range(1, 11)   # seeds 1..10, run back-to-back
    FULL_XLIM = 60000      # shared x-axis for the overview grids (longest run)
    ZOOM_XLIM = 5000       # x-axis for the zoomed spherical-model plots
    ZOOMED_MODELS = ("Spherical Norm", "Fourier + Spherical")

    def _safe_name(model_name):
        """Return ``model_name`` made filename-friendly (spaces -> underscores)."""
        return model_name.replace(" ", "_").replace("(|z|=1)", "z1")

    def _plot_history(ax, model_name, res, x_max, title_suffix=""):
        """Draw the train/test accuracy curves of one run onto ``ax``."""
        hist = res["history"]
        ax.plot(hist["epochs"], hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2.5)
        ax.plot(hist["epochs"], hist["test_acc"], label="Test Acc", color="#d62728", linewidth=2.5)
        ax.fill_between(hist["epochs"], hist["train_acc"], hist["test_acc"], color='red', alpha=0.1)

        ax.set_title(f"{model_name}{title_suffix}\nGeneralized at: {res['grok_epoch']}", fontsize=16, pad=10)
        ax.set_xlabel("Epochs", fontsize=14)
        ax.set_ylabel("Accuracy", fontsize=14)
        ax.set_xlim(0, x_max)
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(loc="lower right", fontsize=12)

    for SEED in SEEDS:
        print(f"\n\n{'=' * 40}")
        print(f"🚀 STARTING RUN FOR SEED {SEED} 🚀")
        print(f"{'=' * 40}\n")

        # 1. Generate the dataset deterministically.
        tr_x, tr_y, te_x, te_y = make_dataset(P, FRAC_TRAIN, seed=SEED)

        # Entries are (name, model, epochs). set_seed is called before every
        # constructor so each model gets the same seeded init stream.
        models = []

        # 2. Standard model (baseline) - 60k epochs
        set_seed(SEED)
        models.append(("Standard", StandardTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM, False), 60000))

        # 3. Fourier-initialised embeddings - 40k epochs
        set_seed(SEED)
        f_model = StandardTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM, False)
        inject_fourier_bias(f_model, P)
        models.append(("Fourier Init", f_model, 40000))

        # 4. Spherical norm (L2-normalised residual stream) - 10k epochs
        set_seed(SEED)
        models.append(("Spherical Norm", StandardTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM, True), 10000))

        # 5. Fourier init + spherical norm - 10k epochs
        set_seed(SEED)
        combo_model = StandardTransformer(P, D_MODEL, NUM_HEADS, MLP_DIM, True)
        inject_fourier_bias(combo_model, P)
        models.append(("Fourier + Spherical", combo_model, 10000))

        results = {}
        for name, model, train_epochs in models:
            model = model.to(DEVICE)
            results[name] = train_model(model, name, tr_x, tr_y, te_x, te_y, train_epochs)

            # Save the per-run history next to the plots.
            save_path = os.path.join(SAVE_DIR, f"{_safe_name(name)}_seed{SEED}.json")
            with open(save_path, "w") as f:
                json.dump(results[name]["history"], f)
            print(f"Saved {name} logs to {save_path}")

        # --- 1x4 horizontal grid (saved only) ---
        fig_1x4, axs_1x4 = plt.subplots(1, 4, figsize=(24, 5))
        for ax, (name, res) in zip(axs_1x4.flatten(), results.items()):
            _plot_history(ax, name, res, FULL_XLIM)
        plt.tight_layout()
        plot_path_1x4 = os.path.join(SAVE_DIR, f"grokking_1x4_grid_seed{SEED}.png")
        plt.savefig(plot_path_1x4, dpi=300, bbox_inches='tight')
        print(f"✅ Saved 1x4 plot to: {plot_path_1x4}")
        plt.close(fig_1x4)

        # --- 2x2 square grid (saved and shown inline) ---
        fig_2x2, axs_2x2 = plt.subplots(2, 2, figsize=(12, 10))
        for ax, (name, res) in zip(axs_2x2.flatten(), results.items()):
            _plot_history(ax, name, res, FULL_XLIM)
        plt.tight_layout()
        plot_path_2x2 = os.path.join(SAVE_DIR, f"grokking_2x2_grid_seed{SEED}.png")
        plt.savefig(plot_path_2x2, dpi=300, bbox_inches='tight')
        print(f"✅ Saved 2x2 plot to: {plot_path_2x2}")
        plt.show()
        plt.close(fig_2x2)  # without this, one open figure accumulates per seed

        # --- 1x1 zoomed plots for the spherical models only ---
        for name, res in results.items():
            if name not in ZOOMED_MODELS:
                continue
            fig_zoom, ax_zoom = plt.subplots(figsize=(8, 6))
            _plot_history(ax_zoom, name, res, ZOOM_XLIM, title_suffix=" (Zoomed 5k)")
            plt.tight_layout()
            plot_path_zoom = os.path.join(SAVE_DIR, f"grokking_zoomed_{_safe_name(name)}_seed{SEED}.png")
            plt.savefig(plot_path_zoom, dpi=300, bbox_inches='tight')
            print(f"✅ Saved Zoomed plot to: {plot_path_zoom}")
            plt.close(fig_zoom)

In [None]:
import os, glob, json
import numpy as np
import pandas as pd

SAVE_DIR = '/content/drive/MyDrive/grokking_logs_detailed'
model_names = ["Standard", "Fourier_Init", "Spherical_Norm", "Fourier_+_Spherical"]
GROK_THRESHOLD = 0.95  # same "generalized" cut-off train_model uses by default


def first_grok_epoch(epochs, test_accs, threshold=GROK_THRESHOLD):
    """Return the first logged epoch whose test accuracy exceeds ``threshold``.

    ``epochs`` and ``test_accs`` are the parallel lists stored in a run's
    history JSON. Returns None when no logged epoch clears the threshold.
    """
    return next((epoch for epoch, acc in zip(epochs, test_accs) if acc > threshold), None)


results_data = []

print("Analyzing grokking moments across all seeds...\n")

for model in model_names:
    # Sorted so the per-model file order is stable across filesystems.
    files = sorted(glob.glob(os.path.join(SAVE_DIR, f"{model}_seed*.json")))
    if not files:
        continue

    grok_epochs = []
    did_not_grok_count = 0

    for path in files:
        with open(path) as fh:  # close each log instead of leaking the handle
            data = json.load(fh)
        grok_epoch = first_grok_epoch(data["epochs"], data["test_acc"])
        if grok_epoch is None:
            did_not_grok_count += 1
        else:
            grok_epochs.append(grok_epoch)

    row = {
        "Model": model.replace("_", " "),
        "Seeds Run": len(files),
        "Successful Groks": len(grok_epochs),
        "Did Not Grok": did_not_grok_count,
    }
    if grok_epochs:
        row.update({
            "Earliest Grok": int(np.min(grok_epochs)),
            "Latest Grok": int(np.max(grok_epochs)),
            "Average Grok": int(np.mean(grok_epochs)),
            "Std Dev": int(np.std(grok_epochs)),  # population std (ddof=0)
        })
    else:
        # Models that never generalized still get a row so the failure is
        # visible in the summary instead of silently disappearing.
        row.update({"Earliest Grok": "n/a", "Latest Grok": "n/a", "Average Grok": "n/a", "Std Dev": "n/a"})
    results_data.append(row)

if results_data:
    # Paper-ready table; to_markdown needs the optional `tabulate` package,
    # so fall back to plain text when it is not installed.
    df = pd.DataFrame(results_data)
    try:
        table = df.to_markdown(index=False)
    except ImportError:
        table = df.to_string(index=False)
    print(table)
else:
    print("No log files found yet. Make sure training has finished!")

In [None]:
import sys
import os
import torch
import subprocess
import datetime

# Directory for the environment log.
# NOTE(review): this is '.../grokking_logs', not the '.../grokking_logs_detailed'
# folder the training cell writes to. Confirm the env log is meant to live apart.
SAVE_DIR = '/content/drive/MyDrive/grokking_logs'
os.makedirs(SAVE_DIR, exist_ok=True)

env_log_path = os.path.join(SAVE_DIR, "environment_info.txt")


def _run_command(args):
    """Run ``args`` directly (no shell) and return its stdout as text."""
    return subprocess.check_output(args).decode(errors="replace")


with open(env_log_path, "w", encoding="utf-8") as f:
    f.write("--- Grokking Experiment Environment Log ---\n")
    f.write(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

    # 1. Python version
    f.write("=== Python Version ===\n")
    f.write(sys.version + "\n\n")

    # 2. PyTorch & GPU info
    f.write("=== PyTorch & CUDA ===\n")
    f.write(f"PyTorch Version: {torch.__version__}\n")
    f.write(f"CUDA Available: {torch.cuda.is_available()}\n")
    if torch.cuda.is_available():
        f.write(f"CUDA Built-in Version: {torch.version.cuda}\n")
        f.write(f"cuDNN Version: {torch.backends.cudnn.version()}\n")
        f.write(f"GPU Model: {torch.cuda.get_device_name(0)}\n")
    f.write("\n")

    # 3. Full nvidia-smi output. Best effort: the binary is missing
    # (FileNotFoundError) or fails on CPU-only machines.
    f.write("=== NVIDIA-SMI ===\n")
    try:
        f.write(_run_command(["nvidia-smi"]) + "\n")
    except (OSError, subprocess.CalledProcessError) as e:
        f.write(f"Could not retrieve nvidia-smi: {e}\n\n")

    # 4. Packages installed for *this* interpreter. `python -m pip` avoids
    # picking up whichever unrelated `pip` happens to be first on PATH.
    f.write("=== Installed PIP Packages ===\n")
    try:
        f.write(_run_command([sys.executable, "-m", "pip", "freeze"]) + "\n")
    except (OSError, subprocess.CalledProcessError) as e:
        f.write(f"Could not retrieve pip freeze: {e}\n")

print(f"✅ Environment successfully logged to: {env_log_path}")