# Chess

In [70]:
import chess, random
import torch
import torch.nn.functional as F

# Module-level board shared by the cells below. It is default-constructed;
# the legal moves printed in the next cell are the 20 opening moves.
board = chess.Board()

In [103]:
# Show every move that is legal on the current board, one per line.
for legal_move in board.legal_moves:
    print(legal_move)

g1h3
g1f3
b1c3
b1a3
h2h3
g2g3
f2f3
e2e3
d2d3
c2c3
b2b3
a2a3
h2h4
g2g4
f2f4
e2e4
d2d4
c2c4
b2b4
a2a4


In [173]:
# Lookup table from a move "shape" to a dense integer index.
#
# Keys are (dx, dy) file/rank deltas for ordinary moves and (dx, dy, piece)
# for promotions. Each value is the insertion position, so the order of the
# loops below defines the index layout and must not be changed.
moves_dict = {}
num_moves = 7  # longest slide along a rank, file or diagonal

# Per-axis unit steps. The order (0, 1, -1) fixes the index layout.
# Deliberately NOT called `dir`, which would shadow the builtin.
axis_steps = (0, 1, -1)

# Sliding moves: 8 directions x `num_moves` distances (indices 0..55).
for step_x in axis_steps:
    for step_y in axis_steps:
        if step_x == 0 and step_y == 0:
            continue  # a zero delta is not a move
        for distance in range(1, num_moves + 1):
            moves_dict[(step_x * distance, step_y * distance)] = len(moves_dict)

# Knight moves (indices 56..63).
knight_deltas = (
    (2, 1), (2, -1), (-2, 1), (-2, -1),
    (1, 2), (1, -2), (-1, 2), (-1, -2),
)
for knight_delta in knight_deltas:
    moves_dict[knight_delta] = len(moves_dict)

# Promotions (indices 64..75): one step forward (dy == +1), either straight
# or capturing to a neighbouring file, for each of 4 promotion piece indices.
# Only dy == +1 keys are stored here.
for piece in range(4):
    for promo_dx in (0, 1, -1):
        moves_dict[(promo_dx, 1, piece)] = len(moves_dict)


class BoardArithmetic:
    """Encode chess boards as one-hot tensors and moves as integer actions.

    Moves are read through their string form (e.g. ``"e2e4"``, ``"e7e8q"``):
    characters 0-1 are the origin square, 2-3 the destination square, and an
    optional fifth character is the promotion piece.
    """

    def __init__(self):
        self.board = chess.Board()
        # Build the symbol vocabulary from this instance's own fresh board.
        # Reading a module-level `board` here would make the vocabulary (and
        # so the channel count) depend on whatever position that global
        # happens to hold, e.g. after a piece type has been captured.
        symbols = str(self.board).replace("\n", "").replace(" ", "")
        self.unique_pieces = sorted(set(symbols))
        self.piece_to_idx = {p: i for i, p in enumerate(self.unique_pieces)}

        self.alpha_idx = {c: i for i, c in enumerate("abcdefgh")}
        self.numeric_idx = {c: i for i, c in enumerate("12345678")}
        self.promotion_idx = {c: i for i, c in enumerate("qrbn")}

    def board_to_state(self, board):
        """Return a one-hot float tensor for ``board``.

        ``board`` may be anything whose ``str()`` is the text diagram (rows
        separated by newlines, symbols separated by single spaces); an
        already-rendered string works as well.

        The result has shape ``(1, C - 1, rows, cols)`` where ``C`` is
        ``len(self.unique_pieces)``. Channel 0 of the vocabulary is dropped;
        with the sorted start-position vocabulary that is the ``"."``
        empty-square symbol.
        """
        state = [row.split(" ") for row in str(board).split("\n")]
        state = [[self.piece_to_idx[p] for p in row] for row in state]
        state_onehot = (
            F.one_hot(torch.tensor(state), num_classes=len(self.unique_pieces))
            .float()
            .permute(2, 0, 1)  # (rows, cols, C) -> (C, rows, cols)
        )
        # Drop channel 0 and add a leading batch dimension.
        state_onehot = state_onehot[1:].unsqueeze(0)
        return state_onehot

    def move_get_origin(self, move):
        """Return the origin square of ``move`` as ``file + 8 * rank`` (0..63)."""
        move = str(move)
        x, y = self.alpha_idx[move[0]], self.numeric_idx[move[1]]
        return x + y * 8

    def _move_delta_key(self, move):
        """Return the ``moves_dict`` key for ``move``: (dx, dy) or (dx, dy, piece)."""
        move = str(move)
        x1, y1 = self.alpha_idx[move[0]], self.numeric_idx[move[1]]
        x2, y2 = self.alpha_idx[move[2]], self.numeric_idx[move[3]]
        dx, dy = x2 - x1, y2 - y1

        if len(move) == 5:  # promotion
            piece = self.promotion_idx[move[4].lower()]
            # moves_dict only stores promotion keys with dy == +1. Black
            # pawns promote with dy == -1, which used to raise KeyError.
            # Fold the rank direction; the origin square that
            # move_get_action combines with this index still tells the two
            # colours apart.
            return (dx, abs(dy), piece)

        return (dx, dy)

    def move_get_delta(self, move):
        """Return the ``moves_dict`` index describing the shape of ``move``.

        Raises:
            KeyError: if the move's delta is not a key of ``moves_dict``.
        """
        return moves_dict[self._move_delta_key(move)]

    def move_get_action(self, move):
        """Return the flat action index ``origin * len(moves_dict) + delta``."""
        origin = self.move_get_origin(move)
        delta = self.move_get_delta(move)

        return origin * len(moves_dict) + delta


In [174]:
# Encode a fresh board as a one-hot tensor and check its shape.
bts = BoardArithmetic()
board = chess.Board()  # rebinds the module-level board used by later cells
# board_to_state calls str() on its argument, so passing the rendered
# string here is equivalent to passing the board object itself.
state = bts.board_to_state(str(board))
print(state.shape)  # the captured output below shows (1, 12, 8, 8)

torch.Size([1, 12, 8, 8])


In [176]:
# Flat action index (origin * len(moves_dict) + delta) of every legal move
# on the current board, collected into a 1-D tensor.
moves = torch.tensor([bts.move_get_action(m) for m in board.legal_moves])

In [177]:
# Display the encoded legal-move indices computed above.
print(moves)

tensor([ 516,  518,  136,  138, 1140, 1064,  988,  912,  836,  760,  684,  608,
        1141, 1065,  989,  913,  837,  761,  685,  609])


In [167]:
# Random 64 x 76 table of standard-normal values.
# NOTE(review): presumably rows are origin squares (64) and columns are
# moves_dict entries (76 with the table built above) - confirm, and
# consider deriving the column count from len(moves_dict).
Q = torch.randn((64, 76))

In [172]:
# Without a `dim` argument, argmax returns the index into the flattened
# tensor. For a 64 x 76 table that is row * 76 + column, the same layout as
# the flat index produced by BoardArithmetic.move_get_action.
torch.argmax(Q)


tensor(2608)