In [None]:
# Show the kernel's working directory; the checkpoint files loaded below
# ("lichess_8layers_ckpt_no_optimizer.pt", "t1_ae.pt") are given as relative paths.
!pwd

In [None]:
from nnsight import LanguageModel
import torch

from dictionary_learning import ActivationBuffer
from dictionary_learning.training import trainSAE
from circuits.nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.trainers.standard import StandardTrainer

In [None]:
# --- Model, hooked submodule and activation buffer -------------------------
DEVICE = torch.device("cuda")

# Convert the nanoGPT checkpoint and wrap it so nnsight can trace activations.
tokenizer = NanogptTokenizer()
model = convert_nanogpt_model("lichess_8layers_ckpt_no_optimizer.pt", DEVICE)
model = LanguageModel(model, device_map=DEVICE, tokenizer=tokenizer)

# MLP of transformer block index 5.
# NOTE(review): the previous comment called this "layer 1 MLP", which does not
# match the index used; confirm h[5] is the layer the loaded dictionary expects.
submodule = model.transformer.h[5].mlp
activation_dim = 512  # output dimension of the MLP
dictionary_size = 8 * activation_dim

batch_size = 8

data = hf_dataset_to_generator("adamkarvonen/chess_sae_test", streaming=False)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    n_ctxs=512,
    ctx_len=256,
    refresh_batch_size=4,
    io="out",
    # Single source of truth for the activation width (was a second literal 512).
    d_submodule=activation_dim,
    device=DEVICE,
    # Each next(buffer) is requested with this batch size.
    out_batch_size=batch_size,
)

In [None]:
from dictionary_learning import AutoEncoder

# Load the dictionary saved in "t1_ae.pt" onto DEVICE.
# NOTE(review): presumably trained on the output of `submodule` above — confirm.
ae = AutoEncoder.from_pretrained("t1_ae.pt", device=DEVICE)

In [None]:
@torch.no_grad()
def get_feature(
    activations,
    ae: "AutoEncoder",
    device,
    verbose: bool = False,
):
    """Encode the next batch of activations into dictionary features.

    Args:
        activations: Iterator yielding activation tensors; one batch is
            consumed per call.
        ae: Dictionary invoked as ``ae(x, output_features=True)`` and expected
            to return ``(x_hat, f)``.
        device: Device the batch is moved to before it is encoded.
        verbose: When True, print the shape of the batch (debugging aid; off
            by default so loops calling this do not flood the cell output).

    Returns:
        The feature tensor ``f`` returned by ``ae`` for this batch, with no
        reduction applied.

    Raises:
        StopIteration: If ``activations`` has no batch left.
    """
    try:
        x = next(activations).to(device)
    except StopIteration:
        # Same exception type as before, but with an actionable message.
        raise StopIteration(
            "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
        ) from None

    if verbose:
        print(x.shape)

    # The reconstruction is not needed here, only the feature activations.
    _, f = ae(x, output_features=True)
    return f
num_iters = 1024
# Width of one feature vector. The historical name `seq_len` is kept, but this
# is the dictionary size (8 * activation_dim = 4096), not a sequence length, so
# it is derived from `dictionary_size` instead of repeating the literal.
seq_len = dictionary_size

# One row of feature activations per sampled activation vector.
features = torch.zeros((batch_size * num_iters, seq_len), device=DEVICE)
probs = []
for i in range(num_iters):
    feature = get_feature(buffer, ae, DEVICE)
    # Fail loudly instead of letting a mis-shaped batch broadcast into the slice.
    assert feature.shape == (batch_size, seq_len), (
        f"expected features of shape {(batch_size, seq_len)}, got {tuple(feature.shape)}"
    )
    # NOTE(review): this is the mean activation of each feature over the batch,
    # not the fraction of inputs on which it fires, even though the variables
    # below are named `feat_prob` / `log_freq`. Confirm which one is intended.
    prob = feature.mean(0)
    features[i * batch_size:(i + 1) * batch_size, :] = feature
    probs.append(prob)

# Average the per-batch means, then take log10 (epsilon keeps zeros finite).
feat_prob = sum(probs) / len(probs)
print(feat_prob.shape)
log_freq = (feat_prob + 1e-10).log10()
print(log_freq.shape)

In [None]:
print(features.shape)
active = (features != 0).float()

# Average number of non-zero features per sample (the usual "L0" statistic).
print(active.sum(dim=-1).mean())

# Per-feature firing frequency: fraction of the num_iters * batch_size samples
# on which each feature is non-zero. The reduction runs over the sample axis
# (dim=0) so that `l0` is indexed by feature, which is how the following cells
# use it (`l0 < 0.5`, `l0[interest]`, `dim_idx=interest`). The previous version
# summed over the feature axis and then divided by the sample count.
l0 = active.sum(dim=0) / (num_iters * batch_size)
print(l0.shape)
print(l0)

print(l0.mean())

In [None]:
# Keep entries of `l0` strictly between 0 and 0.5.
mask = (l0 > 0) & (l0 < 0.5)
# squeeze(-1) drops only the trailing coordinate axis, so `idx` stays 1-D even
# when exactly one entry matches (a bare squeeze() would return a 0-d tensor
# and the slice below would raise).
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
print(idx.shape)
print(idx[:10])

In [None]:
import matplotlib.pyplot as plt

# NOTE: despite the names, no logarithm is applied; these are the raw `l0`
# values. The variable names are kept unchanged in case other cells read them.
l0_log = l0
lo_log_np = l0_log.cpu().numpy()

# Histogram of the raw values. The labels previously claimed log10
# probabilities, which is not what is being plotted here.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(lo_log_np, bins=50, alpha=0.75, color='blue')
ax.set_title('Histogram of l0')
ax.set_xlabel('l0')
ax.set_ylabel('Count')
ax.grid(True)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Bring the log10 values to host memory as a NumPy array for plotting.
log_freq_np = log_freq.cpu().numpy()

# Draw the distribution on an explicit Axes object.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(log_freq_np, bins=50, alpha=0.75, color='blue')
ax.set_title('Histogram of log10 of Feature Probabilities')
ax.set_xlabel('log10(Probability)')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()

In [None]:
# Release GPU memory that PyTorch's caching allocator holds but no longer uses.
# Tensors that are still referenced (e.g. `features`) are not freed by this call.
torch.cuda.empty_cache()

In [None]:
# List the first candidate indices that passed the mask.
print(idx[:20])
# Hand-picked index to inspect; it is passed as `dim_idx` to examine_dimension below.
# NOTE(review): hard-coded — nothing checks that 19 is actually one of the indices in `idx`.
interest = 19
print(l0[interest])

In [None]:
import random
from circuitsvis.activations import text_neuron_activations
from einops import rearrange
import torch as t
from collections import namedtuple
import umap
import pandas as pd
import plotly.express as px



In [None]:
from dictionary_learning.interp import examine_dimension

# Run dictionary_learning's examine_dimension on dictionary dimension `interest`.
# NOTE(review): judging by the result names it returns the top-activating contexts
# and tokens; n_inputs/k/batch_size semantics are defined by the helper — confirm there.
top_contexts, top_tokens = examine_dimension(model, submodule, buffer, dictionary=ae, dim_idx=interest, n_inputs=500, k=30, batch_size=4, device=DEVICE)