# In [None]:
import os
import sys
import time
import chess
import torch

from data_manager import board_to_tensor
from model import SmallResNetValue as Model

from core.transposition_table import LRUCache, Entry, ZobristBoard

# Filesystem locations (relative paths) of the value-network checkpoint and
# of the directory handed to the Gaviota tablebase opener.
MODEL_PATH = "../../models/value_network/CN2_BN2_RLROP.pth"
TB_DIR = "../../tablebases/gaviota"

# Run the network on the GPU when PyTorch reports one, otherwise on the CPU.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Search cache shared by every minimax call, created with maxsize=100000.
transposition_table = LRUCache(maxsize=100_000)

# Running counters and timers updated by the evaluation helpers and by
# minimax; print_stats() reports them.
_stats = dict(
    batches=0,
    total_batch_size=0,
    tensor_time=0.0,
    model_time=0.0,
    eval_calls=0,
    serial_evals=0,
)

def print_stats():
    """Print the counters and timers accumulated in the module-level ``_stats``.

    The average batch size is reported as 0.0 when no batch has been recorded.
    """
    batch_count = _stats["batches"]
    if batch_count:
        mean_batch = _stats["total_batch_size"] / batch_count
    else:
        mean_batch = 0.0

    print("=== Eval stats ===")
    calls, serial = _stats['eval_calls'], _stats['serial_evals']
    print(f"batches: {batch_count}, avg_batch_size: {mean_batch:.2f}, eval_calls: {calls}, serial_evals: {serial}")
    tensor_s, model_s = _stats['tensor_time'], _stats['model_time']
    print(f"tensor_time: {tensor_s:.4f}s, model_time: {model_s:.4f}s")
    print("==================")

# Read the checkpoint onto the CPU. It is either a bare state dict or a dict
# that wraps the weights under the "model_state" key.
ckpt = torch.load(MODEL_PATH, map_location="cpu")
state_dict = (
    ckpt["model_state"]
    if isinstance(ckpt, dict) and "model_state" in ckpt
    else ckpt
)


def _build_eval_model(target):
    """Create a ``Model``, load ``state_dict``, move it to *target*, set eval mode."""
    net = Model()
    net.load_state_dict(state_dict)
    net.to(target)
    net.eval()
    return net


# Two copies of the same weights: one on the selected device for batched
# evaluation, one pinned to the CPU for the serial small-batch path.
model = _build_eval_model(device)
model_cpu = _build_eval_model("cpu")

# Inference only: switch autograd off globally.
torch.set_grad_enabled(False)

# ---------------- helper functions ----------------
def evaluate_board(position):
    """Score a single position with the device model and return a Python float.

    The position is encoded with ``board_to_tensor`` (a (1, C, 8, 8) tensor
    according to this file's comments), moved to ``device`` and passed through
    ``model`` without gradient tracking.
    """
    features = board_to_tensor(position)
    on_device = features.to(device)
    with torch.no_grad():
        prediction = model(on_device)
    return float(prediction.squeeze(0).item())

def evaluate_batch_positions_cpu(positions):
    """Score each position one at a time with the CPU copy of the model.

    Returns a list of floats in the same order as *positions* and bumps
    ``_stats["serial_evals"]`` once per evaluated position.
    """
    scores = []
    for board in positions:
        features = board_to_tensor(board).to("cpu")
        with torch.no_grad():
            raw = model_cpu(features)
        scores.append(float(raw.squeeze(1).cpu().item()))
        _stats["serial_evals"] += 1
    return scores

def evaluate_batch_tensors(batch_cpu_tensor):
    """Run the device model on an already-stacked batch and return a NumPy array.

    The first dimension of *batch_cpu_tensor* is added to
    ``_stats["eval_calls"]``. The host-to-device transfer time goes into
    ``_stats["tensor_time"]`` and the forward pass (including the copy back to
    the CPU) into ``_stats["model_time"]``. On CUDA each phase is followed by
    an explicit synchronize so the wall-clock figures cover finished work.
    """
    on_cuda = device.type == "cuda"
    _stats["eval_calls"] += batch_cpu_tensor.size(0)

    transfer_started = time.time()
    batch = batch_cpu_tensor.to(device, non_blocking=True)
    if on_cuda:
        torch.cuda.synchronize()
    _stats["tensor_time"] += time.time() - transfer_started

    forward_started = time.time()
    with torch.no_grad():
        out = model(batch).squeeze(1).cpu().numpy()
    if on_cuda:
        torch.cuda.synchronize()
    _stats["model_time"] += time.time() - forward_started
    return out

def get_victim_type(position, move):
    """Return the piece type captured by *move*, or ``None`` for a non-capture.

    En passant is reported as a pawn capture even though the destination
    square itself is empty.
    """
    if not position.is_en_passant(move):
        target = position.piece_at(move.to_square)
        return target.piece_type if target else None
    return chess.PAWN

def order_moves(position, moves, tt_best_move=None):
    """Return *moves* as a new list sorted from most to least promising.

    Scoring per move:
      * +10000 when it equals *tt_best_move*;
      * for captures, ``10 * victim_type - attacker_type`` (most valuable
        victim, least valuable attacker first);
      * for promotions, ``5 * promotion piece type``.
    Moves with equal scores keep their original relative order.
    """
    def priority(candidate):
        total = 10000 if (tt_best_move and candidate == tt_best_move) else 0

        captured = get_victim_type(position, candidate)
        if captured:
            mover = position.piece_at(candidate.from_square)
            total += 10 * captured - mover.piece_type

        promoted_to = candidate.promotion
        if promoted_to:
            total += 5 * promoted_to
        return total

    ranked = list(moves)
    ranked.sort(key=priority, reverse=True)
    return ranked

def _tt_bound_flag(value, alpha_orig, beta_orig):
    """Classify *value* against the window the node was entered with."""
    if value <= alpha_orig:
        return 'upperbound'
    if value >= beta_orig:
        return 'lowerbound'
    return 'exact'


def _evaluate_children_batched(position, moves, small_batch_threshold=4):
    """Return one leaf value per move in *moves*, in the same order.

    Each child is reached with push/pop on *position*. A child whose
    transposition-table entry is flagged 'exact' reuses the stored value.
    Every other child is encoded with ``board_to_tensor`` and scored by the
    network: serially on ``model_cpu`` when fewer than *small_batch_threshold*
    positions are pending, otherwise as one stacked batch through
    ``evaluate_batch_tensors``.
    """
    values = [None] * len(moves)
    pending_tensors = []
    pending_indices = []
    pending_keys = []

    for i, mv in enumerate(moves):
        position.push(mv)
        child_key = position.zobrist_hash
        entry = transposition_table[child_key] if child_key in transposition_table else None
        if entry is not None and entry.flag == 'exact':
            values[i] = float(entry.value)
        else:
            # A 'lowerbound'/'upperbound' entry only bounds the score, so it
            # must not be used as the leaf's exact value; evaluate instead.
            # board_to_tensor returns (1,C,8,8) CPU tensor per the file's comments.
            pending_tensors.append(board_to_tensor(position).squeeze(0))  # (C,8,8)
            pending_indices.append(i)
            # Cache the static value only when nothing is stored yet, so an
            # existing deeper bound entry (and its best move) is kept.
            pending_keys.append(child_key if entry is None else None)
        position.pop()

    batch_size = len(pending_tensors)
    if batch_size == 0:
        return values

    _stats["batches"] += 1
    _stats["total_batch_size"] += batch_size

    if batch_size < small_batch_threshold:
        # Very small batches: score serially on the CPU model to avoid the
        # device transfer. The tensors built above are reused directly, so no
        # board cloning or second encoding pass is needed.
        t0 = time.time()
        preds = []
        with torch.no_grad():
            for tensor in pending_tensors:
                out = model_cpu(tensor.unsqueeze(0)).squeeze(1).item()
                preds.append(float(out))
                _stats["serial_evals"] += 1
        _stats["model_time"] += (time.time() - t0)
    else:
        # Stack once, then a single transfer + forward pass.
        t0 = time.time()
        batch_cpu = torch.stack(pending_tensors, dim=0)  # (B, C, 8, 8)
        _stats["tensor_time"] += (time.time() - t0)
        preds = evaluate_batch_tensors(batch_cpu)

    for idx, key, pred in zip(pending_indices, pending_keys, preds):
        values[idx] = float(pred)
        if key is not None:
            transposition_table[key] = Entry(values[idx], 0, 'exact', None)
    return values


def minimax(position, depth, alpha, beta, maximizingPlayer):
    """Alpha-beta minimax search backed by the shared transposition table.

    Args:
        position: board exposing ``zobrist_hash``, ``push``/``pop``,
            ``legal_moves`` and ``is_game_over``.
        depth: remaining plies. 0 hands over to ``quiescence``; at 1 all
            children are scored together via ``_evaluate_children_batched``.
        alpha: best score already guaranteed to the maximizing side.
        beta: best score already guaranteed to the minimizing side.
        maximizingPlayer: truthy when the side to move maximizes the score.

    Returns:
        ``(value, best_move)``. ``best_move`` is ``None`` at depth 0 and for
        game-over positions.
    """
    state_key = position.zobrist_hash

    # Probe the table: a sufficiently deep entry may answer the node outright;
    # any entry at all supplies a move-ordering hint.
    tt_best_move = None
    if state_key in transposition_table:
        entry = transposition_table[state_key]
        if entry.depth >= depth:
            if entry.flag == 'exact':
                return entry.value, entry.best_move
            if entry.flag == 'lowerbound' and entry.value >= beta:
                return entry.value, entry.best_move
            if entry.flag == 'upperbound' and entry.value <= alpha:
                return entry.value, entry.best_move
        tt_best_move = entry.best_move

    if depth == 0:
        return quiescence(position, alpha, beta, maximizingPlayer), None

    if position.is_game_over():
        value = evaluate_board(position)
        transposition_table[state_key] = Entry(value, depth, 'exact', None)
        return value, None

    alpha_orig, beta_orig = alpha, beta
    moves = order_moves(position, list(position.legal_moves), tt_best_move)

    # At the last ply every child value is computed up front in one batch;
    # deeper nodes recurse child by child with the narrowing window.
    leaf_values = _evaluate_children_batched(position, moves) if depth == 1 else None

    best_move = None
    best_value = float('-inf') if maximizingPlayer else float('inf')
    for index, move in enumerate(moves):
        if leaf_values is not None:
            evaluation = leaf_values[index]
        else:
            position.push(move)
            evaluation, _ = minimax(position, depth - 1, alpha, beta, not maximizingPlayer)
            position.pop()

        if maximizingPlayer:
            if evaluation > best_value:
                best_value, best_move = evaluation, move
            alpha = max(alpha, evaluation)
        else:
            if evaluation < best_value:
                best_value, best_move = evaluation, move
            beta = min(beta, evaluation)

        if beta <= alpha:
            break  # cutoff: the opponent already has a better alternative

    flag = _tt_bound_flag(best_value, alpha_orig, beta_orig)
    transposition_table[state_key] = Entry(best_value, depth, flag, best_move)
    return best_value, best_move

# ---------------- quiescence (unchanged except uses evaluate_board) ----------------
def quiescence(position, alpha, beta, maximizingPlayer, qs_depth=0, max_qs_depth=4):
    """Continue the search through captures only, until a quiet position is reached.

    Args:
        position: board to search; moves are pushed and popped in place.
        alpha: lower bound of the search window.
        beta: upper bound of the search window.
        maximizingPlayer: truthy when the side to move maximizes the score.
        qs_depth: current quiescence ply (0 on entry from ``minimax``).
        max_qs_depth: once ``qs_depth`` exceeds this, the static evaluation
            is returned without looking at further captures.

    Returns:
        A score clamped to the ``[alpha, beta]`` window (fail-hard), or the
        static evaluation when the position is game-over, the depth limit is
        exceeded, or there are no captures.
    """
    if qs_depth > max_qs_depth or position.is_game_over():
        return evaluate_board(position)

    # Stand-pat: the side to move may decline every capture.
    stand_pat = evaluate_board(position)

    if maximizingPlayer:
        if stand_pat >= beta:
            return beta
        alpha = max(alpha, stand_pat)
    else:
        if stand_pat <= alpha:
            return alpha
        beta = min(beta, stand_pat)

    # is_capture() also recognises en passant, whose destination square is
    # empty and would be missed by a plain piece_at(to_square) test.
    captures = [move for move in position.legal_moves if position.is_capture(move)]
    if not captures:
        return stand_pat

    captures = order_moves(position, captures)

    if maximizingPlayer:
        for move in captures:
            position.push(move)
            score = quiescence(position, alpha, beta, False, qs_depth + 1, max_qs_depth)
            position.pop()
            if score >= beta:
                return beta
            alpha = max(alpha, score)
    else:
        for move in captures:
            position.push(move)
            score = quiescence(position, alpha, beta, True, qs_depth + 1, max_qs_depth)
            position.pop()
            if score <= alpha:
                return alpha
            beta = min(beta, score)

    return alpha if maximizingPlayer else beta

# ---------------- main play loop (example) ----------------
if __name__ == "__main__":
    # chess.gaviota is a submodule and is not loaded by a bare `import chess`.
    import chess.gaviota
    import chess.svg
    from IPython.display import display, SVG
    from core.gaviota import get_move_from_table

    board = ZobristBoard()
    white = True  # passed to minimax as maximizingPlayer; flipped after every move
    depth = 6
    move_count = 0

    start = time.time()
    # Open the tablebase once for the whole game instead of once per move.
    with chess.gaviota.open_tablebase(TB_DIR) as tb:
        while not board.is_game_over():
            # Use the tablebase move when the probe returns a result,
            # otherwise fall back to the search.
            wdl = tb.get_wdl(board)
            if wdl is not None:
                move_info, move = get_move_from_table(board, tb)
            else:
                score, move = minimax(board, depth, float('-inf'), float('inf'), white)

            if move is None:
                break

            board.push(move)
            move_count += 1
            # display board visually if running in notebook
            try:
                display(SVG(chess.svg.board(board, size=400,
                                            arrows=[chess.svg.Arrow(move.from_square, move.to_square,
                                                                    color="#fc681fcc")])))
            except Exception:
                # Rendering is best-effort; a display failure must not stop the game.
                pass
            white = not white

            # Print intermediate stats occasionally
            if move_count % 2 == 0:
                print(f"After {move_count} moves, elapsed {time.time() - start:.2f}s")
                print_stats()

    total_time = time.time() - start
    print("Minimax with QS and model evaluation:", board.result(), "after", move_count, "moves,", "in", total_time, "seconds.")
    print_stats()
