In [1]:
# import surge
import numpy
import torch
import chess

# surge.init()

In [2]:
# print surge pos using the chess module
def pprint_pos(position: surge.Position):
    print(chess.Board(position.fen()))

# convert a position to the tensor representation used by the nn
def pos_2_tensor(pos: surge.Position) -> torch.tensor:
    bit_state = pos.bit_state()
    bits = np.unpackbits(bit_state.view(np.uint8))
    return torch.tensor(bits).float()

NameError: name 'surge' is not defined

In [296]:
# count and weight the material for each side, a very simple evaluation metric
def static_material_eval(pos: surge.Position) -> float:
    # 1 - pawn, 3 - knight, bishop, 5 - rook, 9 - queen, 50 - kign
    type_value = [1, 3, 3, 5, 9, 50, 0, 0, -1, -3, -3, -5, -9, -50, 0]
    bit_state = pos.bit_state()
    score = 0
    for type, bits in enumerate(bit_state):
        score += bin(bits).count("1") * type_value[type] # count the number of 1s and weight with the piece value
    return score

# return the score of the position - black, + white and if the game is over
def eval_position(pos: surge.Position, nn) -> (float, bool):
    
    if len(pos.legals()) == 0:
        if pos.in_check():
            # in check and no legal moves left -> the side to play is checkmate
            # if white is checkmate black gets an -infinit score and inf for black in checkmate
            return (float("-inf"), True) if pos.color_to_play == 0 else (float("inf"), True)
        # if there is no check but no move it is stalemate
        return 0, True

    # if the game is not over evaluate the board using the nn
    # return nn.eval_position(pos_2_tensor(pos)), False
    return static_material_eval(pos), False

In [297]:
# a minmax tree search of each possible move util the max depth is reached using alpha beta pruning
def minmax(pos: surge.Position, depth: int, alpha: float, beta: float, is_maximizing: bool) -> (float, surge.Move):
    
    score, is_gameover = eval_position(pos, None)
    if is_gameover or depth is 0:
        # stop the tree search and return, no move is returned because there are no legals or we stop exploring here
        return score, []

    print("\t" * (DEPTH - depth), end="")
    print(f"minmax: color: {'W' if is_maximizing else 'B'} depth: {depth}")
    

    best_moves = []
    best_score = 0
    # white wants to maximize the score
    if is_maximizing:
        best_score = float("-inf")

        for move in pos.legals():
            pos.play(move)

            print("\t" * (DEPTH - depth), end="")
            print(move)


            score, next_moves = minmax(pos, depth - 1, alpha, beta, False)
            pos.undo(move)

            if score >= best_score:
                best_score = score
                best_moves = [move] + next_moves
            alpha = max(alpha, score)

            if alpha >= beta:
                break

        print("\t" * (DEPTH - depth), end="")
        print(f"W best: {best_score}, {best_moves}")

    # black wants to minimize the score
    else:
        best_score = float("inf")
        
        for move in pos.legals():
            pos.play(move)

            print("\t" * (DEPTH - depth), end="")
            print(move)

            score, next_moves = minmax(pos, depth - 1, alpha, beta, True)
            pos.undo(move)

            if score <= best_score:
                best_score = score
                best_moves = [move] + next_moves
            beta = min(beta, score)

            if beta <= alpha:
                break
        
        print("\t" * (DEPTH - depth), end="")
        print(f"B best: {best_score}, {best_moves}")

    return best_score, best_moves

def find_best_move(pos, depth):
    move = None
    for i in range(depth, depth + 1):
        score, move = minmax(pos, i, float("-inf"), float("inf"), pos.color_to_play() == 0)
        if score in [float("-inf"), float("inf")]:
            return score, move
    return score, move

In [None]:
pos = surge.Position()
surge.Position.set("r1bk3r/1pp2ppp/pb1p1n2/n2P4/B3P1q1/2Q2N2/PB3PPP/RN3RK1 w - - 0 1", pos)
pprint_pos(pos)
game_over = False
while not game_over:
    move = find_best_move(pos, 4)
    print(move)
    pos.play(move)
    pprint_pos(pos)
    score, game_over = eval_position(pos, None)

print(eval_position(pos, None))
pprint_pos(pos)

In [94]:
move = surge.Move()
[move] + []

[a1a1]