In [1]:
import chess
import chess.engine
import pandas
from utils import *
from main import *
from collections import Counter
import numpy as np
import torch
from tqdm import tqdm
import threading
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
import json
from typing import Dict, Any, List, Tuple
from collections import defaultdict


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def extract_puzzle_data(json_line):
    data = json.loads(json_line)
    fen = data['fen']
    best_move = data['evals'][0]['pvs'][0]['line'].split()[0]
    best_cp = data['evals'][0]['pvs'][0].get('cp', 100000)

    best_moves = []

    pvs = []
    for eval in data['evals']:
        for pv in eval['pvs']:
            pvs.append(pv)

    for pv in pvs:
        if abs(pv.get('cp', 0) - best_cp) < 30:
            best_moves.append(pv['line'].split()[0])

    elo_dict = {
        '<1100': 0, '1100-1199': 1, '1200-1299': 2, '1300-1399': 3, '1400-1499': 4,
        '1500-1599': 5, '1600-1699': 6, '1700-1799': 7, '1800-1899': 8, '1900-1999': 9,
        '>=2000': 10
    }

    maia_data = {}
    for elo_range, category in elo_dict.items():
        if elo_range in data:
            move_probs = data[elo_range]['move_probs']
            top_maia_move = max(move_probs, key=move_probs.get)
            best_move_prob = move_probs.get(best_move, 0)

            maia_data[category] = {
                'top_maia_move': top_maia_move,
                'best_move_prob': best_move_prob
            }

    return {
        'fen': fen,
        'best_move': best_move,
        'maia_data': maia_data,
        'best_moves': list(set(best_moves))
    }


def is_monotonic(maia_data: Dict[int, Dict[str, Any]]) -> bool:
    probs = [data['best_move_prob'] for data in maia_data.values()]
    return all(probs[i] <= probs[i + 1] for i in range(len(probs) - 1))


def is_transitional(maia_data: Dict[int, Dict[str, Any]], best_moves: List[str]) -> Tuple[bool, int]:
    moves = [data['top_maia_move'] for data in maia_data.values()]
    elo_categories = sorted(maia_data.keys())

    # Check if all moves are correct (not truly transitional)
    if all(move in best_moves for move in moves):
        return False, -1

    for x in range(1, len(elo_categories)):
        if all(moves[i] not in best_moves for i in range(x)) and \
                all(moves[i] in best_moves for i in range(x, len(moves))):
            return True, elo_categories[x]
    return False, -1


def analyze_puzzle(puzzle_data: Dict[str, Any]) -> Dict[str, Any]:
    fen = puzzle_data['fen']
    best_move = puzzle_data['best_move']
    maia_data = puzzle_data['maia_data']
    best_moves = puzzle_data['best_moves']

    monotonic = is_monotonic(maia_data)
    transitional, transition_point = is_transitional(maia_data, best_moves)
    all_correct = all(data['top_maia_move'] == best_move for data in maia_data.values())

    return {
        'fen': fen,
        'best_move': best_move,
        'maia_moves': {elo: data['top_maia_move'] for elo, data in maia_data.items()},
        'maia_probs': {elo: data['best_move_prob'] for elo, data in maia_data.items()},
        'is_monotonic': monotonic,
        'is_transitional': transitional,
        'transition_point': transition_point,
        'is_both': monotonic and transitional,
        'all_correct': all_correct,
    }


def process_puzzles(file_path, special_fens, best_transitional_moves, all_best_transitional_moves,
                    transitional_maia_moves, counters, transition_points) -> tuple:
    with open(file_path, 'r') as file:
        for line in file:
            puzzle_data = extract_puzzle_data(line)
            analyzed_data = analyze_puzzle(puzzle_data)
            # results.append(analyzed_data)

            counters['total'] += 1
            if analyzed_data['is_monotonic']:
                counters['monotonic'] += 1
                special_fens['monotonic'].append(analyzed_data['fen'])
            if analyzed_data['is_transitional']:
                counters['transitional'] += 1
                special_fens['transitional'].append(analyzed_data['fen'])
                transition_points.append(analyzed_data['transition_point'])
                best_transitional_moves.append(puzzle_data['best_move'])
                all_best_transitional_moves.append(puzzle_data['best_moves'])
                transitional_maia_moves.append(analyzed_data['maia_moves'])
            if analyzed_data['is_both']:
                counters['both'] += 1
                special_fens['both'].append(analyzed_data['fen'])
            if analyzed_data['all_correct']:
                counters['all_correct'] += 1

    # df = pd.DataFrame(results)
    return special_fens, counters, transition_points, best_transitional_moves, transitional_maia_moves


special_fens = defaultdict(list)
counters = {'total': 0, 'monotonic': 0, 'transitional': 0, 'both': 0, 'all_correct': 0}
transition_points = []
best_transitional_moves = []
all_best_transitional_moves = []
transitional_maia_moves = []

for i in tqdm(range(5)):
    file_path = f'lichess_db_eval_chunk_{i}.jsonl'
    special_fens, counters, transition_points, best_transitional_moves, transitional_maia_moves = process_puzzles(
        file_path, special_fens, best_transitional_moves, all_best_transitional_moves, transitional_maia_moves,
        counters, transition_points)
print(transitional_maia_moves[0])
total_non_all_correct = counters['total'] - counters['all_correct']
print("Total positions:", counters['total'])
print("Positions where all Maia2 predictions are correct:", counters['all_correct'],
      f"({counters['all_correct'] / counters['total']:.2%})")
print("Monotonic positions:", counters['monotonic'], f"({counters['monotonic'] / counters['total']:.2%})")
print("Transitional positions:", counters['transitional'],
      f"({counters['transitional'] / total_non_all_correct:.2%} of non-all-correct positions)")
print("Both monotonic and transitional:", counters['both'],
      f"({counters['both'] / total_non_all_correct:.2%} of non-all-correct positions)")

with open('trans_mono_positions.json', 'w') as f:
    json.dump(special_fens, f, indent=2)


100%|██████████| 5/5 [00:32<00:00,  6.56s/it]

{0: 'c6g2', 1: 'f7g7', 2: 'f7g7', 3: 'f7g7', 4: 'f7g7', 5: 'f7g7', 6: 'f7g7', 7: 'f7g7', 8: 'f7g7', 9: 'f7g7', 10: 'f7g7'}
Total positions: 500000
Positions where all Maia2 predictions are correct: 158219 (31.64%)
Monotonic positions: 123356 (24.67%)
Transitional positions: 50191 (14.69% of non-all-correct positions)
Both monotonic and transitional: 25139 (7.36% of non-all-correct positions)





In [3]:
def is_piece_no_longer_under_attack(fen: str, move: str, square_index) -> bool:
    """
    Determines if a piece on the given square index was under attack before the move
    and is no longer under attack after the move.

    Parameters:
    fen (str): The FEN string representing the current board position.
    move (str): The move in UCI format (e.g., 'e2e4').
    square_index (int): The square to check, as an index between 0 and 63 (0 = 'a1', 63 = 'h8').

    Returns:
    bool: True if the piece on the square was under attack before the move but is no longer under attack after.
    """
    # Load the board from the given FEN
    board = chess.Board(fen)

    piece_before_move = board.piece_at(square_index)

    # Function to count attacks and defenses on the square
    def is_under_attack(board, square_index):
        attackers = board.attackers(chess.BLACK, square_index)  # Opponent's attackers
        defenders = board.attackers(chess.WHITE, square_index)  # Defenders (current player's pieces)

        return len(attackers) > len(defenders)

    # Check if the square is under attack before the move
    was_under_attack_before = is_under_attack(board,
                                              square_index) and piece_before_move is not None and piece_before_move.color is not chess.BLACK

    # Make the move
    move_obj = chess.Move.from_uci(move)
    if not board.is_legal(move_obj):
        return False

    board.push(move_obj)  # Apply the move

    # Check if there is still a piece on the square after the move
    piece_after_move = board.piece_at(square_index)

    # Check if the square is still under attack after the move
    is_under_attack_after = is_under_attack(board, square_index) and piece_after_move is not None

    # Return True if it was under attack before, but is no longer under attack
    return was_under_attack_before and not is_under_attack_after


def is_blunder(fen: str, uci_move: str, stockfish_path: str = '/opt/homebrew/bin/stockfish',
               threshold: int = 150) -> bool:
    # Load the engine
    with chess.engine.SimpleEngine.popen_uci(stockfish_path) as engine:
        # Parse the FEN to create a board
        board = chess.Board(fen)

        # Create the move from the UCI string
        move = chess.Move.from_uci(uci_move)

        # Check if the move is legal
        if move not in board.legal_moves:
            return False  # Automatically return False if the move is illegal

        # Get the evaluation before the move
        info_before = engine.analyse(board, chess.engine.Limit(time=0.1))
        eval_before = info_before['score'].relative.score()

        # If eval_before is None, it's a checkmate scenario, so return False (no blunder possible)
        if eval_before is None:
            return False

        # Apply the move
        board.push(move)

        # Get the evaluation after the move
        info_after = engine.analyse(board, chess.engine.Limit(time=0.1))
        eval_after = info_after['score'].relative.score()

        # If eval_after is None, it's a checkmate scenario, so return False (no blunder possible)
        if eval_after is None:
            return False

        # Calculate the centipawn difference
        centipawn_loss = abs(eval_before - eval_after)

        # Check if the loss is greater than the threshold (150 centipawns)
        return centipawn_loss > threshold


def square_index(square_name: str) -> int:
    if not isinstance(square_name, str) or len(square_name) != 2:
        raise ValueError("Square name must be a string of length 2")

    file = square_name[0].lower()
    rank = square_name[1]

    if file not in 'abcdefgh' or rank not in '12345678':
        raise ValueError("Invalid square name")

    file_index = 'abcdefgh'.index(file)
    rank_index = '12345678'.index(rank)

    return rank_index * 8 + file_index


In [4]:
# squarewise_alarmbells = {"b1": (0, 717), "d1": (0, 1888), "f1": (0, 1864), "h1": (0, 1000), "d2": (0, 1608), "e2":(0, 1701), "f2": (0, 1747), "a3": (1, 416), "c3": (1, 1346), "d3": (0, 1676), "e3": (0, 142), "h3": (1, 1009), "b4": (1, 1), "d4": (0, 120), "e4": (1, 977), "f4": (1, 1846), "g4": (1, 538), "h4": (0, 437), "b5": (0, 1687), "d5": (0, 1538), "e5": (0, 630), "f5": (1, 843), "c6": (1, 574), "e6": (0, 1924), "e7": (0, 2029), "f7": (0, 1715)}

squarewise_alarmbells = {"b1": (0, 717, 0.4865), "d1": (0, 1888, 0.4414), "f1": (0, 1864, 0.5453), "h1": (0, 1000, 0.29), "d2": (0, 1608, 0.87), "e2":(0, 1701, 0.1005), "f2": (0, 1747, 0.3497), "a3": (1, 416, 0.0934), "c3": (1, 1346, 0.2888), "d3": (0, 1676, 0.3614), "e3": (0, 142, 0.2222), "h3": (1, 1009, 0.2700), "b4": (1, 1, 1.0180), "d4": (0, 120, 0.4834), "e4": (0, 977, 0.4095), "f4": (0, 1846, 0.9502), "g4": (0, 538, 0.4124), "h4": (0, 437, 0.6328), "b5": (0, 1687, 0.2789), "d5": (0, 1538, 0.4424), "e5": (0, 630, 0.4104), "f5": (0, 843, 0.2159), "c6": (1, 574, 1.3355), "e6": (0, 1924, 0.5567), "e7": (0, 2029, 0.3414), "f7": (0, 1715, 0.5274)}

lookup_thresholds = {0: {}, 1:{}}
for value in squarewise_alarmbells.values():
    lookup_thresholds[value[0]][value[1]] = value[2]
    


target_key_list = ['transformer block 0 hidden states', 'transformer block 1 hidden states'] # 'conv_last'
all_ground_truths = []
intervention_site = []

true_transitional_positions = {k:[] for k in [square for square in squarewise_alarmbells]}

for key, value in squarewise_alarmbells.items():
    layer, feature_idx = value[0], value[1]
    # layer, feature_idx = random.choice([0, 1]), random.randint(0, 2048)
    intervention_site.append((target_key_list[layer], feature_idx))
    ground_truth = []
    for i in tqdm(range(len(special_fens['transitional']))):
        fen = special_fens['transitional'][i]
        move = best_transitional_moves[i]
        square_idx = square_index(key)
        try:
            no_longer_under_attack = is_piece_no_longer_under_attack(fen, move, square_idx)
            micro_gt = no_longer_under_attack and is_blunder(fen, transitional_maia_moves[i][0])
            ground_truth.append(micro_gt)    
            if is_piece_no_longer_under_attack(fen, move, square_idx):
                temp = {"fen": fen}
                temp["transition_points"] = transition_points[i]
                temp["best_moves"] = all_best_transitional_moves[i]
                temp['blunder'] = transitional_maia_moves[i][0]
                true_transitional_positions[key].append(temp)
        except ValueError:
            ground_truth.append(0)

            # print("illegal_move" + move)
    all_ground_truths.append(ground_truth)
all_ground_truths = torch.tensor(all_ground_truths, dtype=torch.int)
with open("relevant_positions.json", "w") as f:
    json.dump(true_transitional_positions, f)


 47%|████▋     | 23786/50191 [00:05<00:05, 4460.30it/s]


KeyboardInterrupt: 