In [None]:
import os
import zipfile

import torch as t
import numpy as np
import einops
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

from circuits.dictionary_learning.buffer import NNsightActivationBuffer
from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder
import circuits.othello_utils as othello_utils
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

# Root of the chess-gpt-circuits checkout. Set CHESS_GPT_CIRCUITS_DIR to run on another machine;
# the default preserves the original author's path.
repo_dir = os.environ.get('CHESS_GPT_CIRCUITS_DIR', '/share/u/can/chess-gpt-circuits')
device = 'cuda:0'

In [None]:
# Download and extract the pretrained Othello SAEs from the Hugging Face Hub if they are missing.
# NOTE(review): this checks for `othello_5-21`, but the SAE-loading cell reads from
# `group-2024-05-17_othello/...`; confirm which directory the archive actually extracts to,
# otherwise the archive may be re-downloaded on every run.
autoencoder_dir = os.path.join(repo_dir, 'autoencoders')
if not os.path.exists(os.path.join(autoencoder_dir, 'othello_5-21')):
    zip_path = hf_hub_download(
        repo_id='adamkarvonen/othello_saes',
        filename='othello_5-21.zip',
        local_dir=autoencoder_dir,
    )
    # Extract next to the downloaded zip. The previous `unzip -d autoencoders` was relative to the
    # kernel's working directory (not repo_dir) and ignored failures; zipfile raises on error.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(autoencoder_dir)

In [None]:
# Load the pretrained SAE for the chosen architecture and trainer run onto `device`.
ae_type = 'standard_new'  # one of: 'standard', 'gated', 'standard_new'
trainer_id = 0

ae_path = f'{repo_dir}/autoencoders/group-2024-05-17_othello/group-2024-05-17_othello-{ae_type}/trainer{trainer_id}'

# Map each supported ae_type to the dictionary class that knows how to load its checkpoint.
ae_classes = {
    'standard': AutoEncoder,
    'gated': GatedAutoEncoder,
    'standard_new': AutoEncoderNew,
}
if ae_type not in ae_classes:
    raise ValueError(f'Invalid ae_type: {ae_type!r}')

# Use the configured `device` rather than a hard-coded 'cuda:0' so one setting controls placement.
ae = ae_classes[ae_type].from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)

In [None]:
# Load Othello-GPT and stream a fixed batch of games from the Hugging Face dataset.

layer = 5  # layer whose submodule output the SAE encodes
context_length = 59
activation_dim = 512  # output dimension of the layer
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
dataset_name = "taufeeque/othellogpt"

game_batch_size = 100

model = get_model(model_name, device)

# Submodule whose output the SAE reads. NOTE(review): presumably blocks[layer].hook_resid_post;
# confirm in circuits.utils.get_submodule.
submodule = get_submodule(model_name, layer, model)

# MLP `hook_post` of every layer (presumably the post-nonlinearity neuron activations in
# TransformerLens naming); these are correlated against the SAE feature below.
mlp_post_submodules = [model.blocks[l].mlp.hook_post for l in range(model.cfg.n_layers)]

# Build the streaming generator once; the previous duplicate construction was immediately discarded.
data = othello_hf_dataset_to_generator(
    dataset_name, context_length=context_length, split="train", streaming=True
)
game_batch = [next(data) for _ in range(game_batch_size)]
game_batch = t.tensor(game_batch, device=device)
print(f'game_batch: {len(game_batch)}')

## Single SAE feature ~ all MLP neurons
For a single fixed SAE feature, which MLP neurons (in earlier and later layers) have a high Pearson correlation with that feature's activations?

From the feature-visualization notebook: feature #21 of the `standard_new` SAE (`{repo_dir}/autoencoders/group-2024-05-17_othello/group-2024-05-17_othello-standard_new/trainer0`) appears to represent a piece on H1 or G1.

<img src="./feat21.png" alt="Feature-visualization output for SAE feature 21" width="800"/>

In [None]:
# Index of the SAE feature under study (the H1/G1 feature from the feature-visualization notebook).
feat_idx: int = 21

In [None]:
# One traced forward pass over the game batch: save the SAE encoding of `submodule`'s output and
# the output of every layer's MLP hook_post.
mlp_acts = {}

with t.no_grad(), model.trace(game_batch, scan=False, validate=False):
    feature_acts = ae.encode(submodule.output).save()
    for layer_idx in range(model.cfg.n_layers):
        mlp_acts[layer_idx] = mlp_post_submodules[layer_idx].output.save()

# Merge the batch and sequence axes into one sample axis with features/neurons first, then keep
# only the SAE feature under study: feature_acts -> (B*S,), mlp_acts[l] -> (n_neurons, B*S).
feature_acts = einops.rearrange(feature_acts, "B S F -> F (B S)")[feat_idx]
mlp_acts = {
    layer_idx: einops.rearrange(layer_acts, "B S F -> F (B S)")
    for layer_idx, layer_acts in mlp_acts.items()
}

In [None]:
# Pearson correlation calculation function
def pearson_corr(x, y):
    """Pearson correlation coefficient between ``x`` and ``y`` along the last dimension.

    Each input is centred on its own last-axis mean. Leading dimensions broadcast, so ``x`` of
    shape (N,) against ``y`` of shape (D, N) yields D correlations. A zero-variance input gives a
    0/0 division, i.e. NaN, rather than raising.
    """
    x_centered = x - x.mean(dim=-1, keepdim=True)
    y_centered = y - y.mean(dim=-1, keepdim=True)
    covariance = (x_centered * y_centered).sum(dim=-1)
    norm_product = (
        (x_centered * x_centered).sum(dim=-1) * (y_centered * y_centered).sum(dim=-1)
    ).sqrt()
    return covariance / norm_product

# Pearson correlation between the SAE feature and every MLP neuron, per layer.
# pearson_corr reduces over the last axis and broadcasts leading dims, so a whole layer is computed
# in one call instead of one Python iteration (and one GPU->CPU scalar copy) per neuron.
pearson_correlations = {}

for l, mlp_acts_layer in mlp_acts.items():
    # feature_acts: (B*S,) and mlp_acts_layer: (n_neurons, B*S) -> (n_neurons,)
    correlations = pearson_corr(feature_acts, mlp_acts_layer)
    # Store as CPU float32, matching the previous per-element fill of t.zeros(n_neurons).
    pearson_correlations[l] = correlations.to(device='cpu', dtype=t.float32)

In [None]:
# Stacked histogram of |Pearson r| between the SAE feature and every MLP neuron, coloured by layer.
layers = list(pearson_correlations.keys())
# Use a dedicated name so this cell does not overwrite `data` (the game generator).
abs_corrs_by_layer = [pearson_correlations[l].abs().numpy() for l in layers]

fig, ax = plt.subplots()
ax.hist(abs_corrs_by_layer, bins=100, histtype='bar', stacked=True, label=layers)

# Log-scale counts so sparsely populated bins stay visible next to dense ones.
ax.set_yscale('log')
ax.legend(title='Layer')
ax.set_xlabel('Absolute Pearson Correlation')
ax.set_ylabel('Frequency')
ax.set_title('Stacked Histogram of Absolute Pearson Correlations by Layer')
plt.show()

In [None]:
# Per layer, the indices of MLP neurons whose |Pearson r| with the SAE feature exceeds the threshold.
corr_threshold = 0.5
indices = {
    layer_idx: (corrs.abs() > corr_threshold).nonzero(as_tuple=True)[0]
    for layer_idx, corrs in pearson_correlations.items()
}

indices