In [None]:
import torch as t
from torch import Tensor
import torch.nn.functional as F
import numpy as np

from pathlib import Path
import os
import sys

import plotly.express as px
import plotly.graph_objects as go

from functools import *
import gdown
from typing import List, Tuple, Union, Optional
from fancy_einsum import einsum
import einops
from jaxtyping import Float, Int
from tqdm import tqdm

from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from src.circuit import *
from src.fourier import *
from src.my_utils import *
from src.train import *


# Setup


In [None]:
# Directories inside the cloned `Grokking` repository, resolved to absolute paths.
# The first operand must be a `Path`: `/` between two plain strings raises TypeError.
root = (Path('Grokking') / 'saved_runs').resolve()
large_root = (Path('Grokking') / 'large_files').resolve()

In [None]:
# Modulus used throughout the notebook; it also sizes the vocabulary below.
p = 113

# Architecture hyperparameters, gathered in one mapping and unpacked into the config.
cfg_kwargs = dict(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=128 // 4,
    d_mlp=4 * 128,
    d_vocab=p + 1,  # one more token than the p values 0..p-1
    n_ctx=3,
    act_fn="relu",
    normalization_type=None,
    device=device,
)
cfg = HookedTransformerConfig(**cfg_kwargs)

model = HookedTransformer(cfg)

In [None]:
# Fetch the reference repository the first time the notebook runs, then create the
# directory that holds the large checkpoint file.
if not large_root.exists():
    !git clone https://github.com/neelnanda-io/Grokking.git
    os.mkdir(large_root)

# Download the saved training run only when it is not already on disk.
full_run_data_path = (large_root / "full_run_data.pth").resolve()
if not full_run_data_path.exists():
    url = "https://drive.google.com/uc?id=12pmgxpTHLDzSNMbMCuAMXP1lE_XiCQRy"
    # gdown is given the destination as a plain string
    output = str(full_run_data_path)
    gdown.download(url, output)

In [None]:
# Load the saved training run. `map_location` remaps stored tensors onto the device in
# use, so the cell also works when the checkpoint was written from a different device.
# NOTE(review): `t.load` unpickles the file - only load checkpoints from a trusted source.
full_run_data = t.load(full_run_data_path, map_location=device)

# Snapshot index 400 of the saved state dicts.
# NOTE(review): presumably a late-training checkpoint - confirm against `full_run_data["epochs"]`.
state_dict = full_run_data["state_dicts"][400]

model = load_in_state_dict(model, state_dict)

In [None]:
# Shorthand handles on the model's weights; index 0 selects the single layer.
W_Q = model.W_Q[0]
W_K = model.W_K[0]
W_V = model.W_V[0]
W_O = model.W_O[0]
W_in = model.W_in[0]
W_out = model.W_out[0]
W_pos = model.W_pos

# The last vocabulary entry is split off from the embedding and unembedding matrices;
# its embedding plus the position-2 embedding is kept separately.
W_E = model.W_E[:-1]
W_U = model.W_U[:, :-1]
final_pos_resid_initial = model.W_E[-1] + W_pos[2]

In [None]:
# Every (i, j, p) triple with i, j in 0..p-1; token id p fills the final position.
input_triples = [(i, j, p) for i in range(p) for j in range(p)]
all_data = t.tensor(input_triples).to(device)

# Targets come from the project helper `fn`, applied to the first two tokens of each row.
labels = t.tensor([fn(i, j) for i, j, _ in all_data]).to(device)

model_out, cache = model.run_with_cache(all_data)
# Final position only, also remove the logits for `=`
original_logits = model_out[:, -1, :-1]

original_loss = cross_entropy_high_precision(original_logits, labels)
print(f"Original loss: {original_loss.item()}")

In [None]:
# Attention side: scores from the final-position residual to each token embedding,
# scaled by sqrt(d_head).
W_QK = W_Q @ W_K.transpose(-1, -2)
attn_scale = cfg.d_head ** 0.5
W_attn = final_pos_resid_initial @ W_QK @ W_E.T / attn_scale

# Token embedding -> OV circuit -> MLP input weights.
W_OV = W_V @ W_O
W_neur = W_E @ W_OV @ W_in

# MLP output weights composed with the unembedding.
W_logit = W_out @ W_U

Getting activations

In [None]:
# Layer-0 activations from the cached forward pass over `all_data`.
# MLP activations are taken at the last sequence position; the attention pattern is
# indexed at 2 on its third axis.
neuron_acts_pre = cache['pre', 0][:, -1]
neuron_acts_post = cache['post', 0][:, -1]
attn_mat = cache['pattern', 0][:, :, 2]

# Periodicity

**Attention pattern**

In [None]:
# Keep only the first two positions of the last axis (attention from 2 -> 2 is ignored)
attn_mat = attn_mat[:, :, :2]

# Split the flat batch axis so the two leading dims are the operands (x, y)
# of the modular arithmetic equation
attn_mat_sq = einops.rearrange(attn_mat, "(x y) head seq -> x y head seq", x=p)

attn_to_pos0 = attn_mat_sq[..., 0]
inputs_heatmap(
    attn_to_pos0,
    animation_name='head',
    animation_frame=2,
    title=f'Attention score for heads at position 0',
)

**Neuron Activation**

In [None]:
# Un-flatten the batch axis so the two leading dims are the operands (x, y)
# of the modular arithmetic equation; the same pattern is used for both tensors.
xy_pattern = "(x y) d_mlp -> x y d_mlp"
neuron_acts_post_sq = einops.rearrange(neuron_acts_post, xy_pattern, x=p)
neuron_acts_pre_sq = einops.rearrange(neuron_acts_pre, xy_pattern, x=p)

top_k = 3
first_neuron_acts = neuron_acts_post_sq[..., :top_k]
inputs_heatmap(
    first_neuron_acts,
    animation_name='Neuron',
    animation_frame=2,
    title=f'Activations for first {top_k} neurons',
)

**Effective weights:**

**$W_{neur}$**

In [None]:
top_k = 5
head_names = [f'head {hi}' for hi in range(4)]

# One animation frame per neuron, restricted to the first `top_k` neurons.
animate_multi_lines(
    W_neur[..., :top_k],
    title=f'Contribution to first {top_k} neurons via OV-circuit of heads (not weighted by attention)',
    snapshot='Neuron',
    labels={'x':'Input token', 'value':'Contribution to neuron'},
    y_index=head_names,
)

**$W_{attn}$**

In [None]:
# One line per attention head, over input tokens.
head_names = [f'head {hi}' for hi in range(4)]
lines(
    W_attn,
    title=f'Contribution to attention score (pre-softmax) for each head',
    xaxis='Input token',
    yaxis='Contribution to attn score',
    labels=head_names,
)

### Fourier Transform

In [None]:
# Basis tensor plus a name for each component, both from the project helper.
fourier_basis, fourier_basis_names = make_fourier_basis(p)

animate_lines(
    fourier_basis,
    title='Graphs of Fourier Components (Use Slider)',
    snapshot='Fourier Component',
    snapshot_index=fourier_basis_names,
)

In [None]:
# Pairwise inner products between the rows of the Fourier basis.
fourier_basis_gram = fourier_basis @ fourier_basis.T
imshow(fourier_basis_gram)

### Activations in fourier space

**Attention Matrix**

In [None]:
# Transform the (x, y)-indexed attention pattern with the project's `fft2d` helper
attn_mat_fourier_basis = fft2d(attn_mat_sq, fourier_basis)

# Show the slice at index 0 of the last axis, one animation frame per head
attn_to_pos0_fourier = attn_mat_fourier_basis[..., 0]
imshow_fourier(
    attn_to_pos0_fourier,
    animation_name='head',
    animation_frame=2,
    title=f'Attention score for heads at position 0, in Fourier basis',
)

**Neuron Activations**

In [None]:
# Same transform as for attention, applied to the post-activation neuron tensor.
neuron_acts_post_fourier_basis = fft2d(neuron_acts_post_sq, fourier_basis)

top_k = 3
first_neurons_fourier = neuron_acts_post_fourier_basis[..., :top_k]
imshow_fourier(
    first_neurons_fourier,
    animation_name='Neuron',
    animation_frame=2,
    title=f'Activations for first {top_k} neurons',
)

In [None]:
# Transform W_neur along dim 1 with the project helper.
W_neur_fourier = fft1d_given_dim(W_neur, dim=1)

top_k = 5
head_names = [f'head {hi}' for hi in range(4)]
animate_multi_lines(
    W_neur_fourier[..., :top_k],
    title=f'Contribution to first {top_k} neurons via OV-circuit of heads (not weighted by attn), in Fourier basis',
    snapshot='Neuron',
    hover=fourier_basis_names,
    labels={'x':'Fourier component', 'value':'Contribution to neuron'},
    y_index=head_names,
)

In [None]:
# W_attn after the 1D Fourier transform: the x axis now indexes Fourier components
# (named by `fourier_basis_names` in the hover text), not input tokens.
lines(
    fft1d(W_attn),
    labels = [f'head {hi}' for hi in range(4)],
    xaxis='Fourier component',
    yaxis = 'Contribution to attn score',
    title=f'Contribution to attn score (pre-softmax) for each head, in Fourier Basis',
    hover=fourier_basis_names
)

## Circuit Analysis

### Quadratic Terms

In [None]:
# Sum of squared entries, per Fourier component, of the embedding matrix
# projected onto the Fourier basis.
fourier_embed_sq_sums = (fourier_basis @ W_E).pow(2).sum(1)
line(
    fourier_embed_sq_sums,
    title='Norm of embedding of each Fourier Component',
    xaxis='Fourier Component',
    yaxis='Norm',
    hover=fourier_basis_names,
)

In [None]:
# Subtract each neuron's mean over all (x, y) inputs
per_neuron_mean = neuron_acts_post_sq.mean((0, 1), keepdim=True)
neuron_acts_centered = neuron_acts_post_sq - per_neuron_mean

# Take 2D Fourier transform
neuron_acts_centered_fourier = fft2d(neuron_acts_centered, fourier_basis)

# Mean over neurons of the squared Fourier coefficients
mean_sq_fourier_coeffs = neuron_acts_centered_fourier.pow(2).mean(-1)
imshow_fourier(
    mean_sq_fourier_coeffs,
    title=f"Norms of 2D Fourier components of centered neuron activations",
)

### Neuron Clusters

In [None]:
# Per-neuron frequency and the fraction it explains, as returned by the project helper.
freq_fit = find_neuron_freqs(neuron_acts_centered_fourier)
neuron_freqs, neuron_frac_explained = freq_fit

# Distinct frequencies that occur, and how many neurons carry each one.
key_freqs, neuron_freq_counts = t.unique(neuron_freqs, return_counts=True)

In [None]:
# Share of inputs on which each neuron's pre-activation at the last position is positive.
pre_acts_last_pos = cache['pre', 0][:, -1]
fraction_of_activations_positive_at_posn2 = (pre_acts_last_pos > 0).float().mean(0)

scatter(
    x=neuron_freqs,
    y=neuron_frac_explained,
    color=utils.to_numpy(fraction_of_activations_positive_at_posn2),
    title="Fraction of neuron activations explained by key freq",
    xaxis="Neuron frequency",
    yaxis="Frac explained",
    colorbar_title="Frac positive",
)

In [None]:
# Neurons whose explained fraction is below 0.85 go into a separate cluster,
# marked by overwriting their frequency with -1 (in place).
poorly_explained = neuron_frac_explained < 0.85
neuron_freqs[poorly_explained] = -1.

# Key frequencies followed by the -1 marker, in the same dtype/device as `key_freqs`.
minus_one = -key_freqs.new_ones((1,))
key_freqs_plus = t.concatenate([key_freqs, minus_one])

for i, k in enumerate(key_freqs_plus):
    n_in_cluster = (neuron_freqs==k).sum()
    print(f'Cluster {i}: freq k={k}, {n_in_cluster} neurons')

In [None]:
# Squared Fourier coefficients, computed once and then averaged over the neurons
# of each key-frequency cluster.
sq_fourier_acts = neuron_acts_centered_fourier.pow(2)
fourier_norms_in_each_cluster = [
    einops.reduce(
        sq_fourier_acts[..., neuron_freqs==freq],
        'batch_y batch_x neuron -> batch_y batch_x',
        'mean'
    )
    for freq in key_freqs
]

facet_labels = [f"Freq={freq}" for freq in key_freqs]
imshow_fourier(
    t.stack(fourier_norms_in_each_cluster),
    title=f'Norm of 2D Fourier components of neuron activations in each cluster',
    facet_labels=facet_labels,
    facet_col=0,
)

In [None]:
logits_in_freqs = []

for freq in key_freqs:
    in_cluster = neuron_freqs==freq

    # Activations of the neurons assigned to this frequency
    cluster_acts = neuron_acts_post[:, in_cluster]

    # Project onto const/linear/quadratic terms in 2D Fourier basis
    cluster_acts_projected = project_onto_frequency(cluster_acts, freq)

    # Calculate new logits from these filtered neuron activations
    logits_in_freqs.append(cluster_acts_projected @ W_logit[in_cluster])

# Neurons in the always-firing cluster (frequency -1) are added unfiltered
always_firing = neuron_freqs==-1
logits_in_freqs.append(neuron_acts_post[:, always_firing] @ W_logit[always_firing])

# Print new losses
loss_incl_always_firing = test_logits(
    sum(logits_in_freqs),
    bias_correction=True,
    original_logits=original_logits
)
print('Loss with neuron activations ONLY in key freq (inclusing always firing cluster)\n{:.6e}\n'.format(
    loss_incl_always_firing
))

loss_excl_always_firing = test_logits(
    sum(logits_in_freqs[:-1]),
    bias_correction=True,
    original_logits=original_logits
)
print('Loss with neuron activations ONLY in key freq (exclusing always firing cluster)\n{:.6e}\n'.format(
    loss_excl_always_firing
))

print('Original loss\n{:.6e}'.format(original_loss))

In [None]:
print('Loss with neuron activations excluding none:     {:.9f}'.format(original_loss.item()))

# Drop one cluster's logit contribution at a time and re-evaluate the loss.
all_cluster_logits = sum(logits_in_freqs)
for c, freq in enumerate(key_freqs_plus):
    loss_without_cluster = test_logits(
        all_cluster_logits - logits_in_freqs[c],
        bias_correction=True,
        original_logits=original_logits
    )
    print('Loss with neuron activations excluding freq={}:  {:.9f}'.format(freq, loss_without_cluster))

### Logits

#### Logits in Fourier Basis


In [None]:
# Mean over neurons of the squared Fourier coefficients of the centered activations
neuron_fourier_sq_mean = einops.reduce(
    neuron_acts_centered_fourier.pow(2), 'y x neuron -> y x', 'mean'
)
imshow_fourier(
    neuron_fourier_sq_mean,
    title='Norm of Fourier Components of Neuron Acts'
)

# Rearrange logits, so the first two dims represent (x, y) in modular arithmetic equation
original_logits_sq = einops.rearrange(original_logits, "(x y) z -> x y z", x=p)
original_logits_fourier = fft2d(original_logits_sq)

# Same reduction, now over the logit axis
logit_fourier_sq_mean = einops.reduce(
    original_logits_fourier.pow(2), 'y x z -> y x', 'mean'
)
imshow_fourier(
    logit_fourier_sq_mean,
    title='Norm of Fourier Components of Logits'
)

In [None]:
# For every key frequency, project the logits onto the two directions returned by
# `get_trig_sum_directions` (cos first, then sin) and collect the projections.
trig_projections = []

for k in key_freqs:
    cos_xplusy_direction, sin_xplusy_direction = get_trig_sum_directions(k)

    for direction in (cos_xplusy_direction, sin_xplusy_direction):
        trig_projections.append(
            project_onto_direction(original_logits, direction.flatten())
        )

trig_logits = sum(trig_projections)

print(f'Loss with just x+y components: {test_logits(trig_logits, True, original_logits):.4e}')
print(f"Original Loss: {original_loss:.4e}")

#### $W_{logit}$ and SVD

In [None]:
# W_logit with its output axis expressed in the Fourier basis.
US = W_logit @ fourier_basis.T

imshow_div(
    US,
    title='W_logit in the Fourier Basis',
    x=fourier_basis_names,
    yaxis='Neuron index',
    width=600,
    height=800,
)

In [None]:
# Rows of US regrouped cluster by cluster, in the order of `key_freqs_plus`.
cluster_rows = [US[neuron_freqs==freq] for freq in key_freqs_plus]
US_sorted = t.concatenate(cluster_rows)

# Running totals of the key-frequency cluster sizes, closed off with the total neuron count.
cluster_sizes = [(neuron_freqs == freq).sum().item() for freq in key_freqs]
hline_positions = np.cumsum(cluster_sizes).tolist() + [cfg.d_mlp]
hline_labels = [f"Cluster: {freq=}" for freq in key_freqs.tolist()] + ["No freq"]

imshow_div(
    US_sorted,
    title='W_logit in the Fourier Basis (rearranged by neuron cluster)',
    x=fourier_basis_names,
    yaxis='Neuron',
    hline_positions=hline_positions,
    hline_labels=hline_labels,
    width=600,
    height=800,
)

In [None]:
cos_components = []
sin_components = []

for k in key_freqs:
    # Columns of US for frequency k: column 2k-1 is used as the cos direction
    # and column 2k as the sin direction.
    # (The original identifiers were mis-encoded text; plain ASCII names are used here.)
    sigma_u_sin = US[:, 2*k]
    sigma_u_cos = US[:, 2*k-1]

    # Neuron activations contracted with each direction, for every (x, y) input
    logits_in_cos_dir = neuron_acts_post_sq @ sigma_u_cos
    logits_in_sin_dir = neuron_acts_post_sq @ sigma_u_sin

    cos_components.append(fft2d(logits_in_cos_dir))
    sin_components.append(fft2d(logits_in_sin_dir))

# One animated plot per component type, with one frame per key frequency
for title, components in zip(['Cosine', 'Sine'], [cos_components, sin_components]):
    imshow_fourier(
        t.stack(components),
        title=f'{title} components of neuron activations in Fourier basis',
        animation_frame=0,
        animation_name="Frequency",
        animation_labels=key_freqs.tolist()
    )

# Analysis during training

In [None]:
# Epoch value recorded for each saved checkpoint
epochs = full_run_data['epochs']

# Dictionary that the metric cells below fill in (re-running this cell empties it)
metric_cache = dict()

# `lines` pre-configured for per-epoch curves on a log-scaled y axis
plot_metric = partial(lines, log_y=True, xaxis='Epoch', x=epochs)

In [None]:
# Bind the key frequencies under a NEW name. Rebinding `excl_loss` itself would replace
# the imported function with a wrapper of itself, making the cell self-referential.
excl_loss_for_key_freqs = partial(excl_loss, key_freqs=key_freqs)
get_metrics(model, metric_cache, excl_loss_for_key_freqs, 'excl_loss')

# One curve per excluded frequency, followed by the train and test loss curves
lines(
    t.concat([
        metric_cache['excl_loss'].T,
        metric_cache['train_loss'][None, :],
        metric_cache['test_loss'][None, :]
    ], axis=0),
    labels=[f'excl {freq}' for freq in key_freqs]+['train', 'test'],
    title='Excluded Loss for each trig component',
    log_y=True,
    x=full_run_data['epochs'],
    xaxis='Epoch',
    yaxis='Loss'
)

### Embedding in Fourier Basis

In [None]:
get_metrics(model, metric_cache, fourier_embed, 'fourier_embed')

# Plot every 200 epochs so it's not overwhelming: keep every second checkpoint
fourier_embed_every_other = metric_cache['fourier_embed'][::2]
animate_lines(
    fourier_embed_every_other,
    title='Norm of Fourier Components in the Embedding Over Training',
    snapshot='Epoch',
    snapshot_index=epochs[::2],
    hover=fourier_basis_names,
    animation_group='x',
)

In [None]:
get_metrics(model, metric_cache, embed_SVD, 'embed_SVD')

# One animation frame per checkpoint
animate_lines(
    metric_cache['embed_SVD'],
    title='Singular Values of the Embedding During Training',
    xaxis='Singular Number',
    yaxis='Singular Value',
    snapshot='Epoch',
    snapshot_index=epochs,
)

**Development of Trig Components**

In [None]:
trig_modes = ['neuron_pre', 'neuron_post', 'logit']

# Compute (and overwrite, via reset=True) the trig-ratio metric for each mode
for mode in trig_modes:
    get_metrics(
        model,
        metric_cache,
        partial(tensor_trig_ratio, mode=mode),
        f"{mode}_trig_ratio",
        reset=True
    )

# Average each metric over its index axis to get one curve per mode
lines_list = [
    einops.reduce(metric_cache[f"{mode}_trig_ratio"], 'epoch index -> epoch', 'mean')
    for mode in trig_modes
]
line_labels = [f"{mode}_trig_frac" for mode in trig_modes]

plot_metric(
    lines_list,
    title='Fraction of logits and neurons explained by trig terms',
    labels=line_labels,
    yaxis='Ratio',
    log_y=False,
)

### Development of neuron activations

In [None]:
def get_frac_explained(model: HookedTransformer):
    """Per-neuron share of centered activation energy captured by the neuron's assigned frequency.

    Runs `model` on `all_data` and, for both the 'pre' and 'post' layer-0 activations at
    the last position, computes for every neuron:
      * the sum of squared 2D-Fourier coefficients in the block of the frequency assigned
        to it by `neuron_freqs`, divided by the sum over all coefficients, and
      * the same ratio restricted to the `[1:, 1:]` (quadratic) entries of that block.
    Neurons tagged with frequency -1 use the `[1:, 1:]` entries of every frequency block
    for both ratios.

    Reads the notebook globals `all_data`, `p`, `key_freqs_plus` and `neuron_freqs`.

    Returns:
        A tensor with six rows and one column per neuron, stacked in this order:
        frac explained (pre), quadratic frac (pre), frac explained (post),
        quadratic frac (post), `neuron_freqs`, and the fraction of inputs on which the
        post activation is > 0. The result is passed through `t.nan_to_num`, which
        cleans up values produced by zero denominators.
    """
    _, cache = model.run_with_cache(all_data, return_type=None)

    returns = []

    for neuron_type in ['pre', 'post']:
        neuron_acts = cache[neuron_type, 0][:, -1].clone().detach()
        neuron_acts_centered = neuron_acts - neuron_acts.mean(0)
        neuron_acts_fourier = fft2d(
            einops.rearrange(neuron_acts_centered, "(x y) neuron -> x y neuron", x=p)
        )

        # Calculate the sum of squares over all inputs, for each neuron
        square_of_all_terms = einops.reduce(
            neuron_acts_fourier.pow(2), "x y neuron -> neuron", "sum"
        )

        # Size the outputs from the activations themselves (one slot per neuron, same
        # device and dtype) instead of relying on module-level `d_mlp` / `device`.
        frac_explained = t.zeros_like(square_of_all_terms)
        frac_explained_quadratic_terms = t.zeros_like(square_of_all_terms)

        for freq in key_freqs_plus:
            # Neurons assigned to this frequency; the mask is reused for every read and write below
            in_cluster = neuron_freqs == freq

            # Get Fourier activations for neurons in this frequency cluster
            # We arrange by frequency (i.e. each freq has a 3x3 grid with const, linear & quadratic terms)
            acts_fourier = arrange_by_2d_freqs(neuron_acts_fourier[..., in_cluster])

            # Calculate the sum of squares over all inputs, after filtering for just this frequency
            # Also calculate the sum of squares for just the quadratic terms in this frequency
            if freq == -1:
                squares_for_this_freq = squares_for_this_freq_quadratic_terms = einops.reduce(
                    acts_fourier[:, 1:, 1:].pow(2), "freq x y neuron -> neuron", "sum"
                )
            else:
                squares_for_this_freq = einops.reduce(
                    acts_fourier[freq-1].pow(2), "x y neuron -> neuron", "sum"
                )
                squares_for_this_freq_quadratic_terms = einops.reduce(
                    acts_fourier[freq-1, 1:, 1:].pow(2), "x y neuron -> neuron", "sum"
                )

            cluster_totals = square_of_all_terms[in_cluster]
            frac_explained[in_cluster] = squares_for_this_freq / cluster_totals
            frac_explained_quadratic_terms[in_cluster] = squares_for_this_freq_quadratic_terms / cluster_totals

        returns.extend([frac_explained, frac_explained_quadratic_terms])

    # Read the 'post' activations explicitly rather than through the leaked loop variable.
    post_acts = cache['post', 0][:, -1]
    frac_active = (post_acts > 0).float().mean(0)

    return t.nan_to_num(t.stack(returns + [neuron_freqs, frac_active], axis=0))


get_metrics(model, metric_cache, get_frac_explained, 'get_frac_explained')

# Axis 1 of the cached metric holds the six rows returned by `get_frac_explained`
frac_explained_pre = metric_cache['get_frac_explained'][:, 0]
frac_explained_quadratic_pre = metric_cache['get_frac_explained'][:, 1]
frac_explained_post = metric_cache['get_frac_explained'][:, 2]
frac_explained_quadratic_post = metric_cache['get_frac_explained'][:, 3]
neuron_freqs_ = metric_cache['get_frac_explained'][:, 4]
frac_active = metric_cache['get_frac_explained'][:, 5]

# Every 5th checkpoint among the first 200 (the titles below describe this as "up to epoch 20K")
early_snapshots = slice(None, 200, 5)

animate_scatter(
    t.stack([frac_explained_quadratic_pre, frac_explained_quadratic_post], dim=1)[early_snapshots],
    color=neuron_freqs_[early_snapshots],
    color_name='freq',
    snapshot='epoch',
    snapshot_index=epochs[early_snapshots],
    xaxis='Quad ratio pre',
    yaxis='Quad ratio post',
    color_continuous_scale='viridis',
    title='Fraction of variance explained by quadratic terms (up to epoch 20K)'
)

animate_scatter(
    t.stack([neuron_freqs_, frac_explained_pre, frac_explained_post], dim=1)[early_snapshots],
    color=frac_active[early_snapshots],
    color_name='frac_active',
    snapshot='epoch',
    snapshot_index=epochs[early_snapshots],
    xaxis='Freq',
    yaxis='Frac explained',
    # Neuron indices for hover text; taken from the config (this notebook defines no bare `d_mlp`)
    hover=list(range(cfg.d_mlp)),
    color_continuous_scale='viridis',
    title='Fraction of variance explained by this frequency (up to epoch 20K)'
)

### Development of commutativity

In [None]:
get_metrics(model, metric_cache, avg_attn_pattern, 'avg_attn_pattern')

# Keep every 5th checkpoint and animate over that leading (checkpoint) axis; heads and
# positions are the two axes of each frame. With 200 checkpoints spanning 20K epochs
# elsewhere in this notebook, a stride of 5 corresponds to 500 epochs per frame.
imshow_div(
    metric_cache['avg_attn_pattern'][::5],
    animation_frame=0,
    animation_name='snapshot',
    title='Avg attn by position and head, snapped every 500 epochs',
    xaxis='Pos',
    yaxis='Head',
    zmax=0.5,
    zmin=0.0,
    color_continuous_scale='Blues',
    text_auto='.3f',
)

In [None]:
# Per-head gap between average attention at index 0 and index 1 of the last axis,
# transposed so that each head becomes one line over checkpoints.
avg_attn = metric_cache['avg_attn_pattern']
attn_pos0_minus_pos1 = (avg_attn[:, :, 0]-avg_attn[:, :, 1]).T

lines(
    attn_pos0_minus_pos1,
    title='Attention to pos 0 - pos 1 by head over training',
    x=epochs,
    xaxis='Epoch',
    yaxis='Average difference',
    labels=[f"head {i}" for i in range(4)],
)

### Noise clean up

In [None]:
# Trig loss with the helper's default mode, then again with mode='train'
get_metrics(model, metric_cache, trig_loss, 'trig_loss')

trig_loss_train = partial(trig_loss, mode='train')
get_metrics(model, metric_cache, trig_loss_train, 'trig_loss_train')

line_labels = ['test_loss', 'train_loss', 'trig_loss', 'trig_loss_train']
loss_curves = [metric_cache[lab] for lab in line_labels]
plot_metric(loss_curves, labels=line_labels, title='Different losses over training')

# Test loss divided by trig loss, per checkpoint
test_over_trig = metric_cache['test_loss']/metric_cache['trig_loss']
plot_metric([test_over_trig], title='Ratio of trig and test loss')

### Development of squared sum of the weights

In [None]:
# Names of the model's parameters, in `named_parameters()` order
parameter_names = [param_name for param_name, _ in model.named_parameters()]

def sum_sq_weights(model):
    """Return the sum of squared entries of each parameter, in `named_parameters()` order.

    Args:
        model: any object exposing `named_parameters()` that yields (name, tensor) pairs.

    Returns:
        A list of Python floats, one per parameter.
    """
    squared_sums = []
    for _, weight in model.named_parameters():
        squared_sums.append(weight.pow(2).sum().item())
    return squared_sums
get_metrics(model, metric_cache, sum_sq_weights, 'sum_sq_weights')

# Take only the end of each parameter name for brevity
short_parameter_names = [i.split('.')[-1] for i in parameter_names]
plot_metric(
    metric_cache['sum_sq_weights'].T,
    title='Sum of squared weights for each parameter',
    labels=short_parameter_names,
    log_y=False
)

# Collapse the parameter axis to get a single total per checkpoint
total_sq_weights = einops.reduce(metric_cache['sum_sq_weights'], 'epoch param -> epoch', 'sum')
plot_metric(
    [total_sq_weights],
    title='Total sum of squared weights',
    log_y=False
)