In [1]:
import time
import pickle
import random
from pathlib import Path
import pandas as pd
from functools import partial

import chess
import iceberg as ice
import matplotlib.pyplot as plt
import numpy as np
import torch
from leela_interp import Lc0Model, Lc0sight, LeelaBoard
from leela_interp.core.iceberg_board import palette
from leela_interp.tools import figure_helpers as fh
from leela_interp.tools.piece_movement_heads import (
    bishop_heads,
    knight_heads,
    rook_heads,
)
from leela_interp.tools.attention import attention_attribution, top_k_attributions

# Prefer Apple's MPS backend, then CUDA, then CPU. A hard-coded "mps" makes
# torch.load(map_location=device) fail on any machine without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Stacked per-puzzle score matrices (per the file path: L12H12 post-softmax attention
# patterns). Rows are looked up via `puzzle.residual_effects_idx` in `count_top_k`.
# weights_only=True: the file only needs to hold tensors, so refuse arbitrary pickled
# objects. This also silences recent torch's FutureWarning about the default.
all_effects = torch.load(
    "../results/move_trees/L12H12/attention_pattern_post_softmax.pt",
    map_location=device,
    weights_only=True,
)
all_effects.shape

  all_effects = torch.load(


torch.Size([12000, 64, 64])

In [3]:
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
puzzles_pkl_path = "../puzzles_with_move_trees_and_tags.pkl"
with open(puzzles_pkl_path, "rb") as puzzles_file:
    puzzles = pickle.load(puzzles_file)
print(f"len of puzzles = {len(puzzles)}")

len of puzzles = 12000


In [4]:
# NOTE(review): `model` is not referenced by any later cell shown here; the counts
# below read only the precomputed `all_effects`. Confirm it is still needed.
lc0_onnx_path = "../lc0.onnx"
model = Lc0sight(lc0_onnx_path, device=device)

Using device: mps


In [5]:
def get_top_k_scores(
    board: "LeelaBoard", tensor: torch.Tensor, k: int = 5
) -> tuple[list[tuple[str, str]], list[tuple[int, int]]]:
    """Find the ``k`` largest entries of a 2-D (row, column) score matrix.

    Args:
        board: object whose ``idx2sq(i)`` maps a square index to a square name.
        tensor: 2-D score matrix (64 x 64 for the callers in this notebook). The
            row index is the first element of each returned pair and the column
            index the second. It does not need to be contiguous; a 1-D input is
            treated as a flattened matrix with 64 columns, as before.
        k: number of entries to return; clamped to the number of elements.

    Returns:
        ``(moves, indices)`` in descending score order. ``indices`` holds
        ``(row, col)`` integer pairs. ``moves`` holds the matching
        ``(board.idx2sq(row), board.idx2sq(col))`` square-name pairs.
    """
    # Use the real row width for 2-D input instead of assuming 64 columns.
    n_cols = tensor.shape[-1] if tensor.dim() >= 2 else 64
    # reshape (unlike view) also works for non-contiguous tensors, e.g. transposes.
    flat_tensor = tensor.reshape(-1)
    k = min(k, flat_tensor.numel())
    top_k_flat = torch.topk(flat_tensor, k=k).indices.tolist()

    top_k_2d_indices = [divmod(idx, n_cols) for idx in top_k_flat]
    top_k_2d_moves = [(board.idx2sq(i), board.idx2sq(j)) for i, j in top_k_2d_indices]

    return top_k_2d_moves, top_k_2d_indices

In [6]:
def count_top_k(puzzles: pd.DataFrame, get_focus_squares, k=5, effects=None):
    """Count puzzles whose focus square pair is among the top-k entries of their score matrix.

    Args:
        puzzles: frame with a ``residual_effects_idx`` column (row index into
            ``effects``), plus whatever ``get_focus_squares`` and
            ``LeelaBoard.from_puzzle`` read from each row.
        get_focus_squares: callable mapping a puzzle row to a ``(query_sq, key_sq)``
            pair of square names, or ``None`` to skip that puzzle.
        k: number of top entries of each puzzle's matrix to check.
        effects: stacked per-puzzle score matrices indexed by ``residual_effects_idx``.
            Defaults to the notebook-global ``all_effects``, looked up at call time.

    Returns:
        ``(total, correct)``: how many puzzles were evaluated (not skipped) and how
        many of those had their focus pair among the top-k entries.
    """
    if effects is None:
        effects = all_effects
    total = 0
    correct = 0
    for _row_label, puzzle in puzzles.iterrows():
        focus_squares = get_focus_squares(puzzle)
        if focus_squares is None:
            continue
        board = LeelaBoard.from_puzzle(puzzle)
        tensor = effects[puzzle.residual_effects_idx]
        top_k_moves, _top_k_indices = get_top_k_scores(board, tensor, k=k)
        # top_k_moves holds tuples, so normalize (e.g. a list pair) before the membership test.
        if tuple(focus_squares) in top_k_moves:
            correct += 1
        total += 1
    return total, correct

In [7]:
def tgt2tgt(puzzle, query_move, key_move):
    """Return the target squares of two moves stored in ``puzzle.move_tree``.

    ``query_move`` and ``key_move`` are move-tree keys (e.g. "0", "000"). Each stored
    move string is sliced ``[2:4]``. For UCI-style strings such as "e2e4" that slice
    is the destination square (assumed format; TODO confirm).

    Returns:
        ``(query_target, key_target)``, or ``None`` when either key is missing
        (or maps to ``None``).
    """
    tree = puzzle.move_tree
    query_uci = tree.get(query_move)
    key_uci = tree.get(key_move)
    if any(move is None for move in (query_uci, key_uci)):
        return None
    return query_uci[2:4], key_uci[2:4]

In [8]:
# Focus pair: (target of move "0", target of move "000"), checked at top-1 over all puzzles.
third_move_tgt2first_move_tgt = partial(tgt2tgt, key_move="000", query_move="0")
count_top_k(puzzles, third_move_tgt2first_move_tgt, k=1)

(12000, 3366)

In [9]:
# Keep the puzzles whose "000" target effect is strictly above the median.
effects_000 = puzzles["effects_000_tgt"]
threshold = effects_000.quantile(0.5)
filtered_puzzles = puzzles.loc[effects_000 > threshold]
print(f"len of filtered_puzzles = {len(filtered_puzzles)}")

len of filtered_puzzles = 6000


In [10]:
# Same top-1 check as above, restricted to the above-median subset from the previous cell.
count_top_k(puzzles=filtered_puzzles, get_focus_squares=third_move_tgt2first_move_tgt, k=1)

(6000, 2445)

In [22]:
# Pair the target of move "0" (query) with the target of "00000" / "010" (key).
# These partials are reused by the cells below.
move00000_move0 = partial(tgt2tgt, key_move="00000", query_move="0")
move010_move0 = partial(tgt2tgt, key_move="010", query_move="0")

# Top-1 hit rate for "00000" -> "0" on every puzzle with a positive "00000" target effect.
positive_00000 = puzzles["effects_00000_tgt"] > 0
filtered_puzzles = puzzles[positive_00000]
count_top_k(filtered_puzzles, move00000_move0, k=1)

(4444, 487)

In [34]:
# Top-1 hit rate for "00000" -> "0" on the puzzles above the 97th percentile of
# "effects_00000_tgt" (roughly the top 3%). `move00000_move0` is bound in an earlier cell.
threshold = puzzles["effects_00000_tgt"].quantile(0.97)
filtered_puzzles = puzzles[puzzles["effects_00000_tgt"] > threshold]
count_top_k(filtered_puzzles, move00000_move0, k=1)

(360, 135)

In [23]:
# Top-1 hit rate for "010" -> "0" on puzzles with a positive "010" target effect whose
# "000" and "010" moves land on different squares. `move010_move0` is bound in an earlier cell.
def _targets_000_010_differ(move_tree) -> bool:
    """True when both "000" and "010" are in ``move_tree`` and their target squares differ."""
    return (
        "010" in move_tree
        and "000" in move_tree
        and move_tree["000"][2:4] != move_tree["010"][2:4]
    )


positive_010 = puzzles[puzzles["effects_010_tgt"] > 0]
filtered_puzzles = positive_010[positive_010["move_tree"].apply(_targets_000_010_differ)]
count_top_k(filtered_puzzles, move010_move0, k=1)

(2597, 224)

In [32]:
# Top-1 hit rate for "010" -> "0" on puzzles whose "010" target effect exceeds 0.5.
# Unlike the previous cell, the "000" and "010" targets are not required to differ.
# `move010_move0` is bound in an earlier cell.
filtered_puzzles = puzzles[puzzles["effects_010_tgt"] > 0.5]
count_top_k(filtered_puzzles, move010_move0, k=1)

(309, 116)

In [12]:
# Top-5 hit rate for "00000" -> "0" on the puzzles above the 99th percentile of
# "effects_00000_tgt".
effects_00000 = puzzles["effects_00000_tgt"]
threshold = effects_00000.quantile(0.99)
filtered_puzzles = puzzles.loc[effects_00000 > threshold]
count_top_k(filtered_puzzles, move00000_move0, k=5)

(120, 79)

In [13]:
# Puzzles with a high "00000" target effect (>= its `quantile` level) that, within that
# subset, also have a high "010" target effect, and whose "010" and "00000" targets differ.
# Then compare top-1 / top-5 hit rates for the "00000" -> "0" and "010" -> "0" pairs.
quantile = 0.84
effects_00000 = puzzles["effects_00000_tgt"]
threshold_1 = effects_00000.quantile(quantile)
high_00000 = puzzles[effects_00000 >= threshold_1]

threshold_2 = high_00000["effects_010_tgt"].quantile(quantile)
high_00000_and_010 = high_00000[high_00000["effects_010_tgt"] > threshold_2]


def _targets_010_00000_differ(tree) -> bool:
    """True when "010" and "00000" are both in ``tree`` with different target squares."""
    if "010" not in tree or "00000" not in tree:
        return False
    return tree["010"][2:4] != tree["00000"][2:4]


other_puzzles = high_00000_and_010[
    high_00000_and_010["move_tree"].apply(_targets_010_00000_differ)
]
print(f"len of other_puzzles = {len(other_puzzles)}")

top1_00000 = count_top_k(other_puzzles, move00000_move0, k=1)[1]
top5_00000 = count_top_k(other_puzzles, move00000_move0, k=5)[1]
top1_010 = count_top_k(other_puzzles, move010_move0, k=1)[1]
top5_010 = count_top_k(other_puzzles, move010_move0, k=5)[1]

for label, value in (
    ("top1_00000", top1_00000),
    ("top5_00000", top5_00000),
    ("top1_010", top1_010),
    ("top5_010", top5_010),
):
    print(f"{label} = {value}")

len of other_puzzles = 89
top1_00000 = 8
top5_00000 = 13
top1_010 = 7
top5_010 = 12


In [14]:
# Puzzles with a high "00000" target effect that, within that subset, also have a high
# "000" target effect. Compare top-1 vs top-10 hit rates for the "00000" -> "0" and
# "000" -> "0" pairs. The k=10 counts are named top10_* so they do not overwrite the
# top-5 counts computed in the previous cell.
quantile = 0.9
threshold_1 = puzzles["effects_00000_tgt"].quantile(quantile)
other_puzzles = puzzles[puzzles["effects_00000_tgt"] >= threshold_1]
threshold_2 = other_puzzles["effects_000_tgt"].quantile(quantile)
other_puzzles = other_puzzles[other_puzzles["effects_000_tgt"] > threshold_2]
print(f"length of other_puzzles = {len(other_puzzles)}")

_, top1_00000 = count_top_k(other_puzzles, move00000_move0, k=1)
_, top10_00000 = count_top_k(other_puzzles, move00000_move0, k=10)
_, top1_000 = count_top_k(other_puzzles, third_move_tgt2first_move_tgt, k=1)
_, top10_000 = count_top_k(other_puzzles, third_move_tgt2first_move_tgt, k=10)

print(f"top1_00000 = {top1_00000}")
print(f"top10_00000 = {top10_00000}")
print(f"top1_000 = {top1_000}")
print(f"top10_000 = {top10_000}")


length of other_puzzles = 120
top1_00000 = 49
top5_00000 = 66
top1_000 = 29
top5_000 = 52
