In [27]:
import chess
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

def get_mapped():
    """Return the FEN piece-letter -> signed integer code table.

    Upper-case (white) letters map to positive codes 1..6 in the order
    pawn, knight, bishop, rook, queen, king; the matching lower-case
    (black) letters map to the negated code.
    """
    piece_codes = {}
    for code, white_symbol in enumerate('PNBRQK', start=1):
        # Insert white then black so the key order mirrors P, p, N, n, ...
        piece_codes[white_symbol] = code
        piece_codes[white_symbol.lower()] = -code
    return piece_codes
def get_positions() -> np.ndarray:
    """Return the 64 board square names as a 1-D object array.

    Squares are ordered file-major: all ranks of file ``a`` first
    (``a1`` .. ``a8``), then file ``b``, and so on up to ``h8``. The
    index of a square in this array is therefore
    ``file_index * 8 + (rank - 1)``.

    Returns:
        A NumPy array of shape ``(64,)`` with ``dtype=object`` holding
        the square names as Python strings.
    """
    files = 'abcdefgh'
    ranks = range(1, 9)
    squares = [f'{file}{rank}' for file in files for rank in ranks]
    return np.array(squares, dtype=object)

class ChessBoardEnv():
    """Small environment wrapper around a ``chess.Board``.

    The observation is the board encoded as a ``(1, 64)`` integer array
    (see ``state``). Moves are passed to the environment as UCI strings.
    """

    def __init__(self):
        # Per-piece integer weights keyed by FEN letter. No method of this
        # class reads them; they are kept as public attributes.
        self.black_mapped = {
            'p': 1,
            'n': 3,
            'b': 3,
            'r': 5,
            'q': 9,
        }
        self.white_mapped = {
            'P': 1,
            'N': 3,
            'B': 3,
            'R': 5,
            'Q': 9,
        }
        self.board = chess.Board()
        self.n_observations = 64
        # Array of square names; sample() returns indices into it.
        self.action_space = get_positions()

    def get_legal_moves(self) -> np.ndarray:
        """Return the current legal moves as an object array of UCI strings."""
        legal_moves = [str(move) for move in self.board.legal_moves]
        return np.array(legal_moves, dtype='object')

    def sample(self) -> list:
        """Pick a random legal move and return it as square indices.

        Returns:
            ``[from_index, to_index]`` where both entries are indices into
            ``self.action_space``.

        Raises:
            ValueError: from ``np.random.choice`` when there is no legal
                move to choose from.
        """
        moves_list = self.get_legal_moves()
        choice = np.random.choice(moves_list)
        from_position = np.where(self.action_space == choice[:2])[0][0]
        # Slice exactly two characters: promotion moves such as "e7e8q"
        # carry a fifth character that is not part of the square name.
        to_position = np.where(self.action_space == choice[2:4])[0][0]
        return [from_position, to_position]

    # def to_uci(self)
    def state(self) -> np.ndarray:
        """Encode the current board as a ``(1, 64)`` ``int16`` array.

        The piece-placement field of the board's EPD string is expanded
        row by row in the order EPD lists them; empty squares become 0
        and pieces are looked up in ``get_mapped()``.
        """
        placement = self.board.epd().split(" ", 1)[0]
        mapped = get_mapped()
        grid = []
        for row in placement.split("/"):
            cells = []
            for symbol in row:
                if symbol.isdigit():
                    # A digit is a run of that many empty squares.
                    cells.extend([0] * int(symbol))
                else:
                    cells.append(mapped[symbol])
            grid.append(cells)
        return np.array(grid, dtype=np.int16).reshape(1, 64)

    def next_state(self, move_str):
        """Return the EPD string the board would have after ``move_str``.

        The board itself is left unchanged. If the move is not legal,
        "not legal move" is printed and the EPD of the current position
        is returned.
        """
        move = chess.Move.from_uci(move_str)
        if move not in self.board.legal_moves:
            print("not legal move")
            # Nothing was pushed, so nothing may be popped here.
            return self.board.epd()
        self.board.push(move)
        try:
            return self.board.epd()
        finally:
            # Always undo the look-ahead move.
            self.board.pop()

    def make_move(self, move_str) -> None:
        """Play ``move_str`` (UCI) on the board.

        Raises:
            Exception: if the move is not legal in the current position.
        """
        move = chess.Move.from_uci(move_str)
        if move in self.board.legal_moves:
            self.board.push(move)
        else:
            raise Exception("Illegal move")

    def calculate_reward_for_move(self):
        """Return the reward for the move that was just played.

        Returns:
            A scalar ``float32`` tensor: ``1.0`` if white just delivered
            checkmate, ``-1.0`` if black did, ``-1.0`` on a fifty-move
            or stalemate position, and ``-0.001`` for any other move.
        """
        if self.board.is_checkmate():
            # board.turn is True when white is to move, so the side that
            # just moved (and delivered mate) is the opposite colour.
            white_moved_last = not self.board.turn
            value = 1.0 if white_moved_last else -1.0
            return torch.tensor(value, dtype=torch.float32)

        if self.board.is_fifty_moves() or self.board.is_stalemate():
            return torch.tensor(-1.0, dtype=torch.float32)

        # NOTE(review): the original had this value in an unreachable
        # branch and returned None here; using it as the per-move reward
        # is assumed to be the intent - confirm.
        return torch.tensor(-0.001, dtype=torch.float32)

    def reset(self):
        """Reset the board to the starting position and return its FEN string."""
        self.board.reset()
        return self.board.fen()

    def step(self, move_str):
        """Play ``move_str`` and report the outcome.

        Returns:
            ``(observation, reward, terminated, truncated)`` where
            ``observation`` comes from ``state()``, ``reward`` from
            ``calculate_reward_for_move()``, ``terminated`` is True on
            checkmate or stalemate and ``truncated`` is True when the
            fifty-move condition holds.

        Raises:
            Exception: if the move is not legal (from ``make_move``).
        """
        self.make_move(move_str)
        reward = self.calculate_reward_for_move()
        observation = self.state()
        # Both flags are always bound, including for ordinary moves.
        terminated = self.board.is_checkmate() or self.board.is_stalemate()
        truncated = self.board.is_fifty_moves()
        return observation, reward, terminated, truncated
        
        

In [28]:
# observation, reward, terminated, truncated, _ = env.step(action.item())
# print(env.step(action.item())) #np.array(1,4), reward, terminated bool, truncated bool

In [29]:
# Build the array of square names once for the lookup in the next cell.
positions = get_positions()

In [38]:
# Index of the first entry equal to 'g2' in the positions array.
np.where(positions=='g2')[0][0]

49