In [None]:
from nnsight import LanguageModel
import torch
import matplotlib.pyplot as plt
import chess
import json

from dictionary_learning import ActivationBuffer
from circuits.nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning import AutoEncoder

import chess_utils

Step 1: Load the autoencoder (dictionary), the model, the data, and the activation buffer.

In [None]:


# Location of the trained autoencoder checkpoint and the config saved alongside it.
autoencoder_path = "../autoencoders/ef=8_lr=1e-04_l1=1e-03_layer=5/"
autoencoder_model_path = f"{autoencoder_path}ae.pt"
autoencoder_config_path = f"{autoencoder_path}config.json"

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first .to(DEVICE).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ae = AutoEncoder.from_pretrained(autoencoder_model_path, device=DEVICE)

with open(autoencoder_config_path, "r") as f:
    config = json.load(f)

print(config)

# Context length recorded in the autoencoder's config under buffer.ctx_len.
context_length = config['buffer']['ctx_len']


In [None]:
tokenizer = NanogptTokenizer()
model = convert_nanogpt_model("../models/lichess_8layers_ckpt_no_optimizer.pt", torch.device(DEVICE))
model = LanguageModel(model, device_map=DEVICE, tokenizer=tokenizer).to(DEVICE)

submodule = model.transformer.h[5].mlp  # layer 5 MLP (matches layer=5 in autoencoder_path)
activation_dim = 512  # output dimension of the MLP
dictionary_size = 8 * activation_dim  # expansion factor 8 (ef=8 in autoencoder_path)

batch_size = 8

# chess_sae_test is 100MB of data, so no big deal to download it
data = hf_dataset_to_generator("adamkarvonen/chess_sae_test", streaming=False)
# NOTE(review): ctx_len is hard-coded to 256 rather than using `context_length`
# read from the autoencoder config — confirm this mismatch is intentional.
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    n_ctxs=512,
    ctx_len=256,
    refresh_batch_size=4,
    io="out",
    d_submodule=activation_dim,
    device=DEVICE,
    out_batch_size=batch_size,
)

Collect the autoencoder's feature activations on `total_inputs` inputs drawn from the activation buffer.

In [None]:
@torch.no_grad()
def get_feature(
    activations,
    ae: AutoEncoder,
    device,
):
    """Encode the next batch from an activation iterator into dictionary features.

    Args:
        activations: iterator of activation tensors (e.g. an ActivationBuffer);
            exactly one batch is consumed per call.
        ae: autoencoder, called as ``ae(x, output_features=True)`` and expected
            to return ``(reconstruction, features)``.
        device: device the batch is moved to before encoding.

    Returns:
        The feature tensor returned by ``ae`` for the batch.

    Raises:
        StopIteration: if ``activations`` is exhausted.
    """
    try:
        x = next(activations).to(device)
    except StopIteration as err:
        # Replace the bare exhaustion with an actionable message, keeping the
        # original as the explicit cause in the traceback.
        raise StopIteration(
            "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
        ) from err

    # Only the features are needed; the reconstruction is discarded.
    _, f = ae(x, output_features=True)

    return f

total_inputs = 8192
num_iters, remainder = divmod(total_inputs, batch_size)
assert remainder == 0

# Preallocate the (total_inputs, dictionary_size) matrix, then fill it with one
# buffer batch per iteration.
features = torch.zeros((total_inputs, dictionary_size), device=DEVICE)
for i in range(num_iters):
    row_slice = slice(i * batch_size, (i + 1) * batch_size)
    feature = get_feature(buffer, ae, DEVICE)  # (batch_size, dictionary_size)
    features[row_slice] = feature

A few plots of feature firing statistics.

In [None]:
# Fraction of the collected inputs on which each feature is nonzero.
firing_rate_per_feature = (features != 0).float().sum(dim=0).cpu() / total_inputs

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(firing_rate_per_feature, bins=50, alpha=0.75, color='blue')
ax.set_title('Histogram of firing rates for features')
ax.set_xlabel('Probability')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()

In [None]:

# Fraction of dictionary features active on each input. Normalize by the number
# of features (the last axis), not by the number of inputs.
firing_rate_per_input = (features != 0).float().mean(dim=-1).cpu()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(firing_rate_per_input, bins=50, alpha=0.75, color='blue')
ax.set_title('Fraction of features firing per input')
ax.set_xlabel('Fraction of features firing')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()

The log-frequency histogram below is adapted from: https://colab.research.google.com/drive/19Qo9wj5rGLjb6KsB9NkKNJkMiHcQhLqo?usp=sharing#scrollTo=WZMhAzLTvw-u

In [None]:
# Per-feature firing probability: fraction of inputs on which the feature is
# nonzero. (features.mean(0) would be the mean activation magnitude, which is
# not a probability and does not match the axis labels below.)
feat_prob = (features != 0).float().mean(0)
assert feat_prob.shape == (features.shape[1],)
# Small epsilon keeps never-firing features finite on the log scale (-> -10).
log_freq = (feat_prob + 1e-10).log10()

log_freq_np = log_freq.cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(log_freq_np, bins=50, alpha=0.75, color='blue')
ax.set_title('Histogram of log10 of Feature Probabilities')
ax.set_xlabel('log10(Probability)')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()

Get the L0 statistic (the average number of features active per input). Then, get the indices of features that fire on more than 0% but less than 50% of inputs.

In [None]:
print(features.shape)
# L0: average number of nonzero features per input.
l0 = (features != 0).float().sum(dim=-1).mean()
print(f"l0: {l0.item()}")

firing_rate_per_feature = (features != 0).float().sum(dim=0) / total_inputs

assert firing_rate_per_feature.shape[0] == dictionary_size

# Keep features that fire on some, but fewer than half, of the inputs
# (drops never-firing features and very dense ones).
mask = (firing_rate_per_feature > 0) & (firing_rate_per_feature < 0.5)
# as_tuple=True yields one 1-D index tensor per dimension, so [0] is always 1-D.
# .squeeze() would collapse a single match to a 0-d tensor and break idx.shape[0].
idx = torch.nonzero(mask, as_tuple=True)[0]
print(idx.shape)
print(f"\n\nWe have {idx.shape[0]} features that fire between 0 and 50% of the time.")
print(idx[:10])

Next, we collect per-dimension statistics, which include the top tokens each dimension fires on and the top-k inputs and activations per input token.
To speed this up, reduce the number of dimensions processed, e.g. by passing `idx[:10]`. `idx` will probably contain thousands
of dimensions, and we do `n_inputs` work for each dimension.
With few dims, runtime is dominated by running the GPT model forward over `n_inputs` inputs.
With many dims, runtime is dominated by finding the top inputs per dimension, which is `len(dims) * n_inputs` work.
Note that the processing could probably be made much faster; I haven't optimized it at all.

Rough ballpark times on my RTX 3050:

10 dims, 2400 inputs, batch size 48 = 25 seconds

1500 dims, 192 inputs, batch size 48 = 4 minutes 20 seconds

In [None]:
import importlib

from dictionary_learning import interp

# Re-execute interp's source so edits made while the kernel is running take
# effect; the function has to be looked up *after* the reload to get the new one.
examine_dimension = importlib.reload(interp).examine_dimension

per_dim_stats = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary=ae,
    dims=idx[:],
    n_inputs=192,
    k=30,
    batch_size=48,
    device=DEVICE,
)

This cell looks at syntax-related features. Specifically, it looks for features whose top-k activating inputs all end on a PGN "counting number" (the move number). In the PGN below, I've wrapped the counting numbers in angle brackets:

`;<1.>e4 e5 <2.>Nf3 ...`

In [None]:
importlib.reload(chess_utils)

# Count dims whose top-k activating inputs all end on a PGN move number
# ("counting number"), as located by chess_utils.find_num_indices.
minimum_number_of_activations = 10
top_k = 10

nonzero_count = 0
num_idx_count = 0
dim_count = 0
max_dims = 10000
for dim in per_dim_stats:
    dim_count += 1
    if dim_count > max_dims:
        break

    decoded_tokens = per_dim_stats[dim].decoded_tokens
    activations = per_dim_stats[dim].activations
    # Skip dims without at least `minimum_number_of_activations` firing
    # contexts: index N - 1 is the Nth-ranked context.
    # Assumes contexts are sorted by descending activation — TODO confirm
    # against examine_dimension.
    if (
        len(activations) < minimum_number_of_activations
        or activations[minimum_number_of_activations - 1][-1].item() == 0
    ):
        continue
    nonzero_count += 1

    inputs = ["".join(string) for string in decoded_tokens]
    count = 0
    for i, pgn in enumerate(inputs[:top_k]):
        print(f"dim: {dim} pgn: {pgn}, activation: {activations[i][-1].item()}")
        nums = chess_utils.find_num_indices(pgn)

        # If the last token (which contains the max activation for that context) is a number
        # Then we count this firing as a "number index firing"
        if (len(pgn) - 1) in nums:
            count += 1

    if count == top_k:
        print(f"All top {top_k} activations in dim: {dim} are on num indices")
        num_idx_count += 1
print(num_idx_count, nonzero_count)


In [None]:
# For each dim, rebuild boards from its top-k inputs and tally board-state
# entries that are shared by at least `common_state_threshold` of them.
nonzero_count = 0
nonzero_dim_count = 0
dim_count = 0
top_k = 20

max_dims = 10000
common_state_threshold = 0.9

average_input_length = 0

board_tracker = torch.zeros(8, 8, device="cpu")
result_dict = {key: 0 for key in range(0, 13)}
length_tracker = []
board_to_state_fn = chess_utils.board_to_piece_state

for dim in per_dim_stats:
    dim_count += 1
    if dim_count > max_dims:
        break
    decoded_tokens = per_dim_stats[dim].decoded_tokens
    activations = per_dim_stats[dim].activations
    # Skip dims whose context at index 10 has zero activation (assumes contexts
    # are sorted by descending activation — TODO confirm).
    # NOTE(review): top_k (20) exceeds this cutoff, so some averaged inputs may
    # be contexts on which the feature does not fire.
    if activations[10][-1].item() == 0:
        continue
    nonzero_count += 1
    inputs = ["".join(string) for string in decoded_tokens]
    inputs = inputs[:top_k]

    chess_boards = [chess_utils.pgn_string_to_board(pgn, allow_exception=True) for pgn in inputs]

    one_hot_list = chess_utils.chess_boards_to_state_stack(chess_boards, DEVICE, board_to_state_fn)
    one_hot_list = chess_utils.mask_initial_board_states(one_hot_list, DEVICE, board_to_state_fn)
    averaged_one_hot = chess_utils.get_averaged_states(one_hot_list)
    # Presumably a tuple of per-axis index tensors (row, col, state) — verify in chess_utils.
    common_indices = chess_utils.find_common_states(averaged_one_hot, common_state_threshold)

    if any(len(axis_indices) > 0 for axis_indices in common_indices):
        nonzero_dim_count += 1  # Increment if there are nonzero indices
        average_input_length = sum(len(pgn) for pgn in inputs) / len(inputs)
        length_tracker.append(average_input_length)

    # Loop variable deliberately not named `idx`: that global holds the selected
    # feature indices that the examine_dimension cell reads on re-run.
    for state_idx in zip(*common_indices):
        board_tracker[state_idx[0], state_idx[1]] += 1
        state_key = state_idx[2].item()
        result_dict[state_key] = result_dict.get(state_key, 0) + 1
print(nonzero_dim_count, nonzero_count)

for key, count in result_dict.items():
    print(f"Index: {key}, Count: {count}")
# Flip rows so the printed board is rank-reversed relative to board_tracker's row order.
print(board_tracker.flip(0))

In [None]:
# Same analysis as the piece-state cell, but using threat states and only the
# first `max_dims` dims.
nonzero_count = 0
nonzero_dim_count = 0
dim_count = 0
top_k = 20

max_dims = 100
common_state_threshold = 0.9

average_input_length = 0

board_tracker = torch.zeros(8, 8, device="cpu")
result_dict = {key: 0 for key in range(0, 13)}
length_tracker = []
board_to_state_fn = chess_utils.board_to_threat_state

for dim in per_dim_stats:
    dim_count += 1
    if dim_count > max_dims:
        break
    decoded_tokens = per_dim_stats[dim].decoded_tokens
    activations = per_dim_stats[dim].activations
    # Skip dims whose context at index 10 has zero activation (assumes contexts
    # are sorted by descending activation — TODO confirm).
    if activations[10][-1].item() == 0:
        continue
    nonzero_count += 1
    inputs = ["".join(string) for string in decoded_tokens]
    inputs = inputs[:top_k]

    chess_boards = [chess_utils.pgn_string_to_board(pgn, allow_exception=True) for pgn in inputs]

    one_hot_list = chess_utils.chess_boards_to_state_stack(chess_boards, DEVICE, board_to_state_fn)
    one_hot_list = chess_utils.mask_initial_board_states(one_hot_list, DEVICE, board_to_state_fn)
    averaged_one_hot = chess_utils.get_averaged_states(one_hot_list)
    # Presumably a tuple of per-axis index tensors (row, col, state) — verify in chess_utils.
    common_indices = chess_utils.find_common_states(averaged_one_hot, common_state_threshold)

    if any(len(axis_indices) > 0 for axis_indices in common_indices):
        nonzero_dim_count += 1  # Increment if there are nonzero indices
        average_input_length = sum(len(pgn) for pgn in inputs) / len(inputs)
        length_tracker.append(average_input_length)

    # Loop variable deliberately not named `idx`: that global holds the selected
    # feature indices that the examine_dimension cell reads on re-run.
    for state_idx in zip(*common_indices):
        board_tracker[state_idx[0], state_idx[1]] += 1
        # .get(): the threat-state encoding may not use exactly keys 0..12.
        state_key = state_idx[2].item()
        result_dict[state_key] = result_dict.get(state_key, 0) + 1
print(nonzero_dim_count, nonzero_count)

for key, count in result_dict.items():
    print(f"Index: {key}, Count: {count}")
print(board_tracker.flip(0))