# A Mechanistic Interpretability Study of a Chess-Playing Language Model

Many works report that language models trained to play games develop linear representations of the board state. In particular, Karvonen<sup>1</sup> trained a GPT-2-style transformer (ChessGPT) on chess PGNs and showed that the model contains a linearly decodable representation of the board.

However, how the transformer constructs this representation remains unclear. Recent interpretability work, such as that of Davis and Sukthankar<sup>2</sup>, classifies some attention patterns and focuses on finding where the model commits to its next move, but does not analyze the model's board representation. This work investigates the internal workings of ChessGPT and proposes explanations of how the model computes the board state.

---
<sup>1</sup> Karvonen, A. Emergent world models and latent variable estimation in chess-playing language models. In Proceedings of the Conference on Language Modeling (COLM), 2024. URL https://openreview.net/forum?id=PPTrmvEnpW.

<sup>2</sup> Davis, A. L. and Sukthankar, G. Decoding chess mastery: A mechanistic analysis of a chess language transformer model. In Artificial General Intelligence: 17th International Conference, AGI 2024, Seattle, WA, USA, August 13–16, 2024, Proceedings, pp. 63–72, Berlin, Heidelberg, 2024. Springer-Verlag. ISBN 978-3-031-65571-5. doi: 10.1007/978-3-031-65572-2_7. URL https://doi.org/10.1007/978-3-031-65572-2_7.

In [None]:
### import relevant packages

import bisect
import collections
import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path

import chess
import circuitsvis as cv 
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from tabulate import tabulate

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [None]:
# chess utils
PIECE_TO_INT = {
    chess.PAWN: 1,
    chess.KNIGHT: 2,
    chess.BISHOP: 3,
    chess.ROOK: 4,
    chess.QUEEN: 5,
    chess.KING: 6,
}

INT_TO_PIECE = {value: key for key, value in PIECE_TO_INT.items()}

# model params
D_MODEL = 512
N_HEADS = 8

MODEL_DIR = "models/"
DATA_DIR = "data/"

DEVICE = (
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)

for d in (Path(MODEL_DIR), Path(DATA_DIR)):
    d.mkdir(parents=True, exist_ok=True)

In [None]:
# download data and models from huggingface

snapshot_download(
    repo_id="spherical-chisel/ChessGPT-Interp",
    repo_type="dataset",
    local_dir=DATA_DIR,
)

snapshot_download(
    repo_id="spherical-chisel/ChessGPT-Interp",
    repo_type="model",
    local_dir=MODEL_DIR,
)

In [None]:
# utils for loading the model
with open(f"{MODEL_DIR}meta.pkl", "rb") as f:
    meta = pickle.load(f)

stoi, itos = meta["stoi"], meta["itos"]
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])

meta_round_trip_input = "1.e4 e6 2.Nf3"
print(encode(meta_round_trip_input))
print("Performing round trip test on meta")
assert decode(encode(meta_round_trip_input)) == meta_round_trip_input

def get_transformer_lens_model(
    model_name: str, n_layers: int, device: torch.device
) -> HookedTransformer:

    cfg = HookedTransformerConfig(
        n_layers=n_layers,
        d_model=D_MODEL,
        d_head=int(D_MODEL / N_HEADS),
        n_heads=N_HEADS,
        d_mlp=D_MODEL * 4,
        d_vocab=32,
        n_ctx=1023,
        act_fn="gelu",
        normalization_type="LNPre",
    )
    model = HookedTransformer(cfg)
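    # map_location lets a checkpoint saved on GPU load on a CPU-only or MPS machine
    model.load_state_dict(torch.load(f"{MODEL_DIR}{model_name}.pth", map_location=device))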
    model.to(device)
    return model

# convert each game transcript into a sequence of integer token IDs
def get_board_seqs_int(df: pd.DataFrame):
    encoded_df = df["transcript"].apply(encode)
    board_seqs_int_Bl = torch.tensor(encoded_df.apply(list).tolist())
    return board_seqs_int_Bl 

# extract the game string from the dataframe
def get_board_seqs_string(df: pd.DataFrame):

    key = "transcript"
    row_length = len(df[key].iloc[0])

    assert all(
        df[key].apply(lambda x: len(x) == row_length)
    ), "Not all transcripts are of length {}".format(row_length)

    board_seqs_string_Bl = df[key]

    return board_seqs_string_Bl

# load the model
dataset_prefix = "lichess_"
n_layers = 8
model_name = f"tf_lens_{dataset_prefix}{n_layers}layers_ckpt_no_optimizer"

model = get_transformer_lens_model(model_name, n_layers, DEVICE)

In [None]:
# load data
input_file = "data/lichess_train.csv"

df = pd.read_csv(input_file)
df = df[:10000] # we use the first 10,000 games for analysis

## Model Overview

We examine an 8-layer GPT model pretrained by Karvonen. The model uses a 512-dimensional residual stream and 8 heads per layer, for a total of roughly 25 million parameters. The input to the model is a chess PGN string (1.e4 e5 2.Nf3 ...) with a maximum length of 1023 characters. Each character is an input token, and the model's vocabulary is restricted to the 32 characters needed to construct chess PGN strings. Additionally, each game string starts with a ';' token.
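As a rough check of the parameter count, the sketch below tallies parameters for the configuration passed to `HookedTransformerConfig` in `get_transformer_lens_model`. It assumes TransformerLens's usual layout (Q/K/V/O weights with biases, an MLP with biases, learned positional embeddings, an unembedding bias, and no learnable LayerNorm weights under `LNPre`); `count_params` is a hypothetical helper, not part of the original notebook.

```python
def count_params(d_model=512, n_heads=8, n_layers=8, d_vocab=32, n_ctx=1023):
    """Tally parameters for a GPT-style HookedTransformer (assumed layout, see lead-in)."""
    d_head = d_model // n_heads
    d_mlp = 4 * d_model

    # attention: W_Q, W_K, W_V, W_O weights, b_Q/b_K/b_V per head, plus b_O
    attn = 4 * n_heads * d_model * d_head + 3 * n_heads * d_head + d_model
    # MLP: W_in, W_out weights plus b_in, b_out
    mlp = 2 * d_model * d_mlp + d_mlp + d_model

    blocks = n_layers * (attn + mlp)
    embeddings = d_vocab * d_model + n_ctx * d_model  # W_E and W_pos
    unembed = d_model * d_vocab + d_vocab  # W_U and b_U
    return {"blocks": blocks, "total": blocks + embeddings + unembed}

counts = count_params()
print(f"blocks: {counts['blocks']:,}  total: {counts['total']:,}")
```

Under these assumptions, about 25.2M parameters sit in the transformer blocks and about 25.8M in total, in line with the figure above.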

The model was trained to autoregressively predict the next character of the PGN. As reported in Karvonen's work, the model has a legal move rate of 99.6% and an Elo rating of roughly 1300, with a win rate of 46% against Stockfish 16 level 0.

Using the post-MLP residual stream as input, we follow Karvonen and train a linear probe that classifies every square on the chessboard into one of 13 states: blank, or one of the six piece types (pawn, knight, bishop, rook, queen, king), each in white or black.
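The probe's target for each position is one of 13 labels for each of the 64 squares. Below is a minimal sketch of one way to produce such labels from a FEN piece-placement string. The integer assignment (0 for blank, 1-6 for White pieces in the order of `PIECE_TO_INT`, 7-12 for Black) and the helper `fen_to_labels` are assumptions for illustration; Karvonen's code may order the classes differently.

```python
# Hypothetical 13-class labeling: 0 = blank, 1-6 = White pawn..king, 7-12 = Black pawn..king.
FEN_PIECE_ORDER = "PNBRQK"  # same order as PIECE_TO_INT (pawn = 1, ..., king = 6)

def fen_to_labels(placement: str) -> list[int]:
    """Map the piece-placement field of a FEN string to 64 square labels.

    Squares are indexed as in python-chess: square = rank * 8 + file, so a1 = 0 and h8 = 63.
    """
    labels = [0] * 64
    rows = placement.split("/")
    if len(rows) != 8:
        raise ValueError("expected 8 ranks")
    for row_idx, row in enumerate(rows):
        rank = 7 - row_idx  # FEN lists rank 8 first
        file = 0
        for ch in row:
            if ch.isdigit():
                file += int(ch)  # a digit is a run of empty squares
            else:
                piece = FEN_PIECE_ORDER.index(ch.upper()) + 1
                labels[rank * 8 + file] = piece if ch.isupper() else piece + 6
                file += 1
        if file != 8:
            raise ValueError(f"rank {rank + 1} does not describe 8 squares")
    return labels

start_labels = fen_to_labels("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR")
print(start_labels[:8])  # White's back rank, a1 through h1
```

The same `rank * 8 + file` indexing is used later in the notebook when computing square parity.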

We make the following key observations:
- Probe accuracy jumps from 88.7% to 99.0% between layers 4 and 5.
- A probe fit on the pre-MLP residual stream at layer 5 also reaches 99.0%, suggesting that the jump comes from layer 5's attention rather than its MLP.
- There is no drop in accuracy when the probe is tested on randomly generated games.

## Data Overview

The dataset we use is also obtained from Karvonen<sup>1</sup> and consists of games from the Lichess open database. For probing and analysis, we use a set of 10,000 games not found in the model's training data. All games in this set were truncated to a length of 365 tokens, the median game length.

## Attention Overview

The PGN strings input to the model record moves in standard algebraic notation (SAN). In general, a move is written as a letter denoting the piece type followed by the destination square, which consists of a file letter (a-h) and a rank number (1-8). For pawn moves, the piece letter is omitted (e.g., c5). There is also special notation for captures, checks, castling, and piece disambiguation.
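The structure of a SAN move can be made concrete with a small regular expression. This is a simplified sketch, not a full PGN parser: the pattern and the helper `split_san` are our own, castling is rejected rather than parsed, and legality is not checked (the notebook itself relies on python-chess for that).

```python
import re

# Simplified SAN grammar (an assumption for illustration):
# [piece][disambiguation][x]<file><rank>[=promotion][+|#]. Castling is not handled.
SAN_PATTERN = re.compile(
    r"^(?P<piece>[NBRQK])?"
    r"(?P<disambig>[a-h]?[1-8]?)"
    r"(?P<capture>x)?"
    r"(?P<file>[a-h])(?P<rank>[1-8])"
    r"(?:=(?P<promotion>[NBRQ]))?"
    r"(?P<check>[+#])?$"
)

def split_san(move: str) -> tuple[str, str, int, bool]:
    """Return (piece letter, destination file, destination rank, is_capture) for a SAN move."""
    match = SAN_PATTERN.match(move)
    if match is None:
        raise ValueError(f"unsupported SAN move: {move!r}")
    piece = match.group("piece") or "P"  # pawn moves omit the piece letter
    return piece, match.group("file"), int(match.group("rank")), match.group("capture") is not None

for san in ["Nf3", "c5", "exd5", "Nbd7+", "e8=Q#"]:
    print(san, split_san(san))
```

Note that the rank is always the last character before any promotion or check suffix, which is why the rank token is a natural place to collect a move's information.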

Furthermore, the PGN move list is serialized as a space-separated sequence with the following pattern:

```<Move-number>.<White-move> <Black-move> <Move-number+1>.<White-move> <Black-move> ...```

When probing for the board state, we use the residual stream vectors at each dot token, i.e., the position just before White's move. The equivalent token from Black's perspective is the space after White's move.

Previous work shows that the model represents pieces with a (Mine, Yours) scheme, i.e., relative to the side to move rather than by color. Hence, we carry out all of our analysis from White's perspective only, since we expect Black's side to be symmetric. We highlight some observations, made mostly by visual inspection:

- In layer 2, information about the parity, piece type, and file is moved to the rank token. This is sensible because the rank number is the last token of a move and is the natural place to collect the information about that move.
- In layer 3, there is a proto "previous move head" (Head 3.2) for knights and bishops.
- In layer 4, there is a more robust previous move head (Head 4.0), a head that tracks the king's position (Head 4.1), companion piece heads (Head 4.3), and a gather-all head (Head 4.6).
- In layer 5, Head 5.0 tracks opponent pawns, Head 5.1 tracks your pawns, Head 5.2 tracks your queen/rooks/knights, Head 5.3 tracks your kingside bishop, Head 5.4 sometimes attends to the piece being moved, Head 5.5 tracks opponent bishops/knights, Head 5.6 tracks your queenside bishop/knight, and Head 5.7 tracks opponent queen/rooks.

Moreover, once piece information has been aggregated in the rank token, heads in later layers primarily attend to the move's rank token and to the start-of-game delimiter, with some exceptions (e.g., captures). Thus, when we refer to the token for a certain move, we specifically mean the rank token of that move. The roles of several of these heads are examined in more detail later on. The code below visualizes the attention patterns for a chosen game and layer.

---
<sup>1</sup> More details and code can be found at https://github.com/adamkarvonen/chess_llm_interpretability



In [None]:
game = 4 # interested game

game_int_seqs = get_board_seqs_int(df[:100])

raw_tokens_for_game = game_int_seqs[game] 
if isinstance(raw_tokens_for_game, torch.Tensor):
    raw_tokens_for_game = raw_tokens_for_game.tolist()

inp_tokens = torch.tensor([raw_tokens_for_game]).to(DEVICE) # [1, sequence_length]

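# despite its name, this dict stores attention patterns: layer -> [batch, n_heads, query_pos, key_pos]
resid_post_dict_BLD_viz = {}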
with torch.inference_mode():
    _, cache = model.run_with_cache(inp_tokens, return_type=None) # Get all activations 
    for layer in range(n_layers):
        resid_post_dict_BLD_viz[layer] = cache[f"blocks.{layer}.attn.hook_pattern"]


In [None]:
# visualization parameters
nm = 200       # number of tokens to visualize
layer_to_viz = 5 # layer index to visualize attention from

str_tokens = [itos[token_id] for token_id in raw_tokens_for_game]

# ensure nm is not greater than the actual sequence length
max_seq_len_viz = min(nm, len(str_tokens))
token_sub = str_tokens[:max_seq_len_viz]

attention_for_layer = resid_post_dict_BLD_viz[layer_to_viz]
attention_to_visualize = attention_for_layer[0, :, :max_seq_len_viz, :max_seq_len_viz]

display(cv.attention.attention_patterns(
    tokens=token_sub,
    attention=attention_to_visualize
))

## Previous Move Heads

In algebraic notation, a move only includes the destination square and omits the square the piece moved from (except in rare cases where disambiguation is necessary).

However, heads 3.2 and 4.0 recover this "from square" information by attending to the "to square" of the previous move by that piece. Moreover, while Head 4.0 appears to be a generic previous move head, Head 3.2 is a weaker version that specializes in finding previous moves for knights, bishops, and kings.

We observe that for these heads, the rank token position of a move attends to the rank token position of its previous move. For example, if a knight moves from its initial square to c3 (denoted Nc3) and then later moves from c3 to d5 (denoted Nd5), the "5" will direct most of its attention to the "3".

To quantify this, for moves where the piece has already moved at least once, we identify the token with the maximal attention coefficient. Then, we check whether this matches the true previous move token.
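The metric is a top-1 match between the argmax of a head's attention row and the true previous-move position. The pure-Python sketch below mirrors that check on a hand-made attention row for the Nc3 ... Nd5 example; `prev_move_accuracy` and the attention values are hypothetical, while the notebook's `compute_prev_piece_acc` performs the same comparison on real attention patterns.

```python
def argmax_index(values: list[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)

def prev_move_accuracy(attn_rows: list[list[float]], true_prev: list[int]) -> float:
    """Fraction of query rows whose most-attended key is the true previous-move token."""
    if len(attn_rows) != len(true_prev) or not true_prev:
        raise ValueError("need one target per attention row, and at least one row")
    hits = sum(argmax_index(row) == target for row, target in zip(attn_rows, true_prev))
    return hits / len(true_prev)

toy_game = ";1.Nc3 e5 2.Nd5"
query_pos = toy_game.index("Nd5") + 2   # rank token "5" of Nd5
target_pos = toy_game.index("Nc3") + 2  # rank token "3" of Nc3

# toy causal attention row for the query: most of the mass on the "3" of Nc3
toy_row = [0.0] * (query_pos + 1)
toy_row[0] = 0.2           # start-of-game delimiter ';'
toy_row[target_pos] = 0.7
toy_row[query_pos] = 0.1
print(query_pos, target_pos, prev_move_accuracy([toy_row], [target_pos]))
```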

In [None]:
def filter_prev_move_indices(moves_string: str, interested_pieces: list[int]) -> zip:
    """
    Scan a PGN move string and return two parallel tuples, (current move token indices,
    previous move token indices), for every White move by a piece in 'interested_pieces'
    that has a corresponding previous move. If there are no such moves, the returned
    iterator is empty and unpacking it raises ValueError.

    - Token indices point at the rank character of the move.
    - Example: In "Nc3 ... Nd5 ...", the "5" in Nd5 pairs with the "3" in Nc3.
    """

    # find the rank character of each White move (every other space follows a White move)
    indices = [idx for idx, char in enumerate(moves_string) if char == " "]
    indices = indices[::2]
    indices = [idx - 1 for idx in indices]

    for i, idx in enumerate(indices):
        if moves_string[idx] == "+":
            indices[i] -= 1
            
    # the final move may also be valid
    indices.append(len(moves_string) - 1)

    # given a move, we know the to square and from square. There exists a corresponding
    # previous move if a previous move with the same piece type moved to the from square
    move_squares = [] 
    ret = []

    moves = moves_string.strip().split(".")[1:]

    # omit the first king move: it is unclear how castling is attended to
    king_moved = False

    board = chess.Board()
    for i, move in enumerate(moves):
        # process the white move
        t = move.strip().split()
        try:
            white_move = t[0]
            mv = board.parse_san(white_move)
            piece = board.piece_at(mv.from_square)
            board.push_san(white_move)
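        except (IndexError, ValueError):
            # no White move left, or a truncated/unparsable move (python-chess raises ValueError subclasses)
            break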

        to_square = mv.to_square 

        if piece.piece_type != chess.KING or king_moved:
            move_squares.append((indices[i], to_square))

        if piece.piece_type == chess.KING:
            king_moved = True

        # check if there is a corresponding previous move
        if piece.piece_type in interested_pieces:
            from_square = mv.from_square

            for idx, square in reversed(move_squares):
                if square == from_square:
                    ret.append((indices[i], idx))
                    break

        # process the black move
        try:
            black_move = t[1]
            board.push_san(black_move)
        except (IndexError, ValueError):
            # no Black move left, or a truncated/unparsable move
            break

    return zip(*ret)

def compute_prev_piece_acc(
        df,
        num_games=100,
        interested_pieces=[chess.KNIGHT]):
    """
    Measure how often the "previous-move" attention heads in layers 3 and 4 point correctly.

    For each move with a corresponding previous move, see if the highest-attended token index matches the  
    true index of that previous move.

    Returns:
        (accuracy_layer3, accuracy_layer4)
    """

    def l3_attn_hook(attn, hook):
        # attn: [batch, heads, seq, seq]
        layer_attn[3] = attn.detach()
        return attn

    def l4_attn_hook(attn, hook):
        # attn: [batch, heads, seq, seq]
        layer_attn[4] = attn.detach()
        return attn

    # create attention hooks
    l3_hook_name = "blocks.3.attn.hook_pattern"
    l4_hook_name = "blocks.4.attn.hook_pattern"

    hooks = [(l3_hook_name, l3_attn_hook), (l4_hook_name, l4_attn_hook)]

    l3_head = 2
    l4_head = 0

    l3_correct = 0
    l4_correct = 0
    total = 0

    prev_move_df = df[:num_games]

    board_seqs_int_Bl = get_board_seqs_int(prev_move_df)
    board_seqs_str_Bl = get_board_seqs_string(prev_move_df)

    for seqs_int, seq_str in tqdm(zip(board_seqs_int_Bl, board_seqs_str_Bl),
                                    total=len(board_seqs_int_Bl)):
        try:
            indices, prev_indices = filter_prev_move_indices(moves_string=seq_str, 
                                                            interested_pieces = interested_pieces)
        except ValueError: 
            # sometimes there are no values to unpack
            continue

        layer_attn = {}

        # hook into layers to find the attention matrices (no gradients are needed)
        with torch.no_grad(), model.hooks(fwd_hooks=hooks):
            model(seqs_int.unsqueeze(0))

        indices_tensor = torch.tensor(indices)

        l3_attn = layer_attn[3][0, l3_head, indices_tensor, :]
        l4_attn = layer_attn[4][0, l4_head, indices_tensor, :]

        # find the highest-attended tokens
        l3_from = torch.argmax(l3_attn, dim=1)
        l4_from = torch.argmax(l4_attn, dim=1)

        # get the true previous move indices
        prev_idx = torch.as_tensor(prev_indices, device=l3_attn.device)

        # book‑keeping
        batch_size = prev_idx.numel()
        total      += batch_size

        # correctness counts
        l3_correct += (l3_from == prev_idx).sum().item()
        l4_correct += (l4_from == prev_idx).sum().item()

        del layer_attn
        if DEVICE == "mps":
            torch.mps.empty_cache()
        elif DEVICE == "cuda":
            torch.cuda.empty_cache()

    return (l3_correct / total, l4_correct / total)

How well the heads perform varies with the piece type. The heads appear to have more difficulty with long-range pieces such as rooks (\~83% accurate) and queens (\~91% accurate). For bishops, the model can "cheat" by treating the two bishops as distinct pieces: the c1 bishop always stays on dark squares, and the f1 bishop always stays on light squares.
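The bishop argument rests on square color being invariant under diagonal moves. A short sketch using the same `(rank + file) % 2` parity that the notebook computes later (0 for dark, 1 for light, with python-chess indexing where a1 = 0 is dark); the helper names are hypothetical.

```python
def square_color(square: int) -> int:
    """0 for a dark square, 1 for a light square (python-chess indexing: a1 = 0 is dark)."""
    rank, file = divmod(square, 8)
    return (rank + file) % 2

def diagonal_squares(square: int) -> list[int]:
    """All squares a bishop on `square` could reach on an empty board."""
    rank, file = divmod(square, 8)
    reachable = []
    for d_rank, d_file in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
        r, f = rank + d_rank, file + d_file
        while 0 <= r < 8 and 0 <= f < 8:
            reachable.append(r * 8 + f)
            r, f = r + d_rank, f + d_file
    return reachable

C1, F1 = 2, 5
print(square_color(C1), square_color(F1))
print(all(square_color(s) == square_color(C1) for s in diagonal_squares(C1)))
```

Each diagonal step changes both rank and file by one, so `rank + file` changes by an even amount and the parity never flips.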

Furthermore, the heads struggle with pawns (\~56% accurate). Because a pawn's movement is so limited, its destination square often nearly determines its origin square. Hence, the model may not need to explicitly find the previous move to infer the from-square.
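To illustrate how little ambiguity a pawn move leaves, the sketch below lists the possible origin squares of a White pawn move from its SAN alone, ignoring board occupancy. `candidate_white_pawn_origins` is a hypothetical helper for illustration; the model is not claimed to compute it this way.

```python
def candidate_white_pawn_origins(san: str) -> list[str]:
    """Possible from-squares of a White pawn move given only its SAN (board occupancy ignored)."""
    core = san.rstrip("+#").split("=")[0]  # drop check/mate marks and any promotion suffix
    dest_file, dest_rank = core[-2], int(core[-1])
    if "x" in core:
        # pawn captures name the origin file and come from one rank below
        return [f"{core[0]}{dest_rank - 1}"]
    origins = [f"{dest_file}{dest_rank - 1}"]
    if dest_rank == 4:
        origins.append(f"{dest_file}2")  # a double step from the starting rank
    return origins

for move in ["e4", "e5", "exd5", "b8=Q+", "cxb8=Q"]:
    print(move, candidate_white_pawn_origins(move))
```

At most two candidates exist, and only for moves to the fourth rank, which is consistent with the heads rarely needing to look up a pawn's previous move.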

In [None]:
interested_piece = chess.KNIGHT # change to the interested piece type (e.g., chess.BISHOP)

l3_acc, l4_acc = compute_prev_piece_acc(
    df=df,
    num_games=1000, # change to the desired number of games.
    interested_pieces = [interested_piece] # modify for the desired set of piece types
)

print(f"Layer 3 accuracy: {(100 * l3_acc):.2f}")
print(f"Layer 4 accuracy: {(100 * l4_acc):.2f}")

## Layer 5 Pawn Average Head

From the attention patterns, we see that Head 5.1 attends mostly to pawn moves. Hence, we look for a linear representation of the pawn structure in this head. We propose the following account of how the head computes the pawn structure:

### Head Value Vector
For every pawn move $t$, let from($t$) and to($t$) denote the from-square and to-square of the move, and let $v_s \in \mathbb{R}^d$ be a vector representation of square $s$. We claim that the pre-projection value vector $z_t = W^V x_t$ computed by Head 5.1, where $x_t$ is the head's input at the rank token of move $t$, satisfies:

$$
z_t \approx -v_{\text{from}(t)} + v_{\text{to}(t)}
$$

### Averaging

Then, at move $T$<sup>1</sup>, taking the average of $z_t$ over the $n$ pawn moves $t$ up to $T$ yields:

$$
Z_T = \frac{1}{n} \sum_{t \leq T}z_t = \frac{1}{n} \left(-\sum_{t \leq T} v_{\text{from}(t)} + \sum_{t \leq T} v_{\text{to}(t)}\right)
$$

Now, let us track a particular pawn $p$. Every square this pawn moves from, other than its initial square, is a square it previously moved to. For example, if a pawn moves *from* e4 to e5, it must at some earlier point have moved *to* e4. Thus, the pawn's terms telescope, and its total contribution to $Z_T$ is:

$$
\frac1n \Bigl(
      v_{\text{final}(p)}
      - v_{\text{initial}(p)}
   \Bigr)
$$

If we let $P$ be the set of all moved pawns, we can sum over all pawns to obtain:
$$
Z_T = \frac1n \Bigl(
      \sum_{p \in P} v_{\text{final}(p)} - \sum_{p \in P} v_{\text{initial}(p)}
   \Bigr)
$$
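The telescoping step can be checked numerically under the idealized model $z_t = -v_{\text{from}(t)} + v_{\text{to}(t)}$. In the sketch below the square vectors $v_s$ are random and the pawn moves are hypothetical; the point is only that the explicit average and the closed form agree.

```python
import random

def add_scaled(a: list[float], b: list[float], scale: float = 1.0) -> list[float]:
    """Return the list a + scale * b."""
    return [x + scale * y for x, y in zip(a, b)]

def check_telescoping(dim: int = 16, seed: int = 0) -> float:
    """Max absolute gap between Z_T (explicit average) and its telescoped closed form."""
    rng = random.Random(seed)
    # idealized assumption: each square s has a fixed random vector v_s
    v = {s: [rng.gauss(0.0, 1.0) for _ in range(dim)] for s in ("d2", "d3", "e2", "e4", "e5")}

    # hypothetical White pawn moves: e2-e4, e4-e5, d2-d3 (no captures)
    pawn_moves = [("e2", "e4"), ("e4", "e5"), ("d2", "d3")]
    n = len(pawn_moves)

    # Z_T = (1/n) * sum_t z_t with z_t = -v_from(t) + v_to(t)
    z_avg = [0.0] * dim
    for frm, to in pawn_moves:
        z_t = add_scaled(v[to], v[frm], -1.0)
        z_avg = add_scaled(z_avg, z_t, 1.0 / n)

    # closed form: (1/n) * (v_e5 + v_d3 - v_e2 - v_d2); the intermediate e4 terms cancel
    finals = add_scaled(v["e5"], v["d3"])
    initials = add_scaled(v["e2"], v["d2"])
    z_closed = [x / n for x in add_scaled(finals, initials, -1.0)]

    return max(abs(a - b) for a, b in zip(z_avg, z_closed))

print(f"max |Z_T - closed form| = {check_telescoping():.1e}")
```

The remaining gap is floating-point round-off only.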

Unfortunately, this doesn't give us a completely clean linear representation of the pawn structure, which would be $\sum v_{\text{final}(p)}$; how the model deals with the $\frac 1 n$ scaling and initial pawn positions is outside the scope of this work. However, we offer some discussion on this at the end of the section.

### Captures
When the opponent captures a pawn, the head attends to the dot token following the capturing move. This capture token $t$ can then erase the pawn from its square by having the value $z_t = -v_{\text{capture square}}$. In our analysis, we treat these captures as "pawn moves" and include them in the $Z_T$ sum.
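The effect of a capture term can be seen symbolically. The sketch below represents each $z_t$ by its coefficients on the square vectors $v_s$ (a `collections.Counter` from square name to coefficient), under the idealized model above; the event list and the helper `z_coefficients` are hypothetical.

```python
from collections import Counter

def z_coefficients(event: tuple) -> Counter:
    """Coefficients of each v_s in z_t for one event, under the idealized model in the text."""
    kind, *squares = event
    if kind == "move":  # pawn move: z_t = -v_from + v_to
        frm, to = squares
        return Counter({to: 1, frm: -1})
    if kind == "capture":  # opponent captures our pawn on a square: z_t = -v_square
        (square,) = squares
        return Counter({square: -1})
    raise ValueError(f"unknown event kind: {kind!r}")

events = [("move", "e2", "e4"), ("move", "d2", "d3"), ("move", "e4", "e5"), ("capture", "e5")]

coeff_sum = Counter()
for event in events:
    coeff_sum.update(z_coefficients(event))  # update() adds coefficients and keeps negative values

n_events = len(events)
Z_T_coeffs = {sq: c / n_events for sq, c in coeff_sum.items() if c != 0}
print(Z_T_coeffs)
```

The captured e-pawn's terms reduce to $-v_{e2}$, so it no longer contributes a final square, while the d-pawn still contributes $v_{d3} - v_{d2}$.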

---

<sup>1</sup> When we say "move $T$", we generally refer to the dot token following that move. For example, given the game string:

```text
1.e4 c5 2.Nf3 Nc6 3.Nc3 e5 4.Bc4 d6 5.d3 Be7 6.Nd5
```
the $T$ that covers all of the pawn moves up to and including d3 would be the index of the dot following the 6.


In [None]:
# load precomputed data (code can be found in generate_data.py)

cache_file = os.path.join(DATA_DIR, "precomputed_game_cache.pt")

with open(cache_file, "rb") as f:
    data = torch.load(f, weights_only=False)

# per move data
indices = data["index"] # the indices of the rank token for each move
head_v = data["head_v"] # the value vectors associated with the rank token position
to_squares = data["to"] # for each move, the to-square
frm_squares = data["from"] # for each move, the from-square
piece_types = data["piece_type"] # for each move, the piece type (the captured piece if it was a capture by the opponent)

# per dot data
dots_indices = data["dots_game_index"] # the indices of the dots
dots_attn = data["dots_attn"] # the attention matrices for the dots
board_stacks = data["board_state"] # the board state at the dot

# perform sanity checks on the data
n = len(to_squares)
for i in range(n):
    piece_len = to_squares[i].shape[0]

    assert frm_squares[i].shape[0] == piece_len
    assert len(piece_types[i]) == piece_len    
    assert head_v[i].shape[1] == piece_len
    assert indices[i].shape[0] == piece_len
    assert board_stacks[i].shape[0] == piece_len

    dots_len = dots_indices[i].shape[0]
    
    assert dots_attn[i].shape[2] == dots_len
    assert dots_attn[i].shape[3] == piece_len

### Cosine Similarity Analysis

To see whether the model is truly averaging over the pawn moves $t$, we compare $Z_T$ to the true head contribution, which we define as:

$$
H_T = \sum_{t \leq T} \alpha_t z_t
$$

where the sum runs over the same pawn moves $t \leq T$ as in $Z_T$, and $\alpha_t$ is the attention weight that the query at move $T$ assigns to token $t$. We emphasize that this **only includes the contributions from the rank tokens**, as they tend to have the largest attention weights; the delimiter (;) and all other non-rank tokens are excluded from the computation.

In this section, we compute the cosine similarity between $H_T$ and $Z_T$.
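A toy version of this comparison, in pure Python with hypothetical numbers, shows what the metric rewards: if the head attends equally to every pawn move, $H_T$ is a positive multiple of $Z_T$ and the cosine similarity is exactly 1, whereas attention concentrated on the most recent move lowers it.

```python
import math

def weighted_sum(weights: list[float], vectors: list[list[float]]) -> list[float]:
    """Return sum_t weights[t] * vectors[t]."""
    return [sum(w * vec[k] for w, vec in zip(weights, vectors)) for k in range(len(vectors[0]))]

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

# toy value vectors z_t for three pawn moves (hypothetical numbers)
z_toy = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [3.0, 1.0, 0.0]]
Z_T_toy = weighted_sum([1.0 / len(z_toy)] * len(z_toy), z_toy)

# equal attention on every pawn move; the weights need not sum to 1 because the rest of the
# attention mass goes to tokens excluded from H_T (e.g. the ';' delimiter)
cos_uniform = cosine(weighted_sum([0.1, 0.1, 0.1], z_toy), Z_T_toy)
# attention concentrated on the most recent pawn move
cos_recent = cosine(weighted_sum([0.01, 0.02, 0.6], z_toy), Z_T_toy)
print(f"uniform: {cos_uniform:.3f}  recent-heavy: {cos_recent:.3f}")
```

Because cosine similarity is scale-invariant, the $\frac{1}{n}$ factor and the attention mass on excluded tokens do not affect the comparison; only the relative weighting across pawn moves matters.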


In [None]:
def get_piece_sims(interested_head, alpha, piece_type, num_games=100, parity_req = None):
    """
    Compute the cosine similarity between the actual head output and reconstructions built
    from moves of piece_type. The parity requirement restricts the parity of the square
    (0 for dark squares, 1 for light squares) and is used to distinguish between the bishops.

    Computes and returns the similarities between the head and:
     - EMA with decay factor alpha
     - most recent move
     - average
    """
    sims = []
    last_sims = []
    avg_sims = []

    for i in tqdm(range(num_games)):
        piece_len = to_squares[i].shape[0]
        piece_indices = []

        cur_ema = 0
        cur_sum = 0

        emas = []
        sums = []

        for j in range(piece_len):
            piece = piece_types[i][j][1]
            square = to_squares[i][j]
            rank = square // 8
            file = square % 8
            parity = (rank + file) % 2
            meets_parity = True
            if parity_req is not None:
                meets_parity = (parity == parity_req)

            if piece == piece_type and meets_parity:
                piece_indices.append(j)

                piece_v = head_v[i][0, j, interested_head, :]

                # update sum and ema
                cur_ema = alpha * cur_ema + (1 - alpha) * piece_v
                cur_sum = cur_sum + piece_v

            emas.append(cur_ema) 
            sums.append(cur_sum)

        dots_len = dots_indices[i].shape[0]
        for j in range(dots_len):
            dot_idx = dots_indices[i][j]
            cur_piece_indices = [idx for idx in piece_indices if indices[i][idx] <= dot_idx]

            # skip if no piece moves yet
            if len(cur_piece_indices) == 0:
                continue

            last_idx = cur_piece_indices[-1]

            # find the most recent avg and ema values at this move
            last_ema = emas[last_idx]
            last = head_v[i][0, last_idx, interested_head, :]
            last_avg = sums[last_idx] / len(cur_piece_indices)

            cur_piece_indices = torch.tensor(cur_piece_indices)

            # compute the true head value from the piece
            attn = dots_attn[i][0, interested_head, j, :][cur_piece_indices]
            v = head_v[i][0, cur_piece_indices, interested_head, :]
            act = torch.einsum("l, l d -> d", attn, v)

            # compute cosine similarities
            sim = torch.nn.functional.cosine_similarity(last_ema, act, dim=0)
            sims.append(sim.item())

            sim = torch.nn.functional.cosine_similarity(last, act, dim=0)
            last_sims.append(sim.item())

            sim = torch.nn.functional.cosine_similarity(last_avg, act, dim=0)
            avg_sims.append(sim.item())

    return sims, last_sims, avg_sims

The averaging scheme obtains a cosine similarity of roughly 0.9 with the true head contribution, whereas the baseline of using only the most recent pawn move's value vector obtains roughly 0.44.

In [None]:
_, last_sims, avg_sims = get_piece_sims(
    interested_head=1,
    alpha=0,
    piece_type=chess.PAWN,
    num_games=1000,
    parity_req=None
)

print(f"Cosine Similarity between most recent move and head output: {torch.mean(torch.tensor(last_sims)).item():.4f}")
print(f"Cosine Similarity between average and head output: {torch.mean(torch.tensor(avg_sims)).item():.4f}")

### Value Vector Analysis

Earlier, we claimed that the value vector $z_t$ is approximately $-v_{\text{from}(t)} + v_{\text{to}(t)}$. If so, we should be able to extract both the to-square and the from-square from $z_t$ with a linear probe. Let $y_t \in \{0,1\}^{64}$ be the one-hot encoding of the to-square of move $t$. We train a linear probe

$$
\hat{y}_t = \mathrm{softmax}\bigl(W\,z_t\bigr)
$$

by minimizing the average cross-entropy loss between $\hat{y}_t$ and $y_t$. We then train a second probe in the same way for the from-square.
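To see why a linear probe can read both squares off $z_t$, consider an idealized case (an assumption, not what the trained model does) in which each $v_s$ is the one-hot vector $e_s$, so $z_t = e_{\text{to}(t)} - e_{\text{from}(t)}$. Then $W = cI$ recovers the to-square and $W = -cI$ recovers the from-square. The notebook's probes instead learn $W$ from real value vectors with AdamW; this pure-Python sketch only illustrates the softmax and cross-entropy in the equation above.

```python
import math

def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits: list[float], target: int) -> float:
    return -math.log(softmax(logits)[target])

N_SQUARES = 64
frm_sq, to_sq = 12, 28  # e2 -> e4 in python-chess indexing (square = rank * 8 + file)

# idealized assumption: v_s = e_s (one-hot), so z_t = e_to - e_from
z_t = [0.0] * N_SQUARES
z_t[to_sq] += 1.0
z_t[frm_sq] -= 1.0

probe_scale = 10.0  # hand-set: W = c * I for the to-probe, W = -c * I for the from-probe
to_logits = [probe_scale * x for x in z_t]
from_logits = [-probe_scale * x for x in z_t]

pred_to = max(range(N_SQUARES), key=to_logits.__getitem__)
pred_from = max(range(N_SQUARES), key=from_logits.__getitem__)
print(pred_to, pred_from, round(cross_entropy(to_logits, to_sq), 4))
```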



In [None]:
def generate_to_from_stack(head, piece_reqs, num_games=100, include_dots=False):
    """
    piece_reqs is a tuple of the piece type (e.g., chess.KNIGHT) and the required parity of the
    to-square: None, 0, or 1 for no restriction, dark squares, and light squares, respectively.

    For each move of the specified piece type, this function collects the head's value vector,
    the to-square, and the from-square.
    """
    piece_type = piece_reqs[0]
    parity_req = piece_reqs[1]

    piece_state_stack = []
    piece_to_stack = []
    piece_from_stack = []

    for i in tqdm(range(num_games)):
        piece_len = to_squares[i].shape[0]

        dots_index_list = dots_indices[i].tolist()

        for j in range(piece_len):
            if not include_dots and indices[i][j] in dots_index_list: 
                continue

            piece = piece_types[i][j][1]

            square = to_squares[i][j]
            rank = square // 8
            file = square % 8
            parity = (rank + file) % 2

            meets_parity = True
            if parity_req is not None:
                meets_parity = (parity == parity_req)

            if piece == piece_type and meets_parity:
                piece_v = head_v[i][0, j, head, :]

                piece_state_stack.append(piece_v)
                piece_to_stack.append(to_squares[i][j])
                piece_from_stack.append(frm_squares[i][j]) 

    return torch.stack(piece_state_stack), torch.stack(piece_to_stack), torch.stack(piece_from_stack)

The information about opponent captures of a pawn is stored at the dot tokens. Since these tokens have no to-square, we exclude them for now.

In [None]:
pawn_state_stack, pawn_to_stack, pawn_from_stack = generate_to_from_stack(1, (chess.PAWN, None), num_games=10000, include_dots=False)

In [None]:
# probe training utils
# the same default hyperparameters are used to train all probes

@dataclass
class TrainingParams:
    wd: float = 0.01
    lr: float = 0.001
    beta1: float = 0.9
    beta2: float = 0.99
    max_train_games: int = 10000
    max_test_games: int = 10000
    max_val_games: int = 1000
    max_iters: int = 50000
    eval_iters: int = 50
    num_epochs: int = 100

@dataclass
class LinearProbe:
    linear_probe: torch.Tensor
    probe_name: str
    optimiser: torch.optim.AdamW
    loss: torch.Tensor = torch.tensor(0.0)
    accuracy: torch.Tensor = torch.tensor(0.0)
    accuracy_queue: collections.deque = field(
        default_factory=lambda: collections.deque(maxlen=1000)
    )

def create_linear_probe(train_params, num_classes, has_rc=False, num_rows=8, num_cols=8, dim=64):
    linear_probe_name = "probe"

    if has_rc:
        linear_probe_DC = torch.randn(
            dim,
            num_rows,
            num_cols,
            num_classes,
            requires_grad=False,
            device=DEVICE,
        ) / torch.sqrt(torch.tensor(D_MODEL))
    else:
        linear_probe_DC = torch.randn(
            dim,
            num_classes,
            requires_grad=False,
            device=DEVICE,
        ) / torch.sqrt(torch.tensor(D_MODEL))

    linear_probe_DC.requires_grad = True

    optimiser = torch.optim.AdamW(
        [linear_probe_DC],
        lr=train_params.lr,
        betas=(train_params.beta1, train_params.beta2),
        weight_decay=train_params.wd,
    )
    linear_probe = LinearProbe(
        linear_probe=linear_probe_DC,
        probe_name=linear_probe_name,
        optimiser=optimiser,
    )
    return linear_probe

def linear_probe_forward_rc(probe, batch_data, batch_labels):
    logits = torch.einsum("bd,dxyc->bcxy", batch_data, probe.linear_probe)
    loss = F.cross_entropy(logits, batch_labels)
    return logits, loss

def linear_probe_forward_mse(probe, batch_data, batch_labels):
    logits = torch.einsum("bd,dc->bc", batch_data, probe.linear_probe)
    loss = F.mse_loss(logits, batch_labels)
    return logits, loss

def linear_probe_forward(probe, batch_data, batch_labels):
    logits = torch.einsum("bd,dc->bc", batch_data, probe.linear_probe)
    loss = F.cross_entropy(logits, batch_labels)
    return logits, loss

def train_probe(probe, state_stack, label_stack, batch_size, num_epochs, probe_fwd):
    """
    Trains the probe with the state_stack as inputs, label_stack as labels. The loss
    function is specified in probe_fwd

    Returns the validation/training loss and accuracy.
    """

    labels = label_stack.to(DEVICE)

    VAL_FRAC      = 0.10
    num_samples   = state_stack.size(0)
    num_val       = int(num_samples * VAL_FRAC)
    num_train     = num_samples - num_val

    dataset       = TensorDataset(state_stack, labels)
    train_set, val_set = random_split(dataset, [num_train, num_val])

    train_loader  = DataLoader(train_set, batch_size=batch_size,
                            shuffle=True,  drop_last=False)
    val_loader    = DataLoader(val_set,   batch_size=batch_size,
                            shuffle=False, drop_last=False)

    last_val_acc = None
    last_val_loss = None
    last_train_acc = None
    last_train_loss = None

    for epoch in range(num_epochs):
        # train forward pass
        for batch_data, batch_labels in tqdm(train_loader):
            batch_data   = batch_data.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            logits, probe.loss = probe_fwd(probe, batch_data, batch_labels)

            preds = logits.argmax(dim=1)

            # if mse is the loss, accuracy is not a meaningful metric
            if probe_fwd != linear_probe_forward_mse:
                probe.accuracy = (preds == batch_labels).float().mean() 
                probe.accuracy_queue.append(probe.accuracy.item())
                probe.accuracy = torch.tensor(sum(probe.accuracy_queue) / len(probe.accuracy_queue))

            probe.optimiser.zero_grad()
            probe.loss.backward()
            probe.optimiser.step()

        # validation
        val_loss = 0.0
        val_correct = 0
        val_seen = 0

        with torch.no_grad():
            for vdata, vlabels in val_loader:
                vdata   = vdata.to(DEVICE)
                vlabels = vlabels.to(DEVICE)

                logits, loss = probe_fwd(probe, vdata, vlabels)

                val_loss += loss.item() * vlabels.size(0)

                preds = logits.argmax(dim=1)
                if probe_fwd != linear_probe_forward_mse:
                    val_correct += (preds == vlabels).float().mean() * vlabels.size(0)
                val_seen += vlabels.size(0)

        val_accuracy = val_correct / val_seen
        val_loss /= val_seen

        last_val_loss = val_loss
        last_val_acc = val_accuracy
        last_train_acc = probe.accuracy.item()
        last_train_loss = probe.loss.item()

        print(f"Epoch {epoch:3d} │ "
            f"train loss {probe.loss.item():.5f} │ "
            f"train acc {probe.accuracy.item():.4f} │ "
            f"val loss {val_loss:.5f} │ "
            f"val acc {val_accuracy:.4f}")

    return last_val_loss, last_val_acc, last_train_loss, last_train_acc

def test_probe(
    probe,
    state_stack,
    label_stack,
    batch_size=512,
    probe_fwd=None,
):
    """
    Tests the probe with the state_stack as inputs, label_stack as labels. The loss
    function is specified in probe_fwd.

    Returns the loss and accuracy on the dataset.
    """

    assert probe_fwd is not None, "`probe_fwd` helper (forward pass) must be supplied"

    labels = label_stack.to(DEVICE)

    loader = DataLoader(
        TensorDataset(state_stack, labels),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )

    val_loss = 0.0
    val_correct = 0
    val_seen = 0

    with torch.no_grad():
        for vdata, vlabels in loader:
            vdata   = vdata.to(DEVICE)
            vlabels = vlabels.to(DEVICE)

            logits, loss = probe_fwd(probe, vdata, vlabels)

            val_loss += loss.item() * vlabels.size(0)

            preds = logits.argmax(dim=1)
            if probe_fwd != linear_probe_forward_mse:
                val_correct += (preds == vlabels).float().mean() * vlabels.size(0)
            val_seen += vlabels.size(0)

    val_accuracy = val_correct / val_seen
    val_loss /= val_seen

    return val_loss, val_accuracy



In [None]:
probe_params = TrainingParams() # We use the same hyperparameters across all linear probes
pawn_to_probe = create_linear_probe(probe_params, 64)
pawn_from_probe = create_linear_probe(probe_params, 64)

Both probes achieve an accuracy of over 99.7%, which strongly indicates that $z_t$ encodes both the origin and destination square.

In [None]:
pawn_vl_to, pawn_va_to, pawn_tl_to, pawn_ta_to = train_probe(pawn_to_probe, 
                            pawn_state_stack,
                            pawn_to_stack,
                            batch_size=64,
                            num_epochs=3,
                            probe_fwd=linear_probe_forward
                        ) 

pawn_vl_from, pawn_va_from, pawn_tl_from, pawn_ta_from = train_probe(pawn_from_probe, 
                            pawn_state_stack,
                            pawn_from_stack,
                            batch_size=64,
                            num_epochs=3,
                            probe_fwd=linear_probe_forward
                        ) 

print(f"To square probe final validation accuracy: {100 * pawn_va_to:.2f}")
print(f"From square probe final validation accuracy: {100 * pawn_va_from:.2f}")

Another assumption we make is that the to-square and from-square vectors lie in the same subspace and point in opposite directions. This matters for cancellation during averaging, since "moving to square $s$" should negate "moving from square $s$".

If this is truly the case, then negating $z_t$ should reverse the roles of the to-squares and from-squares. Hence, we combine the to-square and from-square datasets such that for each move $t$, we include both $(z_t, \text{to square}_t)$ and $(-z_t, \text{from square}_t)$. 
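The negation argument can be illustrated with a toy sketch. It assumes idealised value vectors $z_t = v_{\text{to}(t)} - v_{\text{from}(t)}$ and uses random unit vectors as hypothetical stand-ins for the learned square vectors $v_s$. A single readout `decode` (a hypothetical helper), which scores every square by its dot product with its input, recovers the to-square from $z_t$ and the from-square from $-z_t$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_squares, d = 64, 1024

# Hypothetical square vectors v_s: random unit vectors stand in for the learned ones
V = F.normalize(torch.randn(num_squares, d), dim=1)

def decode(z: torch.Tensor) -> torch.Tensor:
    """Score every square by its dot product with z and return the best-scoring square."""
    return (z @ V.T).argmax(dim=-1)

# Sample moves with distinct from/to squares
num_moves = 1000
from_sq = torch.randint(0, num_squares, (num_moves,))
offset = torch.randint(1, num_squares, (num_moves,))
to_sq = (from_sq + offset) % num_squares   # offset in 1..63 guarantees to_sq != from_sq

z = V[to_sq] - V[from_sq]                  # idealised value vector: +v_to - v_from

to_acc = (decode(z) == to_sq).float().mean().item()
from_acc = (decode(-z) == from_sq).float().mean().item()   # negation swaps the roles
print(f"to-square accuracy from z:    {to_acc:.3f}")
print(f"from-square accuracy from -z: {from_acc:.3f}")
```

With near-orthogonal stand-ins the margins are large; learned square vectors need not be orthogonal, which is one reason to train a probe rather than rely on a fixed readout.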

In [None]:
combined_pawn_state_stack = torch.cat([pawn_state_stack, -pawn_state_stack])
combined_pawn_square_stack = torch.cat([pawn_to_stack, pawn_from_stack])

combined_pawn_probe = create_linear_probe(probe_params, 64)

The probe obtains an accuracy of 99.9%, demonstrating that negating the value vector reliably swaps the to and from labels.


In [None]:
pawn_vl, pawn_va, pawn_tl, pawn_ta = train_probe(combined_pawn_probe,
                            combined_pawn_state_stack,
                            combined_pawn_square_stack,
                            batch_size=64,
                            num_epochs=3,
                            probe_fwd=linear_probe_forward
                        ) 

print(f"Combined final validation accuracy: {100 * pawn_va:.2f}")


### Value Vector Reconstruction

Finally, we want to confirm that the to-square and from-square vectors account for most of the information that $z_t$ carries. Thus, we examine whether we can reconstruct the value vector using only the to-square and from-square information (and just the from-square for captures).


In the following section, we create a 128-dimensional input for each pawn move by concatenating the two 64-dimensional one-hot square vectors (from-square first, then to-square). Then, we train a linear probe to reconstruct the head value vector using an MSE loss.
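Under the idealised model $z_t = -v_{\text{from}(t)} + v_{\text{to}(t)}$, a linear map is enough: the map whose first 64 columns are $-v_s$ and whose last 64 columns are $+v_s$ reproduces $z_t$ exactly from the [from | to] one-hot input, and a capture (captured square in the from slot, empty to slot) maps to $-v_s$. The sketch below checks this with random stand-in vectors; `V`, `W`, `one_hot` and `encode` are hypothetical names, and the square indices follow the python-chess convention (a1 = 0, e2 = 12, e4 = 28).

```python
import torch

torch.manual_seed(0)
num_squares, d = 64, 32
V = torch.randn(num_squares, d)       # hypothetical square vectors v_s, one per row

# Ideal decoder: columns 0..63 hold -v_s (from slot), columns 64..127 hold +v_s (to slot)
W = torch.cat([-V.T, V.T], dim=1)     # shape (d, 128)

def one_hot(square):
    """64-dim one-hot vector; None gives an all-zero slot (used for captures)."""
    vec = torch.zeros(num_squares)
    if square is not None:
        vec[square] = 1.0
    return vec

def encode(from_sq, to_sq=None):
    """Concatenate [from one-hot | to one-hot]; captures leave the to slot empty."""
    return torch.cat([one_hot(from_sq), one_hot(to_sq)])

z_push = W @ encode(12, 28)   # pawn push e2-e4: should equal -v_e2 + v_e4
z_capture = W @ encode(28)    # captured pawn on e4: should equal -v_e4
```

A trained reconstruction probe only has to find something close to `W`; how far its MSE is from zero then indicates how much of $z_t$ lies outside the to/from description.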

In [None]:
pawn_state_stack_all, pawn_to_stack_all, pawn_from_stack_all = generate_to_from_stack(1, (chess.PAWN, None), num_games=10000, include_dots=True)

# split the dataset into captures and normal moves. Captures are stored with from == -1:
# their only square (the captured one) sits in the to-stack but enters z_t negatively
normal_idx = (torch.nonzero((pawn_from_stack_all != -1), as_tuple=True))[0]
capture_idx = (torch.nonzero((pawn_from_stack_all == -1), as_tuple=True))[0]

# extract the move value vectors (the target vectors we want to reconstruct)
normal_state_stack = pawn_state_stack_all[normal_idx]
capture_state_stack = pawn_state_stack_all[capture_idx]

normal_to = pawn_to_stack_all[normal_idx]
normal_from = pawn_from_stack_all[normal_idx]

capture_to = pawn_to_stack_all[capture_idx]

to_idx = normal_to.squeeze(-1)
from_idx = normal_from.squeeze(-1)

to_onehot   = F.one_hot(to_idx,  num_classes=64).float()
from_onehot = F.one_hot(from_idx, num_classes=64).float()

# the to square should never equal the from square
assert (to_idx != from_idx).all(), "Found a sample where to-idx == from-idx!"

to_from_stack = torch.cat([from_onehot, to_onehot], dim=1)

to_idx_capture = capture_to.squeeze(-1)

# for captures, the capture square is negated, so it should belong with the from squares
from_onehot_capture = F.one_hot(to_idx_capture, num_classes=64).float()
to_onehot_capture = torch.zeros_like(from_onehot_capture)

to_from_stack_capture = torch.cat([from_onehot_capture, to_onehot_capture], dim=1)

to_from_stack_all = torch.cat([to_from_stack, to_from_stack_capture])
pawn_state_stack = torch.cat([normal_state_stack, capture_state_stack])

The reconstruction probe obtains a validation MSE of around 0.19. The average variance of the value vector is \~1.8 for capture moves and \~5 for non-capture moves (consistent with the lower norms of capture moves shown below). Although these are different, we can take the overall average variance (\~4.33) to estimate the $R^2$:

$$
R^2 = 1 - \frac{\text{MSE}}{\text{Var}} = 1 - \frac{0.19}{4.33} \approx 0.956
$$
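The $R^2$ arithmetic can be checked directly. The helper `variance_weighted_r2` is a hypothetical sketch rather than the notebook's method: it computes a multi-output $R^2$ from `(N, d)` predictions and targets, using each dimension's variance across samples. The estimate above instead divides by the mean per-vector variance (as printed in the next cell), so the two numbers need not match exactly.

```python
import torch

# Back-of-the-envelope estimate from the reported numbers
mse, mean_var = 0.19, 4.33
r2_estimate = 1 - mse / mean_var
print(f"Estimated R^2: {r2_estimate:.3f}")

def variance_weighted_r2(pred: torch.Tensor, target: torch.Tensor) -> float:
    """1 - (total squared error) / (total squared deviation from each dimension's mean)."""
    sse = ((pred - target) ** 2).sum()
    sst = ((target - target.mean(dim=0, keepdim=True)) ** 2).sum()
    return (1 - sse / sst).item()

# Sanity checks on random data: a perfect prediction scores 1, predicting the mean scores 0
torch.manual_seed(0)
target = torch.randn(1000, 8)
perfect = variance_weighted_r2(target, target)
mean_only = variance_weighted_r2(target.mean(dim=0).expand_as(target), target)
```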



In [None]:
print(f"Mean variance of the value vectors: {torch.mean(torch.var(pawn_state_stack_all, dim=1)).item():.4f}")

In [None]:
pawn_reconstruction_probe = create_linear_probe(
    train_params=probe_params,
    num_classes=64,
    dim=128
)

In [None]:
pawn_reconstruction_vl, _, pawn_reconstruction_tl, _ = train_probe(
    pawn_reconstruction_probe,
    to_from_stack_all,
    pawn_state_stack,
    64,
    5,
    linear_probe_forward_mse
)

print(f"Final validation MSE loss for pawn reconstruction: {pawn_reconstruction_vl:.4f}")

We observe that the norms of capture moves are lower than those of non-capture moves. This makes sense, since $z_t = -v_{\text{from}(t)}$ for captures whereas $z_t = -v_{\text{from}(t)} + v_{\text{to}(t)}$ for non-capture moves. If the square vectors all had the same magnitude and were roughly orthogonal to one another, we would expect the non-capture norms to be roughly $\sqrt{2}$ times the capture norms.
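A quick numerical check of the $\sqrt{2}$ heuristic under its assumptions (equal-norm, roughly orthogonal square vectors), using random unit vectors in a hypothetical 512-dimensional space as stand-ins for $v_s$:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_squares, d = 64, 512
V = F.normalize(torch.randn(num_squares, d), dim=1)   # equal-norm stand-ins for v_s

from_sq = torch.randint(0, num_squares, (1000,))
to_sq = (from_sq + torch.randint(1, num_squares, (1000,))) % num_squares  # to != from

capture_norms = (-V[from_sq]).norm(dim=1)           # capture: z = -v_from
move_norms = (V[to_sq] - V[from_sq]).norm(dim=1)    # non-capture: z = -v_from + v_to
ratio = (move_norms.mean() / capture_norms.mean()).item()
print(f"non-capture / capture norm ratio: {ratio:.3f} (sqrt(2) = {2 ** 0.5:.3f})")
```

Learned square vectors are neither exactly equal in norm nor exactly orthogonal, so the ratio measured in the next cell can deviate from $\sqrt{2}$.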

In [None]:
norms = torch.norm(capture_state_stack, dim=1)
print(f"Mean of norms for capture states: {torch.mean(norms).item():.2f}")
print(f"Standard deviation of norms for capture states: {torch.std(norms).item():.2f}")
norms = torch.norm(normal_state_stack, dim=1)
print(f"Mean of norms for non capture states: {torch.mean(norms).item():.2f}")
print(f"Standard deviation of norms for non capture states: {torch.std(norms).item():.2f}")

### Board Reconstruction

We want to verify that the average $Z_T$ is capable of recovering the full pawn structure. For each of the 64 squares, we train a linear probe to classify whether the square is blank or contains a white pawn. More formally, the probe computes:

$$
\hat y_{i,j,T} = \text{Softmax}(W_i Z_T)_j
$$

where $W_i \in \mathbb{R}^{2 \times d}$ is the probe for square $i$ and $\hat y_{i,j,T}$ is the predicted probability of class $j \in \{0, 1\}$ (blank or white pawn) for square $i$ on move $T$.
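A minimal sketch of this per-square probe as one batched operation, assuming the weights are stored as a `(64, 2, d)` tensor (one $2 \times d$ matrix $W_i$ per square). It illustrates the equation only; it is not the notebook's `linear_probe_forward_rc`, whose internals are defined elsewhere.

```python
import torch

torch.manual_seed(0)
num_squares, num_classes, d, batch = 64, 2, 32, 5
W = torch.randn(num_squares, num_classes, d)    # hypothetical probe weights, W[i] = W_i
Z = torch.randn(batch, d)                       # a batch of averaged states Z_T

logits = torch.einsum("icd,nd->nic", W, Z)      # logits[n, i, j] = (W_i Z_n)_j
probs = logits.softmax(dim=-1)                  # y_hat[n, i, j], sums to 1 over j
board = probs.argmax(dim=-1).view(batch, 8, 8)  # per square: 0 = blank, 1 = white pawn
```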

In [None]:
def construct_avg_stack(head, piece_reqs, num_games=100):
    """
    For each dot index, collect all moves so far by the piece (specified in piece_reqs) and
    average the head outputs of those moves. Each average is bundled with the true head
    contribution and the corresponding board state.
    """

    piece_type = piece_reqs[0]
    parity_req = piece_reqs[1]

    sims = []

    head_state_stack = []
    avg_state_stack = []
    piece_board_stack = []

    for i in tqdm(range(num_games)):
        piece_len = to_squares[i].shape[0]

        dots_index_list = dots_indices[i].tolist()

        cur_sum = torch.zeros_like(head_v[i][0, 0, head, :])

        cur_piece_indices = []
        nm = 0

        for j in range(piece_len):
            piece = piece_types[i][j][1]

            square = to_squares[i][j]
            rank = square // 8
            file = square % 8
            parity = (rank + file) % 2

            meets_parity = True
            if parity_req is not None:
                meets_parity = (parity == parity_req)

            if piece == piece_type and meets_parity:
                cur_piece_indices.append(j)
                piece_v = head_v[i][0, j, head, :]

                # The board state at the dot includes the most recent black move. So, if the black move
                # was a capture, we need to replace the previous information
                replace = False
                nm += 1

                if indices[i][j] not in dots_index_list:
                    cur_sum = cur_sum + piece_v
                else:
                    # if the character is a dot, it was a capture by black
                    replace = True
                    cur_sum = cur_sum + piece_v

                # find the next dot
                dot_idx = bisect.bisect_left(dots_index_list, indices[i][j])

                if dot_idx >= len(dots_index_list):
                    break 
                
                # compute the true head contribution
                attn = dots_attn[i][0, head, dot_idx, :][torch.tensor(cur_piece_indices)]
                v = head_v[i][0, torch.tensor(cur_piece_indices), head, :]

                # multiply v and attn
                act = torch.einsum("l, l d -> d", attn, v)

                # compute cosine similarity between act and cur_sum 
                sim = torch.nn.functional.cosine_similarity(cur_sum, act, dim=0)

                cur_board_stack = board_stacks[i][j].clone()
                for r in range(8):
                    for c in range(8):
                        parity = (r + c) % 2
                        meets_parity = True
                        if parity_req is not None:
                            meets_parity = (parity == parity_req)

                        if not meets_parity or cur_board_stack[r][c] != piece_type:
                            cur_board_stack[r][c] = 0
                        else:
                            cur_board_stack[r][c] = 1

                # the previous board stack was invalid because it did not include the black capture
                # (guard against an empty list if black captures before any recorded move)
                if replace and piece_board_stack:
                    piece_board_stack[-1] = cur_board_stack
                    avg_state_stack[-1] = cur_sum / nm
                    head_state_stack[-1] = act
                    sims[-1] = sim.item()
                else:
                    piece_board_stack.append(cur_board_stack)
                    avg_state_stack.append(cur_sum / nm)
                    head_state_stack.append(act)
                    sims.append(sim.item())

    # print the cosine similarity for checks
    print(f"Final similarity: {torch.mean(torch.tensor(sims)).item():.4f}")

    return torch.stack(avg_state_stack), torch.stack(piece_board_stack), torch.stack(head_state_stack)

In [None]:
pawn_avg_stack, pawn_board_stack, pawn_head_out = construct_avg_stack(1, (chess.PAWN, None), num_games=10000)

This board probe achieves an accuracy of 99.5% (compared to a baseline of \~91% if all squares are guessed to be blank), confirming that the averaging scheme can accurately recover the pawn structure.

In [None]:
pawn_board_probe = create_linear_probe(train_params=probe_params,
                                       num_classes=2,
                                       has_rc=True)

In [None]:
vl, va, tl, ta = train_probe(
    pawn_board_probe,
    pawn_avg_stack,
    pawn_board_stack.long(),
    64,
    5,
    linear_probe_forward_rc
)

Finally, to see if this average scheme is similar to what the model is computing, we can apply the probe (trained on the averages) to the true head contribution. This achieves an accuracy of 98.9%, which suggests that the average and the true contribution encode the pawn structure similarly.

In [None]:
vl, va = test_probe(
    pawn_board_probe,
    pawn_avg_stack,
    pawn_board_stack.long(),
    probe_fwd=linear_probe_forward_rc
)

print(f"Accuracy using the computed average: {(100 * va.item()):.2f}")

vl, va = test_probe(
    pawn_board_probe,
    pawn_head_out,
    pawn_board_stack.long(),
    probe_fwd=linear_probe_forward_rc
)

print(f"Accuracy using the true head: {(100 * va.item()):.2f}")

### Open Questions

As mentioned earlier, $Z_T = \frac{1}{n} \left( \sum v_{\text{final}(p)} - \sum v_{\text{initial}(p)} \right)$ is not a perfect representation: to recover $\sum_p v_{\text{final}(p)}$ exactly, we would need to rescale by $n$ and add back $v_{\text{initial}(p)}$ for every pawn $p$.
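A small numerical illustration of this limitation, with random stand-in vectors $v_s$ and a hypothetical four-move pawn history (square indices use a1 = 0, as in python-chess). The uniform average telescopes to $\frac{1}{n}\left(\sum v_{\text{final}} - \sum v_{\text{initial}}\right)$, and only rescaling by $n$ and adding back the initial squares recovers $\sum v_{\text{final}}$ exactly.

```python
import torch

torch.manual_seed(0)
num_squares, d = 64, 16
V = torch.randn(num_squares, d)      # hypothetical square vectors v_s

# Hypothetical history of white pawn moves as (from, to): e2-e4, d2-d3, e4-e5, d3-d4
moves = [(12, 28), (11, 19), (28, 36), (19, 27)]
z = torch.stack([V[t] - V[f] for f, t in moves])   # z_t = v_to - v_from
n = len(moves)
Z_T = z.mean(dim=0)                                 # what a uniform average produces

initial = [12, 11]   # starting squares of the two pawns that moved (e2, d2)
final = [36, 27]     # their current squares (e5, d4)
telescoped = (V[final].sum(dim=0) - V[initial].sum(dim=0)) / n
recovered = n * Z_T + V[initial].sum(dim=0)   # rescale by n, add back initial squares
```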

Even so, a linear probe can partly overcome these limitations: the board probe above recovers the pawn structure with high accuracy. However, the picture gets messier when multiple piece types are involved, since the available dimensions may become overloaded through superposition.

#### Scaling
One way the model could remedy the scaling issue is by using the ';' token at the beginning of the game string. When there are few pawn moves, it can place more attention on the ';' token, evening out the scaling as pawn moves accumulate. We anecdotally observe that overall attention on pawn moves increases (and attention on ';' decreases) as pawn moves are made, but we have not investigated this rigorously.

Moreover, the number of pawn moves is bounded (at most 48, since each of the 8 pawns can advance at most 6 ranks) and games contain relatively few of them (especially before the endgame), so an approximate scale may be good enough.
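The attention-sink idea can be illustrated with a toy softmax in which the ';' token has a fixed logit $b$ and every pawn move shares one logit $a$ (both values are hypothetical). Each move then receives weight $1/(n + e^{b-a})$ rather than $1/n$, so the per-move weight changes much less as $n$ grows, while the total attention on pawn moves rises, in line with the trend described above.

```python
import torch

def attention_weights(n_moves: int, move_logit: float, sink_logit: float):
    """Softmax over one ';' sink token followed by n pawn-move tokens with equal logits."""
    logits = torch.tensor([sink_logit] + [move_logit] * n_moves)
    weights = logits.softmax(dim=0)
    return weights[0].item(), weights[1].item()   # (sink weight, weight per move)

for n in [1, 4, 8, 16]:
    sink_w, move_w = attention_weights(n, move_logit=0.0, sink_logit=2.0)
    print(f"n={n:2d}  sink={sink_w:.3f}  per move={move_w:.3f}  "
          f"all moves={n * move_w:.3f}  plain average={1 / n:.3f}")
```

In this toy, the per-move weight falls by a factor of about 2.8 between $n = 1$ and $n = 16$, versus 16 for a plain average.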

#### Initial Pawn Positions
Another question that arises is how the model knows the "initial pawn positions." Since the initial positions are fixed, it is possible that the model could add a constant for each pawn (including unmoved pawns). 

However, a linear probe is expressive enough to deal with this issue and adopts an alternative strategy. We train a board probe (not included in this notebook) using the average of the 128-dimensional to/from vectors from the value vector reconstruction section. We find that for most to-square dimensions, the probe injects a small positive bias onto all eight starting pawn squares. These small contributions are enough to mark the squares as occupied, but when a pawn actually leaves its square, the subtraction term for the from-square overwhelms that bias.

In addition, the probe's learned "occupied" and "blank" vectors are almost antipodal, with a cosine similarity of -0.999. Because the probe's task here is binary classification, it can afford to have a large negative direction for blank squares. However, this strategy may not be viable when multiple piece types (and not enough dimensions) are involved.
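One reason a binary probe can afford near-antipodal class vectors: a two-class softmax depends only on the difference of the two rows of $W_i$, so removing the component the rows share changes no prediction and leaves the rows at $\pm\frac{1}{2}(w_1 - w_0)$, with cosine similarity $-1$. If the probe is trained with weight decay (an assumption about `TrainingParams`, whose fields are defined elsewhere), that shared component is penalised without affecting the loss, which pushes the rows toward this antipodal solution. The sketch checks the invariance with random, hypothetical weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 32
W = torch.randn(2, d)     # hypothetical rows: w_0 (blank), w_1 (occupied)
Z = torch.randn(100, d)   # hypothetical input states

# Remove the shared component: the rows become -/+ half of their difference
shared = W.mean(dim=0, keepdim=True)
W_antipodal = W - shared

probs = (Z @ W.T).softmax(dim=-1)
probs_antipodal = (Z @ W_antipodal.T).softmax(dim=-1)   # same predictions
cos = F.cosine_similarity(W_antipodal[0], W_antipodal[1], dim=0).item()
```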


## Layer 5 Exponential Moving Average Heads

In layer 5, multiple heads work together to construct the positions of the remaining pieces (except the King). Although taking the average (similar to the pawn average head) is a theoretically viable strategy, it seems that the model opts for a different tactic. 

We begin by **fixing** a piece type (e.g., knights).

### Head Value Vector
For every knight move $t$, let from($t$) and to($t$) denote the from-square and to-square of the move. Then, let $v_s \in \mathbb{R}^d$ be a vector representation of the square $s$ for knights. We claim that the pre-projection value vector $z_t  = W^Vx_t$ computed by Head 5.6 satisfies:

$$
z_t \approx -\alpha \: v_{\text{from}(t)} + \: v_{\text{to}(t)} + u_t
$$

with fixed decay factor $\alpha \in (0, 1)$, and auxiliary information $u_t$.

### EMA

Now, we claim the model keeps track of an exponential moving average $Z_T$ where $Z_0 = \mathbf{0}$ and:

$$
    Z_T = \alpha \: Z_{T-1} + (1 - \alpha) \: z_T
$$

If there were $n$ piece moves before $T$, let $t_k$ denote the $k$-th piece move. Expanding the recurrence yields:

$$
    Z_T = (1 - \alpha) \sum_{k=1}^n \alpha^{n-k} z_{t_k} 
$$

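If we track only a single piece, then $\text{from}(t_k) = \text{to}(t_{k-1})$, so the $+v_{\text{to}(t_{k-1})}$ term of move $t_{k-1}$ and the $-\alpha \: v_{\text{from}(t_k)}$ term of move $t_k$ both carry weight $(1-\alpha)\alpha^{n-k+1}$ and cancel exactly (this is why the from-square is scaled by $\alpha$). Only the first from-square and the last to-square survive, leaving

$$
\begin{aligned}
Z_T &= (1 - \alpha) \: v_{\text{final}} - (1-\alpha)\alpha^{n} \: v_{\text{initial}} + u \\
&\approx (1 - \alpha) \: v_{\text{final}} + u
\end{aligned}
$$

where $u = (1-\alpha)\sum_{k=1}^n \alpha^{n-k} u_{t_k}$ collects the auxiliary terms. This essentially gives us the final position of the piece, as the initial vector decays exponentially. Similar to the averaging scheme, on a capture, we can let $\text{from}(t)$ be the capture square and $v_{\text{to}(t)} = \mathbf{0}$.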
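A numerical check of the single-piece derivation, assuming $u_t = \mathbf{0}$, random stand-ins for $v_s$, and a hypothetical knight path b1, c3, e4, f6 (python-chess indices 1, 18, 28, 45). The recurrence, its unrolled sum, and the closed form $(1-\alpha)\,v_{\text{final}} - (1-\alpha)\alpha^{n}\,v_{\text{initial}}$ should all agree.

```python
import torch

torch.manual_seed(0)
num_squares, d, alpha = 64, 16, 0.6
V = torch.randn(num_squares, d)          # hypothetical square vectors v_s

path = [1, 18, 28, 45]                   # hypothetical knight path b1 -> c3 -> e4 -> f6
moves = list(zip(path[:-1], path[1:]))   # (from, to) for each move
n = len(moves)

# EMA recurrence Z_T = alpha * Z_{T-1} + (1 - alpha) * z_T, starting from Z_0 = 0
Z = torch.zeros(d)
for f, t in moves:
    z_t = -alpha * V[f] + V[t]           # value vector with u_t = 0
    Z = alpha * Z + (1 - alpha) * z_t

# Unrolled form: (1 - alpha) * sum_k alpha^(n - k) * z_{t_k}
unrolled = (1 - alpha) * sum(alpha ** (n - k) * (-alpha * V[f] + V[t])
                             for k, (f, t) in enumerate(moves, start=1))

# Closed form after the telescoping cancellation
closed_form = (1 - alpha) * V[path[-1]] - (1 - alpha) * alpha ** n * V[path[0]]
```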

### Persisting Companion Pieces

The previous strategy works for "unique" pieces like bishops (the model distinguishes between the dark-squared bishop and the light-squared bishop) and queens<sup>1</sup>. We can extend it from single pieces to pairs of pieces like knights and rooks. In particular, we can let

$$
u_{t} = (1 - \alpha) \: v_{\text{companion}(t)}
$$

where $\text{companion}(t)$ is the square of the "other" knight/rook ($u_t = \mathbf{0}$ if the other piece has been captured). This keeps the companion's coefficient constant at every update: the decayed previous term $\alpha(1 - \alpha)$ plus the refreshed term $(1 - \alpha)^2$ sum to $(1 - \alpha)$. Therefore, up to normalization, we are left with a linear representation of both the piece's own position and that of its symmetric partner.
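A toy check of the companion mechanism, assuming $u_t = (1-\alpha)\,v_{\text{companion}(t)}$, random stand-in square vectors, and a hypothetical sequence of alternating moves by two knights. If knight A moves to a square with vector $v_A'$, each update maps $Z_{T-1} = (1-\alpha)(v_A + v_B) + E$ to $(1-\alpha)(v_A' + v_B) + \alpha E$, so after $n$ moves the EMA equals $(1-\alpha)(v_A + v_B)$ plus a start-up term $-\alpha^n (1-\alpha)(v_{A,0} + v_{B,0})$ that decays exponentially.

```python
import torch

torch.manual_seed(0)
num_squares, d, alpha = 64, 16, 0.6
V = torch.randn(num_squares, d)      # hypothetical square vectors v_s

pos = {"A": 1, "B": 6}               # knights start on b1 and g1 (a1 = 0)
start = dict(pos)
# Hypothetical alternating moves as (knight, destination square)
moves = [("A", 18), ("B", 21), ("A", 28), ("B", 38), ("A", 45), ("B", 55)]

Z = torch.zeros(d)
for knight, to_sq in moves:
    other = "B" if knight == "A" else "A"
    # value vector: -alpha * v_from + v_to + (1 - alpha) * v_companion
    z_t = -alpha * V[pos[knight]] + V[to_sq] + (1 - alpha) * V[pos[other]]
    Z = alpha * Z + (1 - alpha) * z_t
    pos[knight] = to_sq

n = len(moves)
target = (1 - alpha) * (V[pos["A"]] + V[pos["B"]])
startup = -(alpha ** n) * (1 - alpha) * (V[start["A"]] + V[start["B"]])
```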

---
<sup>1</sup> Perhaps it is naive to treat the queen as a unique piece. While underpromotion to a minor piece (knight, bishop) is extremely rare, it is not uncommon to promote a pawn to a queen. So, it is possible (and even likely) that queens also have "companions," though this requires more investigation.





In [None]:
def grid_search(head, piece_type, parity_req, num_games):
    """
    Grid search for the best decay value alpha. Looks at all values between 0.1 and 0.9 inclusive
    """
    lst = []
    avgs = None
    last = None
    for i in range(9):
        alpha = (i + 1) * 0.1
        sims, last, avgs  = get_piece_sims(
            interested_head=head,
            alpha=alpha,
            piece_type=piece_type,
            num_games=num_games,
            parity_req=parity_req,
        )
        lst.append((alpha, torch.mean(torch.tensor(sims)).item()))
    print(f"Sim between true contribution and average: {torch.mean(torch.tensor(avgs)).item()}")
    print(f"Sim between true contribution and last: {torch.mean(torch.tensor(last)).item()}")
    return lst, torch.mean(torch.tensor(avgs)).item(), torch.mean(torch.tensor(last)).item()

def create_line_graph(data_points,
                      avg,
                      last,
                      title,
                      x_label="Decay factor (α)",
                      y_label="Cosine Similarity"):

    if not data_points:
        raise ValueError("data_points list is empty")

    x_vals, y_vals = zip(*data_points)

    plt.figure(figsize=(10, 6))
    
    plt.plot(x_vals, y_vals,
             marker='o', markersize=6,
             linewidth=2, label='EMA')

    
    plt.axhline(avg,  color='tab:green', linestyle='--',
                linewidth=2, label='Average')
    plt.axhline(last, color='tab:red',  linestyle=':',
                linewidth=2, label='Last-move')

    plt.xlabel(x_label, fontsize=16)
    plt.ylabel(y_label, fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.grid(False)
    plt.legend(fontsize=14)
    plt.title(title, fontsize=20)
    plt.tight_layout()
    plt.show()

### Finding the Decay Factor (and cos sim analysis)

Finding the decay factor $\alpha$ analytically is difficult, so we perform a grid search over $\alpha \in \{0.1, 0.2, \ldots, 0.9\}$ and empirically pick the $\alpha$ that maximizes the cosine similarity with the true contribution $H_T$. For most pieces, the best $\alpha$ achieves a cosine similarity of at least 0.97.
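To see why maximizing cosine similarity over a grid can identify $\alpha$, here is a synthetic sketch (it does not use the notebook's `get_piece_sims`, which is defined elsewhere): the stand-in "true contribution" is itself an EMA with $\alpha = 0.6$ plus small noise, and the grid search recovers that value. Rounding the grid avoids values such as `0.30000000000000004`, which `(i + 1) * 0.1` produces for `i = 2`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_moves, d = 12, 256
z = torch.randn(n_moves, d)   # hypothetical per-move value vectors z_t

def ema_prefixes(z: torch.Tensor, alpha: float) -> torch.Tensor:
    """Return Z_1..Z_n for Z_T = alpha * Z_{T-1} + (1 - alpha) * z_T with Z_0 = 0."""
    out, Z = [], torch.zeros(z.shape[1])
    for z_t in z:
        Z = alpha * Z + (1 - alpha) * z_t
        out.append(Z)
    return torch.stack(out)

# Stand-in for the true head contribution H_T: an EMA with alpha = 0.6 plus small noise
H = ema_prefixes(z, 0.6) + 0.01 * torch.randn(n_moves, d)

candidates = [round(0.1 * i, 1) for i in range(1, 10)]
scores = {a: F.cosine_similarity(ema_prefixes(z, a), H, dim=1).mean().item()
          for a in candidates}
best_alpha = max(scores, key=scores.get)
print(f"best alpha: {best_alpha}, mean cosine similarity: {scores[best_alpha]:.4f}")
```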

In [None]:
# Additional piece types:
rook_lst, rook_avg, rook_last = grid_search(2, chess.ROOK, None, 100)
create_line_graph(
    data_points=rook_lst,
    title="Rook similarities at Head 5.2",
    avg=rook_avg,
    last=rook_last,
    x_label="Decay factor (α)",
    y_label="Cosine Similarity"
)
queen_lst, queen_avg, queen_last = grid_search(2, chess.QUEEN, None, 100)
create_line_graph(
    data_points=queen_lst,
    title="Queen similarities at Head 5.2",
    avg=queen_avg,
    last=queen_last,
    x_label="Decay factor (α)",
    y_label="Cosine Similarity"
)
dsbishop_lst, dsb_avg, dsb_last = grid_search(6, chess.BISHOP, 0, 100)
create_line_graph(
    data_points=dsbishop_lst,
    title="Dark-squared Bishop similarities at Head 5.6",
    avg=dsb_avg,
    last=dsb_last,
    x_label="Decay factor (α)",
    y_label="Cosine Similarity"
)
lsbishop_list, lsb_avg, lsb_last = grid_search(3, chess.BISHOP, 1, 100)
create_line_graph(
    data_points=lsbishop_list,
    title="Light-squared Bishop similarities at Head 5.3",
    avg=lsb_avg,
    last=lsb_last,
    x_label="Decay factor (α)",
    y_label="Cosine Similarity"
)

knight_lst, knight_avg, knight_last = grid_search(6, chess.KNIGHT, None, 100)
create_line_graph(
    data_points=knight_lst,
    title="Knight similarities at Head 5.6",
    avg=knight_avg,
    last=knight_last,
    x_label="Decay factor (α)",
    y_label="Cosine Similarity"
)




### Value Vector Analysis

We repeat the previous value vector analysis. Since the value vector $z_t$ is the sum $-\alpha \: v_{\text{from}(t)} + \: v_{\text{to}(t)} + u_t$, we should again be able to extract the to and from squares using a linear probe. We also check that the to-square and from-square contributions point in opposite directions, constructing the combined to/from dataset in the same way.

In [None]:
piece_head_pairs = [
    ("knight", 6, (chess.KNIGHT, None, 0.6)),
    ("queen", 2, (chess.QUEEN, None, 0.3)),
    ("rook", 2, (chess.ROOK, None, 0.5)),
    ("dsbishop", 6, (chess.BISHOP, 0, 0.4)),
    ("lsbishop", 3, (chess.BISHOP, 1, 0.4)),
]

piece_state_stacks = {}
piece_to_stacks = {}
piece_from_stacks = {}

for name, head, reqs in piece_head_pairs:
    piece_state_stack, piece_to_stack, piece_from_stack = generate_to_from_stack(head, reqs, num_games=10000)
    piece_state_stacks[name] = piece_state_stack
    piece_to_stacks[name] = piece_to_stack
    piece_from_stacks[name] = piece_from_stack

    print(f"Computed to/from data for: {name}")


Although the magnitudes of the to-square and from-square contributions differ (the from-square term is scaled by $\alpha$), a single linear probe trained on the combined set still achieves high accuracy, indicating that the to-square and from-square representations lie in the same subspace.

In [None]:
piece_names = ["knight", "queen", "rook", "dsbishop", "lsbishop"]

piece_to_probes = {}
piece_from_probes = {}
piece_to_from_probes = {}

piece_to_from_probe_results = {}
piece_to_probe_results = {}
piece_from_probe_results = {}

num_iters = 10

for name in piece_names:
    to_probe = create_linear_probe(probe_params, 64)
    from_probe = create_linear_probe(probe_params, 64)
    to_from_probe = create_linear_probe(probe_params, 64)

    piece_to_probes[name] = to_probe
    piece_from_probes[name] = from_probe
    piece_to_from_probes[name] = to_from_probe

    piece_to_probe_results[name] = []
    piece_from_probe_results[name] = []
    piece_to_from_probe_results[name] = []

    vl, va, tl, ta = train_probe(to_probe, 
                    piece_state_stacks[name],
                    piece_to_stacks[name],
                    batch_size=64,
                    num_epochs=num_iters,
                    probe_fwd=linear_probe_forward
                ) 
    
    piece_to_probe_results[name].append({
        "val_loss": vl,
        "val_acc": va,
        "train_loss": tl,
        "train_acc": ta,
    })
    
    vl, va, tl, ta = train_probe(from_probe, 
                    piece_state_stacks[name],
                    piece_from_stacks[name],
                    batch_size=64,
                    num_epochs=num_iters,
                    probe_fwd=linear_probe_forward
                ) 
    
    piece_from_probe_results[name].append({
        "val_loss": vl,
        "val_acc": va,
        "train_loss": tl,
        "train_acc": ta,
    })
    
    vl, va, tl, ta = train_probe(to_from_probe, 
                    torch.cat([piece_state_stacks[name], -piece_state_stacks[name]]),
                    torch.cat([piece_to_stacks[name], piece_from_stacks[name]]),
                    batch_size=64,
                    num_epochs=num_iters,
                    probe_fwd=linear_probe_forward
                ) 
    
    piece_to_from_probe_results[name].append({
        "val_loss": vl,
        "val_acc": va,
        "train_loss": tl,
        "train_acc": ta,
    })

In [None]:
def summarize_to_from(piece_to_probe_results,
                         piece_from_probe_results,
                         how="last"):
    rows = []

    def _select(res_list):
        if how == "last":
            return res_list[-1]
        elif how == "best":
            return max(res_list, key=lambda d: d["val_acc"])
        else:
            raise ValueError("how must be 'last' or 'best'")

    for piece_name, res_list in piece_to_probe_results.items():
        r = _select(res_list)
        rows.append({
            "piece": piece_name,
            "probe": "to",
            "train_loss": r["train_loss"],
            "train_acc": r["train_acc"],
            "val_loss":   r["val_loss"],
            "val_acc":    r["val_acc"],
        })

    for piece_name, res_list in piece_from_probe_results.items():
        r = _select(res_list)
        rows.append({
            "piece": piece_name,
            "probe": "from",
            "train_loss": r["train_loss"],
            "train_acc": r["train_acc"],
            "val_loss":   r["val_loss"],
            "val_acc":    r["val_acc"],
        })

    df = pd.DataFrame(rows)

    df.sort_values(["piece", "probe"], inplace=True)

    print("\n=== Probe summary ({}) ===".format(how))
    print(tabulate(df,
                   headers="keys",
                   tablefmt="github",
                   floatfmt=".4f",
                   showindex=False))

    return df

def summarize_combined(piece_to_from_probe_results, how="last"):
    def _select(res_list):
        if how == "last":
            return res_list[-1]
        if how == "best":
            return max(res_list, key=lambda d: d["val_acc"])
        raise ValueError("how must be 'last' or 'best'")

    rows = []
    for piece_name, res_list in piece_to_from_probe_results.items():
        r = _select(res_list)
        rows.append({
            "piece":       piece_name,
            "train_loss":  r["train_loss"],
            "train_acc":   r["train_acc"],
            "val_loss":    r["val_loss"],
            "val_acc":     r["val_acc"],
        })

    df = pd.DataFrame(rows).sort_values("piece")

    print(f"\n=== Combined to-from probe summary ({how}) ===")
    print(tabulate(df,
                   headers="keys",
                   tablefmt="github",
                   floatfmt=".4f",
                   showindex=False))

    return df

In [None]:
_ = summarize_to_from(
    piece_to_probe_results=piece_to_probe_results,
    piece_from_probe_results=piece_from_probe_results,
    how="best",
)

_ = summarize_combined(
    piece_to_from_probe_results=piece_to_from_probe_results,
    how="best",
)

### Board Reconstruction

We want to verify that the exponential moving average $Z_T$ is capable of recovering the piece positions. For each of the 64 squares, we train a linear probe to classify whether the square is blank or contains the piece. For probe details, refer to the pawn board reconstruction section.

In [None]:
def construct_ema_stack(head, piece_reqs, num_games=100):
    """
    For each dot index, we look at all of the moves by the piece (specified in piece reqs) and
    take the EMA of the head output of those moves. This information is bundled with the true head 
    contribution as well as the board state.
    """
    piece_type = piece_reqs[0]
    parity_req = piece_reqs[1]
    alpha = piece_reqs[2]

    sims = []

    head_state_stack = []
    ema_state_stack = []
    piece_board_stack = []

    for i in tqdm(range(num_games)):
        piece_len = to_squares[i].shape[0]

        dots_index_list = dots_indices[i].tolist()

        cur_ema = torch.zeros_like(head_v[i][0, 0, head, :])

        cur_piece_indices = []

        for j in range(piece_len):
            piece = piece_types[i][j][1]

            square = to_squares[i][j]
            rank = square // 8
            file = square % 8
            parity = (rank + file) % 2

            meets_parity = True
            if parity_req is not None:
                meets_parity = (parity == parity_req)

            if piece == piece_type and meets_parity:
                cur_piece_indices.append(j)
                piece_v = head_v[i][0, j, head, :]

                # The board state at the dot includes the most recent black move. So, if the black move
                # was a capture, we need to replace the previous information
                replace = False

                # update the ema
                if indices[i][j] not in dots_index_list:
                    cur_ema = alpha * cur_ema + (1 - alpha) * piece_v
                else:
                    # the character is a dot, so this was a capture by black
                    replace = True
                    cur_ema = alpha * cur_ema + (1 - alpha) * piece_v

                # find the next dot
                dot_idx = bisect.bisect_left(dots_index_list, indices[i][j])

                if dot_idx >= len(dots_index_list):
                    break 
                
                attn = dots_attn[i][0, head, dot_idx, :][torch.tensor(cur_piece_indices)]
                v = head_v[i][0, torch.tensor(cur_piece_indices), head, :]

                # multiply v and attn
                act = torch.einsum("l, l d -> d", attn, v)

                # compute cosine similarity between act and cur_ema 
                sim = torch.nn.functional.cosine_similarity(cur_ema, act, dim=0)

                cur_board_stack = board_stacks[i][j].clone()
                for r in range(8):
                    for c in range(8):
                        parity = (r + c) % 2
                        meets_parity = True
                        if parity_req is not None:
                            meets_parity = (parity == parity_req)

                        if not meets_parity or cur_board_stack[r][c] != piece_type:
                            cur_board_stack[r][c] = 0
                        else:
                            cur_board_stack[r][c] = 1

                # the previous board stack was invalid because it did not include the black capture
                # (guard against an empty list if black captures before any recorded move)
                if replace and piece_board_stack:
                    piece_board_stack[-1] = cur_board_stack
                    ema_state_stack[-1] = cur_ema
                    head_state_stack[-1] = act
                    sims[-1] = sim.item()
                else:
                    piece_board_stack.append(cur_board_stack)
                    ema_state_stack.append(cur_ema)
                    head_state_stack.append(act)
                    sims.append(sim.item())

    # print the cosine similarity for checks
    print(f"Sim: {torch.mean(torch.tensor(sims)).item():.4f}")

    return torch.stack(ema_state_stack), torch.stack(piece_board_stack), torch.stack(head_state_stack)
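The EMA update inside `construct_ema_stack` is a one-line recurrence, but it is easier to reason about once unrolled. The plain-Python sketch below (no torch) assumes a zero initial state, since the initialization of `cur_ema` is not shown in this section, and it ignores the dot-replacement logic. With $n$ moves, the $t$-th move's value vector ends up weighted by $(1-\alpha)\alpha^{n-1-t}$, so the most recent move dominates and older moves fade geometrically. The helpers `ema`, `ema_unrolled` and `cosine` are illustrative names, not functions from the notebook.

```python
import math


def ema(vectors, alpha):
    """Recurrence used in the notebook: ema_t = alpha * ema_{t-1} + (1 - alpha) * v_t.

    Starts from a zero vector (an assumption for this sketch).
    """
    cur = [0.0] * len(vectors[0])
    for v in vectors:
        cur = [alpha * c + (1 - alpha) * x for c, x in zip(cur, v)]
    return cur


def ema_unrolled(vectors, alpha):
    """The same quantity as an explicit weighted sum: the t-th of n vectors
    gets weight (1 - alpha) * alpha ** (n - 1 - t)."""
    n = len(vectors)
    out = [0.0] * len(vectors[0])
    for t, v in enumerate(vectors):
        w = (1 - alpha) * alpha ** (n - 1 - t)
        out = [o + w * x for o, x in zip(out, v)]
    return out


def cosine(a, b):
    """Cosine similarity, the metric the notebook uses to compare the EMA with the head output."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# Three one-hot "value vectors", one per move, oldest first.
moves = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(ema(moves, 0.5))           # [0.125, 0.25, 0.5]
print(ema_unrolled(moves, 0.5))  # [0.125, 0.25, 0.5]
```

Both forms agree for any $\alpha$; the unrolled form makes it explicit that a larger $\alpha$ gives older moves more weight.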

In [None]:
piece_ema_stacks = {}
piece_board_stacks = {}
piece_head_stacks = {}

for name, head, reqs in piece_head_pairs:
    piece_ema_stack, piece_board_stack, piece_head_stack = construct_ema_stack(head, reqs, num_games=10000)

    piece_ema_stacks[name] = piece_ema_stack
    piece_board_stacks[name] = piece_board_stack
    piece_head_stacks[name] = piece_head_stack

    print(f"Computed ema data for: {name}")

In [None]:
board_probes = {}

for name in piece_names:
    board_probes[name] = create_linear_probe(
        train_params=probe_params,
        num_classes=2,
        has_rc=True,
    )

In [None]:
for name in piece_names:
    vl, va, tl, ta = train_probe(
        board_probes[name],
        piece_ema_stacks[name],
        piece_board_stacks[name].long(),
        64,
        5,
        linear_probe_forward_rc
    )

Finally, to check whether the EMA scheme matches what the model computes, we apply the probe trained on the EMA vectors to the true head contributions. The accuracy remains high, which suggests that the EMA and the true head contribution encode piece positions in a similar way.
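The logic of this transfer test can be sketched on synthetic data. The example below is an illustration under stated assumptions, not the notebook's probe: it swaps in a simple mean-difference linear classifier (`fit_mean_diff_probe`, a hypothetical helper) for the trained probe, uses random vectors as stand-ins for the EMA vectors, and models the true head contribution as the EMA vector plus extra noise. The probe is fit only on the EMA stand-ins and then scored on both sets.

```python
import random


def fit_mean_diff_probe(xs, ys):
    """Hypothetical linear probe: direction = mean(class 1) - mean(class 0),
    with the bias chosen so the midpoint between the class means scores 0."""
    dim = len(xs[0])
    pos = [x for x, y in zip(xs, ys) if y == 1]
    neg = [x for x, y in zip(xs, ys) if y == 0]
    mu1 = [sum(x[k] for x in pos) / len(pos) for k in range(dim)]
    mu0 = [sum(x[k] for x in neg) / len(neg) for k in range(dim)]
    w = [a - b for a, b in zip(mu1, mu0)]
    mid = [(a + b) / 2 for a, b in zip(mu1, mu0)]
    bias = -sum(wk * mk for wk, mk in zip(w, mid))
    return w, bias


def accuracy(w, bias, xs, ys):
    """Fraction of examples whose sign(w . x + bias) matches the 0/1 label."""
    preds = [int(sum(wk * xk for wk, xk in zip(w, x)) + bias > 0) for x in xs]
    return sum(p == y for p, y in zip(preds, ys)) / len(ys)


rng = random.Random(0)
dim, n = 8, 400
signal = [1.0] * dim
labels = [rng.randint(0, 1) for _ in range(n)]
# Stand-in for the EMA vectors: class-dependent signal plus noise.
ema_feats = [[(2 * y - 1) * s + rng.gauss(0, 0.5) for s in signal] for y in labels]
# Stand-in for the true head output: the EMA vector plus extra noise.
head_feats = [[x + rng.gauss(0, 0.3) for x in row] for row in ema_feats]

w, bias = fit_mean_diff_probe(ema_feats, labels)  # fit on the EMA stand-ins only
print(f"EMA accuracy:  {accuracy(w, bias, ema_feats, labels):.3f}")
print(f"head accuracy: {accuracy(w, bias, head_feats, labels):.3f}")
```

If the two representations encoded the piece along unrelated directions, the second accuracy would fall toward chance (about 0.5 for this balanced binary task) instead of staying close to the first.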

In [None]:
for name in piece_names:
    print(f"=== {name} ===")

    vl, va = test_probe(
        board_probes[name],
        piece_ema_stacks[name],
        piece_board_stacks[name].long(),
        probe_fwd=linear_probe_forward_rc
    )

    print(f"Accuracy using the computed ema: {(100 * va.item()):.2f}")

    vl, va = test_probe(
        board_probes[name],
        piece_head_stacks[name],
        piece_board_stacks[name].long(),
        probe_fwd=linear_probe_forward_rc
    )

    print(f"Accuracy using the true head: {(100 * va.item()):.2f}")

### Companion pieces

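For symmetric pieces like knights and rooks, we hypothesize that when one of the pieces moves, the position of the other piece is renewed. Because the EMA down-weights older moves, a piece that stays put for a long time would otherwise fade from the representation. To test this, we train a probe to check whether the square of the unmoved piece can be extracted from the value vector of the moving piece.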

In [None]:
def generate_to_board_stack(head, piece_reqs, num_games=100, include_dots=False):
    """
    For each move of the specified piece type, this function returns the head's value vector, the to square, 
    and the board state
    """
    piece_type = piece_reqs[0]
    parity_req = piece_reqs[1]

    piece_state_stack = []
    piece_to_stack = []
    piece_board_stack = []

    for i in tqdm(range(num_games)):
        piece_len = to_squares[i].shape[0]

        dots_index_list = dots_indices[i].tolist()

        for j in range(piece_len):
            if not include_dots and indices[i][j] in dots_index_list: 
                continue

            if indices[i][j] in dots_index_list:
                assert frm_squares[i][j] == -1

            piece = piece_types[i][j][1]

            square = to_squares[i][j]
            rank = square // 8
            file = square % 8
            parity = (rank + file) % 2

            meets_parity = True
            if parity_req is not None:
                meets_parity = (parity == parity_req)

            if piece == piece_type and meets_parity:
                piece_v = head_v[i][0, j, head, :]

                piece_state_stack.append(piece_v)
                piece_to_stack.append(to_squares[i][j])

                cur_board_stack = board_stacks[i][j].clone()
                for r in range(8):
                    for c in range(8):
                        if cur_board_stack[r][c] != piece_type:
                            cur_board_stack[r][c] = 0
                        else:
                            cur_board_stack[r][c] = 1

                piece_board_stack.append(cur_board_stack)

    
    return torch.stack(piece_state_stack), torch.stack(piece_to_stack), torch.stack(piece_board_stack)


In [None]:
knight_state_stack, knight_to_stack, knight_board_stack = generate_to_board_stack(6, (chess.KNIGHT, None), num_games=10000)
rook_state_stack, rook_to_stack, rook_board_stack = generate_to_board_stack(2, (chess.ROOK, None), num_games=10000)

In [None]:
def construct_companion_data(state_stack, board_stack, to_stack, ignore_squares=()):
    """
    For symmetric pieces like knights and rooks, find the square of the "other" knight/rook,
    i.e. the one that is not on the move's destination square.
    Assumes that there are at most 2 knights/rooks on the board.
    """

    train_indices = []
    labels = []

    for i in tqdm(range(state_stack.shape[0])):
        board = board_stack[i]    
        to = to_stack[i].item()

        found = False

        # check if there is another piece not on the to square; stop at the first one found
        # so that each move contributes at most one label
        for r in range(8):
            for c in range(8):
                idx = r * 8 + c
                if idx == to or idx in ignore_squares:
                    continue

                if board[r][c]:
                    found = True
                    labels.append(idx)
                    train_indices.append(i)
                    break

            if found:
                break

    inp = state_stack[torch.tensor(train_indices)]
    labels = torch.tensor(labels)

    return inp, labels


In [None]:
knight_inp, knight_labels = construct_companion_data(knight_state_stack, knight_board_stack, knight_to_stack)
rook_inp, rook_labels = construct_companion_data(rook_state_stack, rook_board_stack, rook_to_stack)

For knights, the probe achieves an accuracy of \~91%, whereas for rooks it achieves \~80%. This suggests that the value vector carries some information about the companion piece's location, but the signal is imperfect.
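When reading these numbers, it helps to compare them with trivial baselines. Uniform guessing over 64 squares gives about 1.6%, but a majority-class baseline (always predicting the most common companion square) could be considerably higher if the companion piece frequently sits on the same square, for example an unmoved piece on its starting square. A minimal sketch with a made-up list of labels; in the notebook the labels would come from, e.g., `knight_labels.tolist()`. `majority_baseline` is a hypothetical helper, and computing it on the full label set only approximates the baseline on the validation split.

```python
from collections import Counter


def majority_baseline(labels):
    """Accuracy of always predicting the most common label."""
    _, top_count = Counter(labels).most_common(1)[0]
    return top_count / len(labels)


# Toy companion-square labels (0-63), not real data.
toy = [1, 1, 1, 6, 6, 18, 21, 1]
print(f"majority baseline: {majority_baseline(toy):.3f}")  # 0.500
print(f"uniform chance:    {1 / 64:.3f}")                  # 0.016
```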

As for the exact mechanism, we suspect that heads 4.1, 4.3, and 4.6 may be responsible for passing along information about the companion piece, though more investigation is needed.



In [None]:
knight_companion_probe = create_linear_probe(probe_params, 64)
rook_companion_probe = create_linear_probe(probe_params, 64)

In [None]:

knight_companion_vl, knight_companion_va, knight_companion_tl, knight_companion_ta = train_probe(
    knight_companion_probe,
    knight_inp,
    knight_labels,
    64,
    16,
    linear_probe_forward
)

print(f"Final knight companion validation accuracy: {100 * knight_companion_va:.2f}")

In [None]:
rook_companion_vl, rook_companion_va, rook_companion_tl, rook_companion_ta = train_probe(
    rook_companion_probe,
    rook_inp,
    rook_labels,
    64,
    16,
    linear_probe_forward
)

print(f"Final rook companion validation accuracy: {100 * rook_companion_va:.2f}")

## Open Questions and Future Directions

Although we believe we have uncovered the main circuits responsible for board reconstruction, several important details remain open. First, it would be interesting to see how the transformer finds the "previous move" of a piece, since this requires knowledge of how that piece moves. For bishops, a simple (albeit slightly incorrect) rule such as "find the last bishop move on a dark/light square" is sufficient, but it is unclear how the transformer handles other pieces.
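The bishop rule quoted above is simple enough to write down directly. The sketch below is a hypothetical illustration of that heuristic, not a claim about how the model implements it. It uses the same 0-63 square indexing and `(rank + file) % 2` parity as `generate_to_board_stack`, and returns the destination of the most recent bishop move on squares of a given parity; `square_parity` and `previous_bishop_square` are illustrative names.

```python
def square_parity(square):
    """(rank + file) % 2 for a 0-63 square index with a1 = 0, b1 = 1, ..., h8 = 63.
    0 marks dark squares (a1 is dark), 1 marks light squares."""
    rank, file = divmod(square, 8)
    return (rank + file) % 2


def previous_bishop_square(bishop_destinations, parity):
    """Hypothetical heuristic: destination of the most recent bishop move that landed
    on a square of the given parity, or None if there was no such move.

    bishop_destinations: destination squares (0-63) of one side's bishop moves, in game order.
    """
    for square in reversed(bishop_destinations):
        if square_parity(square) == parity:
            return square
    return None


# White plays Bc1-f4 (29), Bf1-c4 (26), then Bf4-g3 (22).
white_bishop_moves = [29, 26, 22]
print(previous_bishop_square(white_bishop_moves, 0))  # 22: dark-squared bishop on g3
print(previous_bishop_square(white_bishop_moves, 1))  # 26: light-squared bishop on c4
```

A bishop can never change square color, which is why parity alone is enough to tell the two bishops apart in ordinary positions.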

For the pawn average head, it is unclear how the transformer deals with scaling and how it incorporates the initial pawn positions. More specifically, it may be worthwhile to investigate how the heads use the `;` token at the beginning of each game.

The EMA heads are even more complicated. First, it is challenging to derive the decay factor $\alpha$ theoretically. Second, the heads do not have clear-cut piece responsibilities: even though most piece types have a "main head," other heads sometimes still attend to their moves. Third, the companion piece mechanism is still unknown, and it may not be accurate enough to support the hypothesized piece persistence.
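Even if $\alpha$ is hard to derive theoretically, it can at least be estimated empirically. One option, sketched below under the assumption that the target is a head-output vector like `act` in `construct_ema_stack`, is a grid search over $\alpha$ that maximizes the cosine similarity between the EMA of the move value vectors and the target. The data here is synthetic and generated with a known $\alpha = 0.7$, so the search should recover it; `fit_alpha` is a hypothetical helper.

```python
import math
import random


def ema(vectors, alpha):
    """ema_t = alpha * ema_{t-1} + (1 - alpha) * v_t, starting from zero."""
    cur = [0.0] * len(vectors[0])
    for v in vectors:
        cur = [alpha * c + (1 - alpha) * x for c, x in zip(cur, v)]
    return cur


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def fit_alpha(move_vectors, target, grid):
    """Return the alpha in `grid` whose EMA of `move_vectors` has the highest
    cosine similarity with `target` (a stand-in for the true head output)."""
    return max(grid, key=lambda a: cosine(ema(move_vectors, a), target))


rng = random.Random(0)
moves = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(12)]
target = ema(moves, 0.7)                 # synthetic target with a known decay factor
grid = [k / 20 for k in range(1, 20)]    # 0.05, 0.10, ..., 0.95
print(fit_alpha(moves, target, grid))    # 0.7
```

Because cosine similarity ignores magnitude, this fits only the direction of the EMA. With real head outputs the objective would be averaged over many dots and games rather than fit to a single vector.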

Finally, although a large share of the attention goes to the rank tokens of the moves, a significant portion of "ambient" attention is spread across the remaining tokens. It would be interesting to explore why this is the case.
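The split between rank-token and ambient attention can be quantified per query. A minimal sketch, assuming a single attention row given as a list of weights (in the notebook, something like `dots_attn[i][0, head, dot_idx, :].tolist()`) and a list of the token positions treated as rank tokens; `attention_share` is a hypothetical helper and the row below is made up.

```python
def attention_share(attn_row, token_indices):
    """Fraction of one query's attention mass that lands on `token_indices`.

    attn_row: attention weights of a single query position (non-negative, summing to ~1).
    Duplicate indices are counted once.
    """
    total = sum(attn_row)
    focused = sum(attn_row[i] for i in set(token_indices))
    return focused / total


# Toy row over 8 tokens; suppose tokens 2 and 5 are the rank tokens of the tracked moves.
row = [0.05, 0.05, 0.40, 0.05, 0.05, 0.30, 0.05, 0.05]
share = attention_share(row, [2, 5])
print(f"rank tokens: {share:.2f}, ambient: {1 - share:.2f}")  # rank tokens: 0.70, ambient: 0.30
```

Averaging this share over many dots and games would show how much of each head's output the rank-token story actually covers.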

### Future Directions

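In this work, many of the comparisons between hypothesized mechanisms and the true head contributions relied on cosine similarity. A more direct approach worth exploring is activation patching, i.e., replacing a head's output with the hypothesized contribution and measuring the effect downstream. However, getting the correct scale for the hypothesized contributions could be a challenge, since cosine similarity ignores magnitude.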
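Since cosine similarity ignores magnitude, a hypothesized vector such as the EMA would have to be rescaled before it could be patched in place of the true head contribution. Two simple choices are sketched below as assumptions, not as a method used in this work: the least-squares scalar $s = \langle h, a\rangle / \langle h, h\rangle$, which minimizes $\lVert s h - a \rVert^2$, and the scalar that matches the L2 norms. They agree when $h$ and $a$ point in the same direction and differ otherwise.

```python
import math


def least_squares_scale(hypothesis, actual):
    """Scalar s minimising ||s * hypothesis - actual||^2, which is <h, a> / <h, h>."""
    hh = sum(h * h for h in hypothesis)
    ha = sum(h * a for h, a in zip(hypothesis, actual))
    return ha / hh


def norm_match_scale(hypothesis, actual):
    """Scalar that gives `hypothesis` the same L2 norm as `actual`."""
    return math.sqrt(sum(a * a for a in actual)) / math.sqrt(sum(h * h for h in hypothesis))


# Same direction: both choices agree.
print(least_squares_scale([3.0, 4.0], [6.0, 8.0]))  # 2.0
print(norm_match_scale([3.0, 4.0], [6.0, 8.0]))     # 2.0

# Different directions: least squares keeps only the projected component.
print(least_squares_scale([1.0, 0.0], [1.0, 1.0]))  # 1.0
print(norm_match_scale([1.0, 0.0], [1.0, 1.0]))     # sqrt(2), about 1.414
```

In the notebook, the rescaled vector could then be substituted for the head's output with a TransformerLens forward hook (for example via `HookedTransformer.run_with_hooks`); which hook point to use would depend on where the contribution is measured.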
