In [98]:
import gym
import gym_chess
import chess
import numpy as np
import random
import time
from IPython.display import clear_output

In [99]:
env = gym.make("Chess-v0")

In [100]:
class Agent:
    def __init__(self,exploration_rate,min_exploration_rate,max_exploration_rate,exploration_decay_rate,num_episodes,learning_rate,discount_rate,avg_num_episodes):
        self.exploration_rate = exploration_rate
        self.min_exploration_rate = min_exploration_rate
        self.max_exploration_rate = max_exploration_rate
        self.exploration_decay_rate = exploration_decay_rate
        self.num_episodes = num_episodes
        self.learning_rate = learning_rate
        self.discount_rate = discount_rate
        self.avg_num_episodes = avg_num_episodes
        self.rewards_all_episodes = []
        self.q_table = {}
    
    def getMoves(self,legal_moves):
        return str([str(move) for move in list(legal_moves)])
    def isLegalAction(self,action,legal_moves):
        return True if action in legal_moves else False
    def Play(self):
        for episode in range(num_episodes):
            finished = False
            iteration_counts = 0
            rewards_current_episode = 0

            state = env.reset()
            previous_action = None

            clear_output(wait=True)
            print("*****EPISODE ", episode+1, "*****\n\n\n\n")
    #         print(env.render(mode='unicode'))

            while not finished:
                if state.is_checkmate() or state.is_stalemate() or state.is_insufficient_material() or state.is_game_over() or state.can_claim_threefold_repetition() or state.can_claim_fifty_moves() or state.can_claim_draw() or state.is_fivefold_repetition() or state.is_seventyfive_moves():
                    finished = True
                if finished:
    #                 print(state.)
    #                 print('Game Finished')
    #                 print(env.render(mode='unicode'))
    #                 time.sleep(2)
                    break
                else:
    #                 print(env.render(mode='unicode'))
                    previous_state = self.getMoves(state.legal_moves)


                    # Exploration-exploitation trade-off
                    self.exploration_rate_threshold = random.uniform(0, 1)
                    if self.exploration_rate_threshold > self.exploration_rate:
                        if previous_state in self.q_table:
                            current_action = np.argmax(self.q_table[previous_state]) # Exploitation
                            if not self.isLegalAction(current_action, state.legal_moves):
                                current_action = random.choice(list(state.legal_moves)) # Exploration
                        else:
                            current_action = random.choice(list(state.legal_moves)) # Exploration
                    else:
                        current_action = random.choice(list(state.legal_moves)) # Exploration


                        new_state, reward, done, info = env.step(current_action)

                        if iteration_counts >= 1:

                            current_state = self.getMoves(new_state.legal_moves)

                            if self.q_table.get(previous_state) is None:
                                self.q_table[previous_state] = {previous_action:0}

                            if self.q_table.get(current_state) is None:
                                self.q_table[current_state] = {current_action:0}

                            if current_action in self.q_table[current_state] and previous_action in self.q_table[previous_state]:

                                self.q_table[previous_state][previous_action] = self.q_table[previous_state][previous_action]  + \
                                    self.learning_rate * (reward + self.discount_rate * self.q_table[current_state][current_action] - self.q_table[previous_state][previous_action])

                        state = new_state
                        previous_action = current_action
                        iteration_counts += 1
                        rewards_current_episode += reward

            # Exploration rate decay
            self.exploration_rate = self.min_exploration_rate + \
                (self.max_exploration_rate - self.min_exploration_rate) * np.exp(-self.exploration_decay_rate*episode)

            self.rewards_all_episodes.append(rewards_current_episode)

        # Calculate and print the average reward per thousand episodes

        rewards_per_thousand_episodes = np.split(np.array(self.rewards_all_episodes),self.num_episodes/self.avg_num_episodes)
        count = self.avg_num_episodes

        print("********Average reward per "+str(self.avg_num_episodes)+" episodes********\n")
        for r in rewards_per_thousand_episodes:
            print(count, ": ", str(sum(r/self.avg_num_episodes)))
            count += self.avg_num_episodes

        env.close()
        return self.q_table
    
    
    def PlayWithPolicy(self,q_table):
        for episode in range(num_episodes):
            finished = False
            rewards_current_episode = 0

            state = env.reset()
            previous_action = None

            clear_output(wait=True)
            print("*****EPISODE ", episode+1, "*****\n\n\n\n")
#             print(env.render(mode='unicode'))

            while not finished:
                if state.is_checkmate() or state.is_stalemate() or state.is_insufficient_material() or state.is_game_over() or state.can_claim_threefold_repetition() or state.can_claim_fifty_moves() or state.can_claim_draw() or state.is_fivefold_repetition() or state.is_seventyfive_moves():
                    finished = True
                if finished:
#                     print('Game Finished')
#                     print(env.render(mode='unicode'))
#                     time.sleep(2)
                    break
                else:
                    action = None
    #                 print(env.render(mode='unicode'))
                    previous_state = self.getMoves(state.legal_moves)

                    if previous_state in self.q_table:
                        action = np.argmax(q_table[previous_state]) # Exploitation
                        if not self.isLegalAction(action, state.legal_moves):
                            action = random.choice(list(state.legal_moves)) # Exploration
                    else:
                        action = random.choice(list(state.legal_moves)) # Exploration

                        
                    new_state, reward, done, info = env.step(action)

                    state = new_state
                    rewards_current_episode += reward

            self.rewards_all_episodes.append(rewards_current_episode)

        # Calculate and print the average reward per thousand episodes

        rewards_per_thousand_episodes = np.split(np.array(self.rewards_all_episodes),self.num_episodes/self.avg_num_episodes)
        count = self.avg_num_episodes

        print("********Average reward per "+str(self.avg_num_episodes)+" episodes********\n")
        for r in rewards_per_thousand_episodes:
            print(count, ": ", str(sum(r/self.avg_num_episodes)))
            count += self.avg_num_episodes

        env.close()

In [101]:
num_episodes = 1000
avg_num_episodes = round(num_episodes/10)

learning_rate = 0.1
discount_rate = 0.99

exploration_rate = 1
max_exploration_rate = 1
min_exploration_rate = 0.01
exploration_decay_rate = 0.001

In [102]:
agent = Agent(exploration_rate,min_exploration_rate,max_exploration_rate,exploration_decay_rate,num_episodes,learning_rate,discount_rate,avg_num_episodes)

In [103]:
Q_table = agent.Play()

*****EPISODE  1000 *****




********Average reward per 100 episodes********

100 :  -0.019999999999999997
200 :  -0.02
300 :  0.04
400 :  -0.01
500 :  -0.04
600 :  -0.010000000000000004
700 :  0.03
800 :  3.469446951953614e-18
900 :  -0.05
1000 :  0.04


In [104]:
agent.PlayWithPolicy(Q_table)

*****EPISODE  1000 *****




********Average reward per 100 episodes********

100 :  -0.04
200 :  0.03
300 :  -0.05
400 :  0.03
500 :  -0.009999999999999997
600 :  0.03
700 :  -0.05
800 :  3.469446951953614e-18
900 :  0.09
1000 :  0.16
