<a href="https://colab.research.google.com/github/AlperYildirim1/geometric-grokking/blob/main/Grokking_S5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np
import itertools
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import json

# ==========================================
# CONFIGURATION
# ==========================================
# Output folder for per-run JSON histories and the comparison plot.
# NOTE(review): this is a Google Drive path; presumably Drive is mounted in Colab.
SAVE_DIR = '/content/drive/MyDrive/grokking_logs_s5'

# --- Task: composition in the symmetric group S_N ---
N_DEGREE = 5                            # S5 Symmetric Group
GROUP_SIZE = math.factorial(N_DEGREE)   # 120 elements
FRAC_TRAIN = 0.3                        # share of the GROUP_SIZE**2 pairs used for training

# --- Model shape ---
D_MODEL = 128
NUM_HEADS = 4
MLP_DIM = 512

# --- Optimisation / logging ---
LR = 4e-3
WEIGHT_DECAY = 1.0
EPOCHS = 60000                          # S5 usually takes slightly longer than Z_p
LOG_EVERY = 200                         # evaluate + record history every N epochs

# --- Runtime ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SEED = 5

os.makedirs(SAVE_DIR, exist_ok=True)

# ==========================================
# BEST-EFFORT DETERMINISM (seeded RNGs + deterministic-algorithm flags)
# ==========================================
def set_seed(seed):
    """Seed every RNG this script relies on and request deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, plus all
    CUDA devices when CUDA is available), disables cuDNN autotuning and turns
    on ``torch.use_deterministic_algorithms``.

    Caveats:
      * ``warn_only=True`` means ops without a deterministic implementation
        only emit a warning instead of raising, so bitwise reproducibility is
        best-effort rather than guaranteed.
      * ``PYTHONHASHSEED`` set at runtime does not change hashing in the
        already-running interpreter; it only affects child processes.
      * ``CUBLAS_WORKSPACE_CONFIG`` is required for deterministic cuBLAS.
    """
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)

# ==========================================
# S5 DATASET GENERATION
# ==========================================
def generate_s5_elements():
    """Enumerate the permutations of ``range(N_DEGREE)`` (120 for S5).

    Permutations are listed in ``itertools.permutations`` order, and each is
    represented as a tuple of images (position ``i`` holds where ``i`` maps).

    Returns:
        ``(elements, elem_to_id, id_to_elem)``:
          * ``elements``: list of permutation tuples, index == element id;
          * ``elem_to_id``: permutation tuple -> id;
          * ``id_to_elem``: id -> permutation tuple.
    """
    elements = list(itertools.permutations(range(N_DEGREE)))
    id_to_elem = dict(enumerate(elements))
    elem_to_id = {perm: idx for idx, perm in id_to_elem.items()}
    return elements, elem_to_id, id_to_elem

def compose_s5(id1, id2, id_to_elem, elem_to_id):
    """Return the id of the composition ``p1 ∘ p2`` (apply p2 first, then p1).

    Permutations are tuples of images, so ``(p1 ∘ p2)[i] == p1[p2[i]]``.
    Iterating over ``p2`` directly (instead of ``range(N_DEGREE)``) makes the
    function work for permutations of any degree without relying on a global.

    Args:
        id1: id of the outer permutation ``p1`` (applied second).
        id2: id of the inner permutation ``p2`` (applied first).
        id_to_elem: mapping id -> permutation tuple.
        elem_to_id: mapping permutation tuple -> id.

    Returns:
        The id of the composed permutation.

    Raises:
        KeyError: if an id (or the composed tuple) is not in the lookup tables.
    """
    p1 = id_to_elem[id1]
    p2 = id_to_elem[id2]
    p_out = tuple(p1[image] for image in p2)
    return elem_to_id[p_out]

def make_s5_dataset(frac_train, seed=42):
    """Build a shuffled train/test split of the full S5 composition table.

    Every ordered pair ``(a, b)`` of the GROUP_SIZE element ids is encoded as
    the token sequence ``[a, b, GROUP_SIZE]`` (id GROUP_SIZE acts as the
    operator token) and labelled with ``compose_s5(a, b, ...)``, the id of
    ``p_a ∘ p_b``. Pairs are shuffled with a dedicated ``random.Random(seed)``
    and the first ``int(n_pairs * frac_train)`` form the training split.

    Note: this also calls ``set_seed(seed)``, which reseeds the global
    Python / NumPy / torch RNGs as a side effect.

    Args:
        frac_train: fraction of pairs placed in the training split, in [0, 1].
        seed: seed for the shuffle (and for the global RNGs, see note above).

    Returns:
        ``(train_x, train_y, test_x, test_y)`` where the ``*_x`` tensors are
        ``long`` with shape ``(n, 3)`` -- also when a split is empty -- and the
        ``*_y`` tensors are ``long`` with shape ``(n,)``.

    Raises:
        ValueError: if ``frac_train`` is outside [0, 1].
    """
    if not 0.0 <= frac_train <= 1.0:
        raise ValueError(f"frac_train must be in [0, 1], got {frac_train!r}")

    set_seed(seed)
    rng = random.Random(seed)

    _, elem_to_id, id_to_elem = generate_s5_elements()

    all_pairs = [(a, b) for a in range(GROUP_SIZE) for b in range(GROUP_SIZE)]
    rng.shuffle(all_pairs)

    n_train = int(len(all_pairs) * frac_train)

    # x inputs: [elem_A, elem_B, OP_TOKEN]
    op_token = GROUP_SIZE  # The index 120 acts as the composition operator

    # Encode and label every pair once, then slice. Slicing a 2-D tensor keeps
    # the (0, 3) shape for an empty split, whereas torch.tensor([]) would be
    # 1-D and break the model's `B, L = x.shape`.
    all_x = torch.tensor([[a, b, op_token] for a, b in all_pairs], dtype=torch.long)
    all_y = torch.tensor(
        [compose_s5(a, b, id_to_elem, elem_to_id) for a, b in all_pairs],
        dtype=torch.long,
    )

    return all_x[:n_train], all_y[:n_train], all_x[n_train:], all_y[n_train:]

# ==========================================
# STANDARD & SPHERICAL TRANSFORMER
# ==========================================
class StandardTransformer(nn.Module):
    """One attention layer + one ReLU MLP transformer for ``a ∘ b`` prediction.

    Inputs are token sequences ``[elem_A, elem_B, OP_TOKEN]``. Every position
    attends to every other (no causal mask). The logits over the
    ``vocab_size`` group elements are read from the final (operator) position.

    With ``normalize_hiddens=True`` the residual stream is L2-normalised over
    the feature dimension after the embedding, after attention and after the
    MLP (the "spherical" variant); otherwise no normalisation is applied.

    Args:
        vocab_size: number of group elements / output classes. The token
            embedding has one extra row for the operator token (id
            ``vocab_size``).
        d_model: residual-stream width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (positive).
        mlp_dim: hidden width of the ReLU MLP.
        normalize_hiddens: enable the spherical normalisation described above.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, vocab_size, d_model, num_heads, mlp_dim, normalize_hiddens=False):
        super().__init__()
        # Validate before creating any parameters so the RNG draw order used
        # for initialisation is unchanged for valid configurations.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.normalize_hiddens = normalize_hiddens

        # NOTE: parameter creation order matters for seeded reproducibility.
        self.tok_embed = nn.Embedding(vocab_size + 1, d_model)  # +1 for OP_TOKEN
        self.pos_embed = nn.Embedding(3, d_model)  # supports sequences of length <= 3

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.mlp_in = nn.Linear(d_model, mlp_dim)
        self.mlp_out = nn.Linear(mlp_dim, d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)  # back to group elements

    def _maybe_normalize(self, h):
        """Project onto the unit sphere when the spherical variant is enabled."""
        return F.normalize(h, dim=-1) if self.normalize_hiddens else h

    def _split_heads(self, t, B, L):
        """Reshape ``(B, L, d_model)`` into ``(B, num_heads, L, d_head)``."""
        return t.view(B, L, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        """Return logits of shape ``(B, vocab_size)`` for token ids ``x`` of shape ``(B, L)``."""
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0)
        h = self._maybe_normalize(self.tok_embed(x) + self.pos_embed(pos))

        Q = self._split_heads(self.W_Q(h), B, L)
        K = self._split_heads(self.W_K(h), B, L)
        V = self._split_heads(self.W_V(h), B, L)

        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)
        context = (F.softmax(scores, dim=-1) @ V).transpose(1, 2).contiguous().view(B, L, self.d_model)

        h = self._maybe_normalize(h + self.W_O(context))
        h = self._maybe_normalize(h + self.mlp_out(F.relu(self.mlp_in(h))))

        # Read the prediction from the last (operator-token) position.
        return self.unembed(h[:, -1, :])

# ==========================================
# TRAINING LOGIC
# ==========================================
def train_model(model, name, train_x, train_y, test_x, test_y, epochs, grok_threshold=0.95):
    """Full-batch AdamW training with periodic train/test evaluation.

    Uses the module-level ``LR``, ``WEIGHT_DECAY``, ``DEVICE`` and
    ``LOG_EVERY`` settings. Every ``LOG_EVERY`` epochs (and on the final
    epoch) both splits are evaluated in eval mode with the *same* post-update
    weights, and the results are appended to ``history``.

    Args:
        model: the network to train (already on ``DEVICE``).
        name: label used in the progress bar and messages.
        train_x, train_y, test_x, test_y: dataset tensors (moved to ``DEVICE``).
        epochs: number of full-batch optimisation steps.
        grok_threshold: test accuracy that counts as "generalized".

    Returns:
        ``{"grok_epoch": int | str, "history": dict}``. ``grok_epoch`` is the
        first logged epoch whose test accuracy exceeds ``grok_threshold``, or
        the string ``">{epochs}"`` if that never happened.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
    criterion = nn.CrossEntropyLoss()

    train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
    test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)

    grok_epoch = None
    history = {'epochs': [], 'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}

    pbar = tqdm(range(epochs), desc=f"Training {name}")

    for epoch in pbar:
        model.train()
        loss = criterion(model(train_x), train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % LOG_EVERY != 0 and epoch != epochs - 1:
            continue

        model.eval()
        with torch.no_grad():
            # Re-evaluate the train split after the step so that train and
            # test metrics in the same history row describe the same weights.
            train_logits = model(train_x)
            train_loss = criterion(train_logits, train_y).item()
            train_acc = (train_logits.argmax(-1) == train_y).float().mean().item()
            test_logits = model(test_x)
            test_loss = criterion(test_logits, test_y).item()
            test_acc = (test_logits.argmax(-1) == test_y).float().mean().item()

        history['epochs'].append(epoch)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)

        pbar.set_postfix({
            'tr_acc': f"{train_acc:.3f}",
            'te_acc': f"{test_acc:.3f}"
        })

        if grok_epoch is None and test_acc > grok_threshold:
            grok_epoch = epoch
            tqdm.write(f"\n⚡ {name} generalized at epoch {epoch}! (test_acc={test_acc:.4f})")

    # `is not None`: a grok at epoch 0 is falsy but still a valid result.
    return {"grok_epoch": grok_epoch if grok_epoch is not None else f">{epochs}", "history": history}

# ==========================================
# EXECUTION & PLOTTING (1x2 Grid)
# ==========================================
if __name__ == "__main__":
    print("=" * 60)
    print("S5 GROKKING TEST: The Non-Circular Baseline")
    print("=" * 60)

    tr_x, tr_y, te_x, te_y = make_s5_dataset(FRAC_TRAIN, seed=SEED)

    models = []

    # Reseed before each construction so both variants start from the same
    # initial weights (same parameter shapes and creation order).
    # 1. Standard Model (Should Grok)
    set_seed(SEED)
    models.append(("Standard", StandardTransformer(GROUP_SIZE, D_MODEL, NUM_HEADS, MLP_DIM, False), EPOCHS))

    # 2. Spherical Norm (Should Struggle/Fail)
    set_seed(SEED)
    models.append(("Spherical Norm", StandardTransformer(GROUP_SIZE, D_MODEL, NUM_HEADS, MLP_DIM, True), EPOCHS))

    results = {}
    for name, model, train_epochs in models:
        model = model.to(DEVICE)
        results[name] = train_model(model, name, tr_x, tr_y, te_x, te_y, train_epochs)

        safe_name = name.replace(" ", "_")
        save_path = os.path.join(SAVE_DIR, f"{safe_name}_S5_seed{SEED}.json")
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(results[name]["history"], f)

    # --- PLOTTING: one panel per model (1x2 for the two models above) ---
    n_panels = len(results)
    # squeeze=False keeps a 2-D array even for a single panel, so indexing is uniform.
    fig, axs = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)
    axs = axs.flatten()

    for ax, (name, res) in zip(axs, results.items()):
        hist = res["history"]

        ax.plot(hist["epochs"], hist["train_acc"], label="Train Acc", color="#1f77b4", linewidth=2.5)
        ax.plot(hist["epochs"], hist["test_acc"], label="Test Acc", color="#d62728", linewidth=2.5)
        ax.fill_between(hist["epochs"], hist["train_acc"], hist["test_acc"], color='red', alpha=0.1)

        title_text = f"S5: {name} Transformer\nGeneralized at: {res['grok_epoch']}"
        ax.set_title(title_text, fontsize=16, pad=10)
        ax.set_xlabel("Epochs", fontsize=14)
        ax.set_ylabel("Accuracy", fontsize=14)
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(loc="lower right", fontsize=12)

    fig.tight_layout()
    plot_path = os.path.join(SAVE_DIR, f"grokking_S5_1x2_grid_seed{SEED}.png")
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"✅ Saved S5 plot to: {plot_path}")

    plt.show()

In [None]:
import os, glob, json
import numpy as np
import pandas as pd

SAVE_DIR = '/content/drive/MyDrive/grokking_logs'
model_names = ["Standard", "Fourier_Init", "Spherical_Norm"]

# Summary columns that only make sense when at least one seed generalized.
_GROK_STAT_COLUMNS = ("Earliest Grok", "Latest Grok", "Average Grok", "Std Dev")

results_data = []

print("Analyzing grokking moments across all seeds...\n")

for model in model_names:
    files = sorted(glob.glob(f"{SAVE_DIR}/{model}_seed*.json"))
    if not files:
        continue

    grok_epochs = []
    did_not_grok_count = 0

    for path in files:
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)

        # First logged epoch whose test accuracy exceeds 0.95 (None if never).
        grok_epoch = next(
            (epoch for epoch, acc in zip(data["epochs"], data["test_acc"]) if acc > 0.95),
            None,
        )

        if grok_epoch is None:
            did_not_grok_count += 1
        else:
            grok_epochs.append(grok_epoch)

    # Always emit a row so models that never generalized are visible too.
    row = {
        "Model": model.replace("_", " "),
        "Seeds Run": len(files),
        "Successful Groks": len(grok_epochs),
        "Failed Groks": did_not_grok_count,
    }
    if grok_epochs:
        row.update({
            "Earliest Grok": int(np.min(grok_epochs)),
            "Latest Grok": int(np.max(grok_epochs)),
            "Average Grok": int(np.mean(grok_epochs)),
            "Std Dev": int(np.std(grok_epochs)),
        })
    else:
        row.update({column: "n/a" for column in _GROK_STAT_COLUMNS})
    results_data.append(row)

if results_data:
    # Use Pandas to print a clean, paper-ready table
    df = pd.DataFrame(results_data)
    print(df.to_markdown(index=False))
else:
    print("No log files found yet. Make sure training has finished!")

In [None]:
import sys
import os
import torch
import subprocess
import datetime

# Make sure the directory exists
SAVE_DIR = '/content/drive/MyDrive/grokking_logs'
os.makedirs(SAVE_DIR, exist_ok=True)

env_log_path = os.path.join(SAVE_DIR, "environment_info.txt")

# Errors we expect from a missing binary (OSError) or a failing command.
_CMD_ERRORS = (OSError, subprocess.CalledProcessError)

with open(env_log_path, "w", encoding="utf-8") as f:
    f.write("--- Grokking Experiment Environment Log ---\n")
    # astimezone() attaches the local zone so the timestamp is unambiguous.
    f.write(f"Timestamp: {datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S %Z')}\n\n")

    # 1. Python Version
    f.write("=== Python Version ===\n")
    f.write(sys.version + "\n\n")

    # 2. PyTorch & GPU Info
    f.write("=== PyTorch & CUDA ===\n")
    f.write(f"PyTorch Version: {torch.__version__}\n")
    f.write(f"CUDA Available: {torch.cuda.is_available()}\n")
    if torch.cuda.is_available():
        f.write(f"CUDA Built-in Version: {torch.version.cuda}\n")
        f.write(f"cuDNN Version: {torch.backends.cudnn.version()}\n")
        f.write(f"GPU Model: {torch.cuda.get_device_name(0)}\n")
    f.write("\n")

    # 3. Full NVIDIA-SMI Output (best effort: absent on CPU-only machines)
    f.write("=== NVIDIA-SMI ===\n")
    try:
        nvidia_smi = subprocess.check_output(["nvidia-smi"], text=True, errors="replace")
        f.write(nvidia_smi + "\n")
    except _CMD_ERRORS as e:
        f.write(f"Could not retrieve nvidia-smi: {e}\n\n")

    # 4. Pip Freeze (All installed libraries). Use the running interpreter's
    # pip so the list matches the environment that actually ran the experiment.
    f.write("=== Installed PIP Packages ===\n")
    try:
        pip_freeze = subprocess.check_output(
            [sys.executable, "-m", "pip", "freeze"], text=True, errors="replace"
        )
        f.write(pip_freeze + "\n")
    except _CMD_ERRORS as e:
        f.write(f"Could not retrieve pip freeze: {e}\n")

print(f"✅ Environment successfully logged to: {env_log_path}")