In [7]:
################################################################################
#                           1 Import packages                                  #
################################################################################
%reload_ext autoreload
%autoreload 2
from chess_gym import chess_gym

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.optim as optim

# Use a double ended queue (deque) for memory
# When memory is full, this will replace the oldest value with the new one
from collections import deque

# Supress all warnings (e.g. deprecation warnings) for regular use
import warnings
warnings.filterwarnings("ignore")

In [8]:
################################################################################
#                           2 Define model parameters                          #
################################################################################

# Set whether to display on screen (slows model)
DISPLAY_ON_SCREEN = False
# Discount rate of future rewards
GAMMA = 0.95
# Learing rate for neural network
LEARNING_RATE = 0.0003
# Maximum number of game steps (observation, action, reward, next observation) to keep
MEMORY_SIZE = 1000000
# Sample batch size for policy network update
BATCH_SIZE = 3
# Number of game steps to play before starting training (all random actions)
REPLAY_START_SIZE = 10
# Time step between actions
TIME_STEP = 1
# Number of steps between policy -> target network update
SYNC_TARGET_STEPS = 10
# Exploration rate (episolon) is probability of choosign a random action
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.001
# Reduction in epsilon with each game step
EXPLORATION_DECAY = 0.9
# Simulation duration
SIM_DURATION = 200
# Training episodes
TRAINING_EPISODES = 100000

In [9]:
################################################################################
#                      3 Define DQN (Deep Q Network) class                     #
#                    (Used for both policy and target nets)                    #
################################################################################

from chess_class import chess_class
dummy_chess=chess_class()
class DQN(nn.Module):

    """Deep Q Network. Udes for both policy (action) and target (Q) networks."""

    def __init__(self, observation_space, action_space, neurons_per_layer=1024):
        """Constructor method. Set up neural nets."""

        # Set starting exploration rate
        self.exploration_rate = EXPLORATION_MAX
        
        # Set up action space (choice of possible actions)
        self.action_space = action_space
              
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=observation_space[0], out_channels=32, kernel_size=3,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            torch.nn.Flatten(),
            nn.Linear(15488,action_space)

        )
        # self.net = nn.Sequential(
        #     nn.Linear(observation_space, 1024),
        #     nn.Sigmoid(),
        #     nn.Linear(1024, 2048),
        #     nn.Sigmoid(),
        #     nn.Linear(2048, 4096),
        #     nn.Sigmoid(),
        #     nn.Linear(4096, action_space)
        #     )
        
    def act(self, observation):
        """Act either randomly or by redicting action that gives max Q"""
        
        # Act randomly if random number < exploration rate
        if np.random.rand() < self.exploration_rate:
            action=dummy_chess.choose_random_move_from_obs(observation[0],"black")
            # action = random.randrange(self.action_space)
            
        else:
            # Otherwise get predicted Q values of actions
            q_values = self.net(torch.FloatTensor(observation))
            # Get index of action with best Q
            action=dummy_chess._action_from_Q_values(observation[0],"black",q_values.detach().numpy()[0])
            # action = np.argmax(q_values.detach().numpy()[0])
        
        return  action
        
        
    def forward(self, x):
        """Forward pass through network"""
        return self.net(x)

In [10]:
################################################################################
#                    4 Define policy net training function                     #
################################################################################

def optimize(policy_net, target_net, memory):
    """
    Update  model by sampling from memory.
    Uses policy network to predict best action (best Q).
    Uses target network to provide target of Q for the selected next action.
    """
      
    # Do not try to train model if memory is less than reqired batch size
    if len(memory) < BATCH_SIZE:
        return    
 
    # Reduce exploration rate (exploration rate is stored in policy net)
    policy_net.exploration_rate *= EXPLORATION_DECAY
    policy_net.exploration_rate = max(EXPLORATION_MIN, 
                                      policy_net.exploration_rate)
    # Sample a random batch from memory
    batch = random.sample(memory, BATCH_SIZE)
    for observation, action, reward, observation_next, terminal in batch:
        
        observation_action_values = policy_net(torch.FloatTensor(observation))
        
        # Get target Q for policy net update
       
        if not terminal:
            # For non-terminal actions get Q from policy net
            expected_observation_action_values = policy_net(torch.FloatTensor(observation))
            # Detach next observation values from gradients to prevent updates
            expected_observation_action_values = expected_observation_action_values.detach()
            # Get next observation action with best Q from the policy net (double DQN)
            policy_next_observation_values = policy_net(torch.FloatTensor(observation_next))
            policy_next_observation_values = policy_next_observation_values.detach()
            best_action = np.argmax(policy_next_observation_values[0].numpy())
            # Get target net next observation
            next_observation_action_values = target_net(torch.FloatTensor(observation_next))
            # Use detach again to prevent target net gradients being updated
            next_observation_action_values = next_observation_action_values.detach()
            best_next_q = next_observation_action_values[0][best_action].numpy()
            updated_q = reward + (GAMMA * best_next_q)      
            expected_observation_action_values[0][action] = updated_q
        else:
            # For termal actions Q = reward (-1)
            expected_observation_action_values = policy_net(torch.FloatTensor(observation))
            # Detach values from gradients to prevent gradient update
            expected_observation_action_values = expected_observation_action_values.detach()
            # Set Q for all actions to reward (-1)
            expected_observation_action_values[0] = reward
 
        # Set net to training mode
        policy_net.train()
        # Reset net gradients
        policy_net.optimizer.zero_grad()  
        # calculate loss
        loss_v = nn.MSELoss()(observation_action_values, expected_observation_action_values)
        # Backpropogate loss
        loss_v.backward()
        # Update network gradients
        policy_net.optimizer.step()  

    return

In [11]:
################################################################################
#                            5 Define memory class                             #
################################################################################

class Memory():
    """
    Replay memory used to train model.
    Limited length memory (using deque, double ended queue from collections).
      - When memory full deque replaces oldest data with newest.
    Holds, observation, action, reward, next observation, and episode done.
    """
    
    def __init__(self):
        """Constructor method to initialise replay memory"""
        self.memory = deque(maxlen=MEMORY_SIZE)

    def remember(self, observation, action, reward, next_observation, done):
        """observation/action/reward/next_observation/done"""
        self.memory.append((observation, action, reward, next_observation, done))

In [12]:
"""Main program loop"""
from chess_gym import chess_gym
############################################################################
#                          8 Set up environment                            #
############################################################################
    
# Set up game environemnt
sim = chess_gym()

# Get number of observations returned for observation
observation_space = sim.observation_size

# Get number of actions possible
action_space = sim.action_size

############################################################################
#                    9 Set up policy and target nets                       #
############################################################################

# Set up policy and target neural nets
policy_net = DQN(observation_space, action_space)
target_net = DQN(observation_space, action_space)

# Set loss function and optimizer
policy_net.optimizer = optim.Adam(
        params=policy_net.parameters(), lr=LEARNING_RATE)

# Copy weights from policy_net to target
target_net.load_state_dict(policy_net.state_dict())

# Set target net to eval rather than training mode
# We do not train target net - ot is copied from policy net at intervals
target_net.eval()

############################################################################
#                            10 Set up memory                              #
############################################################################
    
# Set up memomry
memory = Memory()

############################################################################
#                     11 Set up + start training loop                      #
############################################################################

# Set up run counter and learning loop    
run = 0
all_steps = 0
continue_learning = True

# Set up list for results
results_run = []
results_exploration = []
results_score = []

# Continue repeating games (episodes) until target complete
while continue_learning:
    
    ########################################################################
    #                           12 Play episode                            #
    ########################################################################
    
    # Increment run (episode) counter
    run += 1
    
    ########################################################################
    #                             13 Reset game                            #
    ########################################################################
    
    # Reset game environment and get first observation observations
    observation = sim.reset()

    # Trackers for observation
    rewards = []
    
    # Reset total reward
    total_reward = 0
    
    # Reshape observation into 2D array with observation obsverations as first 'row'
    # observation = np.reshape(observation, [1, observation_space])
    observation = np.reshape(observation, [1]+list(observation_space))
    
    # Continue loop until episode complete
    while True:
        
    ########################################################################
    #                       14 Game episode loop                           #
    ########################################################################
        
        ####################################################################
        #                       15 Get action                              #
        ####################################################################
        
        # Get action to take (se eval mode to avoid dropout layers)
        policy_net.eval()
        action = policy_net.act(observation)
        
        ####################################################################
        #                 16 Play action (get S', R, T)                    #
        ####################################################################
        
        # Act 
        observation_next, reward, terminal, info = sim.step(action)
        total_reward += reward

        # # Update trackers
        # tolva_acululacion_camion.append(observation_next[0])
        # cinta_pre_evicerado.append(observation_next[1])
        # tolva_ev_aut.append(observation_next[2])
        # tolva_ev_man.append(observation_next[3])
        # salida_evicerado.append(observation_next[4])
        rewards.append(reward)
                                                        
        # Reshape observation into 2D array with observation obsverations as first 'row'
        observation_next = np.reshape(observation_next, [1]+list(observation_space))
        
        # Update display if needed
        if DISPLAY_ON_SCREEN:
            sim.render()
        
        ####################################################################
        #                  17 Add S/A/R/S/T to memory                      #
        ####################################################################
        
        # Record observation, action, reward, new observation & terminal
        
        memory.remember(observation, action, reward, observation_next, terminal)
        
        # Update observation
        observation = observation_next
        
        ####################################################################
        #                  18 Check for end of episode                     #
        ####################################################################
        
        # Actions to take if end of game episode
        if terminal:
            # Get exploration rate
            exploration = policy_net.exploration_rate
            # Clear print row content
            clear_row = '\r' + ' '*79 + '\r'
            print (clear_row, end ='')
            print (f'Run: {run}, ', end='')
            print (f'Exploration: {exploration: .3f}, ', end='')
            average_reward = total_reward/SIM_DURATION
            print (f'Average reward: {average_reward:4.6f}', end='')
            print(sim.game_status)
            # Add to results lists
            # results_run.append(run)
            # results_exploration.append(exploration)
            # results_score.append(total_reward)
            
            ################################################################
            #             18b Check for end of learning                    #
            ################################################################
            
            if run == TRAINING_EPISODES:
                continue_learning = False
            
            # End episode loop
            break
        
        
        ####################################################################
        #                        19 Update policy net                      #
        ####################################################################
        
        # Avoid training model if memory is not of sufficient length
        if len(memory.memory) > REPLAY_START_SIZE:
    
            # Update policy net
            optimize(policy_net, target_net, memory.memory)

            ################################################################
            #             20 Update target net periodically                #
            ################################################################
            
            # Use load_state_dict method to copy weights from policy net
            if all_steps % SYNC_TARGET_STEPS == 0:
                target_net.load_state_dict(policy_net.state_dict())
            
############################################################################
#                      21 Learning complete - plot results                 #
############################################################################

# Add last run to DataFrame. summarise, and return
# run_details = pd.DataFrame()
# run_details['tolva_acululacion_camion'] = tolva_acululacion_camion 
# run_details['cinta_pre_evicerado'] = cinta_pre_evicerado
# run_details['tolva_ev_aut'] = tolva_ev_aut
# run_details['tolva_ev_man'] = tolva_ev_man
# run_details['salida_evicerado'] = salida_evicerado
# run_details['reward'] = rewards    
    
# Target reached. Plot results
# plot_results(
#     results_run, results_exploration, results_score, run_details)

Run: 1, Exploration:  1.000, Average reward: -0.049400Black made an illegal move
Run: 2, Exploration:  1.000, Average reward: -0.049625Black made an illegal move
Run: 3, Exploration:  0.656, Average reward: -0.049125Black made an illegal move
Run: 4, Exploration:  0.314, Average reward: -0.048425Black made an illegal move


KeyboardInterrupt: 

In [None]:
memory.memory


In [None]:
observation[0][:12]

In [None]:
observation_space


In [None]:
[1]*8