`eval_sae_as_classifier.py` does the following:

A custom function maps a board to a "state stack". For example, `board_to_pin_state` returns, for every token in the PGN string, a state of 0 or 1: 0 means "there is no pin on the board at this character" and 1 means "there is a pin on the board at this character". Another example is `board_to_piece_state`, which returns a one-hot tensor of shape (8, 8, 13), or (rows, cols, num_classes), describing the state of every square on the chess board.
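As a minimal sketch of the (8, 8, 13) shape, assume each square already holds an integer class index in 0..12 (the hypothetical `square_classes` grid below, with 6 == blank and 0 == black king as in the class legend used later in this notebook). `torch.nn.functional.one_hot` then produces the one-hot state stack. This only illustrates the shape; it is not the actual `board_to_piece_state` implementation.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 13  # piece-color classes, 6 == blank

# Hypothetical input: an 8x8 grid of class indices, all blank except two squares
square_classes = torch.full((8, 8), 6, dtype=torch.long)
square_classes[0, 4] = 12  # assumed index of a white king, for illustration only
square_classes[7, 4] = 0   # 0 == black king

# One-hot over the class dimension: (rows, cols, classes)
one_hot_RRC = F.one_hot(square_classes, num_classes=NUM_CLASSES)
print(one_hot_RRC.shape)  # torch.Size([8, 8, 13])
```

Every square has exactly one active class, so the tensor sums to 64.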

Over thousands of input PGN strings, for every activation of every dictionary feature and for a range of threshold values, we check whether the activation is above each threshold. Whenever the activation is above a threshold (the feature is "on"), we add the state stack from every custom function to the on tracker. Whenever it is not (the feature is "off"), we add the state stack from every custom function to the off tracker. This runs reasonably quickly: around 2 minutes on an RTX 3090 per 1000 input PGN strings.

The on and off trackers both have shape (thresholds, features, rows, cols, classes).
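The accumulation described above can be sketched as follows, assuming a batch of activations of shape (tokens, features), a matching float state stack of shape (tokens, rows, cols, classes), and a 1D tensor of global thresholds. The function name `update_trackers` is hypothetical and this is not the code in `eval_sae_as_classifier.py`; it only shows how one `einsum` per tracker can replace the nested loops over thresholds and features.

```python
import torch


def update_trackers(
    activations_BF: torch.Tensor,  # (tokens, features)
    states_BRRC: torch.Tensor,  # (tokens, rows, cols, classes), float
    thresholds_T: torch.Tensor,  # (thresholds,)
    trackers: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    # 1.0 where a feature's activation is above a threshold, else 0.0: (T, B, F)
    on_TBF = (activations_BF[None, :, :] > thresholds_T[:, None, None]).float()
    off_TBF = 1.0 - on_TBF

    # Sum the state stack over every token where the feature was on / off: (T, F, R, R, C)
    trackers["on"] += torch.einsum("tbf,brcl->tfrcl", on_TBF, states_BRRC)
    trackers["off"] += torch.einsum("tbf,brcl->tfrcl", off_TBF, states_BRRC)

    # Number of tokens where each feature was on / off at each threshold: (T, F)
    trackers["on_count"] += on_TBF.sum(dim=1)
    trackers["off_count"] += off_TBF.sum(dim=1)
    return trackers


# Toy sizes: 4 tokens, 2 features, 3 thresholds, a 1x1 "board" with 2 classes
T, F, R, C = 3, 2, 1, 2
trackers = {
    "on": torch.zeros(T, F, R, R, C),
    "off": torch.zeros(T, F, R, R, C),
    "on_count": torch.zeros(T, F),
    "off_count": torch.zeros(T, F),
}
activations = torch.tensor([[0.0, 1.0], [2.0, 0.5], [3.0, 0.0], [0.1, 4.0]])
states = torch.nn.functional.one_hot(torch.tensor([[[0]], [[1]], [[1]], [[0]]]), C).float()
trackers = update_trackers(activations, states, torch.tensor([0.0, 0.5, 2.0]), trackers)
```

Because every token lands in either the on or the off tracker, on + off summed over classes always equals the number of tokens, which is the sanity check used further down.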

So, if 100% of the times that a feature is active the board has a corresponding state (such as a pinned piece on the board, or a white knight on f3), then it's likely that the feature corresponds to that state.
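Written as a formula (the symbols are introduced here for illustration), the quantity this analysis looks at is the fraction of a feature's activations at which a given state holds, i.e. the feature's precision as a classifier for that state:

$$
\text{precision}_{t,f,r,c,k} = \frac{\text{on}_{t,f,r,c,k}}{\text{on\_count}_{t,f}}
$$

where $t$ indexes thresholds, $f$ features, $(r, c)$ the square, and $k$ the class. A value of 1 means the state was present every time feature $f$ fired above threshold $t$.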

There are some promising, but not great, results. I'm seeing things that are directionally correct. For example, SAEs with good L0s trained on layer 6/8 (layer 6 has 99% board state accuracy with linear probes) have hundreds of features that are high-precision classifiers for square state when they fire above some threshold, and a couple that are high-precision classifiers for "there is a pin on the board". SAEs trained on layer 0, or SAEs with a poor L0, have almost no features that are good classifiers for square state, and none for "there is a pin on the board". However, I think there are a few features I still need to add that are very important:

Currently, the threshold is something like `torch.arange(0, 4, 0.5)`. However, the max activation per feature is 0.2 for some features and 13 for others. So, a feature whose max is 0.2 never fires above any nonzero threshold, while at the highest threshold value some other features still have many thousands of activations. My plan is to collect the max activation per feature over n examples, then use an individual threshold per feature of `torch.arange(0, 1.1, 0.1) * max_activation`.
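A minimal sketch of the per-feature threshold plan, assuming `max_activations_F` holds each feature's max activation over n examples (the helpers `per_feature_thresholds` and `active_mask` are hypothetical). It uses `torch.linspace(0, 1, 11)` rather than `torch.arange(0, 1.1, 0.1)` to get the same nominal fractions without relying on floating-point step arithmetic to decide whether the endpoint is included.

```python
import torch


def per_feature_thresholds(max_activations_F: torch.Tensor, num_steps: int = 11) -> torch.Tensor:
    """Return thresholds of shape (num_steps, features): fractions of each feature's max."""
    fractions_T = torch.linspace(0, 1, num_steps)  # 0.0, 0.1, ..., 1.0
    return fractions_T[:, None] * max_activations_F[None, :]


def active_mask(activations_BF: torch.Tensor, thresholds_TF: torch.Tensor) -> torch.Tensor:
    """(T, B, F) boolean mask: is each token's activation above its feature's threshold?"""
    return activations_BF[None, :, :] > thresholds_TF[:, None, :]


max_acts = torch.tensor([0.2, 13.0])
thresholds_TF = per_feature_thresholds(max_acts)
print(thresholds_TF[5])  # thresholds at half of each feature's max
```

With per-feature thresholds, index t means the same relative activation level (t / 10 of the max) for every feature, so low-magnitude features are no longer silent at every nonzero threshold.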

Investigate Mine / Yours vs. White / Black model "thinking" for square state, as discussed here: https://adamkarvonen.github.io/machine_learning/2024/01/03/chess-world-models.html
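One possible starting point for the mine / yours experiment, sketched under assumptions: in the 13-class layout used in this notebook, class i and class 12 - i are the same piece type in opposite colors (0 == black king, 6 == blank), so swapping colors is just reversing the class dimension. The hypothetical helper `to_mine_yours` applies that swap to boards where black is to move, so that classes 7..12 always mean "side to move" rather than "white". Treating the side to move as "mine" is itself an assumption, and whether the model uses this representation is the open question; the sketch only shows the relabeling.

```python
import torch
import torch.nn.functional as F


def to_mine_yours(one_hot_BRRC: torch.Tensor, black_to_move_B: torch.Tensor) -> torch.Tensor:
    """Relabel a batch of (8, 8, 13) one-hot boards so classes 7..12 mean the side to move.

    Assumes class i and class 12 - i are the same piece type in opposite colors,
    so reversing the class dimension swaps colors and leaves blank (6) in place.
    """
    swapped = one_hot_BRRC.flip(-1)
    # Broadcast the per-board flag over (rows, cols, classes)
    return torch.where(black_to_move_B[:, None, None, None], swapped, one_hot_BRRC)


# Two toy boards: a black king (class 0) on square (0, 0), blank everywhere else
classes = torch.full((2, 8, 8), 6, dtype=torch.long)
classes[:, 0, 0] = 0
boards = F.one_hot(classes, num_classes=13)
relabeled = to_mine_yours(boards, torch.tensor([True, False]))
```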

Add syntax classification filters from chess_interp.py.

Do qualitative analysis of the features, and probably build a notebook interface for viewing them.

In [None]:
import pickle
import torch
from typing import Callable
import circuits.chess_utils as chess_utils

# This should have been downloaded and unzipped by setup.sh
filename = "group1_results/autoencoders_group1_ef=4_lr=1e-03_l1=1e-01_layer=5_results.pkl"
# filename = "layer0_results/autoencoders_layer0_ef=4_lr=1e-03_l1=1e-01_layer=0_results.pkl"
# filename = "group1_results/autoencoders_group1_ef=16_lr=1e-03_l1=1e-01_layer=5_results.pkl"

with open(filename, 'rb') as f:
    results = pickle.load(f)

print(results.keys())
print("\nAs we can see, every custom function shares the same keys.\n")
print(results['board_to_pin_state'].keys())
print(results['board_to_piece_state'].keys())
print("However, the shapes of the values are different.\n")
print(results['board_to_pin_state']['on'].shape, results['board_to_pin_state']['off'].shape)
print(results['board_to_piece_state']['on'].shape, results['board_to_piece_state']['off'].shape)

In [None]:
# This usually isn't needed, as eval_sae_as_classifier now moves results to the CPU, but I have some results that are still on the GPU
from circuits.utils import to_cpu

results = to_cpu(results)

We have raw counts of how many times every state was present while a feature was on / off. Dividing by the number of times the feature was on (or off) converts these counts to frequencies. For example, a normalized value of 1.0 means this state was present 100% of the time this feature was active.

In [None]:
print(results['board_to_piece_state']['on'].shape)

In [None]:

def normalize_tracker(
    results: dict, tracker_type: str, custom_functions: list[Callable], device: torch.device
) -> dict:
    """Normalize the specified tracker ("on" or "off") by its per-(threshold, feature) count.

    Stores the result under results[function_name][f"{tracker_type}_normalized"].
    """
    counts_TF = results[f"{tracker_type}_count"].to(device)

    # Calculate the inverse of the counts safely: entries with a count of 0 stay 0.
    # Use a float dtype so 1 / count is not truncated if the counts are integers.
    inverse_counts_TF = torch.zeros_like(counts_TF, dtype=torch.float32)
    non_zero_mask = counts_TF > 0
    inverse_counts_TF[non_zero_mask] = 1 / counts_TF[non_zero_mask].float()

    for custom_function in custom_functions:
        tracker_TFRRC = results[custom_function.__name__][tracker_type].to(device)

        # Normalize using element-wise multiplication, broadcasting (T, F) over (rows, cols, classes)
        normalized_tracker_TFRRC = tracker_TFRRC * inverse_counts_TF[:, :, None, None, None]

        # Store the normalized results
        results[custom_function.__name__][f"{tracker_type}_normalized"] = normalized_tracker_TFRRC

    return results

results = normalize_tracker(results, "on", [chess_utils.board_to_pin_state, chess_utils.board_to_piece_state], torch.device("cpu"))
results = normalize_tracker(results, "off", [chess_utils.board_to_pin_state, chess_utils.board_to_piece_state], torch.device("cpu"))

In [None]:
print(results['on_count'][:, :5].to(torch.int))

These results came from `n_inputs` PGN strings of length 256. So, if we add the on and off trackers and sum across the possible square states (the class dimension), every element should equal 256 * `n_inputs`, which is also the total number of tokens (characters) the SAE was evaluated on.

In [None]:
if 'hyperparameters' in results:
    n_inputs = results['hyperparameters']['n_inputs']
    print(f"Every square should sum to {n_inputs * 256}.")

print(results['board_to_piece_state']['on'][0].shape)
print(results['board_to_piece_state']['off'][0].shape)

on_tracker = results['board_to_piece_state']['on'][0].sum(dim=-1)
off_tracker = results['board_to_piece_state']['off'][0].sum(dim=-1)

compare = on_tracker + off_tracker
print(compare.shape)
print(compare[0])
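The printed comparison can also be turned into an assertion. This is a small sketch; `check_square_totals` is a hypothetical helper that takes the 5D on / off trackers (thresholds, features, rows, cols, classes) and the expected number of tokens, e.g. `n_inputs * 256`. Exact equality is safe here because float32 represents integer counts exactly up to 2**24.

```python
import torch


def check_square_totals(on_TFRRC: torch.Tensor, off_TFRRC: torch.Tensor, expected_total: int) -> None:
    """Raise if on + off, summed over classes, differs from the token count for any square."""
    totals_TFRR = (on_TFRRC + off_TFRRC).sum(dim=-1)
    bad = totals_TFRR != expected_total
    if bad.any():
        raise AssertionError(
            f"{int(bad.sum())} (threshold, feature, square) totals differ from {expected_total}"
        )
```

For example: `check_square_totals(results['board_to_piece_state']['on'], results['board_to_piece_state']['off'], n_inputs * 256)`.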

In contrast, for every feature, the on and off pin-state trackers should sum to the number of characters where there was a pin on the board. This often seems to be about 10% of the number of characters (the number above).

In [None]:
print(results['board_to_pin_state']['on'].squeeze()[0])
print(results['board_to_pin_state']['off'].squeeze()[0])

compare = results['board_to_pin_state']['on'].squeeze()[1] + results['board_to_pin_state']['off'].squeeze()[1]
print(compare[:5])

This next cell looks for elements that were present more than some fraction of the time (`high_threshold`) whenever a feature was active. For example, maybe there was a pin on the board 98% of the time that feature 253 was active above threshold index 5 of 10 (the threshold for this index might be 2.0).

When this is the case, we also check that this feature was active (with the state present) at least `signficance_threshold` times. Otherwise, any feature that was active only once would match many states at 100%.
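To make the two filters concrete, here is a toy example with made-up numbers (1 threshold, 3 elements) that applies the same frequency and count conditions as `get_above_below_counts`, without its printing. The local name `min_count` plays the role of `signficance_threshold`.

```python
import torch

# Made-up numbers: 1 threshold, 3 elements (e.g. 3 (feature, state) pairs)
tracker_TF = torch.tensor([[1.00, 0.99, 0.40]])  # fraction of activations with the state present
counts_TF = torch.tensor([[1.0, 120.0, 500.0]])  # raw number of co-occurrences

high_threshold, min_count = 0.98, 50
passes_TF = (tracker_TF >= high_threshold) & (counts_TF >= min_count)
print(passes_TF)  # only the middle element passes
```

The first element is 100% precise but was seen only once, and the third is frequent but imprecise; only the second survives both filters.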

In [None]:
def get_above_below_counts(
    tracker_TF: torch.Tensor,
    counts_TF: torch.Tensor,
    low_threshold: float,
    high_threshold: float,
    signficance_threshold: int = 10,
) -> torch.Tensor:
    """Filter (threshold, element) pairs by frequency and by raw count.

    tracker_TF: normalized tracker of shape (num_thresholds, num_elements), values in [0, 1].
    counts_TF: raw counts with the same shape as tracker_TF.
    Returns counts_TF with every element zeroed out unless it was present at least
    high_threshold of the time AND at least signficance_threshold times.
    """

    # Find all elements that were present at most low_threshold of the time
    below_freq_TF_mask = tracker_TF <= low_threshold
    # Find all elements that were present at least high_threshold of the time
    above_freq_TF_mask = tracker_TF >= high_threshold

    # For the counts tensor, zero out all elements that were not present often enough
    above_counts_TF = counts_TF * above_freq_TF_mask

    # Keep only elements that were also present at least signficance_threshold times
    above_counts_TF_mask = above_counts_TF >= signficance_threshold
    above_counts_TF = above_counts_TF * above_counts_TF_mask

    print(f"Shape (thresholds, elements): {tuple(above_counts_TF.shape)}")

    # Count the elements that pass both filters, per threshold
    above_counts_T = above_counts_TF_mask.sum(dim=1)
    print(
        f"\nThis is the number of elements that were present at least {high_threshold:.0%} "
        f"of the time and at least {signficance_threshold} times."
    )
    print("Note that this shape is num_thresholds, and every element corresponds to a threshold.")
    print(above_counts_T)

    # Count the elements that were present at most low_threshold / at least high_threshold of the time
    below_T = below_freq_TF_mask.sum(dim=1)
    above_T = above_freq_TF_mask.sum(dim=1)

    print(f"\nThis is the number of elements that were present at most {low_threshold:.0%} of the time.")
    print("Note that this shape is num_thresholds, and every element corresponds to a threshold.")
    print(below_T)
    print(f"\nThis is the number of elements that were present at least {high_threshold:.0%} of the time.")
    print(above_T)

    values_above_threshold = [tracker_TF[i, above_freq_TF_mask[i]] for i in range(tracker_TF.size(0))]
    counts_above_threshold = [counts_TF[i, above_freq_TF_mask[i]] for i in range(tracker_TF.size(0))]

    # for i, values in enumerate(values_above_threshold):
    #     print(f"Row {i} values above {high_threshold}: {values.tolist()}")

    # for i, counts in enumerate(counts_above_threshold):
    #     print(f"Row {i} counts above {high_threshold}: {counts.tolist()}")

    return above_counts_TF

_ = get_above_below_counts(
    results["board_to_pin_state"]["on_normalized"].squeeze().clone(),
    results["board_to_pin_state"]["on"].squeeze().clone(),
    0.00,
    0.95,
    signficance_threshold=50
)

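Now we flatten the `board_to_piece_state` tracker from (thresholds, features, rows, cols, classes) to (thresholds, features * rows * cols * classes). Before rerunning the same analysis, we mask off the states present on the initial board and, optionally, the blank class.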

In [None]:
import einops
import chess

key = "on"
normalized_key = "on_normalized"

num_thresholds = results["board_to_piece_state"][normalized_key].shape[0]

piece_state_on_normalized = results["board_to_piece_state"][normalized_key].clone().view(num_thresholds, -1)
piece_state_on = results["board_to_piece_state"][key].clone()
original_shape = piece_state_on.shape


# We mask off the initial board state, otherwise this will have tons of matches for any feature that fires early game
initial_board = chess.Board()
initial_state = chess_utils.board_to_piece_state(initial_board)
initial_state = initial_state.view(1, 1, 8, 8)
initial_one_hot = chess_utils.state_stack_to_one_hot(chess_utils.piece_config, "cpu", initial_state).squeeze()
mask = (initial_one_hot == 1)
piece_state_on[:, :, mask] = 0

# Optionally, we also mask off the blank class
piece_state_on[:, :, :, :, 6] = 0

# Flatten the tensor to a 2D shape for compatibility with get_above_below_counts()
piece_state_on = piece_state_on.view(num_thresholds, -1)
print(piece_state_on_normalized.shape)

new_piece_state_on = get_above_below_counts(piece_state_on_normalized, piece_state_on, 0.00, 0.98, signficance_threshold=50)
new_piece_state_on = new_piece_state_on.view(original_shape)

# Sum over thresholds, features, and classes to get one value per square
summary_board = einops.reduce(new_piece_state_on, "T F R1 R2 C -> R1 R2", "sum").to(torch.int)
print("\nFor each square, the number of co-occurrences summed over all (threshold, feature, class) entries that were present at least 98% of the time the feature was active and passed the significance threshold.")
print(summary_board)

# Sum over thresholds, features, and squares to get one value per class
class_dict = einops.reduce(new_piece_state_on, "T F R1 R2 C -> C", "sum").to(torch.int)
print("\nThe same summed counts per piece class. 0 == black king, 1 == black queen, 6 == blank, 7 == white pawn, etc.")
print(class_dict)

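This is an experiment looking at mine / yours vs. white / black. It's half-baked right now: as a first step, we collapse the 13 color-specific classes into 7 colorless classes (for example, black king and white king both become class 0) and rerun the same analysis.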

In [None]:
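def transform_board_from_piece_color_to_piece(board: torch.Tensor) -> torch.Tensor:
    """Collapse the last dimension from 13 piece-color classes to 7 colorless classes.

    Class i (0 <= i < 6) is black piece i plus its white counterpart at index 12 - i
    (e.g. 0 == black king and 12 == white king both map to 0). Class 6 (blank) is unchanged.
    """
    new_board = torch.zeros(board.shape[:-1] + (7,), dtype=board.dtype, device=board.device)

    for i in range(7):
        if i == 6:
            new_board[..., i] = board[..., 6]
        else:
            new_board[..., i] = board[..., i] + board[..., 12 - i]
    return new_board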

results['board_to_piece_state']['on_piece'] = transform_board_from_piece_color_to_piece(results['board_to_piece_state']['on'])
# normalize_tracker() looks up results[f"{tracker_type}_count"], so alias the on counts for the new "on_piece" tracker
results['on_piece_count'] = results['on_count']
results = normalize_tracker(results, "on_piece", [chess_utils.board_to_piece_state], torch.device("cpu"))
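Because class i and class 12 - i are the same piece type, the loop in `transform_board_from_piece_color_to_piece` can also be written without a Python loop by adding the class dimension to its reverse. This is an equivalent sketch (the name `collapse_colors` is hypothetical); the blank class is its own mirror and would otherwise be counted twice, so it is copied over separately.

```python
import torch


def collapse_colors(board: torch.Tensor) -> torch.Tensor:
    """(..., 13) piece-color classes -> (..., 7) colorless classes; class 6 (blank) is kept as-is."""
    # board.flip(-1)[..., i] == board[..., 12 - i]
    new_board = board[..., :7] + board.flip(-1)[..., :7]
    # Class 6 maps to itself, so the sum above double-counts it
    new_board[..., 6] = board[..., 6]
    return new_board
```

The loop version is arguably easier to read; the vectorized one avoids indexing the last dimension seven times on large trackers.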

In [None]:
import einops
import chess

key = "on_piece"
normalized_key = "on_piece_normalized"

num_thresholds = results["board_to_piece_state"][normalized_key].shape[0]

piece_state_on_normalized = results["board_to_piece_state"][normalized_key].clone().view(num_thresholds, -1)
piece_state_on = results["board_to_piece_state"][key].clone()
original_shape = piece_state_on.shape


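# Mask off the initial board state, converted to the 7 colorless classes
initial_board = chess.Board()
initial_state = chess_utils.board_to_piece_state(initial_board)
initial_state = initial_state.view(1, 1, 8, 8)
initial_one_hot = chess_utils.state_stack_to_one_hot(chess_utils.piece_config, "cpu", initial_state).squeeze()

initial_one_hot = transform_board_from_piece_color_to_piece(initial_one_hot)

mask = (initial_one_hot == 1)
piece_state_on[:, :, mask] = 0

# Optionally, also mask off the blank class (still index 6 after the transform)
piece_state_on[:, :, :, :, 6] = 0

# Flatten to (thresholds, features * rows * cols * classes) for get_above_below_counts()
piece_state_on = piece_state_on.view(num_thresholds, -1)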
print(piece_state_on_normalized.shape)

new_piece_state_on = get_above_below_counts(piece_state_on_normalized, piece_state_on, 0.00, 0.98, signficance_threshold=50)
new_piece_state_on = new_piece_state_on.view(original_shape)

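summary_board = einops.reduce(new_piece_state_on, "T F R1 R2 C -> R1 R2", "sum").to(torch.int)
print("\nPer-square summed counts of entries that passed both filters (colorless pieces).")
print(summary_board)

class_dict = einops.reduce(new_piece_state_on, "T F R1 R2 C -> C", "sum").to(torch.int)
print("\nPer-class summed counts. 0 == king, 1 == queen, 5 == pawn, 6 == blank.")
print(class_dict)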