# Do Transformers Grok Succinct Algorithms?

## Interactive Analysis Notebook

This notebook provides an interactive walkthrough of the experiments from the ACL 2026 paper:

> **"Do Transformers Grok Succinct Algorithms? Mechanistic Evidence for Counting Circuits"**

### Contents
1. [Setup and Data Exploration](#1-setup)
2. [Transformer Training with Grokking](#2-training)
3. [Mechanistic Analysis](#3-analysis)
4. [Visualization](#4-visualization)
5. [RNN Baseline Comparison](#5-baselines)

## 1. Setup and Data Exploration <a id='1-setup'></a>

In [None]:
# Install dependencies if needed
# !pip install torch numpy matplotlib seaborn tqdm pyyaml

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

In [None]:
# Import project modules
from src.data.counter_dataset import (
    LargeCounterDataset, 
    create_dataloader,
    int_to_binary,
    get_carry_chain_length
)
from src.models.transformer import TransformerConfig, TransformerLM
from src.models.rnn import RNNConfig, RNNLM
from src.training.trainer import TrainingConfig, Trainer

### 1.1 Understanding the LARGECOUNTER Task

The task is to predict the next n-bit binary number:
- Input: `N_i # ` (current number followed by delimiter)
- Output: `N_{i+1}` (next number in sequence)

Example: `0111 # 1000` (7 → 8 in binary)
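
The example format can be reproduced without the project code. The sketch below is a minimal illustration, assuming one character per bit, MSB-first ordering, and wrap-around at `2**n_bits` (as in the demo cell that follows); `to_bits` and `make_example` are hypothetical helpers, not functions from `src.data.counter_dataset`.

```python
def to_bits(value: int, n_bits: int) -> str:
    """Zero-padded binary string, most significant bit first."""
    return format(value % (1 << n_bits), f"0{n_bits}b")


def make_example(value: int, n_bits: int) -> str:
    """Build the text 'N_i # N_{i+1}', wrapping around at 2**n_bits."""
    successor = (value + 1) % (1 << n_bits)
    return f"{to_bits(value, n_bits)} # {to_bits(successor, n_bits)}"


print(make_example(7, 4))   # 0111 # 1000
print(make_example(15, 4))  # 1111 # 0000 (wrap-around)
```

How the real tokenizer splits this string into tokens is not shown here, so the sketch only covers the text form.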

In [None]:
# Demonstrate binary counting (wraps around at 2**n_bits)
n_bits = 4
print("Binary Counting Examples:")
print("="*40)
for n in range(2**n_bits):
    n_bin = int_to_binary(n, n_bits)
    next_n = (n + 1) % 2**n_bits
    next_bin = int_to_binary(next_n, n_bits)
    carry_len = get_carry_chain_length(n, n_bits)
    print(f"{n:2d} = {n_bin} → {next_bin} = {next_n:2d}  (carry chain: {carry_len})")

### 1.2 Carry Chain Length Distribution

The carry chain length `k` determines how many bits flip during incrementing (a chain of length `k` flips `k+1` bits):
- `k=0`: Only the LSB flips (numbers ending in 0)
- `k=n-1`: Global carry, all `n` bits flip (e.g., 0111 → 1000)
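
A small sketch makes the relation between carry chain length and flipped bits concrete. It assumes that `k` is the number of trailing ones in the input; the project's `get_carry_chain_length` may use a slightly different convention (for instance for the all-ones input), so `trailing_ones` and `flipped_bits` below are illustrative helpers only.

```python
def trailing_ones(value: int, n_bits: int) -> int:
    """Number of consecutive 1 bits at the least significant end."""
    k = 0
    while k < n_bits and (value >> k) & 1:
        k += 1
    return k


def flipped_bits(value: int, n_bits: int) -> int:
    """Count bit positions that differ between value and value + 1 (mod 2**n_bits)."""
    successor = (value + 1) % (1 << n_bits)
    return bin(value ^ successor).count("1")


n_bits = 4
for value in (0b0110, 0b0101, 0b0011, 0b0111, 0b1111):
    k = trailing_ones(value, n_bits)
    # Incrementing flips the k trailing ones plus the zero above them,
    # except for the all-ones value, which has no zero left to flip.
    assert flipped_bits(value, n_bits) == min(k + 1, n_bits)
    print(f"{value:04b}: k={k}, bits flipped={flipped_bits(value, n_bits)}")
```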

In [None]:
# Visualize carry chain distribution (why stratified sampling matters)
n_bits = 20
carry_counts = [0] * (n_bits + 1)

for n in range(2**min(n_bits, 16)):  # Sample first 2^16 for visualization
    k = get_carry_chain_length(n, n_bits)
    carry_counts[k] += 1

fig, ax = plt.subplots(figsize=(12, 5))
ax.bar(range(len(carry_counts)), carry_counts, color='steelblue', alpha=0.7)
ax.set_xlabel('Carry Chain Length (k)', fontsize=12)
ax.set_ylabel('Count (log scale)', fontsize=12)
ax.set_title('Natural Distribution of Carry Chain Lengths\n'
             r'(Exponential Decay: $P(k) \propto 2^{-k}$)', fontsize=14)
ax.set_yscale('log')

# Add theoretical line: each extra carry step halves the count
theoretical = [carry_counts[0] * (0.5**k) for k in range(len(carry_counts))]
ax.plot(range(len(carry_counts)), theoretical, 'r--', linewidth=2, label=r'Theoretical: $\propto 2^{-k}$')
ax.legend()
plt.tight_layout()
plt.show()

print("\nKey insight: Global carries (k=n-1) are exponentially rare!")
print("→ Stratified sampling ensures the model sees all difficulty levels equally.")

### 1.3 Dataset Creation
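
The datasets in this section are built with `stratified=True`. The internals of `LargeCounterDataset` are not shown in this notebook, so the following is only a sketch of one way stratification by carry chain length could work: pick `k` uniformly, then draw a number with exactly `k` trailing ones. All helper names are hypothetical, and the sketch ignores de-duplication and the train/test split.

```python
import random
from collections import Counter


def trailing_ones(value: int) -> int:
    """Count consecutive 1 bits at the least significant end."""
    k = 0
    while (value >> k) & 1:
        k += 1
    return k


def sample_with_carry_length(k: int, n_bits: int, rng: random.Random) -> int:
    """Draw an n-bit integer with exactly k trailing ones (0 <= k < n_bits)."""
    if not 0 <= k < n_bits:
        raise ValueError("k must satisfy 0 <= k < n_bits")
    low = (1 << k) - 1                 # k ones in the lowest bits, bit k left at 0
    free_bits = n_bits - k - 1         # bits above position k are unconstrained
    high = rng.getrandbits(free_bits) if free_bits > 0 else 0
    return (high << (k + 1)) | low


def stratified_sample(n_samples: int, n_bits: int, seed: int = 0) -> list:
    """Pick k uniformly from 0..n_bits-1, then a value with that carry length."""
    rng = random.Random(seed)
    return [
        sample_with_carry_length(rng.randrange(n_bits), n_bits, rng)
        for _ in range(n_samples)
    ]


values = stratified_sample(8000, n_bits=8)
counts = Counter(trailing_ones(v) for v in values)
for k in sorted(counts):
    print(f"k={k}: {counts[k]} samples")  # roughly 1000 per bucket
```

Under uniform sampling of integers the `k=7` bucket would receive about 1 in 256 samples; here every bucket receives about 1 in 8.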

In [None]:
# Create datasets
n_bits = 20
train_dataset = LargeCounterDataset(n_bits=n_bits, split='train', train_ratio=0.3, stratified=True)
test_dataset = LargeCounterDataset(n_bits=n_bits, split='test', train_ratio=0.3, stratified=True)

print(f"Training samples: {len(train_dataset):,}")
print(f"Test samples: {len(test_dataset):,}")
print(f"State space size: 2^{n_bits} = {2**n_bits:,}")
print(f"Training coverage: {len(train_dataset) / 2**n_bits * 100:.1f}%")

In [None]:
# Examine a sample
sample = train_dataset[0]
print("Sample format:")
print(f"  Input IDs shape: {sample['input_ids'].shape}")
print(f"  Labels shape: {sample['labels'].shape}")
print(f"  Input tokens: {sample['input_ids'][:30].tolist()}...")

## 2. Transformer Training with Grokking <a id='2-training'></a>

Key training settings for inducing grokking:
- **High weight decay (λ=1.0)**: Critical for the "complexity collapse", the sharp drop in weight norm at the grokking transition
- **Extended training**: Continue well after training accuracy saturates
- **Small model**: 2 layers, 4 heads, `d_model=64`
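
To get a feel for how strong λ=1.0 is, the toy calculation below applies a decay-only update to a single weight. It assumes decoupled (AdamW-style) weight decay, where each step multiplies the weight by `1 - lr * wd`; the optimizer used by `Trainer` is not stated in this notebook, so this is an assumption, and warmup and gradients are ignored.

```python
import math

learning_rate = 1e-3   # same value as TrainingConfig.learning_rate in this notebook
weight_decay = 1.0     # same value as TrainingConfig.weight_decay in this notebook


def decay_only_step(weight: float, lr: float, wd: float) -> float:
    """Decoupled weight decay with a zero gradient: w <- w * (1 - lr * wd)."""
    return weight * (1.0 - lr * wd)


weight = 1.0
for _ in range(10_000):
    weight = decay_only_step(weight, learning_rate, weight_decay)

# Closed form for comparison: (1 - lr * wd) ** steps
closed_form = (1.0 - learning_rate * weight_decay) ** 10_000
print(f"after 10,000 steps: {weight:.3e} (closed form {closed_form:.3e})")
```

A weight that receives no gradient signal shrinks by a factor of roughly e^-10 over the 10,000 demo steps, so only weights that the loss actively maintains survive.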

In [None]:
# Model configuration (paper settings)
model_config = TransformerConfig(
    vocab_size=train_dataset.tokenizer.vocab_size,
    d_model=64,
    n_heads=4,
    n_layers=2,
    d_ff=256,
    max_seq_len=128,
    dropout=0.0,
    use_rope=True  # Critical for Same-Bit Lookup
)

model = TransformerLM(model_config).to(device)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")
print(f"Configuration: {model_config}")
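
The configuration above marks `use_rope=True` as critical for Same-Bit Lookup. The property that plausibly matters is that rotary position embeddings make the query-key score depend only on the relative offset between positions. The sketch below checks this for a single 2-D frequency pair; it is not the `TransformerLM` implementation, and `theta` is an arbitrary illustrative value.

```python
import math


def rotate(vec, position, theta=0.1):
    """Rotate a 2-D vector by angle position * theta (one RoPE frequency pair)."""
    x, y = vec
    angle = position * theta
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def rope_score(query, key, q_pos, k_pos):
    """Dot product of the position-rotated query and key."""
    q = rotate(query, q_pos)
    k = rotate(key, k_pos)
    return q[0] * k[0] + q[1] * k[1]


query, key = (0.3, -1.2), (0.8, 0.5)
offset = -21  # -(n+1) for n = 20, the offset discussed in this notebook

# The same relative offset at different absolute positions gives the same score.
a = rope_score(query, key, q_pos=30, k_pos=30 + offset)
b = rope_score(query, key, q_pos=41, k_pos=41 + offset)
print(abs(a - b) < 1e-12)
```

Because the score is a function of the offset alone, a head can in principle learn "attend 21 tokens back" once and apply it at every output position.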

In [None]:
# Training configuration
train_config = TrainingConfig(
    learning_rate=1e-3,
    weight_decay=1.0,  # High weight decay for grokking!
    batch_size=512,
    max_steps=10000,  # Reduce for demo (paper uses 50000)
    eval_interval=500,
    warmup_steps=500,
    save_checkpoints=False
)

print("Training Configuration:")
print(f"  Learning rate: {train_config.learning_rate}")
print(f"  Weight decay: {train_config.weight_decay} (critical for grokking!)")
print(f"  Max steps: {train_config.max_steps}")

In [None]:
# Create data loaders
train_loader = create_dataloader(train_dataset, batch_size=train_config.batch_size, shuffle=True)
test_loader = create_dataloader(test_dataset, batch_size=train_config.batch_size, shuffle=False)

print(f"Train batches per epoch: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Train the model
# Note: This is a shortened demo. Full training takes ~2 hours.

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    config=train_config,
    output_dir='../outputs/notebook_demo',
    device=device
)

print("Starting training...")
print("Watch for the grokking phase transition!")
print()

history = trainer.train()

## 3. Mechanistic Analysis <a id='3-analysis'></a>

After grokking, we analyze the internal circuits to verify alignment with B-RASP theory.
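
The cells in this section score each head with `compute_diagonal_score`. Its definition lives in `src.analysis.attention_analysis` and is not shown here, so the following is only a plausible sketch: the mean attention weight that each query places on the key at a fixed relative offset. `diagonal_score` is a hypothetical stand-in operating on plain nested lists.

```python
def diagonal_score(attn, offset):
    """Mean attention weight on key position (query + offset), over valid queries.

    attn is a square list of rows indexed as attn[query][key]; offset is
    negative for attention to earlier tokens.
    """
    size = len(attn)
    weights = [
        attn[q][q + offset]
        for q in range(size)
        if 0 <= q + offset < size
    ]
    return sum(weights) / len(weights) if weights else 0.0


# Toy 6x6 pattern: queries at position >= 3 attend fully to position q - 3.
size, offset = 6, -3
attn = [[0.0] * size for _ in range(size)]
for q in range(size):
    target = q + offset if q + offset >= 0 else q  # early queries attend to themselves
    attn[q][target] = 1.0

print(diagonal_score(attn, offset))      # 1.0
print(diagonal_score(attn, offset=-1))   # 0.0
```

A score near 1 at the target offset and near 0 elsewhere is the signature one would expect from a dedicated lookup head.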

In [None]:
from src.analysis.attention_analysis import (
    analyze_attention_patterns,
    compute_diagonal_score,
    find_lookup_heads
)

In [None]:
# Extract attention patterns
model.eval()

# Generate test sequences
test_batch = next(iter(test_loader))
input_ids = test_batch['input_ids'][:8].to(device)

# Get attention patterns
with torch.no_grad():
    patterns = model.get_attention_patterns(input_ids)

print(f"Extracted attention patterns from {len(patterns)} layers")
print(f"Pattern shape: {patterns[0].shape}")

In [None]:
# Find Same-Bit Lookup heads
# Theory predicts attention at offset -(n+1), which retrieves the bit at the
# same position in the previous number N_i

target_offset = -(n_bits + 1)
print(f"Target offset for Same-Bit Lookup: {target_offset}")
print(f"(For n={n_bits} bits, the same bit of the previous number is at offset -(n+1)={target_offset})")
print()

# Compute diagonal scores for each head
print("Diagonal Attention Scores:")
print("="*50)

for layer_idx, layer_attn in enumerate(patterns):
    print(f"\nLayer {layer_idx}:")
    for head_idx in range(layer_attn.shape[1]):
        head_attn = layer_attn[0, head_idx].cpu()
        score = compute_diagonal_score(head_attn, offset=target_offset)
        bar = '█' * int(score * 20)
        print(f"  Head {head_idx}: {score:.3f} {bar}")

## 4. Visualization <a id='4-visualization'></a>

In [None]:
from src.analysis.visualization import (
    plot_grokking_dynamics,
    plot_weight_norm,
    plot_attention_heatmap
)

In [None]:
# Plot grokking dynamics (Figure 3 from paper)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Accuracy dynamics
ax1 = axes[0]
ax1.plot(history['steps'], history['train_acc'], 'b-', label='Train Accuracy', linewidth=2)
ax1.plot(history['steps'], history['test_acc'], 'g-', label='Test Accuracy', linewidth=2)
ax1.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
ax1.set_xlabel('Training Steps', fontsize=12)
ax1.set_ylabel('Sequence Accuracy', fontsize=12)
ax1.set_title('Grokking Dynamics: Delayed Generalization', fontsize=14)
ax1.legend(loc='lower right')
ax1.set_xscale('log')
ax1.grid(True, alpha=0.3)

# Weight norm (complexity collapse)
ax2 = axes[1]
ax2.plot(history['steps'], history['weight_norm'], 'purple', linewidth=2)
ax2.set_xlabel('Training Steps', fontsize=12)
ax2.set_ylabel('L2 Weight Norm ||θ||₂', fontsize=12)
ax2.set_title('Mechanism: Complexity Collapse', fontsize=14)
ax2.set_xscale('log')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nExpected observations (the shortened demo run may stop before the transition):")
print("• Train accuracy saturates early (memorization phase)")
print("• Test accuracy jumps suddenly (grokking transition)")
print("• Weight norm drops sharply at transition (complexity collapse)")
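
One way to put a number on "delayed generalization" is the gap between the step where train accuracy first saturates and the step where test accuracy does. The sketch below assumes a `history` dictionary with the same keys used in the plotting cell above (`steps`, `train_acc`, `test_acc`); the helper names and the 0.99 threshold are illustrative choices, and the toy curves are synthetic.

```python
def first_step_reaching(steps, values, threshold):
    """Return the first step whose value is >= threshold, or None."""
    for step, value in zip(steps, values):
        if value >= threshold:
            return step
    return None


def grokking_delay(history, threshold=0.99):
    """Steps between train and test accuracy first reaching the threshold."""
    train_step = first_step_reaching(history["steps"], history["train_acc"], threshold)
    test_step = first_step_reaching(history["steps"], history["test_acc"], threshold)
    if train_step is None or test_step is None:
        return None  # one of the curves never reached the threshold
    return test_step - train_step


# Synthetic curves for illustration only, not results from the paper.
toy_history = {
    "steps":     [500, 1000, 1500, 2000, 2500, 3000],
    "train_acc": [0.40, 0.995, 1.0, 1.0, 1.0, 1.0],
    "test_acc":  [0.01, 0.02, 0.02, 0.05, 0.60, 1.0],
}
print(grokking_delay(toy_history))  # 2000
```

A return value of `None` on the real `history` would indicate that the shortened run ended before one of the curves crossed the threshold.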

In [None]:
# Visualize attention pattern (Figure 4 from paper)
# Find the head with highest diagonal score (Same-Bit Lookup head)

best_score = 0
best_layer, best_head = 0, 0

for layer_idx, layer_attn in enumerate(patterns):
    for head_idx in range(layer_attn.shape[1]):
        score = compute_diagonal_score(layer_attn[0, head_idx].cpu(), offset=target_offset)
        if score > best_score:
            best_score = score
            best_layer, best_head = layer_idx, head_idx

print(f"Best Same-Bit Lookup head: Layer {best_layer}, Head {best_head} (score: {best_score:.3f})")

# Plot the attention pattern
fig, ax = plt.subplots(figsize=(10, 8))

attn = patterns[best_layer][0, best_head].cpu().numpy()
im = ax.imshow(attn, cmap='Blues', aspect='auto')

ax.set_xlabel('Source Position (Input N_i)', fontsize=12)
ax.set_ylabel('Target Position (Output N_{i+1})', fontsize=12)
ax.set_title(f'Attention Pattern: Layer {best_layer}, Head {best_head}\n(Same-Bit Lookup at offset {target_offset})', fontsize=14)

# Mark the theoretical diagonal
for i in range(abs(target_offset), attn.shape[0]):
    ax.scatter(i + target_offset, i, marker='o', s=30, c='red', alpha=0.5)

plt.colorbar(im, ax=ax, label='Attention Weight')
plt.tight_layout()
plt.show()

print("\nRed dots: Theoretical Same-Bit Lookup positions (offset = -(n+1))")
print("High attention along this diagonal is consistent with the B-RASP circuit.")

## 5. RNN Baseline Comparison <a id='5-baselines'></a>

Compare with RNN baselines to demonstrate the succinctness gap.
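
The accuracies in this section are sequence-level. The metric code in `Trainer` is not shown, so the sketch below assumes the usual meaning: a prediction counts only if every output bit matches. It contrasts that with per-bit accuracy, which can look far better on the same predictions; both helpers are hypothetical.

```python
def sequence_accuracy(predictions, targets):
    """Fraction of sequences whose predicted bits all match the target."""
    correct = sum(1 for p, t in zip(predictions, targets) if p == t)
    return correct / len(targets)


def bit_accuracy(predictions, targets):
    """Fraction of individual bit positions predicted correctly."""
    pairs = [(pb, tb) for p, t in zip(predictions, targets) for pb, tb in zip(p, t)]
    return sum(pb == tb for pb, tb in pairs) / len(pairs)


targets     = ["1000", "0110", "0000", "1011"]
predictions = ["1000", "0111", "0000", "0011"]

print(sequence_accuracy(predictions, targets))  # 0.5
print(bit_accuracy(predictions, targets))       # 0.875
```

With 20 output bits, a model that gets most bits right can still score close to zero at the sequence level, which is worth keeping in mind when reading the RNN numbers.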

In [None]:
from src.models.rnn import compare_model_sizes

# Compare model sizes
sizes = compare_model_sizes()

print("Model Parameter Comparison:")
print("="*50)
for name, params in sizes.items():
    print(f"{name:20s}: {params:>10,} parameters")

print("\nKey insight:")
print("RNN with 30x more parameters still fails!")
print("This is the SUCCINCTNESS GAP.")

In [None]:
# Instantiate a large LSTM to compare parameter counts (no training in this cell)
rnn_config = RNNConfig(
    vocab_size=train_dataset.tokenizer.vocab_size,
    hidden_dim=2048,  # Largest hidden size in the comparison below
    n_layers=2,
    model_type='lstm'
)

rnn_model = RNNLM(rnn_config).to(device)
rnn_params = sum(p.numel() for p in rnn_model.parameters())

print(f"LSTM Parameters: {rnn_params:,}")
print(f"Transformer Parameters: {n_params:,}")
print(f"RNN is {rnn_params/n_params:.1f}x larger!")

In [None]:
# Visualize the succinctness gap (Figure 8 from paper)
fig, ax = plt.subplots(figsize=(10, 6))

models = ['Transformer\n(d=64)', 'LSTM\n(d=64)', 'LSTM\n(d=256)', 'LSTM\n(d=1024)', 'LSTM\n(d=2048)', 'GRU\n(d=2048)']
accuracies = [100.0, 0.0, 0.0, 4.2, 5.8, 5.8]  # Paper results
colors = ['green'] + ['red'] * 5

bars = ax.bar(models, accuracies, color=colors, alpha=0.7, edgecolor='black')

# Annotate
for bar, acc in zip(bars, accuracies):
    ax.annotate(f'{acc:.1f}%', 
                xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                ha='center', va='bottom', fontsize=12, fontweight='bold')

# Add note box
ax.annotate('Comparison on n=20 bits\nRNNs fail to generalize\neven with 30x parameters', 
            xy=(0.7, 0.7), xycoords='axes fraction',
            fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

ax.set_ylabel('Sequence-Level Test Accuracy (%)', fontsize=12)
ax.set_title('The Succinctness Gap: Transformers vs RNNs', fontsize=14)
ax.set_ylim(0, 110)
ax.axhline(y=100, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

## Summary

### Key Findings

1. **Succinctness Gap**: Transformers (50K params) achieve 100% accuracy while RNNs (1.5M params) fail (<6%)

2. **Grokking Phase Transition**: The succinct circuit emerges via:
   - Memorization phase (high weight norm)
   - Complexity collapse (weight norm drops)
   - Generalization phase (100% test accuracy)

3. **Mechanistic Alignment with B-RASP**:
   - Same-Bit Lookup heads with precise offset -(n+1)
   - MLPs implementing XOR/AND logic
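
For reference, the XOR/AND logic mentioned above corresponds to the textbook ripple-carry increment: each output bit is the input bit XOR the incoming carry, and the carry continues only while the input bit is 1. The sketch below is that standard formulation, not an extraction of what the trained MLPs compute; the function names are illustrative.

```python
def increment_bits(bits):
    """Add one to a bit list given least significant bit first.

    Each output bit is b XOR carry; the next carry is b AND carry.
    The initial carry of 1 represents the '+1'.
    """
    carry = 1
    out = []
    for b in bits:
        out.append(b ^ carry)
        carry = b & carry
    return out  # a final carry of 1 (all-ones input) is dropped: wrap-around


def increment_string(text):
    """Same operation on an MSB-first string such as '0111'."""
    lsb_first = [int(ch) for ch in reversed(text)]
    return "".join(str(b) for b in reversed(increment_bits(lsb_first)))


print(increment_string("0111"))  # 1000
print(increment_string("1111"))  # 0000
```

The carry into bit `j` is 1 exactly when all lower bits are 1, which is why the carry chain length governs how far the update reaches.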

### Citation

If you use this code, please cite:

```bibtex
@inproceedings{anonymous2026grokking,
    title={Do Transformers Grok Succinct Algorithms? Mechanistic Evidence for Counting Circuits},
    author={Anonymous},
    booktitle={Proceedings of ACL 2026},
    year={2026}
}
```