In [None]:
import chess
import torch
import os
import sys
from data_manager import board_to_tensor
from model import SmallResNetValue as Model

sys.path.append(os.path.abspath("../.."))

from core.transposition_table import LRUCache, Entry, ZobristBoard

# Shared cache of already-searched positions, read and written by minimax().
transposition_table = LRUCache(maxsize=100_000)

# Prefer the GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Restore the trained value-network weights, place the model on the chosen
# device and switch it to evaluation mode.
model = Model()
model.load_state_dict(torch.load("../../models/value_network/CN2_BN2_RLROP.pth", map_location=device))
model.to(device).eval()

def evaluate_board(position):
    '''Score a position with the value network and return it as a Python scalar.

    The board is encoded with board_to_tensor(), moved to the module-level
    device and passed through the module-level model with gradients disabled.
    '''
    # assumes the network yields a single value per call (.item() requires
    # a one-element tensor) — TODO confirm against the model definition
    features = board_to_tensor(position).to(device)
    with torch.no_grad():
        output = model(features).squeeze(0)
    return output.item()

def get_victim_type(position, move):
    '''Return the piece type captured by `move`, or None when nothing is captured.

    En passant is handled explicitly because the captured pawn does not stand
    on the destination square of the move.
    '''
    if not position.is_en_passant(move):
        target = position.piece_at(move.to_square)
        if not target:
            return None
        return target.piece_type
    return chess.PAWN

def order_moves(position, moves, tt_best_move=None):
    '''Return `moves` sorted from most to least promising to speed up alpha-beta.

    Scoring per move:
      * +10000 when the move equals `tt_best_move` (transposition-table hint),
      * captures add 10 * victim type - attacker type, so capturing a
        higher-typed piece with a lower-typed one ranks first,
      * promotions add 5 * promotion piece type.
    The sort is stable, so equally scored moves keep their incoming order.
    '''
    def priority(candidate):
        total = 10000 if (tt_best_move and candidate == tt_best_move) else 0
        captured = get_victim_type(position, candidate)
        if captured:
            mover = position.piece_at(candidate.from_square).piece_type
            total += 10 * captured - mover
        if candidate.promotion:
            total += 5 * candidate.promotion
        return total

    ranked = list(moves)
    ranked.sort(key=priority, reverse=True)
    return ranked

def minimax(position, depth, alpha, beta, maximizingPlayer):
    '''Minimax search with alpha-beta pruning backed by the transposition table.

    Args:
        position: board to search; moves are pushed and popped on it during
            the search, so it is back in its incoming state on return.
        depth: remaining plies; at 0 the search hands over to quiescence().
        alpha: lower edge of the search window.
        beta: upper edge of the search window.
        maximizingPlayer: True when the side to move maximizes the evaluation.

    Returns:
        A (value, best_move) tuple. best_move is None at depth 0 and for
        finished games; a transposition-table hit returns the stored move.
    '''
    key = position.zobrist_hash

    hint_move = None
    if key in transposition_table:
        cached = transposition_table[key]
        if cached.depth >= depth:
            # A stored result is reusable when it is exact, or when its bound
            # already falls outside the current window.
            reusable = (
                cached.flag == 'exact'
                or (cached.flag == 'lowerbound' and cached.value >= beta)
                or (cached.flag == 'upperbound' and cached.value <= alpha)
            )
            if reusable:
                return cached.value, cached.best_move
        # Even a shallower entry still provides a move-ordering hint.
        hint_move = cached.best_move

    if depth == 0:
        return quiescence(position, alpha, beta, maximizingPlayer), None

    if position.is_game_over():
        # NOTE(review): finished games are scored by the network, not by an
        # explicit mate/draw score — confirm this is intended.
        terminal_value = evaluate_board(position)
        transposition_table[key] = Entry(terminal_value, depth, 'exact', None)
        return terminal_value, None

    # Remember the incoming window; it decides the bound type stored below.
    window_low, window_high = alpha, beta
    best_move = None
    candidates = order_moves(position, list(position.legal_moves), hint_move)

    if maximizingPlayer:
        value = float('-inf')
        for candidate in candidates:
            position.push(candidate)
            child_value, _ = minimax(position, depth - 1, alpha, beta, False)
            position.pop()
            if child_value > value:
                value, best_move = child_value, candidate
            if child_value > alpha:
                alpha = child_value
            if alpha >= beta:
                break
    else:
        value = float('inf')
        for candidate in candidates:
            position.push(candidate)
            child_value, _ = minimax(position, depth - 1, alpha, beta, True)
            position.pop()
            if child_value < value:
                value, best_move = child_value, candidate
            if child_value < beta:
                beta = child_value
            if alpha >= beta:
                break

    # Classify the result relative to the original window.
    if value <= window_low:
        flag = 'upperbound'
    elif value >= window_high:
        flag = 'lowerbound'
    else:
        flag = 'exact'

    transposition_table[key] = Entry(value, depth, flag, best_move)

    return value, best_move

def quiescence(position, alpha, beta, maximizingPlayer, qs_depth=0, max_qs_depth=4):
    '''Extend the search through capture moves only, until the position is quiet.

    Args:
        position: board to search; captures are pushed and popped on it.
        alpha: lower edge of the search window.
        beta: upper edge of the search window.
        maximizingPlayer: True when the side to move maximizes the evaluation.
        qs_depth: number of capture plies already searched in this extension.
        max_qs_depth: once qs_depth exceeds this, the static evaluation is
            returned without searching further captures.

    Returns:
        The static evaluation when the extension limit is exceeded, the game
        is over, or no capture is available; otherwise a value bounded by the
        window (beta / alpha is returned as soon as a score reaches it).
    '''
    if qs_depth > max_qs_depth or position.is_game_over():
        return evaluate_board(position)

    # "Stand pat": the side to move may decline every capture, so the static
    # evaluation bounds the result from its point of view.
    stand_pat = evaluate_board(position)

    if maximizingPlayer:
        if stand_pat >= beta:
            return beta
        alpha = max(alpha, stand_pat)
    else:
        if stand_pat <= alpha:
            return alpha
        beta = min(beta, stand_pat)

    # is_capture() also recognises en passant, where the destination square is
    # empty; testing piece_at(move.to_square) alone would skip those captures.
    captures = [move for move in position.legal_moves if position.is_capture(move)]
    if not captures:
        return stand_pat

    captures = order_moves(position, captures)

    if maximizingPlayer:
        for move in captures:
            position.push(move)
            score = quiescence(position, alpha, beta, False, qs_depth + 1, max_qs_depth)
            position.pop()
            if score >= beta:
                return beta
            alpha = max(alpha, score)
    else:
        for move in captures:
            position.push(move)
            score = quiescence(position, alpha, beta, True, qs_depth + 1, max_qs_depth)
            position.pop()
            if score <= alpha:
                return alpha
            beta = min(beta, score)

    return alpha if maximizingPlayer else beta

In [None]:
import time
import chess.svg
from IPython.display import display, SVG

from core.gaviota import get_move_from_table

TB_DIR = "../../tablebases/gaviota"

board = ZobristBoard()

depth = 6
move_count = 0
# Side to move, passed to minimax() as `maximizingPlayer`. It is read from the
# board rather than toggled by hand so it cannot drift out of sync.
white = board.turn

start = time.time()
# Open the tablebase once for the whole game instead of reopening it on
# every move; the context manager still closes it when the game ends.
with chess.gaviota.open_tablebase(TB_DIR) as tb:
    while not board.is_game_over():
        # A non-None probe result means the tablebase can supply the move;
        # otherwise fall back to the search.
        wdl = tb.get_wdl(board)
        if wdl is not None:
            move_info, move = get_move_from_table(board, tb)
        else:
            score, move = minimax(board, depth, float('-inf'), float('inf'), white)

        if move is None:
            break

        board.push(move)
        move_count += 1
        display(SVG(chess.svg.board(board, size=400,arrows=[chess.svg.Arrow(move.from_square, move.to_square, color="#fc681fcc")],)))
        white = board.turn

print("Minimax with QS and model evaluation:", board.result(), "after", move_count, "moves,", "in", time.time() - start, "seconds.")
