In [None]:
from nnsight import LanguageModel
import torch
import matplotlib.pyplot as plt
import chess
import json
from tqdm import tqdm

from dictionary_learning import ActivationBuffer
from nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning import AutoEncoder

import chess_utils

In [None]:
# Sanity check: show the name of the board-state function attached to chess_utils.piece_config.
print(chess_utils.piece_config.custom_board_state_function.__name__)

Step 1: Load the model, dictionary, data, and activation buffers.

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "../models/lichess_8layers_ckpt_no_optimizer.pt"
# Number of activation rows requested from the buffer per batch (passed as out_batch_size below).
batch_size = 8

# Directory holding the trained autoencoder weights (ae.pt) and its training config (config.json).
# The f-strings below rely on the trailing slash.
autoencoder_path = "../autoencoders/ef4_20k_resample/ef=4_lr=1e-03_l1=5e-02_layer=5/"

autoencoder_model_path = f"{autoencoder_path}ae.pt"
autoencoder_config_path = f"{autoencoder_path}config.json"
ae = AutoEncoder.from_pretrained(autoencoder_model_path, device=DEVICE)

with open(autoencoder_config_path, "r") as f:
    config = json.load(f)

# Reuse the context length and layer index recorded in the autoencoder's config.
context_length = config["buffer"]["ctx_len"]
layer = config["trainer"]["layer"]

tokenizer = NanogptTokenizer(meta_path="../models/meta.pkl")
model = convert_nanogpt_model(MODEL_PATH, torch.device(DEVICE))
# Wrap the converted model in nnsight's LanguageModel.
model = LanguageModel(model, device_map=DEVICE, tokenizer=tokenizer).to(DEVICE)

submodule = model.transformer.h[layer].mlp  # MLP of the transformer block selected by `layer` from the config
activation_dim = config["trainer"]["activation_dim"]  # passed to the buffer as d_submodule
dictionary_size = config["trainer"]["dictionary_size"]  # used later as the width of the collected feature matrix

# chess_sae_test is 100MB of data, so no big deal to download it
data = hf_dataset_to_generator("adamkarvonen/chess_sae_test", streaming=False)
# Buffer over `submodule` with io="out"; presumably it yields the submodule's output
# activations in batches of `batch_size` rows — TODO confirm against ActivationBuffer.
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    n_ctxs=512,
    ctx_len=context_length,
    refresh_batch_size=4,
    io="out",
    d_submodule=activation_dim,
    device=DEVICE,
    out_batch_size=batch_size,
    )

Collect feature activations on total_inputs inputs.

In [None]:
@torch.no_grad()
def get_feature(
    activations,
    ae: AutoEncoder,
    device,
):
    """Encode the next batch of activations with the autoencoder.

    Args:
        activations: iterator that yields activation batches; advanced by one step.
        ae: autoencoder called as ``ae(x, output_features=True)``.
        device: device the batch is moved to before encoding.

    Returns:
        The second element of the autoencoder's output (its feature activations).

    Raises:
        StopIteration: if ``activations`` is exhausted.
    """
    try:
        batch = next(activations).to(device)
    except StopIteration:
        # Re-raise with a message explaining how to avoid running out of data.
        raise StopIteration(
            "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
        )

    # The reconstruction (first element) is not needed here.
    _, feature_acts = ae(batch, output_features=True)
    return feature_acts

# Number of activation rows to collect; it must split evenly into buffer batches.
total_inputs = 8192
assert total_inputs % batch_size == 0
num_iters = total_inputs // batch_size

# Pre-allocate the (total_inputs, dictionary_size) matrix and fill it one batch at a time.
features = torch.zeros((total_inputs, dictionary_size), device=DEVICE)
batch_starts = range(0, total_inputs, batch_size)
for start in tqdm(batch_starts, total=num_iters, desc="Extracting features"):
    # Each call returns a (batch_size, dictionary_size) block of feature activations.
    features[start:start + batch_size, :] = get_feature(buffer, ae, DEVICE)

A few plots about various statistics.

In [None]:
# Fraction of the collected inputs on which each feature is nonzero.
is_active = features != 0
firing_rate_per_feature = is_active.float().sum(dim=0).cpu() / total_inputs

# Histogram of the per-feature firing rates.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(firing_rate_per_feature, bins=50, alpha=0.75, color='blue')
ax.set_title('Histogram of firing rates for features')
ax.set_xlabel('Probability')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()

In [None]:

# Fraction of dictionary features that are nonzero on each input.
# Averaging the nonzero indicator over the feature axis normalizes by the
# dictionary width (features.shape[-1]); dividing the per-input count by the
# number of inputs would not give a fraction of features.
firing_rate_per_input = (features != 0).float().mean(dim=-1).cpu()

# Creating the histogram
plt.figure(figsize=(10, 6))
plt.hist(firing_rate_per_input, bins=50, alpha=0.75, color='blue')
plt.title('Percentage of features firing per input')
plt.xlabel('Probability')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

The log-frequency histogram below is adapted from this Colab notebook: https://colab.research.google.com/drive/19Qo9wj5rGLjb6KsB9NkKNJkMiHcQhLqo?usp=sharing#scrollTo=WZMhAzLTvw-u

In [None]:
# Firing probability of each feature: the fraction of inputs on which it is nonzero.
# (Averaging the raw activations would give the mean activation magnitude, which is
# not a probability and would not match the axis labels below.)
feat_prob = (features != 0).float().mean(0)
print(feat_prob.shape)
# The small epsilon keeps log10 finite for features that never fire.
log_freq = (feat_prob + 1e-10).log10()
print(log_freq.shape)

log_freq_np = log_freq.cpu().numpy()

# Creating the histogram
plt.figure(figsize=(10, 6))
plt.hist(log_freq_np, bins=50, alpha=0.75, color='blue')
plt.title('Histogram of log10 of Feature Probabilities')
plt.xlabel('log10(Probability)')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

Get the L0 statistic. Then, get a list of indices for features that fire between 0 and 50% of the time.

In [None]:
print(features.shape)
# L0: average number of nonzero features per input.
l0 = (features != 0).float().sum(dim=-1).mean()
print(f"l0: {l0}")

# Fraction of inputs on which each feature is nonzero (kept on the features' device).
firing_rate_per_feature = (features != 0).float().sum(dim=0) / total_inputs

assert firing_rate_per_feature.shape[0] == dictionary_size

# Keep features that fire at least once but on fewer than half of the inputs.
mask = (firing_rate_per_feature > 0) & (firing_rate_per_feature < 0.5)
# nonzero() returns an (n, 1) tensor for a 1-D mask. Squeeze only the last axis so
# the result stays 1-D even when exactly one feature matches; a bare squeeze()
# would produce a 0-d tensor and break idx.shape[0] / idx[:10] below.
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
print(idx.shape)
print(f"\n\nWe have {idx.shape[0]} features that fire between 0 and 50% of the time.")
print(idx[:10])

Next, we collect per-dimension stats, which include the top tokens each feature fires on, and the top-k inputs together with the activations for each input token.

Rough ballpark times on my RTX 3050: 

2000 dims, 3000 inputs, batch size 50 = 42 seconds

Note that I perform the activation processing on my CPU. This is comparable speed, but much lower VRAM usage.

In [None]:
import importlib
import chess_interp

# Reload chess_interp so on-disk edits are picked up without restarting the kernel.
importlib.reload(chess_interp)

# Gather per-feature statistics for every selected feature index.
per_dim_stats = chess_interp.examine_dimension_chess(
    model,
    submodule,
    buffer,
    dictionary=ae,
    dims=idx[:],
    n_inputs=3000,
    k=30,
    batch_size=50,
    processing_device="cpu",
)

This cell looks at syntax related features. Specifically, it looks for features that always fire on a PGN "counting number". In this PGN, I've wrapped the "counting numbers" in brackets.

;<1.>e4 e5 <2.>Nf3 ...

We can easily analyze different syntax related attributes by just passing in a different syntax function, such as one that just finds space indices.

In [None]:
# Reload chess_utils so on-disk edits to the module are picked up without restarting the kernel.
importlib.reload(chess_utils)


def syntax_analysis(per_dim_stats: dict, minimum_number_of_activations: int, top_k: int, max_dims: int, syntax_function: callable, verbose: bool = False) -> dict:
    """Count features whose top inputs all end on an index selected by `syntax_function`.

    At most `max_dims` entries of `per_dim_stats` are examined, in iteration order.
    A feature "matches" when, for every one of its first `top_k` inputs, the index of
    the final character of the joined input string is among the indices returned by
    `syntax_function` for that string.

    Args:
        per_dim_stats: maps a feature id to a dict with the keys 'decoded_tokens'
            (one sequence of token strings per input) and 'activations' (one
            activation sequence per input; only the last element of each is read).
        minimum_number_of_activations: rank index probed to decide whether a feature
            fires often enough. The feature is skipped when the last activation of
            the input at this index is zero, or when fewer inputs than that are stored.
            NOTE(review): this presumably relies on the inputs being sorted by
            descending activation, and index N being nonzero implies N + 1 firing
            inputs rather than N — confirm the intended threshold.
        top_k: number of leading inputs per feature that must all match.
        max_dims: maximum number of features to examine.
        syntax_function: callable taking the joined input string and returning a
            collection of character indices.
        verbose: when True, print a short summary.

    Returns:
        dict with the keys 'dim_count' (features examined), 'nonzero_count' (features
        that passed the activation filter), 'syntax_match_idx_count' (features that
        matched on all `top_k` inputs) and 'average_input_length' (mean, over
        matching features, of the average character length of their top inputs;
        -1 when no feature matched).
    """
    nonzero_count = 0
    syntax_match_idx_count = 0
    dim_count = 0
    length_tracker = []

    for dim in per_dim_stats:
        # Check the limit before counting so dim_count is exactly the number of
        # features examined (it must not overshoot by one when truncated).
        if dim_count >= max_dims:
            break
        dim_count += 1

        decoded_tokens = per_dim_stats[dim]['decoded_tokens']
        activations = per_dim_stats[dim]['activations']
        # Too few stored inputs to probe the requested rank: treat as not firing enough.
        if len(activations) <= minimum_number_of_activations:
            continue
        # Skip features whose input at the probed rank does not fire on its last token.
        if activations[minimum_number_of_activations][-1].item() == 0:
            continue
        nonzero_count += 1

        inputs = ["".join(tokens) for tokens in decoded_tokens][:top_k]

        # The last token holds the activation of interest for each context, so an
        # input counts as a hit when its final character index is selected.
        count = 0
        for pgn in inputs:
            if (len(pgn) - 1) in syntax_function(pgn):
                count += 1

        if count == top_k:
            syntax_match_idx_count += 1
            # Record the length once per matching feature only, so the reported
            # average describes pattern-matching features and nothing else.
            if inputs:
                length_tracker.append(sum(len(pgn) for pgn in inputs) / len(inputs))

    total_average_length = -1
    if len(length_tracker) > 0:
        total_average_length = sum(length_tracker) / len(length_tracker)

    if verbose:
        print(f"Out of {dim_count} features, {nonzero_count} had at least {minimum_number_of_activations} activations.")
        print(f"{syntax_match_idx_count} features matched on all top {top_k} inputs for our syntax function {syntax_function.__name__}")
        print(f"The average length of inputs of pattern matching features was {total_average_length}")

    results = {}

    results['dim_count'] = dim_count
    results['nonzero_count'] = nonzero_count
    results['syntax_match_idx_count'] = syntax_match_idx_count
    results['average_input_length'] = total_average_length

    return results

# Run the syntax analysis twice on the same stats (minimum_number_of_activations=20,
# top_k=20, max_dims=2500), once per syntax function from chess_utils.
syntax_analysis(per_dim_stats, 20, 20, 2500, chess_utils.find_num_indices, verbose=True)
syntax_analysis(per_dim_stats, 20, 20, 2500, chess_utils.find_spaces_indices, verbose=True)


In [None]:
# Reload chess_utils first so the Config imported below comes from the freshly loaded module.
importlib.reload(chess_utils)
from chess_utils import Config


def board_analysis(
    per_dim_stats: dict,
    minimum_number_of_activations: int,
    top_k: int,
    max_dims: int,
    threshold: float,
    configs: list[Config],
    device: str = "cpu",
    verbose: bool = False,
) -> dict:
    """Count features whose top inputs share a common board state, per config.

    At most `max_dims` entries of `per_dim_stats` are examined, in iteration order.
    For every feature that passes the activation filter, its first `top_k` inputs are
    converted to boards with `chess_utils.pgn_string_to_board`. Then, for each config,
    the boards are turned into a state stack, masked, averaged and passed to
    `chess_utils.find_common_states` together with `threshold`.
    NOTE(review): the exact semantics of these chess_utils helpers are not visible
    here; presumably `find_common_states` returns index tensors of states whose
    averaged value exceeds `threshold` — confirm.

    Args:
        per_dim_stats: maps a feature id to a dict with the keys 'decoded_tokens'
            (one sequence of token strings per input) and 'activations' (one
            activation sequence per input; only the last element of each is read).
        minimum_number_of_activations: rank index probed to decide whether a feature
            fires often enough. The feature is skipped when the last activation of
            the input at this index is zero, or when fewer inputs than that are stored.
        top_k: number of leading inputs per feature to analyze.
        max_dims: maximum number of features to examine.
        threshold: forwarded to `chess_utils.find_common_states`.
        configs: configs to evaluate; results are keyed by the `__name__` of each
            config's `custom_board_state_function`.
        device: device for the per-config tracker tensor and the chess_utils helpers.
        verbose: when True, print a summary per config.

    Returns:
        dict mapping each config name to a dict with the keys 'pattern_match_count',
        'total_average_length' and 'average_matches_per_dim' (both divided by the
        match count when it is positive), 'per_class_dict', 'board_tracker' (flipped
        along its first axis and converted to nested lists), 'dim_count' (features
        examined) and 'nonzero_count' (features that passed the activation filter).
    """
    nonzero_count = 0
    dim_count = 0

    results = {}

    for config in configs:
        board_tracker = torch.zeros(config.num_rows, config.num_cols, device=device)
        num_classes = abs(config.min_val) + abs(config.max_val) + 1
        per_class_dict = {key: 0 for key in range(num_classes)}

        results[config.custom_board_state_function.__name__] = {
            "pattern_match_count": 0,
            "total_average_length": 0,
            "average_matches_per_dim": 0,
            "per_class_dict": per_class_dict,
            "board_tracker": board_tracker,
        }

    for dim in tqdm(per_dim_stats, total=len(per_dim_stats), desc="Processing chess pgn strings"):
        # Check the limit before counting so dim_count is exactly the number of
        # features examined (it must not overshoot by one when truncated).
        if dim_count >= max_dims:
            break
        dim_count += 1

        decoded_tokens = per_dim_stats[dim]["decoded_tokens"]
        activations = per_dim_stats[dim]["activations"]
        # Too few stored inputs to probe the requested rank: treat as not firing enough.
        if len(activations) <= minimum_number_of_activations:
            continue
        # Skip features whose input at the probed rank does not fire on its last token.
        if activations[minimum_number_of_activations][-1].item() == 0:
            continue
        nonzero_count += 1

        inputs = ["".join(tokens) for tokens in decoded_tokens][:top_k]

        # Boards are built once per feature and shared by all configs.
        chess_boards = [
            chess_utils.pgn_string_to_board(pgn, allow_exception=True) for pgn in inputs
        ]

        for config in configs:
            config_name = config.custom_board_state_function.__name__
            config_results = results[config_name]

            one_hot_list = chess_utils.chess_boards_to_state_stack(chess_boards, device, config)
            one_hot_list = chess_utils.mask_initial_board_states(one_hot_list, device, config)
            averaged_one_hot = chess_utils.get_averaged_states(one_hot_list)
            common_indices = chess_utils.find_common_states(averaged_one_hot, threshold)

            # A feature counts as a pattern match when any common state was found.
            if any(len(idx) > 0 for idx in common_indices):
                config_results["pattern_match_count"] += 1
                average_input_length = sum(len(pgn) for pgn in inputs) / len(inputs)
                config_results["total_average_length"] += average_input_length

            # Per-square and per-class tallies are only kept for 8-row configs.
            if config.num_rows == 8:
                for idx in zip(*common_indices):
                    config_results["board_tracker"][idx[0], idx[1]] += 1
                    config_results["per_class_dict"][idx[2].item()] += 1
                    config_results["average_matches_per_dim"] += 1

    for config in configs:
        config_name = config.custom_board_state_function.__name__
        config_results = results[config_name]
        match_count = config_results["pattern_match_count"]
        config_results["dim_count"] = dim_count
        config_results["nonzero_count"] = nonzero_count
        # Reverse the row order and convert to plain nested lists for the caller.
        config_results["board_tracker"] = config_results["board_tracker"].flip(0).tolist()
        if match_count > 0:
            # Turn the running sums into per-matching-feature averages.
            config_results["total_average_length"] /= match_count
            config_results["average_matches_per_dim"] /= match_count

    if verbose:
        for config in configs:
            config_name = config.custom_board_state_function.__name__
            pattern_match_count = results[config_name]["pattern_match_count"]
            total_average_length = results[config_name]["total_average_length"]
            print(f"\n{config_name} Results:")
            print(
                f"Out of {dim_count} features, {nonzero_count} had at least {minimum_number_of_activations} activations."
            )
            print(
                f"{pattern_match_count} features matched on all top {top_k} inputs for our board to state function {config_name}"
            )
            print(
                f"The average length of inputs of pattern matching features was {total_average_length}"
            )

            if config.num_rows == 8:
                per_class_dict = results[config_name]["per_class_dict"]
                print("\nThe following square states had the following number of occurances:")
                for key, count in per_class_dict.items():
                    print(f"Index: {key}, Count: {count}")

                print("\nHere are the most common squares:")
                print(results[config_name]["board_tracker"])

    return results

In [None]:
# Board analysis with chess_utils.piece_config (minimum_number_of_activations=20, top_k=20, max_dims=2500, threshold=0.99).
board_analysis(per_dim_stats, 20, 20, 2500, 0.99, [chess_utils.piece_config], device="cpu", verbose=True)

In [None]:
# Board analysis with chess_utils.threat_config (minimum_number_of_activations=20, top_k=20, max_dims=2500, threshold=0.99).
board_analysis(per_dim_stats, 20, 20, 2500, 0.99, [chess_utils.threat_config], verbose=True)

In [None]:
# Board analysis with chess_utils.check_config (minimum_number_of_activations=20, top_k=20, max_dims=2500, threshold=0.99).
board_analysis(per_dim_stats, 20, 20, 2500, 0.99, [chess_utils.check_config], verbose=True)

In [None]:
# Board analysis with chess_utils.pin_config (minimum_number_of_activations=20, top_k=20, max_dims=2500, threshold=0.99).
board_analysis(per_dim_stats, 20, 20, 2500, 0.99, [chess_utils.pin_config], verbose=True)