In [None]:
# Show the working directory: the checkpoint files loaded below
# ("lichess_8layers_ckpt_no_optimizer.pt", "t1_ae.pt") are relative paths resolved against it.
!pwd

In [None]:
from nnsight import LanguageModel
import torch

from dictionary_learning import ActivationBuffer
from dictionary_learning.training import trainSAE
from circuits.nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.trainers.standard import StandardTrainer

In [None]:
DEVICE = torch.device("cuda")
LAYER = 5  # zero-indexed transformer block whose MLP output is studied (the 6th block)

tokenizer = NanogptTokenizer()
# Convert the nanoGPT checkpoint to an HF-transformers model, then wrap it for nnsight tracing.
hf_model = convert_nanogpt_model("lichess_8layers_ckpt_no_optimizer.pt", DEVICE)
model = LanguageModel(hf_model, device_map=DEVICE, tokenizer=tokenizer)

submodule = model.transformer.h[LAYER].mlp  # MLP of block LAYER
activation_dim = 512  # output dimension of the MLP
dictionary_size = 8 * activation_dim

batch_size = 8

data = hf_dataset_to_generator("adamkarvonen/chess_sae_test", streaming=False)
# Streams MLP outputs ("out") of `submodule` in batches of `batch_size` activation vectors.
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    n_ctxs=512,
    ctx_len=256,
    refresh_batch_size=4,
    io="out",
    d_submodule=activation_dim,  # keep in sync with activation_dim rather than repeating 512
    device=DEVICE,
    out_batch_size=batch_size,
)

In [None]:
from dictionary_learning import AutoEncoder

# Load the trained dictionary (autoencoder) checkpoint onto DEVICE.
ae = AutoEncoder.from_pretrained(
    "t1_ae.pt",
    device=DEVICE,
)

In [None]:
@torch.no_grad()
def get_feature(
    activations,
    ae: "AutoEncoder",
    device,
):
    """Encode the next batch from ``activations`` with ``ae`` and return its features.

    Args:
        activations: iterator yielding activation tensors (e.g. the ``ActivationBuffer``).
        ae: autoencoder called as ``ae(x, output_features=True)``, which must return
            ``(x_hat, f)``.
        device: device the batch is moved to before encoding.

    Returns:
        The feature activations ``f`` for the batch. The reconstruction is discarded.

    Raises:
        StopIteration: if ``activations`` is exhausted.
    """
    try:
        x = next(activations).to(device)
    except StopIteration:
        raise StopIteration(
            "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
        ) from None
    _, f = ae(x, output_features=True)
    return f
num_iters = 1024
n_features = dictionary_size  # SAE width; previously hard-coded as 4096 under the name seq_len

# Feature activations of every sampled batch, stacked row-wise: (batch_size * num_iters, n_features).
features = torch.zeros((batch_size * num_iters, n_features), device=DEVICE)
# Running sum of per-batch firing rates: fraction of rows on which each feature is nonzero.
firing_rate_sum = torch.zeros(n_features, device=DEVICE)
for i in range(num_iters):
    feature = get_feature(buffer, ae, DEVICE)
    features[i * batch_size : (i + 1) * batch_size, :] = feature
    # A probability of firing, not the mean activation magnitude.
    firing_rate_sum += (feature != 0).float().mean(0)

feat_prob = firing_rate_sum / num_iters
print(feat_prob.shape)
log_freq = (feat_prob + 1e-10).log10()  # epsilon keeps never-firing features finite
print(log_freq.shape)

In [None]:
print(features.shape)
# Per-feature number of rows on which the feature is nonzero.
fire_counts = (features != 0).sum(dim=0).float()
print(fire_counts.mean())
# Normalise to a firing rate (fraction of all collected rows); despite the name, not an L0 norm.
l0 = fire_counts / (num_iters * batch_size)
print(l0.shape)
print(l0)

print(l0.mean())

In [None]:
MAX_FIRING_RATE = 0.5  # drop dense features that fire on at least half the rows

# Keep features that fire at least once but below MAX_FIRING_RATE.
mask = (l0 > 0) & (l0 < MAX_FIRING_RATE)
# flatten() (unlike squeeze()) keeps idx 1-D even when exactly one feature matches.
idx = torch.nonzero(mask, as_tuple=False).flatten()
print(idx.shape)
print(idx[:10])

In [None]:
import matplotlib.pyplot as plt

# l0 holds raw per-feature firing rates (fraction of rows where the feature is nonzero); no log is applied.
firing_rate_np = l0.cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(firing_rate_np, bins=50, alpha=0.75, color='blue')
ax.set_title('Histogram of firing rates for features')
ax.set_xlabel('Firing rate (fraction of activations where feature is nonzero)')
ax.set_ylabel('Number of features')
ax.grid(True)
plt.show()

In [None]:
import random
from circuitsvis.activations import text_neuron_activations
from einops import rearrange
import torch as t
from collections import namedtuple
import umap
import pandas as pd
import plotly.express as px



In [None]:
import matplotlib.pyplot as plt

# Histogram of log10 feature probabilities computed in the collection cell.
log_freq_np = log_freq.cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(log_freq_np, bins=50, alpha=0.75, color='blue')
ax.set_title('Histogram of log10 of Feature Probabilities')
ax.set_xlabel('log10(Probability)')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()

In [None]:
interest = 6  # feature index spot-checked in the cells below

# Peek at the selected feature indices and a few firing rates.
print(idx[:20])
print(l0[interest])
print(idx.shape)
print(l0[:10])

In [None]:
import random
from circuitsvis.activations import text_neuron_activations
from einops import rearrange
import torch as t
from collections import namedtuple
import umap
import pandas as pd
import plotly.express as px


def feature_effect(
    model,
    submodule,
    dictionary,
    feature,
    inputs,
    add_residual=True,  # whether to compensate for dictionary reconstruction error by adding residual
    k=10,
    largest=True,
):
    """
    Effect of ablating the feature on top k predictions for next token.

    Runs ``inputs`` through ``model.trace`` twice: a clean run, and a run where
    ``feature`` is zeroed at the last sequence position. It then compares the
    last-position log-probabilities of the two runs.

    Args:
        model: traceable model exposing ``trace(inputs)`` and ``output``.
        submodule: module whose output is read and patched.
        dictionary: autoencoder used to encode/decode ``submodule``'s output. If
            ``None``, ``feature`` is zeroed directly as a dimension of that output.
        feature: index of the feature (or raw dimension) to ablate.
        inputs: inputs passed to ``model.trace``.
        add_residual: if True, add back ``x - x_hat`` so only the ablation changes
            the output. If False, the clean run is also routed through ``dictionary``.
        k: number of tokens to return.
        largest: True returns tokens whose log-prob drops most under ablation;
            False returns tokens whose log-prob rises most.

    Returns:
        ``(top_tokens, top_probs)``: token ids and their batch-mean
        clean-minus-ablated log-prob differences.
    """
    # clean run
    with model.trace(inputs):
        if dictionary is None:
            pass
        elif not add_residual:  # run hidden state through autoencoder
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = dictionary(submodule.output[0])
            else:
                submodule.output = dictionary(submodule.output)
        clean_output = model.output.save()
    # HF-style outputs carry `.logits`; plain-tensor outputs are the logits themselves.
    # (Replaces a bare `except:` that also swallowed unrelated errors such as bad shapes.)
    clean_value = clean_output.value
    clean_logits = getattr(clean_value, "logits", clean_value)[:, -1, :]
    clean_logprobs = t.nn.functional.log_softmax(clean_logits, dim=-1)

    # ablated run
    with model.trace(inputs):
        if dictionary is None:
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:, -1, feature] = 0
            else:
                submodule.output[:, -1, feature] = 0
        else:
            x = submodule.output
            if type(x.shape) == tuple:
                x = x[0]
            x_hat, f = dictionary(x, output_features=True)
            residual = x - x_hat

            f[:, -1, feature] = 0
            if add_residual:
                x_hat = dictionary.decode(f) + residual
            else:
                x_hat = dictionary.decode(f)

            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = x_hat
            else:
                submodule.output = x_hat
        ablated_output = model.output.save()
    ablated_value = ablated_output.value
    ablated_logits = getattr(ablated_value, "logits", ablated_value)[:, -1, :]
    ablated_logprobs = t.nn.functional.log_softmax(ablated_logits, dim=-1)

    diff = clean_logprobs - ablated_logprobs
    top_probs, top_tokens = t.topk(diff.mean(dim=0), k=k, largest=largest)
    return top_tokens, top_probs


@t.no_grad()
def examine_dimension_old(
    model,
    submodule,
    buffer,
    dictionary=None,
    max_length=128,
    n_inputs=512,
    dim_idx=None,
    k=30,
    batch_size=4,
    device=t.device("cuda"),
):
    """Profile one dimension of ``submodule``'s output (or of its ``dictionary`` encoding).

    Collects that dimension's activations over ``n_inputs // batch_size`` batches of
    ``buffer.text_batch`` texts. It then reports the ablation effect on next-token
    predictions, the tokens with the highest mean activation, and the top-k
    activating contexts.

    Args:
        dim_idx: dimension to profile. If None, one is drawn uniformly at random
            from the activation width observed on the first batch.
        dictionary: if given, activations are ``dictionary.encode(...)`` features.
        max_length: tokens per context (truncation length).
        k: number of tokens / contexts to report.

    Returns:
        A ``featureProfile`` namedtuple with fields ``(top_contexts, top_tokens,
        top_affected, decoded_tokens, activations)``. ``activations`` is the list of
        per-context activation slices for the top-k contexts.
    """

    def _list_decode(x):
        if isinstance(x, int):
            return model.tokenizer.decode(x)
        else:
            return [_list_decode(y) for y in x]

    n_iters = n_inputs // batch_size

    activations = t.zeros((n_iters * batch_size, max_length), device=device, dtype=t.float32)
    tokens = t.zeros((n_iters * batch_size, max_length), device=device, dtype=t.int64)

    for i in range(n_iters):
        inputs = buffer.text_batch(batch_size=batch_size)
        with model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)):
            cur_tokens = model.input[1][
                "input_ids"
            ].save()  # if you're getting errors, check here; might only work for pythia models
            cur_activations = submodule.output
            if type(cur_activations.shape) == tuple:
                cur_activations = cur_activations[0]
            if dictionary is not None:
                cur_activations = dictionary.encode(cur_activations)
            # Save all dimensions so a random dim_idx can be drawn once the width is known.
            # (Previously dim_idx=None read `activations` before assignment -> UnboundLocalError.)
            cur_activations = cur_activations.save()
        batch_acts = cur_activations.value
        if dim_idx is None:
            dim_idx = random.randint(0, batch_acts.shape[-1] - 1)
        activations[i * batch_size : (i + 1) * batch_size, :] = batch_acts[:, :, dim_idx]
        tokens[i * batch_size : (i + 1) * batch_size, :] = cur_tokens.value

    top_affected = feature_effect(model, submodule, dictionary, dim_idx, tokens, k=k)
    top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)]

    # get top k tokens by mean activation
    token_mean_acts = {}
    for ctx in tokens:
        for tok in ctx:
            if tok.item() in token_mean_acts:
                continue
            idxs = (tokens == tok).nonzero(as_tuple=True)
            token_mean_acts[tok.item()] = activations[idxs].mean().item()
    top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
    top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]

    # Top-k (context, position) pairs by activation; each context is truncated at that position.
    flattened_acts = rearrange(activations, "b n -> (b n)")
    topk_indices = t.argsort(flattened_acts, dim=0, descending=True)[:k]
    batch_indices = topk_indices // activations.shape[1]
    token_indices = topk_indices % activations.shape[1]
    context_tokens = [
        tokens[batch_idx, : token_idx + 1].tolist()
        for batch_idx, token_idx in zip(batch_indices, token_indices)
    ]
    context_acts = [
        activations[batch_idx, : token_id + 1, None, None]
        for batch_idx, token_id in zip(batch_indices, token_indices)
    ]
    decoded_tokens = _list_decode(context_tokens)
    top_contexts = text_neuron_activations(decoded_tokens, context_acts)

    return namedtuple(
        "featureProfile",
        ["top_contexts", "top_tokens", "top_affected", "decoded_tokens", "activations"],
    )(top_contexts, top_tokens, top_affected, decoded_tokens, context_acts)


In [None]:
import random
from circuitsvis.activations import text_neuron_activations
from einops import rearrange
import torch as t
from collections import namedtuple
import umap
import pandas as pd
import plotly.express as px


def feature_effect(
    model,
    submodule,
    dictionary,
    feature,
    inputs,
    add_residual=True,  # whether to compensate for dictionary reconstruction error by adding residual
    k=10,
    largest=True,
):
    """
    Effect of ablating the feature on top k predictions for next token.

    Runs ``inputs`` through ``model.trace`` twice: a clean run, and a run where
    ``feature`` is zeroed at the last sequence position. It then compares the
    last-position log-probabilities of the two runs.

    NOTE(review): this redefines an identical ``feature_effect`` from an earlier
    cell; the later definition silently wins.

    Args:
        model: traceable model exposing ``trace(inputs)`` and ``output``.
        submodule: module whose output is read and patched.
        dictionary: autoencoder used to encode/decode ``submodule``'s output. If
            ``None``, ``feature`` is zeroed directly as a dimension of that output.
        feature: index of the feature (or raw dimension) to ablate.
        inputs: inputs passed to ``model.trace``.
        add_residual: if True, add back ``x - x_hat`` so only the ablation changes
            the output. If False, the clean run is also routed through ``dictionary``.
        k: number of tokens to return.
        largest: True returns tokens whose log-prob drops most under ablation;
            False returns tokens whose log-prob rises most.

    Returns:
        ``(top_tokens, top_probs)``: token ids and their batch-mean
        clean-minus-ablated log-prob differences.
    """
    # clean run
    with model.trace(inputs):
        if dictionary is None:
            pass
        elif not add_residual:  # run hidden state through autoencoder
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = dictionary(submodule.output[0])
            else:
                submodule.output = dictionary(submodule.output)
        clean_output = model.output.save()
    # HF-style outputs carry `.logits`; plain-tensor outputs are the logits themselves.
    # (Replaces a bare `except:` that also swallowed unrelated errors such as bad shapes.)
    clean_value = clean_output.value
    clean_logits = getattr(clean_value, "logits", clean_value)[:, -1, :]
    clean_logprobs = t.nn.functional.log_softmax(clean_logits, dim=-1)

    # ablated run
    with model.trace(inputs):
        if dictionary is None:
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:, -1, feature] = 0
            else:
                submodule.output[:, -1, feature] = 0
        else:
            x = submodule.output
            if type(x.shape) == tuple:
                x = x[0]
            x_hat, f = dictionary(x, output_features=True)
            residual = x - x_hat

            f[:, -1, feature] = 0
            if add_residual:
                x_hat = dictionary.decode(f) + residual
            else:
                x_hat = dictionary.decode(f)

            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = x_hat
            else:
                submodule.output = x_hat
        ablated_output = model.output.save()
    ablated_value = ablated_output.value
    ablated_logits = getattr(ablated_value, "logits", ablated_value)[:, -1, :]
    ablated_logprobs = t.nn.functional.log_softmax(ablated_logits, dim=-1)

    diff = clean_logprobs - ablated_logprobs
    top_probs, top_tokens = t.topk(diff.mean(dim=0), k=k, largest=largest)
    return top_tokens, top_probs


@t.no_grad()
def examine_dimension(
    model,
    submodule,
    buffer,
    dictionary=None,
    max_length=128,
    n_inputs=512,
    dims=torch.tensor([0]),
    k=30,
    batch_size=4,
    device=t.device("cuda"),
):
    """Profile several dimensions of ``submodule``'s output (or its ``dictionary`` encoding) at once.

    A single pass over ``n_inputs // batch_size`` batches of ``buffer.text_batch``
    texts collects activations for every index in ``dims``.

    Args:
        dims: 1-D tensor of dimension indices to profile.
        dictionary: if given, activations are ``dictionary.encode(...)`` features.
        max_length: tokens per context (truncation length).
        k: number of tokens / contexts to report per dimension.

    Returns:
        Dict mapping ``int(dim)`` to a ``featureProfile`` namedtuple with fields
        ``(top_contexts, top_tokens, top_affected, decoded_tokens, activations)``.
        ``top_affected`` is always None here (ablation effects are not computed).
        ``top_contexts`` is only rendered when a single dim is requested.
    """

    def _list_decode(x):
        if isinstance(x, int):
            return model.tokenizer.decode(x)
        else:
            return [_list_decode(y) for y in x]

    n_iters = n_inputs // batch_size

    dim_count = dims.shape[0]

    activations = t.zeros((dim_count, n_iters * batch_size, max_length), device=device, dtype=t.float32)
    tokens = t.zeros((n_iters * batch_size, max_length), device=device, dtype=t.int64)

    for i in range(n_iters):
        inputs = buffer.text_batch(batch_size=batch_size)
        with model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)):
            cur_tokens = model.input[1][
                "input_ids"
            ].save()  # if you're getting errors, check here; might only work for pythia models
            cur_activations = submodule.output
            if type(cur_activations.shape) == tuple:
                cur_activations = cur_activations[0]
            if dictionary is not None:
                cur_activations = dictionary.encode(cur_activations)
            cur_activations = cur_activations[:, :, dims].save() # Shape: (batch_size, max_length, dim_count)
        cur_activations = rearrange(cur_activations.value, "b n d -> d b n") # Shape: (dim_count, batch_size, max_length)
        activations[:, i * batch_size : (i + 1) * batch_size, :] = cur_activations
        tokens[i * batch_size : (i + 1) * batch_size, :] = cur_tokens.value

    # Token id -> positions where it occurs. This depends only on `tokens`, so build it once
    # (in first-occurrence order, as before) instead of once per dimension.
    token_positions = {}
    for tok in tokens.flatten().tolist():
        if tok not in token_positions:
            token_positions[tok] = (tokens == tok).nonzero(as_tuple=True)

    FeatureProfile = namedtuple(
        "featureProfile",
        ["top_contexts", "top_tokens", "top_affected", "decoded_tokens", "activations"],
    )

    per_dim_stats = {}
    for i, dim in enumerate(dims):
        dim_acts = activations[i]  # (n_iters * batch_size, max_length)

        # Ablation effects (feature_effect) are not computed per dimension here.
        top_affected = None

        # get top k tokens by mean activation
        token_mean_acts = {
            tok: dim_acts[positions].mean().item() for tok, positions in token_positions.items()
        }
        top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
        top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]

        # Top-k (context, position) pairs by activation; each context is truncated at that position.
        flattened_acts = rearrange(dim_acts, "b n -> (b n)")
        topk_indices = t.argsort(flattened_acts, dim=0, descending=True)[:k]
        batch_indices = topk_indices // dim_acts.shape[1]
        token_indices = topk_indices % dim_acts.shape[1]
        context_tokens = [
            tokens[batch_idx, : token_idx + 1].tolist()
            for batch_idx, token_idx in zip(batch_indices, token_indices)
        ]
        context_acts = [
            dim_acts[batch_idx, : token_id + 1, None, None]
            for batch_idx, token_id in zip(batch_indices, token_indices)
        ]
        decoded_tokens = _list_decode(context_tokens)

        if dim_count == 1:
            # Render with this dim's per-context activations, which line up with decoded_tokens
            # (previously the whole 3-D activation buffer was passed).
            top_contexts = text_neuron_activations(decoded_tokens, context_acts)
        else:
            top_contexts = None
        per_dim_stats[dim.item()] = FeatureProfile(
            top_contexts, top_tokens, top_affected, decoded_tokens, context_acts
        )

    return per_dim_stats

In [None]:
# from dictionary_learning.interp import examine_dimension
# Single-feature profile of `interest` using the original (one-dim) implementation.
(
    top_contexts,
    top_tokens,
    top_affected,
    decoded_tokens,
    activations,
) = examine_dimension_old(
    model,
    submodule,
    buffer,
    dictionary=ae,
    dim_idx=interest,
    n_inputs=20,
    k=30,
    batch_size=4,
    device=DEVICE,
)
print(top_tokens)

In [None]:
# Compare the selected feature indices with the spot-checked feature as a device tensor.
interest_tensor = torch.tensor([interest], device=DEVICE)
print(idx)
print(interest_tensor)

In [None]:
N_EXAMINED_DIMS = 500  # cap on how many selected features are profiled

per_dim_stats = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary=ae,
    dims=idx[:N_EXAMINED_DIMS],
    n_inputs=200,
    k=30,
    batch_size=4,
    device=DEVICE,
)
# `interest` is only profiled if it passed the firing-rate filter and falls within the cap.
if interest in per_dim_stats:
    print(per_dim_stats[interest].top_tokens)
else:
    print(f"Feature {interest} was not among the examined dims")

In [None]:
# List the feature indices that were profiled (the keys of per_dim_stats).
print(per_dim_stats.keys())

In [None]:
import chess_utils
from itertools import islice

# For each profiled feature, test whether its top contexts all end on a position that
# chess_utils.find_num_indices reports (presumably move-number characters — confirm in chess_utils).
max_dims = 100  # only scan the first 100 profiled features
top_k = 10  # number of strongest contexts checked per feature

nonzero_count = 0
num_idx_count = 0
for dim in islice(per_dim_stats, max_dims):
    dim_decoded_tokens = per_dim_stats[dim].decoded_tokens
    dim_activations = per_dim_stats[dim].activations
    # Skip features whose 11th-strongest context has zero activation at its last token.
    # NOTE(review): index 10 requires 11 firing contexts while only top_k=10 are checked — confirm intent.
    if dim_activations[10][-1].item() == 0:
        continue
    nonzero_count += 1
    inputs = ["".join(string) for string in dim_decoded_tokens[:top_k]]
    # A context counts if its final character index is among the reported num indices.
    count = sum((len(pgn) - 1) in chess_utils.find_num_indices(pgn) for pgn in inputs)
    if count == top_k:
        num_idx_count += 1
print(num_idx_count, nonzero_count)


In [None]:
import chess_utils
import chess
from itertools import islice

# For each profiled feature, count how many of its top contexts parse to a board in check.
max_dims = 500  # only scan the first 500 profiled features
top_k = 10  # number of strongest contexts checked per feature

nonzero_count = 0
check_dim_count = 0  # features where more than top_k - 2 of the top contexts are in check
for dim in islice(per_dim_stats, max_dims):
    dim_decoded_tokens = per_dim_stats[dim].decoded_tokens
    dim_activations = per_dim_stats[dim].activations
    # Skip features whose 11th-strongest context has zero activation at its last token.
    if dim_activations[10][-1].item() == 0:
        continue
    nonzero_count += 1
    inputs = ["".join(string) for string in dim_decoded_tokens]

    if dim == 1082:  # spot-check one specific feature's contexts
        print(inputs)

    # Only the top_k boards are inspected, so only parse those.
    chess_boards = [
        chess_utils.pgn_string_to_board(pgn, allow_exception=True) for pgn in inputs[:top_k]
    ]
    # NOTE(review): pgn_string_to_board may return None on parse failure with allow_exception=True
    # (unverified); the None guard is harmless otherwise.
    count = sum(1 for board in chess_boards if board is not None and board.is_check())
    if count > top_k - 2:
        print(f"dim: {dim} has {count}/{top_k} top contexts in check")
        check_dim_count += 1
print(check_dim_count, nonzero_count)
