In [2]:
import json
from pathlib import Path

import pandas as pd
import stockfish as sf
from stockfish import Stockfish
from tqdm import tqdm

In [3]:
# Engine instance shared by process_game. Assumes a `stockfish` executable can be
# resolved from the working directory / PATH — TODO confirm on a new machine.
stockfish = Stockfish("stockfish")
# Threads / Hash are Stockfish UCI options (Hash is the hash-table size, in MB per the
# Stockfish UCI docs — confirm); tune them to the machine running the analysis.
stockfish.update_engine_parameters({"Threads": 4, "Hash": 1024})

def process_game(row, engine=None):
    """
    Replay one game through Stockfish, recording the engine's best move before each played move.

    Args:
        row: a record (e.g. a row from ``games_df.iterrows()``) with "piece_uci",
            "white_elo" and "black_elo" fields. "piece_uci" is a space-separated
            token string in which every third token (indices 0, 3, 6, ...) is a
            played move: a one-character piece prefix followed by a UCI move.
        engine: Stockfish-like object to query. Defaults to the notebook-level
            ``stockfish`` engine; pass another instance to reuse the function with
            a differently configured engine. Its position is reset to the start.

    Returns:
        dict with "game" (played moves), "white_elo", "black_elo",
        "stockfish_moves" (the engine's best move before each played move,
        prefixed with the upper-cased ``.value`` of the piece on its origin
        square) and "tokens" (the subset of "piece_uci" tokens kept for the model).
    """
    if engine is None:
        engine = stockfish
    engine.set_position([])

    # Parse once and derive both views from the same token list.
    all_tokens = row["piece_uci"].split(" ")
    # Keep positions 0, 2, 3 and 4 of every 6-token window (two plies of 3 tokens).
    # NOTE(review): which per-ply fields these are depends on the CSV format — confirm.
    tokens = [token for i, token in enumerate(all_tokens) if i % 6 in (0, 2, 3, 4)]
    game = all_tokens[::3]

    best_moves = []
    for move in game:
        best_move = engine.get_best_move()
        # Prefix with the piece so engine moves use the same format as played moves.
        piece = engine.get_what_is_on_square(best_move[:2]).value.upper()
        best_moves.append(piece + best_move)
        # Played moves carry a one-character prefix before the UCI string.
        engine.make_moves_from_current_position([move[1:]])

    return {
        "game": game,
        "white_elo": row["white_elo"],
        "black_elo": row["black_elo"],
        "stockfish_moves": best_moves,
        "tokens": tokens,
    }

In [4]:


train_data_path = "./data/train_piece_count.csv"
test_data_path = "./data/test_piece_count.csv"
# Cached Stockfish annotations; the following cell loads this file.
stockfish_moves_path = Path("data/stockfish_moves.json")

N_GAMES = 1000  # number of test games to annotate
SEED = 42  # fixes which games are sampled

games_df = pd.read_csv(
    test_data_path,
    delimiter=";",
    usecols=["white_elo", "black_elo", "result", "ply", "piece_uci"],
).sample(n=N_GAMES, random_state=SEED)

# Running Stockfish on every position is slow, so only do it when the cache is missing.
# The cache is not keyed on N_GAMES / SEED: delete the file after changing either.
if not stockfish_moves_path.exists():
    processed_games = [
        process_game(row) for _, row in tqdm(games_df.iterrows(), total=len(games_df))
    ]
    stockfish_moves_path.parent.mkdir(parents=True, exist_ok=True)
    with stockfish_moves_path.open("w") as f:
        json.dump(processed_games, f)

In [5]:
# Load the cached Stockfish annotations: presumably one process_game(...) result per
# sampled game, produced by the previous cell — confirm if the file came from elsewhere.
with open("data/stockfish_moves.json", mode="r") as cache_file:
    processed_games = json.load(cache_file)


In [6]:
from training import load_models


# Checkpoint of the Elo-conditioned piece-count model evaluated below.
piece_count_checkpoint = (
    "./models/full_training/elo_piece_count_ignore_material_prediction/"
    "epoch=9-step=1250000.ckpt"
)
piece_count_model = load_models.piece_count_model(piece_count_checkpoint)

number of parameters: 28.17M


In [58]:
import torch

import chess

from playing.utils import legal_moves_piece_uci


def get_predictions_per_elo(game, model, elos=tuple(range(800, 2801, 100))):
    """
    Get the model's probability for every legal move at every ply, for each Elo rating.

    The game's token sequence is scored once per Elo in a single batch, with the Elo
    token prepended twice (presumably one slot per player — TODO confirm against the
    training input format).

    Args:
        game: dict with "tokens" (model input tokens), "game" (played moves, each a
            one-character prefix followed by a UCI move) and "stockfish_moves"
            (one entry per played move), e.g. an element of ``processed_games``.
        model: object exposing ``tokenizer`` (with ``encode`` / ``encode_token``),
            ``eval()`` and ``logits_and_targets_for_masked_elo``.
        elos: iterable of Elo ratings to condition on; each ``str(elo)`` must be a
            token known to ``model.tokenizer``. Defaults to 800..2800 in steps of 100.

    Returns:
        A list with one dict per played move: "history" (moves before it), "fen"
        (position before it), "played_move", "stockfish_move" and
        "predictions_per_elo" ({elo: {legal_move: probability}}, each inner dict
        sorted by descending probability).
    """
    # Materialize once: elos is iterated several times below, so a generator would
    # otherwise be exhausted after building the batch.
    elos = list(elos)
    moves = game["game"]

    encoded = model.tokenizer.encode(game["tokens"])
    elo_tokens = [model.tokenizer.encode_token(str(elo)) for elo in elos]
    batch = torch.tensor(
        [[elo_token, elo_token] + encoded for elo_token in elo_tokens],
        dtype=torch.int64,
    )

    model.eval()
    with torch.no_grad():
        # Shifted (input, target) pair, the form the model's batch method expects.
        batch_lightning = (batch[:, :-1], batch[:, 1:])
        logits, _ = model.logits_and_targets_for_masked_elo(batch_lightning)
        outputs = logits.softmax(dim=-1)

    board = chess.Board()
    results = []

    for i_move, move in enumerate(moves):
        legal_moves = legal_moves_piece_uci(board)
        legal_moves_encoded = model.tokenizer.encode(legal_moves)
        # Rows follow `elos`, columns follow `legal_moves`. Converting the slice once
        # replaces one .item() call (and device sync) per (elo, move) pair.
        scores_per_elo = outputs[:, i_move, legal_moves_encoded].tolist()

        probabilities_per_elo = {}
        for elo, scores in zip(elos, scores_per_elo):
            move_probabilities = dict(zip(legal_moves, scores))
            probabilities_per_elo[elo] = dict(
                sorted(move_probabilities.items(), key=lambda item: item[1], reverse=True)
            )

        results.append({
            "history": moves[:i_move],
            "fen": board.fen(),
            "played_move": move,
            "stockfish_move": game["stockfish_moves"][i_move],
            "predictions_per_elo": probabilities_per_elo,
        })

        # Played moves carry a one-character prefix before the UCI string.
        board.push_uci(move[1:])

    return results

In [61]:
# Score every cached game with the piece-count model, keeping the game metadata next
# to the per-ply predictions, then persist the whole result.
predictions = []
for game in tqdm(processed_games):
    game_predictions = get_predictions_per_elo(game, piece_count_model)
    record = {key: game[key] for key in ("game", "tokens", "white_elo", "black_elo")}
    record["predictions"] = game_predictions
    predictions.append(record)

with open("data/predictions_per_elo.json", "w") as f:
    json.dump(predictions, f, indent=2)


100%|██████████| 1000/1000 [23:30<00:00,  1.41s/it]


In [None]:
# The full list is very large (every ply × every Elo × every legal move, for every game);
# show a single position's record instead of dumping all of it.
predictions[0]["predictions"][0]