In [38]:
import chess
import gym
import gym_chess
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random
import math

class Node:
    def __init__(self, move=None, parent=None, state=None):
        self.move = move
        self.parent = parent
        self.state = state
        self.visits = 0
        self.score = 0
        self.children = []

    def is_fully_expanded(self):
        return len(self.children) == len(list(self.state.legal_moves))

    def is_terminal(self):
        return self.state.is_game_over()

    def select_child(self, exploration_constant):
        return max(self.children, key=lambda child: child.score / child.visits +
                                                   exploration_constant * math.sqrt(
                                                       2 * math.log(self.visits) / child.visits))

    def expand(self):
        untried_moves = list(set(self.state.legal_moves) - set(child.move for child in self.children))
        random_move = random.choice(untried_moves)
        new_state = self.state.copy()
        new_state.push(random_move)
        child_node = Node(move=random_move, parent=self, state=new_state)
        self.children.append(child_node)
        return child_node

    def update(self, reward):
        self.visits += 1
        self.score += reward
    
    def get_untried_moves(self):
        return self.untried_moves.copy()

def uct_search(initial_state, num_iterations):
    root = Node(state=initial_state)
    for _ in range(num_iterations):
        node = root
        state = initial_state.copy()

        # Selection
        while node.children:
            if node.is_fully_expanded() and not node.is_terminal():
                node = node.select_child(math.sqrt(2))
                state.push(node.move)
            else:
                break

        # Expansion
        if not node.is_terminal():
            node = node.expand()
            state.push(node.move)

        # Simulation
        while not state.is_game_over():
            random_move = random.choice(list(state.legal_moves))
            state.push(random_move)

        # Backpropagation
        while node is not None:
            reward = evaluate_state(node.state)
            node.update(reward)
            node = node.parent

    best_child = max(root.children, key=lambda child: child.visits)
    return best_child.move

def evaluate_state(state):
    # Placeholder evaluation function
    return np.random.uniform(0, 1)

# Training
env = gym.make('Chess-v0')
num_episodes = 100
steps_per_episode = []

for episode in tqdm(range(num_episodes)):
    state = env.reset()
    counter = 0

    while True:
        print("Current Board Position:")
        print(chess.Board())
        print("Legal Moves:", chess.Board().legal_moves)
        
        print("Current Board Position:")
        env.render()

        print("Legal Moves:", list(env.legal_moves))

        best_move_uci = best_move.uci()
        print("Best Move UCI:", best_move_uci)

        best_move = chess.Move.from_uci(best_move_uci)
        print("Best Move:", best_move)

        next_state, reward, done, _ = env.step(best_move)
        counter += 1
        if done:
            break

        state = next_state

# Plotting the results
plt.plot(steps_per_episode)
plt.xlabel('Episode')
plt.ylabel('Steps per Episode')
plt.title('Monte Carlo Tree Search: Learning Progress')
plt.show()

  0%|          | 0/100 [00:00<?, ?it/s]

Current Board Position:
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
Legal Moves: <LegalMoveGenerator at 0x1b77ec05310 (Nh3, Nf3, Nc3, Na3, h3, g3, f3, e3, d3, c3, b3, a3, h4, g4, f4, e4, d4, c4, b4, a4)>
Current Board Position:
Legal Moves: [Move.from_uci('g1h3'), Move.from_uci('g1f3'), Move.from_uci('b1c3'), Move.from_uci('b1a3'), Move.from_uci('h2h3'), Move.from_uci('g2g3'), Move.from_uci('f2f3'), Move.from_uci('e2e3'), Move.from_uci('d2d3'), Move.from_uci('c2c3'), Move.from_uci('b2b3'), Move.from_uci('a2a3'), Move.from_uci('h2h4'), Move.from_uci('g2g4'), Move.from_uci('f2f4'), Move.from_uci('e2e4'), Move.from_uci('d2d4'), Move.from_uci('c2c4'), Move.from_uci('b2b4'), Move.from_uci('a2a4')]
Best Move UCI: g1f3
Best Move: g1f3
Current Board Position:
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
Legal Moves: <LegalMoveGen

ValueError: Illegal move g1f3 for board position rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b KQkq - 1 1