In [44]:
import torch
import chess
import numpy as np
import onnxruntime as ort
from maia2 import inference

all_moves = inference.get_all_possible_moves()
all_moves_dict = {move: i for i, move in enumerate(all_moves)}
all_moves_dict_reversed = {v: k for k, v in all_moves_dict.items()}
elo_dict = inference.create_elo_dict()

In [45]:
def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy()
        if tensor.requires_grad
        else tensor.cpu().numpy()
    )


def mirror_move(move_uci):
    # Check if the move is a promotion (length of UCI string will be more than 4)
    is_promotion = len(move_uci) > 4

    # Extract the start and end squares, and the promotion piece if applicable
    start_square = move_uci[:2]
    end_square = move_uci[2:4]
    promotion_piece = move_uci[4:] if is_promotion else ""

    # Mirror the start and end squares
    mirrored_start = mirror_square(start_square)
    mirrored_end = mirror_square(end_square)

    # Return the mirrored move, including the promotion piece if applicable
    return mirrored_start + mirrored_end + promotion_piece


def mirror_square(square):
    file = square[0]
    rank = str(9 - int(square[1]))

    return file + rank

In [46]:
fen = input("Enter FEN: ")
elo_self = 1100
elo_opp = 1100

board = chess.Board(fen=fen)
board_input, elo_self_category, elo_opp_category, legal_moves = (
    inference.preprocessing(
        board.fen(), elo_self, elo_opp, elo_dict, all_moves_dict
    )
)

# boards = to_numpy(torch.tensor(inference.board_to_tensor(board)).unsqueeze(0))
boards = to_numpy(torch.tensor(board_input).unsqueeze(0))

np.save("python_boards.npy", boards)
elos_self = to_numpy(torch.tensor([elo_self_category]))
elos_oppo = to_numpy(torch.tensor([elo_opp_category]))

NEW FEN:  8/1N1p4/R3P3/6pB/1PkpKP1b/8/Np3Q1B/8 w - - 0 1


  boards = to_numpy(torch.tensor(board_input).unsqueeze(0))


In [47]:
onnx_model = ort.InferenceSession("./maia_rapid_onnx.onnx")
inputs = {
    "boards": boards,
    "elo_self": elos_self,
    "elo_oppo": elos_oppo,
}
outputs = onnx_model.run(None, inputs)

In [None]:
logits_maia_legal = outputs[0] * legal_moves.numpy()
probs = torch.tensor(logits_maia_legal).softmax(dim=-1).cpu().tolist()
preds = np.argmax(logits_maia_legal, axis=-1)

black_flag = False
if board.fen().split(" ")[1] == "b":
    # logits_value = 1 - logits_value
    black_flag = True

move_probs = {}
legal_move_indices = legal_moves.nonzero().flatten().cpu().numpy().tolist()
legal_moves_mirrored = []
for move_idx in legal_move_indices:
    move = all_moves_dict_reversed[move_idx]
    if black_flag:
        move = mirror_move(move)
    legal_moves_mirrored.append(move)

for j in range(len(legal_move_indices)):
    move_probs[legal_moves_mirrored[j]] = round(
        probs[0][legal_move_indices[j]], 4
    )

move_probs = dict(
    sorted(move_probs.items(), key=lambda item: item[1], reverse=True)
)

print(move_probs)

FEN:  8/nP3q1b/8/1pKPkp1B/6Pb/r3p3/1n1P4/8 b - - 0 1
{'f7b7': 0.4286, 'f7c7': 0.2149, 'f7d5': 0.0754, 'f7f8': 0.0718, 'f7e7': 0.0537, 'b2d3': 0.0277, 'h4e7': 0.0191, 'e3d2': 0.0177, 'f7f6': 0.0134, 'h7g6': 0.0119, 'b2a4': 0.0098, 'f7g6': 0.0056, 'f5g4': 0.0054, 'a3a4': 0.0046, 'f7h5': 0.0044, 'e3e2': 0.0038, 'f7g8': 0.0033, 'a3c3': 0.003, 'a3d3': 0.0029, 'f7g7': 0.0023, 'a3a6': 0.002, 'h7g8': 0.0017, 'b2c4': 0.0015, 'f7d7': 0.0014, 'a7c8': 0.0013, 'b5b4': 0.0012, 'a3b3': 0.0012, 'a3a5': 0.0012, 'f7e8': 0.0011, 'a3a2': 0.0011, 'f5f4': 0.0009, 'h4d8': 0.0007, 'a7c6': 0.0005, 'h4f6': 0.0005, 'h4g5': 0.0004, 'a3a1': 0.0004, 'f7e6': 0.0003, 'e5e4': 0.0003, 'b2d1': 0.0003, 'e5f4': 0.0002, 'h4e1': 0.0002, 'h4f2': 0.0002, 'h4g3': 0.0002, 'e5f6': 0.0001}
