In [73]:
import chess
import chess.engine
import pandas
from utils import *
from main import *
from collections import Counter
import numpy as np
import torch
from tqdm import tqdm
import threading
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
import json
from typing import Dict, Any, List, Tuple
from collections import defaultdict

In [74]:
class Config:
    def __init__(self):
        self.data_root = 'pgn'
        self.seed = 42
        self.num_workers = 8
        self.verbose = True
        self.max_epochs = 1
        self.max_ply = 300
        self.clock_threshold = 30
        self.chunk_size = 8000  # Note: This is a string in your argparse, but consider changing it to int if applicable
        self.start_year = 2013
        self.start_month = 1
        self.end_year = 2019
        self.end_month = 12
        self.from_checkpoint = False
        self.checkpoint_year = 2018
        self.checkpoint_month = 12
        self.test_year = 2024
        self.test_month = 1
        self.num_cpu_left = 4
        self.model = 'ViT'  # Default model type
        self.lr = 1e-4
        self.wd = 1e-5
        self.batch_size = 30000
        self.first_n_moves = 10
        self.last_n_moves = 10
        self.dim_cnn = 256
        self.dim_vit = 1024
        self.num_blocks_cnn = 5
        self.num_blocks_vit = 2
        self.input_channels = 18
        self.vit_length = 8
        self.elo_dim = 128
        self.side_info = True
        self.max_games_per_elo_range = 20
        self.value = True
        self.value_coefficient = 1
        self.side_info_coefficient = 1

In [75]:
cfg = Config()
all_moves = get_all_possible_moves()
all_moves_dict = {move: i for i, move in enumerate(all_moves)}
elo_dict = create_elo_dict()
move_dict = {v: k for k, v in all_moves_dict.items()}

trained_model_path = "weights.v2.pt"
ckpt = torch.load(trained_model_path, map_location=torch.device('cpu'))
model = MAIA2Model(len(all_moves), elo_dict, cfg)
model = torch.nn.DataParallel(model)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

{'<1100': 0, '1100-1199': 1, '1200-1299': 2, '1300-1399': 3, '1400-1499': 4, '1500-1599': 5, '1600-1699': 6, '1700-1799': 7, '1800-1899': 8, '1900-1999': 9, '>=2000': 10}


  ckpt = torch.load(trained_model_path, map_location=torch.device('cpu'))


DataParallel(
  (module): MAIA2Model(
    (chess_cnn): ChessResNet(
      (conv1): Conv2d(18, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (layers): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (dropout): Dropout(p=0.5, inplace=False)
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (con

In [78]:
def extract_puzzle_data(json_line):
    data = json.loads(json_line)
    fen = data['fen']
    best_move = data['evals'][0]['pvs'][0]['line'].split()[0]
    best_cp = data['evals'][0]['pvs'][0].get('cp', 100000)
    
    best_moves = []
    
    pvs = []
    for eval in data['evals']:
        for pv in eval['pvs']:
            pvs.append(pv)
            
    for pv in pvs:
        if abs(pv.get('cp', 0) - best_cp) < 30:
            best_moves.append(pv['line'].split()[0])
    
    
    elo_dict = {
        '<1100': 0, '1100-1199': 1, '1200-1299': 2, '1300-1399': 3, '1400-1499': 4,
        '1500-1599': 5, '1600-1699': 6, '1700-1799': 7, '1800-1899': 8, '1900-1999': 9,
        '>=2000': 10
    }
    
    maia_data = {}
    for elo_range, category in elo_dict.items():
        if elo_range in data:
            move_probs = data[elo_range]['move_probs']
            top_maia_move = max(move_probs, key=move_probs.get)
            best_move_prob = move_probs.get(best_move, 0)
            
            maia_data[category] = {
                'top_maia_move': top_maia_move,
                'best_move_prob': best_move_prob
            }
    
    return {
        'fen': fen,
        'best_move': best_move,
        'maia_data': maia_data,
        'best_moves': list(set(best_moves))
    }

def is_monotonic(maia_data: Dict[int, Dict[str, Any]]) -> bool:
    probs = [data['best_move_prob'] for data in maia_data.values()]
    return all(probs[i] <= probs[i+1] for i in range(len(probs)-1))

def is_transitional(maia_data: Dict[int, Dict[str, Any]], best_moves: List[str]) -> Tuple[bool, int]:
    moves = [data['top_maia_move'] for data in maia_data.values()]
    elo_categories = sorted(maia_data.keys())
    
    # Check if all moves are correct (not truly transitional)
    if all(move in best_moves for move in moves):
        return False, -1
    
    for x in range(1, len(elo_categories)):
        if all(moves[i] not in best_moves for i in range(x)) and \
           all(moves[i] in best_moves for i in range(x, len(moves))):
            return True, elo_categories[x]
    return False, -1

def analyze_puzzle(puzzle_data: Dict[str, Any]) -> Dict[str, Any]:
    fen = puzzle_data['fen']
    best_move = puzzle_data['best_move']
    maia_data = puzzle_data['maia_data']
    best_moves = puzzle_data['best_moves']
    
    monotonic = is_monotonic(maia_data)
    transitional, transition_point = is_transitional(maia_data, best_moves)
    all_correct = all(data['top_maia_move'] == best_move for data in maia_data.values())
    
    return {
        'fen': fen,
        'best_move': best_move,
        'maia_moves': {elo: data['top_maia_move'] for elo, data in maia_data.items()},
        'maia_probs': {elo: data['best_move_prob'] for elo, data in maia_data.items()},
        'is_monotonic': monotonic,
        'is_transitional': transitional,
        'transition_point': transition_point,
        'is_both': monotonic and transitional,
        'all_correct': all_correct,
    }

def process_puzzles(file_path, special_fens, best_transitional_moves, all_best_transitional_moves, transitional_maia_moves, counters, transition_points) -> tuple:
    with open(file_path, 'r') as file:
        for line in file:
            puzzle_data = extract_puzzle_data(line)
            analyzed_data = analyze_puzzle(puzzle_data)
            # results.append(analyzed_data)
            
            counters['total'] += 1
            if analyzed_data['is_monotonic']:
                counters['monotonic'] += 1
                special_fens['monotonic'].append(analyzed_data['fen'])
            if analyzed_data['is_transitional']:
                counters['transitional'] += 1
                special_fens['transitional'].append(analyzed_data['fen'])
                transition_points.append(analyzed_data['transition_point'])
                best_transitional_moves.append(puzzle_data['best_move'])
                all_best_transitional_moves.append(puzzle_data['best_moves'])
                transitional_maia_moves.append(analyzed_data['maia_moves'])
            if analyzed_data['is_both']:
                counters['both'] += 1
                special_fens['both'].append(analyzed_data['fen'])
            if analyzed_data['all_correct']:
                counters['all_correct'] += 1
    
    # df = pd.DataFrame(results)
    return special_fens, counters, transition_points, best_transitional_moves, transitional_maia_moves

special_fens = defaultdict(list)
counters = {'total': 0, 'monotonic': 0, 'transitional': 0, 'both': 0, 'all_correct': 0}
transition_points = []
best_transitional_moves = []
all_best_transitional_moves = []
transitional_maia_moves = []

for i in tqdm(range(9)):
    file_path = f'lichess_db_eval_chunk_{i}.jsonl'
    special_fens, counters, transition_points, best_transitional_moves, transitional_maia_moves = process_puzzles(file_path, special_fens, best_transitional_moves, all_best_transitional_moves, transitional_maia_moves, counters, transition_points)
print(transitional_maia_moves[0])
total_non_all_correct = counters['total'] - counters['all_correct']
print("Total positions:", counters['total'])
print("Positions where all Maia2 predictions are correct:", counters['all_correct'], f"({counters['all_correct']/counters['total']:.2%})")
print("Monotonic positions:", counters['monotonic'], f"({counters['monotonic']/counters['total']:.2%})")
print("Transitional positions:", counters['transitional'], f"({counters['transitional']/total_non_all_correct:.2%} of non-all-correct positions)")
print("Both monotonic and transitional:", counters['both'], f"({counters['both']/total_non_all_correct:.2%} of non-all-correct positions)")

with open('trans_mono_positions.json', 'w') as f:
    json.dump(special_fens, f, indent=2)

100%|██████████| 9/9 [01:00<00:00,  6.68s/it]

{0: 'c6g2', 1: 'f7g7', 2: 'f7g7', 3: 'f7g7', 4: 'f7g7', 5: 'f7g7', 6: 'f7g7', 7: 'f7g7', 8: 'f7g7', 9: 'f7g7', 10: 'f7g7'}
Total positions: 900000
Positions where all Maia2 predictions are correct: 283697 (31.52%)
Monotonic positions: 222689 (24.74%)
Transitional positions: 90562 (14.69% of non-all-correct positions)
Both monotonic and transitional: 45695 (7.41% of non-all-correct positions)





In [81]:
def is_piece_no_longer_under_attack(fen: str, move: str, square_index) -> bool:
    """
    Determines if a piece on the given square index was under attack before the move
    and is no longer under attack after the move.

    Parameters:
    fen (str): The FEN string representing the current board position.
    move (str): The move in UCI format (e.g., 'e2e4').
    square_index (int): The square to check, as an index between 0 and 63 (0 = 'a1', 63 = 'h8').

    Returns:
    bool: True if the piece on the square was under attack before the move but is no longer under attack after.
    """
    # Load the board from the given FEN
    board = chess.Board(fen)

    piece_before_move = board.piece_at(square_index)

    # Function to count attacks and defenses on the square
    def is_under_attack(board, square_index):
        attackers = board.attackers(chess.BLACK, square_index)  # Opponent's attackers
        defenders = board.attackers(chess.WHITE, square_index)  # Defenders (current player's pieces)

        return len(attackers) > len(defenders)

    # Check if the square is under attack before the move
    was_under_attack_before = is_under_attack(board, square_index) and piece_before_move is not None and piece_before_move.color is not chess.BLACK

    # Make the move
    move_obj = chess.Move.from_uci(move)
    if not board.is_legal(move_obj):
        return False

    board.push(move_obj)  # Apply the move

    # Check if there is still a piece on the square after the move
    piece_after_move = board.piece_at(square_index)

    # Check if the square is still under attack after the move
    is_under_attack_after = is_under_attack(board, square_index) and piece_after_move is not None

    # Return True if it was under attack before, but is no longer under attack
    return was_under_attack_before and not is_under_attack_after

def is_blunder(fen: str, uci_move: str, stockfish_path: str = '/opt/homebrew/bin/stockfish', threshold: int = 150) -> bool:
    # Load the engine
    with chess.engine.SimpleEngine.popen_uci(stockfish_path) as engine:
        # Parse the FEN to create a board
        board = chess.Board(fen)
        
        # Create the move from the UCI string
        move = chess.Move.from_uci(uci_move)

        # Check if the move is legal
        if move not in board.legal_moves:
            return False  # Automatically return False if the move is illegal

        # Get the evaluation before the move
        info_before = engine.analyse(board, chess.engine.Limit(time=0.1))
        eval_before = info_before['score'].relative.score()

        # If eval_before is None, it's a checkmate scenario, so return False (no blunder possible)
        if eval_before is None:
            return False

        # Apply the move
        board.push(move)

        # Get the evaluation after the move
        info_after = engine.analyse(board, chess.engine.Limit(time=0.1))
        eval_after = info_after['score'].relative.score()

        # If eval_after is None, it's a checkmate scenario, so return False (no blunder possible)
        if eval_after is None:
            return False

        # Calculate the centipawn difference
        centipawn_loss = abs(eval_before - eval_after)

        # Check if the loss is greater than the threshold (150 centipawns)
        return centipawn_loss > threshold

def square_index(square_name: str) -> int:
    if not isinstance(square_name, str) or len(square_name) != 2:
        raise ValueError("Square name must be a string of length 2")

    file = square_name[0].lower()
    rank = square_name[1]

    if file not in 'abcdefgh' or rank not in '12345678':
        raise ValueError("Invalid square name")

    file_index = 'abcdefgh'.index(file)
    rank_index = '12345678'.index(rank)

    return rank_index * 8 + file_index

In [82]:
squarewise_alarmbells = {"b1": (0, 717), "d1": (0, 1888), "f1": (0, 1864), "h1": (0, 1000), "d2": (0, 1608), "e2":(0, 1701), "f2": (0, 1747), "a3": (1, 416), "c3": (1, 1346), "d3": (0, 1676), "e3": (0, 142), "h3": (1, 1009), "b4": (1, 1), "d4": (0, 120), "e4": (1, 977), "f4": (1, 1846), "g4": (1, 538), "h4": (0, 437), "b5": (0, 1687), "d5": (0, 1538), "e5": (0, 630), "f5": (1, 843), "c6": (1, 574), "e6": (0, 1924), "e7": (0, 2029), "f7": (0, 1715)}

squarewise_alarmbells = {"b1": (0, 717, 0.4865), "d1": (0, 1888, 0.4414), "f1": (0, 1864, 0.5453), "h1": (0, 1000, 0.29), "d2": (0, 1608, 0.87), "e2":(0, 1701, 0.1005), "f2": (0, 1747, 0.3497), "a3": (1, 416, 0.0934), "c3": (1, 1346, 0.2888), "d3": (0, 1676, 0.3614), "e3": (0, 142, 0.2222), "h3": (1, 1009, 0.2700), "b4": (1, 1, 1.0180), "d4": (0, 120, 0.4834), "e4": (0, 977, 0.4095), "f4": (0, 1846, 0.9502), "g4": (0, 538, 0.4124), "h4": (0, 437, 0.6328), "b5": (0, 1687, 0.2789), "d5": (0, 1538, 0.4424), "e5": (0, 630, 0.4104), "f5": (0, 843, 0.2159), "c6": (1, 574, 1.3355), "e6": (0, 1924, 0.5567), "e7": (0, 2029, 0.3414), "f7": (0, 1715, 0.5274)}

lookup_thresholds = {0: {}, 1:{}}
for value in squarewise_alarmbells.values():
    lookup_thresholds[value[0]][value[1]] = value[2]
    

# squarewise_alarmbells = {"b1": (0, 717), "d1": (0, 1888), "f1": (0, 1864), "h1": (0, 1000), "d2": (0, 1608), "e2":(0, 1701), "f2": (0, 1747), "a3": (1, 416), "c3": (1, 1346), "d3": (0, 1676), "e3": (0, 142), "h3": (1, 1009)}


target_key_list = ['transformer block 0 hidden states', 'transformer block 1 hidden states'] # 'conv_last'
all_ground_truths = []
intervention_site = []

for_ashton = {k:[] for k in [square for square in squarewise_alarmbells]}

for key, value in squarewise_alarmbells.items():
    layer, feature_idx = value[0], value[1]
    # layer, feature_idx = random.choice([0, 1]), random.randint(0, 2048)
    intervention_site.append((target_key_list[layer], feature_idx))
    ground_truth = []
    for i in tqdm(range(len(special_fens['transitional']))):
        fen = special_fens['transitional'][i]
        move = best_transitional_moves[i]
        square_idx = square_index(key)
        try:
            no_longer_under_attack = is_piece_no_longer_under_attack(fen, move, square_idx)
            # micro_gt = no_longer_under_attack and is_blunder(fen, transitional_maia_moves[i][0])
            micro_gt = no_longer_under_attack and not is_piece_no_longer_under_attack(fen, transitional_maia_moves[i][0], square_idx)
            ground_truth.append(micro_gt)    
            if is_piece_no_longer_under_attack(fen, move, square_idx):
                temp = {"fen": fen}
                temp["transition_points"] = transition_points[i]
                temp["best_moves"] = all_best_transitional_moves[i]
                temp['blunder'] = transitional_maia_moves[i][0]
                for_ashton[key].append(temp)
        except ValueError:
            ground_truth.append(0)

            # print("illegal_move" + move)
    all_ground_truths.append(ground_truth)
    print(len(ground_truth))
all_ground_truths = torch.tensor(all_ground_truths, dtype=torch.int)
with open("non_ashtonian_relevant_positions.json", "w") as f:
    json.dump(for_ashton, f)

100%|██████████| 90562/90562 [00:10<00:00, 9026.63it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 9045.27it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 8927.14it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 9036.56it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9091.18it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9092.86it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 9015.62it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 8994.84it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9074.60it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 8979.17it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9127.31it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9160.80it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 8540.49it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9110.94it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9146.40it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9268.25it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 8953.45it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9102.58it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9107.62it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 9045.73it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 8955.21it/s]


90562


100%|██████████| 90562/90562 [00:10<00:00, 8994.87it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9103.52it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9073.97it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9066.42it/s]


90562


100%|██████████| 90562/90562 [00:09<00:00, 9141.72it/s]


90562


In [83]:
valid_indices = [i for i in range(all_ground_truths.shape[1]) if torch.sum(all_ground_truths[:, i]) != 0]


filtered_special_fens = [special_fens['transitional'][i] for i in valid_indices]
filtered_transition_points = [transition_points[i] for i in valid_indices]
filtered_best_transitional_moves = [best_transitional_moves[i] for i in valid_indices]
filtered_all_best_transitional_moves = [all_best_transitional_moves[i] for i in valid_indices]
filtered_all_ground_truths = all_ground_truths[:, valid_indices]

In [84]:
board_fens = []
board_inputs = []
for fen in filtered_special_fens:
    board_tensor = board_to_tensor(chess.Board(fen))
    board_fens.append(fen)
    board_inputs.append(board_tensor)

board_inputs = torch.stack(board_inputs, dim=0)

In [85]:
torch.sum(filtered_all_ground_truths, dim=1)

tensor([ 13,  21,   7,  22,  26,  63, 155,  46, 203,  74, 105,  80, 125, 556,
        668, 153,  92,  79,  83, 209, 440,  75,  28,  40,  18,  26])

In [86]:
intervention_site

[('transformer block 0 hidden states', 717),
 ('transformer block 0 hidden states', 1888),
 ('transformer block 0 hidden states', 1864),
 ('transformer block 0 hidden states', 1000),
 ('transformer block 0 hidden states', 1608),
 ('transformer block 0 hidden states', 1701),
 ('transformer block 0 hidden states', 1747),
 ('transformer block 1 hidden states', 416),
 ('transformer block 1 hidden states', 1346),
 ('transformer block 0 hidden states', 1676),
 ('transformer block 0 hidden states', 142),
 ('transformer block 1 hidden states', 1009),
 ('transformer block 1 hidden states', 1),
 ('transformer block 0 hidden states', 120),
 ('transformer block 0 hidden states', 977),
 ('transformer block 0 hidden states', 1846),
 ('transformer block 0 hidden states', 538),
 ('transformer block 0 hidden states', 437),
 ('transformer block 0 hidden states', 1687),
 ('transformer block 0 hidden states', 1538),
 ('transformer block 0 hidden states', 630),
 ('transformer block 0 hidden states', 843),


In [87]:
sae_dim = 2048
sae_lr = 1e-05
sae_site = "res"
sae_date = "2023-12"
sae = torch.load(f'trained_saes_{sae_date}-{sae_dim}-{sae_lr}-{sae_site}.pt')
target_key_list = ['transformer block 0 hidden states', 'transformer block 1 hidden states'] # 'conv_last'

  sae = torch.load(f'trained_saes_{sae_date}-{sae_dim}-{sae_lr}-{sae_site}.pt')


In [88]:
def _enable_activation_hook(model, cfg):
    def get_activation(name):
        def hook(model, input, output):
            if not hasattr(_thread_local, 'residual_streams'):
                _thread_local.residual_streams = {}
            _thread_local.residual_streams[name] = output.detach()
        return hook
        
    for i in range(cfg.num_blocks_vit):
        feedforward_module = model.module.transformer.elo_layers[i][1]
        feedforward_module.register_forward_hook(get_activation(f'transformer block {i} hidden states'))

def apply_sae_to_activations(sae, activations, target_key_list):
    sae_activations = {}
    for key in target_key_list:
        if key in activations and key in sae:
            # act = activations[key].view(-1, activations[key].size(-1))
            act = torch.mean(activations[key], dim=1)
            # print(act.shape)
            encoded = nn.functional.linear(act, sae[key]['encoder_DF.weight'], sae[key]['encoder_DF.bias'])
            encoded = nn.functional.relu(encoded)
            
            sae_activations[key] = encoded
    
    return sae_activations

def apply_sae_to_reconstruction(sae, activations, target_key_list):
    sae_activations = {}
    for key in target_key_list:
        if key in activations and key in sae:
            act = torch.mean(activations[key], dim=1)
            encoded = nn.functional.linear(act, sae[key]['encoder_DF.weight'], sae[key]['encoder_DF.bias'])
            encoded = nn.functional.relu(encoded)
            decoded = nn.functional.linear(encoded, sae[key]['decoder_FD.weight'], sae[key]['decoder_FD.bias'])
            
            sae_activations[key] = decoded
    
    return sae_activations

def get_legal_moves_idx(board, all_moves_dict):
    legal_moves = torch.zeros(len(all_moves_dict))
    legal_moves_idx = []
    for move in board.legal_moves:
        move_uci = move.uci()
        if move_uci in all_moves_dict:
            legal_moves_idx.append(all_moves_dict[move_uci])
    legal_moves_idx = torch.tensor(legal_moves_idx)
    legal_moves[legal_moves_idx] = 1
    return legal_moves

In [89]:
def _enable_intervention_hook(model, cfg):
    def get_intervention_hook(name):
        def hook(module, input, output):
            if not hasattr(_thread_local, 'residual_streams'):
                _thread_local.residual_streams = {}
            _thread_local.residual_streams[name] = output.detach()
            if hasattr(_thread_local, 'modified_values') and name in _thread_local.modified_values:
                return _thread_local.modified_values[name]
            return None
        return hook
    
    for i in range(cfg.num_blocks_vit):
        feedforward_module = model.module.transformer.elo_layers[i][1]
        feedforward_module.register_forward_hook(
            get_intervention_hook(f'transformer block {i} hidden states')
        )

def set_modified_values(modified_dict):
    _thread_local.modified_values = modified_dict

def clear_modified_values():
    if hasattr(_thread_local, 'modified_values'):
        del _thread_local.modified_values

In [62]:
change = {}

In [112]:
intervention_site = []
for key, value in squarewise_alarmbells.items():
    layer, feature_idx = value[0], value[1]
    #Random intervention
    # layer, feature_idx = random.choice([0, 1]), random.randint(0, 2048)
    intervention_site.append((target_key_list[layer], feature_idx))


specific_square_idx = [val[1] for val in squarewise_alarmbells.values()].index(1009)
specific_square_name = "h3"
# intervention_site[specific_square_idx] = ('transformer block 1 hidden states', 603)

print(torch.sum(filtered_all_ground_truths, dim=1)[specific_square_idx])


specific_square_idx = intervention_site[specific_square_idx][1]
# specific_square_idx = 603

tensor(80)


In [113]:
_thread_local = threading.local()
_enable_activation_hook(model, cfg)
target_key_list = ['transformer block 0 hidden states', 'transformer block 1 hidden states']
# all_sae_activations = {key: [] for key in target_key_list}
# all_sae_reconstruct = {key: [] for key in target_key_list}
intervened_pred_list = []

legal_moves_list = []
for fen in filtered_special_fens:
    board = chess.Board(fen)
    legal_moves_list.append(get_legal_moves_idx(board, all_moves_dict))
legal_moves = torch.stack(legal_moves_list)

elos_self = torch.zeros(len(filtered_special_fens))
elos_oppo = torch.zeros(len(filtered_special_fens))
intervention_strength = 10
epsilon = 0.005

for elo in range(len(elo_dict) - 1):
    
    elos_self = elos_self.fill_(elo).long()
    elos_oppo = elos_oppo.fill_(elo).long()

    # Clean Run
    with torch.no_grad():
        logits_maia, logits_side_info, logits_value = model(board_inputs, elos_self, elos_oppo)
        activations = getattr(_thread_local, 'residual_streams', {})
        sae_activations = apply_sae_to_activations(sae, activations, target_key_list)
        sae_reconstruct_activations = apply_sae_to_reconstruction(sae, activations, target_key_list)
        # for key in target_key_list:
        #     if key in activations:
        #         all_sae_activations[key].append(sae_activations[key])
        #         all_sae_reconstruct[key].append(sae_reconstruct_activations[key])
    
        logits_maia_legal = logits_maia * legal_moves
        preds = logits_maia_legal.argmax(dim=-1)

    # Intervention
    intervened_sae_activations = {}
    for key in sae_activations:
        intervened_sae_activations[key] = sae_activations[key].clone()
    
    threshold_statistics = {k:[0, 0] for k in [val[1] for val in squarewise_alarmbells.values()]}
    
    for i in range(filtered_all_ground_truths.shape[1]):
        for j in range(filtered_all_ground_truths.shape[0]):
            if filtered_all_ground_truths[j][i] == 1:
                layer, feature_idx = intervention_site[j]
                intervened_sae_activations[layer][i, feature_idx]  += epsilon
                intervened_sae_activations[layer][i, feature_idx]  *= intervention_strength
        
    reconstructed_activations = {}
    for key in intervened_sae_activations:
        reconstructed_activations[key] = nn.functional.linear(intervened_sae_activations[key], sae[key]['decoder_FD.weight'], 
                                                              sae[key]['decoder_FD.bias']).unsqueeze(1).expand(-1, 8, -1)

    # intervene_site = 'transformer block 1 hidden states'
    _enable_intervention_hook(model, cfg)
    set_modified_values(reconstructed_activations)
    with torch.no_grad():
        intervened_logits_maia, intervened_logits_side_info, intervened_logits_value = model(board_inputs, elos_self, elos_oppo)
        intervened_logits_maia_legal = intervened_logits_maia * legal_moves
        intervened_preds = intervened_logits_maia_legal.argmax(dim=-1)
    clear_modified_values()

    intervened_pred_list.append(intervened_preds)

    # Best move rate


    original_results_dict = {k:[0, 0] for k in [val[1] for val in intervention_site]}


    print(f"Maia 2 Strength: {elo}")
    original_cnt = 0
    for i in range(filtered_all_ground_truths.shape[1]):
        for j in range(filtered_all_ground_truths.shape[0]):
            if filtered_all_ground_truths[j][i] == 1:
         
                pred = move_dict[preds[i].item()]
                if pred in filtered_all_best_transitional_moves[i]:
                    original_cnt += 1
                    original_results_dict[intervention_site[j][1]][1] += 1
                else:
                    original_results_dict[intervention_site[j][1]][0] += 1
                    pass
                    
    original_cnt /= filtered_all_ground_truths.shape[1]
    print(f"Original rate for predicting the best {specific_square_name} move: {original_results_dict[specific_square_idx][1]/(original_results_dict[specific_square_idx][0]+original_results_dict[specific_square_idx][1])}")
    # print(f"Original rate for predicting the best move: {original_cnt}")
    # print(original_results_dict)

    intervened_results_dict = {k:[0, 0] for k in [val[1] for val in intervention_site]}

    intervened_cnt = 0
    for i in range(filtered_all_ground_truths.shape[1]):
        pred = move_dict[intervened_preds[i].item()]
        for j in range(filtered_all_ground_truths.shape[0]):
            if filtered_all_ground_truths[j][i] == 1:
                if pred in filtered_all_best_transitional_moves[i]:
                    intervened_cnt += 1
                    intervened_results_dict[intervention_site[j][1]][1] += 1
                else:
                    pass
                    intervened_results_dict[intervention_site[j][1]][0] += 1
    intervened_cnt /= filtered_all_ground_truths.shape[1]
    # print(f"Intervened rate for predicting the best move: {intervened_cnt}")
    # print(intervened_results_dict)

    print(f"Intervened rate for predicting the best {specific_square_name} move:{intervened_results_dict[specific_square_idx][1]/(intervened_results_dict[specific_square_idx][0]+intervened_results_dict[specific_square_idx][1])}")
    
intervened_pred_list = torch.stack(intervened_pred_list, dim=0)

intervened_transition_points = []
for i in range(intervened_pred_list.shape[1]):
    intervened = False
    for j in range(1, len(elo_dict) - 1):
        if all(move_dict[intervened_pred_list[k][i].item()] != filtered_best_transitional_moves[i] for k in range(j)) and \
           all(move_dict[intervened_pred_list[k][i].item()] == filtered_best_transitional_moves[i] for k in range(j, len(elo_dict) - 1)):
               intervened_transition_points.append(j)
               intervened = True
    if not intervened:
        intervened_transition_points.append(-1)

cnt = 0
tot_sum = 0
for i in range(intervened_pred_list.shape[1]):
    if intervened_transition_points[i] != -1:
        cnt += 1
        tot_sum += (filtered_transition_points[i] - intervened_transition_points[i])

print(tot_sum/cnt, cnt)
change[intervention_strength] = tot_sum/cnt

Maia 2 Strength: 0
Original rate for predicting the best h3 move: 0.0375
Intervened rate for predicting the best h3 move:0.15
Maia 2 Strength: 1
Original rate for predicting the best h3 move: 0.225
Intervened rate for predicting the best h3 move:0.4125
Maia 2 Strength: 2
Original rate for predicting the best h3 move: 0.325
Intervened rate for predicting the best h3 move:0.4625
Maia 2 Strength: 3
Original rate for predicting the best h3 move: 0.4125
Intervened rate for predicting the best h3 move:0.4875


KeyboardInterrupt: 