In [None]:
import os
import pickle
import zipfile

import einops
import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from tqdm import tqdm

from circuits.dictionary_learning.buffer import NNsightActivationBuffer
from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder
import circuits.othello_utils as othello_utils
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
)

import circuits.analysis as analysis
import circuits.utils as utils
import circuits.othello_utils as othello_utils
import circuits.chess_utils as chess_utils

# Root of the chess-gpt-circuits checkout. Set the CHESS_GPT_CIRCUITS_DIR environment
# variable to point elsewhere (e.g. '/share/u/can/chess-gpt-circuits') instead of
# editing this absolute path.
repo_dir = os.environ.get("CHESS_GPT_CIRCUITS_DIR", "/home/adam/chess-gpt-circuits/")
device = 'cuda:0'

# True: Othello-GPT + Othello SAEs; False: chess model + chess SAEs.
othello = True

In [None]:
if othello:
    # Download and extract the Othello SAEs from Hugging Face if they are not already present.
    autoencoder_dir = os.path.join(repo_dir, 'autoencoders')
    if not os.path.exists(os.path.join(autoencoder_dir, 'othello_5-21')):
        zip_path = hf_hub_download(
            repo_id='adamkarvonen/othello_saes',
            filename='othello_5-21.zip',
            local_dir=autoencoder_dir,
        )
        # Extract into repo_dir/autoencoders (next to the zip), not a cwd-relative
        # `autoencoders/` folder, so the existence check above succeeds on re-runs
        # regardless of the working directory.
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(autoencoder_dir)

In [None]:
# load SAE
ae_type = 'p_anneal'
trainer_id = 4

if othello:
    ae_path = f'{repo_dir}/autoencoders/othello_5-21/othello-{ae_type}/trainer{trainer_id}'
else:
    ae_path = f'{repo_dir}/autoencoders/chess-trained_model-layer_5-2024-05-23/chess-trained_model-layer_5-{ae_type}/trainer{trainer_id}'

# Dictionary class able to load each ae_type's checkpoint.
ae_classes = {
    'standard': AutoEncoder,
    'p_anneal': AutoEncoder,
    'gated': GatedAutoEncoder,
    'gated_anneal': GatedAutoEncoder,
    'standard_new': AutoEncoderNew,
}
if ae_type not in ae_classes:
    raise ValueError('Invalid ae_type')

# Load onto the configured `device` rather than a hard-coded 'cuda:0', so the SAE and
# the model always live on the same device.
ae = ae_classes[ae_type].from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)

print(ae.encoder.weight.shape)
ae_hidden_dim = ae.encoder.weight.shape[0]  # number of SAE features
d_model = ae.encoder.weight.shape[1]  # width of the activations the SAE encodes

t.set_grad_enabled(False)

In [None]:
# load model
# Sets up the model, the submodule whose output the SAE reads, the game data generator
# and an activation buffer.
# NOTE(review): this overwrites the `d_model` read from ae.encoder.weight in the SAE cell;
# the two values should agree — confirm if a different SAE is loaded.
d_model = 512  # output dimension of the layer
layer = 5  # layer whose output (`submodule`) is encoded by the SAE

if not othello:
    # Metadata passed to chess_hf_dataset_to_generator (presumably the chess tokenizer
    # vocabulary — confirm).
    # NOTE(review): path is relative to the working directory, unlike the repo_dir-based
    # paths used elsewhere. pickle.load can execute arbitrary code; only load a trusted file.
    with open("models/meta.pkl", "rb") as f:
        meta = pickle.load(f)

    context_length = 256
    model_name = "adamkarvonen/8LayerChessGPT2"
    dataset_name = "adamkarvonen/chess_sae_text"
    data = utils.chess_hf_dataset_to_generator(
        dataset_name, meta, context_length=context_length, split="train", streaming=True
    )
    model_type = "chess"
else:
    context_length = 59
    model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
    dataset_name = "taufeeque/othellogpt"
    data = utils.othello_hf_dataset_to_generator(
        dataset_name, context_length=context_length, split="train", streaming=True
    )
    model_type = "othello"

model = utils.get_model(model_name, device)
submodule = utils.get_submodule(model_name, layer, model)

# MLP post-activation hook points for every layer. The comprehension variable `layer` is
# local to the comprehension (Python 3 scoping), so the module-level `layer = 5` survives.
mlp_post_submodules = [model.blocks[layer].mlp.hook_post for layer in range(model.cfg.n_layers)]

batch_size = 8
total_games_size = batch_size * 10  # number of games cached by the activation-caching cell

# NOTE(review): `buffer` is not used by the visible cells, and n_ctxs is passed as the float
# 8e3 — confirm NNsightActivationBuffer accepts/casts a non-int value.
buffer = NNsightActivationBuffer(
    data,
    model,
    submodule,
    n_ctxs=8e3,
    ctx_len=context_length,
    refresh_batch_size=batch_size,
    io="out",
    d_submodule=d_model,
    device=device,
)

## Single SAE feature ~ all MLP neurons
For a fixed SAE feature, which MLP neurons (in earlier and later layers) have a high Pearson correlation with it?

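From the feature-visualization notebook: feature #21 of the SAE at `{repo_dir}/autoencoders/group-2024-05-17_othello/group-2024-05-17_othello-standard_new/trainer0` appears to represent a piece on H1 or G1. (Note that this notebook loads a different SAE — `othello-p_anneal/trainer4` from `othello_5-21` — so feature indices may not correspond.)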

<img src="./feat21.png" alt="Visualization of SAE feature 21" width="800"/>

In [None]:
def get_cosine_similarities_for_mlp_neuron(model, feat_idx: int, layer: int, y_vectors_DF: t.Tensor) -> t.Tensor:
    """Cosine similarity of one MLP neuron's output direction with every column of `y_vectors_DF`.

    Despite its name, `feat_idx` selects an MLP neuron: row `feat_idx` of
    `model.blocks[layer].mlp.W_out`. `y_vectors_DF` must be 2-D; each of its columns is one
    candidate direction. Returns a 1-D tensor with one similarity per column.
    """
    # Columns -> rows so that each candidate direction is one row (same as 'd f -> f d').
    candidate_rows = y_vectors_DF.permute(1, 0)
    neuron_direction = model.blocks[layer].mlp.W_out[feat_idx, :]
    return F.cosine_similarity(neuron_direction, candidate_rows)

def get_max_cos_sim_for_all_mlp_neurons(model, layer: int, y_vectors_DF: t.Tensor) -> t.Tensor:
    """Best cosine similarity each MLP neuron in `layer` reaches against the columns of `y_vectors_DF`.

    Returns a 1-D tensor with one entry per row of `model.blocks[layer].mlp.W_out`.
    """
    n_neurons = model.blocks[layer].mlp.W_out.shape[0]
    per_neuron_best = [
        get_cosine_similarities_for_mlp_neuron(model, neuron, layer, y_vectors_DF).max()
        for neuron in range(n_neurons)
    ]
    return t.stack(per_neuron_best)

def get_cosine_similarities_for_ae_decoder_neuron(ae, feat_idx: int, layer: int, y_vectors_FD: t.Tensor) -> t.Tensor:
    """Cosine similarity of SAE feature `feat_idx`'s decoder direction with every row of `y_vectors_FD`.

    The decoder direction is column `feat_idx` of `ae.decoder.weight`. `layer` is accepted for
    signature parity with the MLP variant and is not used.
    """
    decoder_column = ae.decoder.weight[:, feat_idx]
    similarities = F.cosine_similarity(decoder_column, y_vectors_FD)
    return similarities

def get_max_cos_sim_for_all_ae_decoder_neurons(ae, layer: int, y_vectors_FD: t.Tensor) -> t.Tensor:
    """Highest cosine similarity each SAE feature's decoder direction reaches against the rows of `y_vectors_FD`.

    Iterates over the columns of `ae.decoder.weight` (one per SAE feature) and returns a 1-D
    tensor with one maximum per feature. `layer` is only forwarded; it is not used.
    """
    n_features = ae.decoder.weight.shape[1]
    return t.stack([
        get_cosine_similarities_for_ae_decoder_neuron(ae, feature, layer, y_vectors_FD).max()
        for feature in range(n_features)
    ])

layer = 5

# Cosine similarities between one MLP neuron's output direction and every SAE decoder direction.
mlp_neuron_idx = 1024
cos_sims_1024 = get_cosine_similarities_for_mlp_neuron(model, mlp_neuron_idx, layer, ae.decoder.weight)
print(f'Cos sim of mlp neuron {mlp_neuron_idx} with SAE feature 2192: {cos_sims_1024[2192]}')
print(f'Cos sim of mlp neuron {mlp_neuron_idx} with SAE feature 3098: {cos_sims_1024[3098]}')
print(f'Maximum cosine similarity of mlp neuron {mlp_neuron_idx}: {cos_sims_1024.max()}')

# For every SAE feature: the max cosine similarity between its decoder direction and any
# MLP neuron output direction. Rows of W_out are the neurons' d_model-sized output
# directions (the same rows get_cosine_similarities_for_mlp_neuron indexes). Comparing
# against ae.encoder.weight would measure the SAE against itself, not against MLP neurons.
mlp_out_dirs_ND = model.blocks[layer].mlp.W_out
max_cos_sims = get_max_cos_sim_for_all_ae_decoder_neurons(ae, layer, mlp_out_dirs_ND)
average_max_cos_sim = max_cos_sims.mean()
print(average_max_cos_sim)

print(len(max_cos_sims))
# max_cos_sims is already a tensor; no need to re-wrap it with t.tensor().
print((max_cos_sims > 0.5).sum().item())

fig, ax = plt.subplots()
ax.hist(max_cos_sims.cpu().numpy(), bins=100)
ax.set(
    title='SAE feature max cosine similarities with MLP neurons',
    xlabel='max cosine similarity with any MLP neuron',
    ylabel='number of SAE features',
)
plt.show()

In [None]:
# This cell caches MLP post-activations (every layer) and SAE feature activations for
# `total_games_size` games drawn from `data`.
# Loop variables are named `layer_idx` so the module-level `layer = 5` is not overwritten.
assert total_games_size % batch_size == 0, "total_games_size must be a multiple of batch_size"

# Preallocated (games, positions, neurons) buffers per layer; the neuron count is read from
# W_out rather than assumed to be 4 * d_model.
mlp_acts_bLF = {
    layer_idx: t.zeros(
        (total_games_size, context_length, model.blocks[layer_idx].mlp.W_out.shape[0]),
        dtype=t.float32,
        device=device,
    )
    for layer_idx in range(model.cfg.n_layers)
}

tokens_bL = t.zeros((total_games_size, context_length), dtype=t.int16, device=device)
feature_activations_bLF = t.zeros((total_games_size, context_length, ae_hidden_dim), dtype=t.float32, device=device)

for i in range(0, total_games_size, batch_size):
    game_batch_BL = t.tensor([next(data) for _ in range(batch_size)], device=device)
    with t.no_grad(), model.trace(game_batch_BL, scan=False, validate=False):
        x_BLD = submodule.output
        feature_acts_BLF = ae.encode(x_BLD).save()
        # Only save proxies inside the trace; they are copied into the preallocated buffers
        # after the trace has run, exactly like feature_acts_BLF below.
        mlp_acts_saved = [
            mlp_post_submodules[layer_idx].output.save() for layer_idx in range(model.cfg.n_layers)
        ]
    tokens_bL[i:i+batch_size] = game_batch_BL
    feature_activations_bLF[i:i+batch_size] = feature_acts_BLF
    for layer_idx, mlp_act_BLF in enumerate(mlp_acts_saved):
        mlp_acts_bLF[layer_idx][i:i+batch_size] = mlp_act_BLF

feature_activations_Fb = einops.rearrange(feature_activations_bLF, "B S F -> F (B S)")

# Flatten into a new dict instead of overwriting mlp_acts_bLF in place, so each name keeps
# matching its layout and re-running this cell gives the same result.
mlp_acts_Fb = {
    layer_idx: einops.rearrange(acts_bLF, "B S F -> F (B S)")
    for layer_idx, acts_bLF in mlp_acts_bLF.items()
}

In [None]:
# Pearson correlation calculation function
def pearson_corr(x, y):
    """Pearson correlation of `x` and `y` along their last dimension.

    Leading dimensions broadcast. Wherever the product of the two sums of squared deviations
    is zero (e.g. either input is constant along the last dim), the result is 0 instead of NaN.
    """
    x_dev = x - x.mean(dim=-1, keepdim=True)
    y_dev = y - y.mean(dim=-1, keepdim=True)
    numerator = (x_dev * y_dev).sum(dim=-1)
    denominator = ((x_dev * x_dev).sum(dim=-1) * (y_dev * y_dev).sum(dim=-1)).sqrt()
    degenerate = denominator == 0
    return t.where(degenerate, t.zeros_like(numerator), numerator / denominator)


def get_correlation_for_activation(
    x_activations_Fb: t.Tensor, x_activation_index: int, y_activations_Fb: t.Tensor
) -> t.Tensor:
    """Pearson correlation of one row of `x_activations_Fb` with every row of `y_activations_Fb`.

    Args:
        x_activations_Fb: activations, one row per unit (feature/neuron), one column per token.
        x_activation_index: which row of `x_activations_Fb` to correlate.
        y_activations_Fb: activations to correlate against; must have the same number of columns.

    Returns:
        float32 CPU tensor of shape (y_activations_Fb.shape[0],). Rows with zero variance
        get correlation 0 (see pearson_corr).
    """
    x_activation_1b = x_activations_Fb[x_activation_index].unsqueeze(0)
    # pearson_corr reduces over the last dim and broadcasts the leading dims, so a single call
    # covers every row of y instead of one Python iteration (and one GPU->CPU sync) per row.
    correlations = pearson_corr(x_activation_1b, y_activations_Fb)
    return correlations.to(device="cpu", dtype=t.float32)

mlp_neuron_idx = 1024
feat_idx = 21
mlp_layer = 5  # same layer as `submodule`, whose output the SAE encodes

# One MLP neuron vs every SAE feature.
mlp_correlations = get_correlation_for_activation(mlp_acts_Fb[mlp_layer], mlp_neuron_idx, feature_activations_Fb)
# One SAE feature vs every MLP neuron in `mlp_layer` — the question this section asks.
# (Correlating against feature_activations_Fb would compare the SAE feature with other SAE
# features, trivially peaking at 1.0 on itself.)
sae_correlations = get_correlation_for_activation(feature_activations_Fb, feat_idx, mlp_acts_Fb[mlp_layer])

# Calculate Pearson correlation
# pearson_correlations = {}

# for layer in mlp_acts:
#     mlp_acts_layer = mlp_acts[layer]
#     correlations = t.zeros(mlp_acts_layer.shape[0])
#     for i in range(mlp_acts_layer.shape[0]):
#         mlp_feature = mlp_acts_layer[i]
#         corr = pearson_corr(feature_acts_BLF, mlp_feature)
#         correlations[i] = corr
#     pearson_correlations[layer] = correlations

In [None]:
# Summary of, for each SAE feature, its max correlation with any layer-5 MLP neuron.
# `max_cors_sae` is computed by a cell further down this notebook; fail with a clear message
# on a fresh kernel instead of a bare NameError.
if 'max_cors_sae' not in globals():
    raise RuntimeError(
        "max_cors_sae is not defined: run the cell that computes max correlations per SAE feature first"
    )

# Show the size and a short preview instead of dumping every value.
print(len(max_cors_sae), max_cors_sae[:10])

# Drop features whose max correlation is exactly 0 (pearson_corr returns 0 for zero-variance
# inputs, e.g. a feature that never changes across the cached tokens).
max_cors_sae_filtered = [x for x in max_cors_sae if x != 0]

# Average of the remaining max correlations.
print(t.mean(t.tensor(max_cors_sae_filtered)).item())

fig, ax = plt.subplots()
ax.hist(max_cors_sae_filtered, bins=100)
ax.set(
    title='Max correlation of any MLP neuron with each SAE feature',
    xlabel='max Pearson correlation',
    ylabel='number of SAE features',
)
plt.show()

In [None]:
# How many (non-zero) SAE features have some MLP neuron correlated above 0.5?
print(len(max_cors_sae_filtered))
strong_matches = t.tensor(max_cors_sae_filtered) > 0.5
print(strong_matches.sum().item())

In [None]:
print(mlp_correlations.shape)


def analyze_correlations(correlations: t.Tensor, k: int = 20, title=None):
    """Print the `k` largest correlations with their indices, then plot a histogram of all values.

    Args:
        correlations: 1-D tensor of correlation values (expected shape: (n,)).
        k: number of top entries to print; clamped to the number of values so that short
            inputs do not make t.topk raise.
        title: optional histogram title.
    """
    k = min(k, correlations.numel())
    values, indices = t.topk(correlations, k)

    # Print the top-k values and their indices.
    for index, value in zip(indices, values):
        print(f"Index: {index.item()}, Value: {value.item()}")
    print()

    fig, ax = plt.subplots()
    ax.hist(correlations.cpu().numpy(), bins=100)
    ax.set(xlabel='Pearson correlation', ylabel='count')
    if title:
        ax.set_title(title)
    plt.show()


analyze_correlations(sae_correlations, title='sae_correlations')
analyze_correlations(mlp_correlations, title='mlp_correlations')

In [None]:
mlp_layer_acts_Nb = mlp_acts_Fb[5]
print(mlp_layer_acts_Nb.shape)


def _standardize_rows(acts: t.Tensor) -> t.Tensor:
    """Center each row and scale it to unit L2 norm; zero-variance rows become all zeros."""
    centered = acts - acts.mean(dim=-1, keepdim=True)
    norms = centered.norm(dim=-1, keepdim=True)
    # Zero-norm rows would divide by 0; map them to 0 so their correlation is 0,
    # matching pearson_corr's zero-variance convention.
    return t.where(norms == 0, t.zeros_like(centered), centered / norms)


# The Pearson correlation between two rows equals the dot product of their centered,
# unit-normalized versions, so one (features x neurons) matmul replaces
# n_features * n_neurons separate pearson_corr calls.
sae_std_Fb = _standardize_rows(feature_activations_Fb)
mlp_std_Nb = _standardize_rows(mlp_layer_acts_Nb)
corr_FN = sae_std_Fb @ mlp_std_Nb.T

# For each SAE feature: its max correlation with any MLP neuron in layer 5 (list of floats).
max_cors_sae = corr_FN.max(dim=1).values.tolist()


In [None]:
# # Prepare data
# layers = list(pearson_correlations.keys())
# data = [pearson_correlations[l].abs() for l in layers]

# # Create stacked histogram
# plt.hist(data, bins=100, histtype='bar', stacked=True, label=layers)

# # Add legend and log scale for y-axis
# plt.legend(title='Layer')
# plt.yscale('log')

# # Display plot
# plt.xlabel('Absolute Pearson Correlation')
# plt.ylabel('Frequency')
# plt.title('Stacked Histogram of Pearson Correlations by Layer')
# plt.show()

In [None]:
# # save indices and layer for pearson_correlations above a certain threshold
# corr_threshold = 0.5
# indices = {}
# for layer in pearson_correlations:
#     indices[layer] = t.where(pearson_correlations[layer].abs() > corr_threshold)[0]

# indices