In [6]:
import os
import time
import random
import chess
import chess.svg
import pickle
import cProfile
import numpy as np
import ipywidgets as wg
import tensorflow as tf

from enum import Enum, auto
from typing import Union
from collections import Counter, deque
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from IPython.display import display, clear_output

In [2]:
# ---------------------------------------------------------------------------
# Notebook-wide constants
# ---------------------------------------------------------------------------

# Side length, in pixels, of the rendered SVG board (BOARD_SIZE x BOARD_SIZE).
BOARD_SIZE = 500

# Promotion piece letters, file letters and rank numbers.
PIECES_TO_PROMOTE = 'qrnb'
COLS = 'abcdefgh'
ROWS = list(range(1, 9))

# 8x8 grid of square names; the first row is rank 8 and the last row is rank 1.
BOARD = np.array([[f'{col}{row}' for col in COLS] for row in reversed(ROWS)])
flat_board = BOARD.flatten()

# Numeric code per piece symbol: lower-case symbols get negative codes and the
# matching upper-case symbols get the same magnitude with a positive sign.
PIECE_MAP = {'p': -1, 'r': -5, 'n': -4, 'b': -3, 'q': -9, 'k': -10, '.': 0}
PIECE_MAP.update({symbol.upper(): -code for symbol, code in PIECE_MAP.items()})

# Every ordered pair of distinct squares.
# There are some illegal moves here, fix it later?
all_moves = {f'{source}{target}' for source in flat_board for target in flat_board if source != target}

print(len(all_moves))

4032


In [3]:
def vertical_moves(pos1):
    """Yield the UCI string of every move from ``pos1`` to another square on the same file.

    Targets are produced in the order of ``ROWS``; ``pos1`` itself is skipped.
    """
    file_letter = pos1[0]
    same_file_squares = (file_letter + str(rank) for rank in ROWS)
    yield from (pos1 + pos2 for pos2 in same_file_squares if pos2 != pos1)


def horizontal_moves(pos1):
    """Yield the UCI string of every move from ``pos1`` to another square on the same rank.

    Castling is covered by these strings: e1c1 / e1g1 are the white castling
    moves and e8c8 / e8g8 the black ones.
    """
    rank_digit = pos1[1]
    same_rank_squares = (file_letter + rank_digit for file_letter in COLS)
    yield from (pos1 + pos2 for pos2 in same_rank_squares if pos2 != pos1)


def diagonal_moves(pos1):
    """Yield the UCI string of every move from ``pos1`` along its two diagonals.

    The first pass walks the diagonal where the row and column index of
    ``BOARD`` change together, the second pass the one where they change in
    opposite directions.  Each pass goes from the smallest to the largest
    column offset.
    """
    row_hits, col_hits = np.where(BOARD == pos1)
    row, col = row_hits[0], col_hits[0]

    for row_direction in (1, -1):
        for offset in range(-8, 8):
            if offset == 0:
                # Zero offset is the starting square itself.
                continue
            target_row = row + row_direction * offset
            target_col = col + offset
            if 0 <= target_row <= 7 and 0 <= target_col <= 7:
                yield pos1 + BOARD[target_row, target_col]
                

def knight_moves(pos1):
    """Yield the UCI string of every knight jump from ``pos1`` that stays on the board."""
    row_hits, col_hits = np.where(BOARD == pos1)
    row, col = row_hits[0], col_hits[0]

    # (row offset, column offset) pairs: first the four jumps that shift the
    # column by two and the row by one, then the four that shift the row by
    # two and the column by one.
    jumps = (
        (-1, -2), (1, -2), (-1, 2), (1, 2),
        (-2, -1), (-2, 1), (2, -1), (2, 1),
    )
    for row_offset, col_offset in jumps:
        target_row, target_col = row + row_offset, col + col_offset
        if 0 <= target_row < 8 and 0 <= target_col < 8:
            yield pos1 + BOARD[target_row, target_col]


def promotion_moves(pos1):
    """Yield promotion move strings for a pawn standing on ``pos1``.

    Nothing is yielded unless ``pos1`` is on rank 2 or rank 7.  From rank 2
    the target rank is 1, from rank 7 it is 8; the target file is the same
    file or a neighbouring one.  Each target is emitted once per piece in
    ``PIECES_TO_PROMOTE``, with both the lower-case and the upper-case
    suffix.
    """
    source_rank = pos1[1]
    if source_rank not in ('2', '7'):
        return

    target_rank = '1' if source_rank == '2' else '8'
    source_file_index = COLS.index(pos1[0])

    for file_index in range(source_file_index - 1, source_file_index + 2):
        if not 0 <= file_index < 8:
            continue
        target_square = COLS[file_index] + target_rank
        for piece in PIECES_TO_PROMOTE:
            # Hack: piece type was creating problems, so both cases are emitted. Fix this later.
            for suffix in (piece, piece.upper()):
                yield pos1 + target_square + suffix
           
def get_all_legal_moves():
    """Build the action space used by the Q network and the agent.

    Collects every move string produced by the vertical, horizontal, diagonal,
    knight and promotion generators for each square of ``flat_board``.

    Returns:
        dict: ``{uci_move: index}`` with indices ``0 .. n-1``.  The moves are
        sorted before numbering so the mapping is identical on every run;
        numbering a bare ``set`` would follow its iteration order, which for
        strings changes between interpreter sessions (hash randomization) and
        would silently re-label the network's output units.
    """
    move_generators = (vertical_moves, horizontal_moves, diagonal_moves, knight_moves, promotion_moves)

    all_legal_moves = set()
    for pos1 in flat_board:
        for generate_moves in move_generators:
            # str() turns numpy string scalars into plain strings so the keys have one type.
            all_legal_moves.update(str(move) for move in generate_moves(pos1))

    return {move: index for index, move in enumerate(sorted(all_legal_moves))}


# Module-level action space.  This assignment has to live outside the function
# (and call it): placed after the ``return`` it never ran, leaving
# ALL_LEGAL_MOVES undefined for the classes below.
ALL_LEGAL_MOVES = get_all_legal_moves()

In [4]:
# Print the output of each move generator for a sample square.
for generate_moves, sample_square in (
    (vertical_moves, 'a2'),
    (horizontal_moves, 'a2'),
    (diagonal_moves, 'a2'),
    (knight_moves, 'c4'),
    (promotion_moves, 'a2'),
    (promotion_moves, 'g7'),
):
    print(list(generate_moves(sample_square)))

['a2a1', 'a2a3', 'a2a4', 'a2a5', 'a2a6', 'a2a7', 'a2a8']
['a2b2', 'a2c2', 'a2d2', 'a2e2', 'a2f2', 'a2g2', 'a2h2']
['a2b1', 'a2b3', 'a2c4', 'a2d5', 'a2e6', 'a2f7', 'a2g8']
['c4a5', 'c4a3', 'c4e5', 'c4e3', 'c4b6', 'c4d6', 'c4b2', 'c4d2']
['a2a1q', 'a2a1Q', 'a2a1r', 'a2a1R', 'a2a1n', 'a2a1N', 'a2a1b', 'a2a1B', 'a2b1q', 'a2b1Q', 'a2b1r', 'a2b1R', 'a2b1n', 'a2b1N', 'a2b1b', 'a2b1B']
['g7f8q', 'g7f8Q', 'g7f8r', 'g7f8R', 'g7f8n', 'g7f8N', 'g7f8b', 'g7f8B', 'g7g8q', 'g7g8Q', 'g7g8r', 'g7g8R', 'g7g8n', 'g7g8N', 'g7g8b', 'g7g8B', 'g7h8q', 'g7h8Q', 'g7h8r', 'g7h8R', 'g7h8n', 'g7h8N', 'g7h8b', 'g7h8B']


In [5]:
# Show the square-name grid: the first row holds rank 8, the last row rank 1.
BOARD

array([['a8', 'b8', 'c8', 'd8', 'e8', 'f8', 'g8', 'h8'],
       ['a7', 'b7', 'c7', 'd7', 'e7', 'f7', 'g7', 'h7'],
       ['a6', 'b6', 'c6', 'd6', 'e6', 'f6', 'g6', 'h6'],
       ['a5', 'b5', 'c5', 'd5', 'e5', 'f5', 'g5', 'h5'],
       ['a4', 'b4', 'c4', 'd4', 'e4', 'f4', 'g4', 'h4'],
       ['a3', 'b3', 'c3', 'd3', 'e3', 'f3', 'g3', 'h3'],
       ['a2', 'b2', 'c2', 'd2', 'e2', 'f2', 'g2', 'h2'],
       ['a1', 'b1', 'c1', 'd1', 'e1', 'f1', 'g1', 'h1']], dtype='<U2')

In [6]:
class QType(Enum):
    """Selects how Q values are stored."""

    DQN = 1       # Q values predicted by a neural network
    TABULAR = 2   # Q values kept in a lookup table

In [10]:

class DeepQNetwork:
    """Convolutional Q network: an 8x8x1 board encoding in, one Q value per move out.

    Two Keras models with the same layers are kept:

    * ``model`` is built with a fixed batch of ``batch_size`` and compiled; it is
      the one that gets trained.
    * ``predictor_model`` is built with a batch of 1.  After construction and
      after every training round its weights are copied from ``model``, frozen
      to ``./frozen_models/frozen_graph.pb`` and reloaded as ``frozen_func``,
      which is what ``get_Q`` calls.
    """

    def __init__(self, epochs: int = 10, batch_size: int = 100):
        """
        Args:
            epochs: Epochs passed to ``fit`` on every ``train`` call.
            batch_size: Batch size of the training model and of ``fit``.
        """
        self.name = 'cnn_chessRL'

        self.optimizer = keras.optimizers.Adam(learning_rate=0.00025, clipnorm=1.0)
        self.loss_function = keras.losses.Huber()
        self.epochs = epochs
        self.batch_size = batch_size

        # Number of completed train() calls.
        self.training_count = 0

        self.model = self.create_model(batch_size=self.batch_size)
        self.predictor_model = self.create_model(batch_size=1)

        self.frozen_func = None
        self._sync_predictor()

    def __repr__(self):
        """Return the Keras summary of the training model as text.

        ``Model.summary()`` prints and returns ``None``, so the lines are
        collected through ``print_fn`` instead of stringifying the return value
        (which always produced ``"None"``).
        """
        summary_lines = []

        def collect(line, *_args, **_kwargs):
            # Extra arguments are accepted because some Keras versions pass
            # options such as ``line_break`` to print_fn.
            summary_lines.append(line)

        self.model.summary(print_fn=collect)
        return "\n".join(summary_lines)

    def create_model(self, batch_size: int):
        """Build the network for a fixed ``batch_size``.

        Three valid-padding convolutions (32, 64, 64 filters) are followed by a
        512-unit dense layer and a linear output with one unit per entry of
        ``ALL_LEGAL_MOVES``.  The model is compiled only when ``batch_size`` is
        not 1, i.e. the single-sample predictor is never compiled.
        """
        input_layer = layers.Input(shape=(8, 8, 1), batch_size=batch_size)

        hidden_layer_1 = layers.Conv2D(32, (3, 3), strides=1, activation='relu')(input_layer)
        hidden_layer_2 = layers.Conv2D(64, (3, 3), strides=1, activation='relu')(hidden_layer_1)
        hidden_layer_3 = layers.Conv2D(64, (2, 2), strides=1, activation='relu')(hidden_layer_2)

        hidden_layer_4 = layers.Flatten()(hidden_layer_3)
        hidden_layer_5 = layers.Dense(512, activation='relu')(hidden_layer_4)

        output_layer = layers.Dense(len(ALL_LEGAL_MOVES), activation='linear')(hidden_layer_5)

        model = keras.Model(inputs=input_layer, outputs=output_layer)

        if batch_size != 1:
            model.compile(optimizer=self.optimizer, loss=self.loss_function, metrics=['accuracy'])

        return model

    def train(self, features_and_labels, save_path):
        """Fit the training model, save it, and refresh the frozen predictor.

        Args:
            features_and_labels: Iterable of ``(board_state, action_q_values)``
                pairs; features are board states, labels the per-action Q values.
            save_path: Where ``model.save`` writes the trained model.
        """
        features, labels = zip(*features_and_labels)
        features = np.array(features)
        labels = np.array(labels)

        self.model.fit(features, labels, epochs=self.epochs, batch_size=self.batch_size, shuffle=True)

        self.model.save(save_path)
        self._sync_predictor()

        self.training_count += 1

    def _sync_predictor(self):
        """Copy the trained weights into the predictor, then re-freeze and reload it."""
        self.predictor_model.set_weights(self.model.get_weights())
        self.save_as_function(self.predictor_model)
        self.load_frozen_graph()

    @staticmethod
    def save_as_function(model):
        """Freeze ``model`` into a constant graph and write it to ./frozen_models/frozen_graph.pb."""
        # Convert Keras model to ConcreteFunction
        full_model = tf.function(lambda x: model(x))
        full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

        # Get frozen ConcreteFunction
        frozen_func = convert_variables_to_constants_v2(full_model)

        # Save frozen graph from frozen ConcreteFunction to hard drive
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir="./frozen_models", name="frozen_graph.pb", as_text=False)

    @staticmethod
    def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
        """Import ``graph_def`` and return a callable pruned to ``inputs`` -> ``outputs``.

        Args:
            graph_def: Parsed ``GraphDef`` to import.
            inputs: Tensor names fed to the returned function.
            outputs: Tensor names the returned function fetches.
            print_graph: Accepted for compatibility with existing callers; unused.
        """
        _imports_graph_def = lambda: tf.compat.v1.import_graph_def(graph_def, name="")
        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph

        return wrapped_import.prune(tf.nest.map_structure(import_graph.as_graph_element, inputs),
                                    tf.nest.map_structure(import_graph.as_graph_element, outputs))

    def load_frozen_graph(self):
        """Read the frozen graph back from disk and store it as ``self.frozen_func``."""
        with tf.io.gfile.GFile("./frozen_models/frozen_graph.pb", "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        self.frozen_func = self.wrap_frozen_graph(graph_def=graph_def, inputs=["x:0"], outputs=["Identity:0"], print_graph=True)

    def get_Q(self, state):
        """Return the frozen network's output row (one value per move) for one board state.

        ``state`` is converted to float32 and reshaped to the predictor's input shape.
        """
        state = tf.convert_to_tensor(state, dtype=tf.float32)
        state = tf.reshape(state, self.predictor_model.inputs[0].shape)

        # First fetched output, first (only) sample of the batch.
        return self.frozen_func(state)[0][0]

In [11]:
# Constants
# The DQN is considered for retraining whenever the number of recorded replay samples
# is a multiple of this value (see ChessRL.set_Q).
TRAIN_INTERVAL = 2000


class ChessRL:
    """Self-play reinforcement-learning agent for chess built on python-chess.

    Q values come either from a ``DeepQNetwork`` (``QType.DQN``) or from a
    lookup table ``{simplified_state: {uci_move: q}}`` (``QType.TABULAR``).
    States are FEN strings and actions are UCI move strings.
    """

    def __init__(self, init_reward=None, q_method: QType = QType.DQN, experience_replay_size_threshold: int = 100000,
                 saved_q: Union[str, os.PathLike, None] = 'dqn_py_chess'):
        """
        Args:
            init_reward: Optional zero-argument callable giving the initial Q
                value of an unseen (state, action) pair; defaults to a uniform
                draw from [-1, 1].
            q_method: Which Q representation to use.
            experience_replay_size_threshold: Capacity of the replay buffer; the
                DQN is only trained once the buffer is full.
            saved_q: Path the tabular Q values are loaded from and the DQN is
                saved to during training.
        """
        # The live game board, plus a scratch board for inspecting arbitrary
        # states without touching the game in progress.
        self.board = chess.Board()
        self.temp_board = chess.Board()

        # Rewards, and the bounds used to (de)normalize Q values.
        self._init_reward = init_reward
        self.reward_lb = -1
        self.reward_ub = 1

        # Initialize Q
        self.q_architecture = q_method
        self.dqn_model_path = saved_q
        self.initialize_q(saved_q)

        # Experience replay
        self.experience_replay_visits = 0
        self.experience_replay = deque(maxlen=experience_replay_size_threshold)
        self.experience_replay_size_threshold = experience_replay_size_threshold

        # All possible actions (including illegal ones)
        self.action_space_size = len(ALL_LEGAL_MOVES)

        # Q cache: emptied every time the network is trained again
        self.q_cache = {}

        # Extra information to determine terminal state
        self._sub_episode = None
        self._max_episode_length = None

    def initialize_q(self, saved_q: Union[str, os.PathLike, None] = None):
        """Create the Q representation selected by ``self.q_architecture``.

        Raises:
            NotImplementedError: For an unknown architecture.
        """
        if self.q_architecture == QType.DQN:
            # A fresh network is always created.
            # TODO: restore previously trained weights from ``saved_q``.
            self.dqn = DeepQNetwork()

        elif self.q_architecture == QType.TABULAR:
            # Load the lookup table of an already trained agent (empty if the file is missing)
            self._Q = self.load_q_from_file(saved_q)
        else:
            raise NotImplementedError("Q type not implemented.")

    @property
    def init_reward(self):
        """A fresh initial Q value; uniform in [-1, 1] unless a callable was supplied."""
        if not self._init_reward:
            self._init_reward = lambda: np.random.uniform(-1, 1)
        return self._init_reward()

    # Showing the board
    def show_board(self, board=None, size=BOARD_SIZE):
        """Display ``board`` (the live board when omitted) as an SVG of ``size`` pixels."""
        if board is None:
            board = self.board

        display(chess.svg.board(board, size=size))

    def game_real_time(self):
        """Generator that redraws the board for every ``(board, action)`` pair sent to it."""
        while True:
            board, action = yield
            clear_output()
            display(chess.svg.board(board, size=BOARD_SIZE, lastmove=chess.Move.from_uci(action)))

    def display_next_board(self, action):
        """Redraw the live board with ``action`` highlighted as the last move."""
        gamert = self.game_real_time()
        next(gamert)
        gamert.send((self.board, action))
        # A new generator is made per call, so finish this one instead of leaving it suspended.
        gamert.close()

    def show_game(self, states, size=BOARD_SIZE):
        """Show a slider widget that steps through the boards of ``states`` (FEN strings)."""
        t_board = chess.Board()
        boards = []
        for state in states:
            t_board.set_fen(state)
            boards.append(chess.svg.board(t_board, size=size))

        wg.interact(lambda x: display(boards[x]), x=wg.IntSlider(min=0, max=len(states)-1, step=1))

    # State space stuff
    def reset_to_initial_state(self):
        """Put the live board back to the starting position."""
        self.board.reset()

    @property
    def current_state(self):
        """FEN string of the live board."""
        return str(self.board.fen())

    @property
    def turn(self):
        """``"white"`` or ``"black"``, whoever is to move on the live board."""
        return "white" if self.board.turn else "black"

    @property
    def terminal(self):
        """Whether the game on the live board is over."""
        return self.board.is_game_over()

    def simplify_state(self, state: str) -> str:
        """Returns only the position of pieces, everything else is removed from fen."""
        self.temp_board.set_fen(state)
        return f"{self.temp_board}".replace(' ', '').replace('\n', '')

    def num_rep_state(self, state: str) -> np.ndarray:
        """
        Returns the numerical representation of the board state as an 8x8 array
        of ``PIECE_MAP`` codes.
        Sum can be used for static evaluation of the board.
        This will be used as the input to the neural net.
        """
        state = self.simplify_state(state)
        return np.array([PIECE_MAP[piece] for piece in state]).reshape(8, 8)

    # Action stuff
    def num_rep_action_q_values(self, state: str, action: str = None, value: float = None):
        """
        Returns the training label for ``state``: a vector with one entry per
        move of ``ALL_LEGAL_MOVES``.

        Entries of moves that are legal in ``state`` hold their current
        normalized Q value, all other entries are 0.  When both ``action`` and
        ``value`` are given, the entry of ``action`` is replaced by ``value``,
        normalized like the other entries so the whole vector is on one scale.
        """
        action_q_labels = np.zeros(len(ALL_LEGAL_MOVES))

        # Plain index assignment: dict views are not a valid index/value
        # sequence for np.put.
        for legal_action in self.possible_actions(state):
            action_q_labels[ALL_LEGAL_MOVES[legal_action]] = self.normalized_Q(state, legal_action)

        if action is not None and value is not None:
            action_q_labels[ALL_LEGAL_MOVES[action]] = self._normalize(value)

        return action_q_labels

    def possible_actions(self, state):
        """UCI strings of all legal moves in ``state``."""
        self.temp_board.set_fen(state)
        return [str(move) for move in self.temp_board.legal_moves]

    def choose_greedy_action(self, epsilon: float) -> bool:
        """Return True with probability ``1 - epsilon``."""
        return bool(np.random.uniform(0, 1) < (1 - epsilon))

    def epsilon_greedy_action(self, state, epsilon):
        """Pick the best known action with probability ``1 - epsilon``, otherwise a random legal one."""
        if self.choose_greedy_action(epsilon):
            return self.argmax_Q(state)
        return random.choice(self.possible_actions(state))

    def take_action(self, action: str):
        """Play ``action`` (UCI) on the live board and return ``(new_state, reward)``."""
        self.board.push_uci(action)
        return (self.current_state, self.reward)

    # Reward Stuff
    @property
    def reward(self):
        """Reward for the position on the live board.

        0 while the game is running, 1 when it ended decisively ("1-0" or
        "0-1") and -1 for any other finished game.
        """
        if self.terminal:
            # result is a method; comparing the bound method itself with the
            # strings never matched, so every finished game scored -1.
            if self.board.result() in ["1-0", "0-1"]:
                return 1
            return -1
        return 0

    # Q stuff
    def update_q_cache(self, state, simple_state):
        """Predict the Q values of ``state`` with the DQN and cache them under ``simple_state``."""
        numerical_state = self.num_rep_state(state)
        normalized_q = self.dqn.get_Q(numerical_state)
        self.q_cache[simple_state] = self.denormalize_Q(normalized_q)

    def Q(self, state, action):
        """
        Return the Q values from the Deep Q Network if DQN is used otherwise return q values from the lookup table.

        Unseen table entries are initialized with ``init_reward`` on first access.
        """
        if self.q_architecture == QType.DQN:
            action_index = ALL_LEGAL_MOVES[action]
            # Predictions are cached per FEN until the network is retrained.
            if state not in self.q_cache:
                self.update_q_cache(state, state)
            return self.q_cache[state][action_index]

        # Tabular: use the same keys set_Q writes (simplified state, UCI
        # string).  Reading with the raw FEN and the move index could never find
        # what set_Q had stored.
        simple_state = self.simplify_state(state)
        if simple_state not in self._Q:
            for a in self.possible_actions(state):
                self.set_Q(state=state, action=a, value=self.init_reward)

        state_q_values = self._Q.setdefault(simple_state, {})
        if action not in state_q_values:
            state_q_values[action] = self.init_reward

        return state_q_values[action]

    def _normalize(self, q_value):
        """Map a Q value from [reward_lb, reward_ub] onto [0, 1]."""
        return (q_value - self.reward_lb) / (self.reward_ub - self.reward_lb)

    def normalized_Q(self, state, action):
        """Q value of ``(state, action)`` mapped onto [0, 1]."""
        return self._normalize(self.Q(state, action))

    def denormalize_Q(self, normalized_q_value):
        """Inverse of the normalization: map a value from [0, 1] back onto [reward_lb, reward_ub]."""
        return normalized_q_value * (self.reward_ub - self.reward_lb) + self.reward_lb

    def set_Q(self, state, action, value=None):
        """Record ``value`` as the new Q value of ``(state, action)``.

        DQN: the sample is appended to the replay buffer; once the buffer is
        full, the network is retrained every ``TRAIN_INTERVAL`` samples and the
        prediction cache is emptied.  Tabular: the table entry is overwritten.
        """
        if self.q_architecture == QType.DQN:
            # Keep accumulating data
            state_feature = self.num_rep_state(state)
            action_label = self.num_rep_action_q_values(state, action, value)
            self.experience_replay.append((state_feature, action_label))
            self.experience_replay_visits += 1

            # Train dqn with this accumulated data if have enough data
            train_interval_ok = not self.experience_replay_visits % TRAIN_INTERVAL
            have_enough_training_data = len(self.experience_replay) == self.experience_replay_size_threshold

            if train_interval_ok and have_enough_training_data:
                self.dqn.train(self.experience_replay, save_path=self.dqn_model_path)
                self.q_cache = {}

        else:
            state = self.simplify_state(state)
            self._Q.setdefault(state, {})
            self._Q[state][action] = value

    def max_Qa(self, state):
        """Return ``(q, action)`` for the legal action of ``state`` with the highest Q value."""
        return max([(self.Q(state, action), action) for action in self.possible_actions(state)])

    def max_Q(self, state):
        """Returns the max Q value for the best action."""
        return self.max_Qa(state)[0]

    def argmax_Q(self, state):
        """Returns the best action for which max Q was found for the input state."""
        return self.max_Qa(state)[1]

    # Learning Stuff
    def q_learning(self, episodes, epsilon, alpha, gamma, max_episode_length, show_interval):
        """Run off-policy Q-learning by self-play, then save the Q values.

        Args:
            episodes: Number of games to play.
            epsilon: Exploration probability of the epsilon-greedy policy.
            alpha: Learning rate of the update.
            gamma: Discount factor.
            max_episode_length: A game is cut off once more than this many states were recorded.
            show_interval: Every ``show_interval``-th game is shown in a slider widget.
        """
        # Initialize progress bar
        f = wg.IntProgress(min=0, max=episodes)
        display(f)

        for episode in range(episodes):
            f.value = episode

            self.reset_to_initial_state()
            state = self.current_state

            # Stored only to visualize the game
            play_states = [state]

            while not self.terminal and len(play_states) <= max_episode_length:
                action = self.epsilon_greedy_action(state, epsilon)
                new_state, reward = self.take_action(action)
                play_states.append(new_state)

                if not self.terminal:
                    # Q-learning update
                    best_next_q = self.max_Q(new_state)
                    current_q = self.Q(state, action)
                    off_policy_td = reward + gamma * best_next_q - current_q
                    new_Q_sa = current_q + alpha * off_policy_td
                else:
                    new_Q_sa = reward

                # Update the DQN model or the Q lookup table
                self.set_Q(state, action, new_Q_sa)

                state = new_state

            if episode % show_interval == 0:
                print(len(play_states))
                self.show_game(states=play_states)

        # All episodes are done; let the bar reach its maximum.
        f.value = episodes
        self.save_q_to_file('dqn')

    def sarsa(self, episodes, epsilon, alpha, gamma, max_episode_length, show_interval):
        """Run on-policy SARSA by self-play, then save the Q values.

        Takes the same arguments as ``q_learning``.
        """
        # Initialize progress bar
        f = wg.IntProgress(min=0, max=episodes)
        display(f)

        for episode in range(episodes):
            f.value = episode

            self.reset_to_initial_state()
            state = self.current_state

            # Stored only to visualize the game
            play_states = [state]

            action = self.epsilon_greedy_action(state, epsilon)

            while not self.terminal and len(play_states) <= max_episode_length:
                new_state, reward = self.take_action(action)
                play_states.append(new_state)

                if self.terminal:
                    # No successor action exists; store the final reward and
                    # stop instead of reading a next action that was never chosen.
                    self.set_Q(state, action, reward)
                    break

                # SARSA update
                new_action = self.epsilon_greedy_action(new_state, epsilon)
                next_q = self.Q(new_state, new_action)
                current_q = self.Q(state, action)
                on_policy_td = reward + gamma * next_q - current_q
                self.set_Q(state, action, current_q + alpha * on_policy_td)

                state, action = new_state, new_action

            if episode % show_interval == 0:
                print(len(play_states))
                self.show_game(states=play_states)

        # All episodes are done; let the bar reach its maximum.
        f.value = episodes
        self.save_q_to_file('sarsa_nn')

    def truncated_rollout(self, episodes, epsilon, alpha, gamma, truncation_depth, max_episode_length, show_interval):
        """Not implemented yet."""
        pass

    def monte_carlo_tree_search(self, episodes, epsilon, alpha, gamma, truncation_depth, max_episode_length, show_interval):
        """Not implemented yet."""
        pass

    # Heuristic policy
    def minimax_with_alpha_beta_pruning(self):
        """Not implemented yet."""
        pass

    # Static Evaluation of game state
    def remaining_materials(self, state):
        """Count every symbol of the simplified board string (empty squares are counted as '.')."""
        return Counter(str(self.simplify_state(state)))

    def material_score(self, state):
        """Signed material balance of ``state`` (upper-case pieces positive, lower-case negative).

        Knights are re-weighted from 4 to 3 and kings from 10 to 200 (Claude
        Shannon's formulation) for both signs; re-weighting only the positive
        codes left one side's knights and king at their raw codes.
        """
        piece_codes = self.num_rep_state(state)
        signs = np.sign(piece_codes)
        magnitudes = np.abs(piece_codes)

        magnitudes[magnitudes == 4] = 3
        magnitudes[magnitudes == 10] = 200

        return np.sum(signs * magnitudes)

    def mobility_score(self, state):
        """
        Returns the mobility score at state s
        mobility score(s) = current players legal moves(s) - opponents legal moves(s)
        """
        # Use the scratch board so the actual game is not affected.
        self.temp_board.set_fen(state)

        # count() is the move generator's counting API; len() is not relied upon.
        current_player_moves = self.temp_board.legal_moves.count()

        # Pass control to the opponent
        self.temp_board.push(chess.Move.null())

        opponents_moves = self.temp_board.legal_moves.count()

        return current_player_moves - opponents_moves

    def static_analysis_of_board(self, state):
        """Returns the positional evaluation of the board."""
        return self.material_score(state) + self.mobility_score(state)

    # Saving and loading q values
    def save_q_to_file(self, filename: str):
        """Save the Q values to ``filename + 'py_chess'``.

        The DQN is saved with Keras, the lookup table with pickle.

        Raises:
            NotImplementedError: For an unknown architecture.
        """
        # NOTE(review): q_learning passes 'dqn', giving 'dqnpy_chess', while the
        # default ``saved_q`` is 'dqn_py_chess' - confirm which name is intended.
        filename += 'py_chess'
        if self.q_architecture == QType.DQN:
            self.dqn.model.save(filename)

        elif self.q_architecture == QType.TABULAR:
            with open(filename, 'wb') as file:
                pickle.dump(self._Q, file, protocol=pickle.HIGHEST_PROTOCOL)

        else:
            raise NotImplementedError("Unknown Q architecture")

    def load_q_from_file(self, filename: str):
        """Load previously saved Q values.

        Returns:
            DQN: the loaded Keras model, or ``None`` when it cannot be read.
            Tabular: the unpickled lookup table, or ``{}`` when the file is missing.

        Raises:
            NotImplementedError: For an unknown architecture.
        """
        if self.q_architecture == QType.DQN:
            dqn_model = None
            try:
                dqn_model = keras.models.load_model(filename)
            except OSError:
                # FileNotFoundError is a subclass of OSError, so one clause covers both.
                print('File does not exist')

            return dqn_model

        elif self.q_architecture == QType.TABULAR:
            q_lookup_table = {}
            try:
                with open(filename, 'rb') as file:
                    # SECURITY: pickle can execute arbitrary code; only load files this notebook wrote.
                    q_lookup_table = pickle.load(file)
            except FileNotFoundError:
                # No saved table yet: start from an empty one.
                pass
            return q_lookup_table

        else:
            raise NotImplementedError("Unknown Q architecture")

    def play(self):
        """Play one game from the start, always taking the best known action, then show it."""
        self.reset_to_initial_state()
        state = self.current_state
        states = [state]
        while not self.terminal:
            action = self.argmax_Q(state)
            state, _ = self.take_action(action)
            states.append(state)
        self.show_game(states=states)

In [12]:
# Profile a short DQN Q-learning run. Both names bound here are read as globals by
# later cells: `pr` (the collected profile) and `C` (the agent instance).
with cProfile.Profile() as pr:
    C = ChessRL(saved_q='dqn_py_chess', experience_replay_size_threshold=1000)
    C.q_learning(episodes=2, epsilon=0.5, alpha=0.1, gamma=0.95, max_episode_length=200, show_interval=500)
    # Alternative runs, kept for reference:
    # C.sarsa(episodes=10000, epsilon=0.5, alpha=0.1, gamma=0.95, max_episode_length=200, show_interval=500)
    # C.play()
    

IntProgress(value=0, max=2)

201


interactive(children=(IntSlider(value=0, description='x', max=200), Output()), _dom_classes=('widget-interact'…

INFO:tensorflow:Assets written to: dqn_pychess\assets


In [13]:
# Profile.print_stats() writes the report to stdout itself and returns None, so it is
# called directly; wrapping it in pprint only appended a stray "None" to the output.
pr.print_stats()

         11855584 function calls (11473706 primitive calls) in 10.040 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1202    0.001    0.000    0.047    0.000 193473104.py:101(terminal)
      801    0.003    0.000    0.161    0.000 193473104.py:105(simplify_state)
      801    0.003    0.000    0.174    0.000 193473104.py:110(num_rep_state)
      801    0.003    0.000    0.003    0.000 193473104.py:117(<listcomp>)
      400    0.033    0.000    3.114    0.008 193473104.py:120(num_rep_action_q_values)
     1200    0.004    0.000    0.305    0.000 193473104.py:133(possible_actions)
     1200    0.019    0.000    0.183    0.000 193473104.py:135(<listcomp>)
      400    0.001    0.000    0.007    0.000 193473104.py:137(choose_greedy_action)
      400    0.001    0.000    1.560    0.004 193473104.py:140(epsilon_greedy_action)
      400    0.001    0.000    0.078    0.000 193473104.py:148(take_action)
      400    0.000    0

In [None]:
# Idea: initialize every (state, action) pair with the same constant Q value. Because Q values are continuous, the
# probability that the model outputs exactly that constant is zero, so seeing the constant means the (state, action)
# pair has not been explored yet. If the Q value equals the initial constant, take the epsilon-greedy action with
# epsilon = 1 (explore); otherwise use epsilon = 0 (exploit).

In [None]:
def play():
    """Play one game with the global agent ``C``, redrawing the board after every move.

    Every move is chosen with ``C.epsilon_greedy_action(state, 0)``.  When the
    game is over the whole sequence of states is shown with ``C.show_game``.
    """
    C.reset_to_initial_state()
    history = [C.current_state]

    while not C.terminal:
        chosen_move = C.epsilon_greedy_action(history[-1], 0)
        next_state, _ = C.take_action(chosen_move)
        history.append(next_state)

        C.display_next_board(chosen_move)

    C.show_game(states=history)

In [None]:
# Run one game with the agent `C` created in the profiling cell.
play()