In [26]:
from itertools import combinations
from collections import namedtuple, defaultdict
from random import choice
from copy import deepcopy
from abc import ABC, abstractmethod

from tqdm.auto import tqdm
import numpy as np

## Problem : Reinforcement Learning for Tic-Tac-Toe Game

In general, to develop a reinforcement learning algorithm, we need to define the following components:
* **Environment** 
  * Possible states in the game environment
  * Possible actions in each state
  * Rewards for each action in each state

* **Agent**
  * Policy: the strategy to choose an action given a state
  * Value function: the expected return of each state under a given policy
  * Model: the agent's representation of the environment

* **Learning Algorithm**
  * How the agent updates its policy and value function based on the experience

In this problem, we will implement a reinforcement learning algorithm for the Tic-Tac-Toe game.

`State` is a namedtuple with two fields, `x` and `o`, holding the cells occupied by X and O on the board.

MAGIC is a list of values that can be used to check whether a player has won the game. They are based on the magic square of order 3.

<table>
  <tr>
    <td>2</td>
    <td>7</td>
    <td>6</td>
  </tr>
  <tr>
    <td>9</td>
    <td>5</td>
    <td>1</td>
  </tr>
  <tr>
    <td>4</td>
    <td>3</td>
    <td>8</td>
  </tr>
</table>

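In this way, the sum of the three numbers in any row, column, or diagonal is always 15. Conversely, those eight lines are the only sets of three distinct numbers from 1 to 9 that add up to 15, so a player has won exactly when three of their cells sum to 15.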

In [2]:
# A position is a pair of collections: the cells taken by X and the cells taken by O.
State = namedtuple('State', 'x o')

# Order-3 magic square, listed row by row.
MAGIC = [
    2, 7, 6,
    9, 5, 1,
    4, 3, 8,
]

## General Player

In [28]:
# stolen from quixo repo
class Player(ABC):
    """Abstract base class for anything that can choose a Tic-Tac-Toe move."""

    def __init__(self) -> None:
        '''You can change this for your player if you need to handle state/have memory'''
        pass

    @abstractmethod
    def make_move(self, state, available_moves):
        '''
        Choose the next move to play.

        state: the current position, an object exposing the cells taken by
            each side as `state.x` and `state.o`
        available_moves: the collection of cells that are still free
        return value: this method shall return one element of available_moves
        '''
        pass

## Game

In [51]:
class Game:
    """A single Tic-Tac-Toe match played on magic-square cell values.

    Cells are identified by the numbers 1 to 9; a side has won as soon as any
    three of its cells sum to 15.
    """

    def __init__(self):
        self.state = State(set(), set()) # actual state of the game
        self.trajectory = list() # snapshot of the state taken after every move
        self.available_moves = set(range(1, 10)) # available moves
        self.winner = None # winner of the game

    def play(self, player1, player2, verbose=True):
        """Play a game between two players.

        player1 moves first and fills ``state.x``; player2 fills ``state.o``.
        Each player's ``make_move(state, available_moves)`` must return one of
        the available moves; anything else raises ``KeyError`` and leaves the
        state untouched.

        After every move a deep copy of the state is appended to
        ``self.trajectory``. When the game ends ``self.winner`` is set to
        1, 2, or -1 (nobody won).

        verbose: when True (the default) the moves, the remaining cells and
            the winner are printed.
        """
        if verbose:
            print("available moves: ", self.available_moves)
        local_winner = -1
        # (label used in the log, player object, State field that player fills)
        turns = (("player1", player1, "x"), ("player2", player2, "o"))
        while local_winner == -1 and len(self.available_moves) > 0:
            for label, player, field in turns:
                move = player.make_move(self.state, self.available_moves)
                if verbose:
                    print(f"{label} move: ", move)

                # Remove first: an illegal move raises KeyError here, before
                # the board has been modified.
                self.available_moves.remove(move)
                getattr(self.state, field).add(move)

                # Store a copy, because self.state keeps being mutated.
                self.trajectory.append(deepcopy(self.state))
                if verbose:
                    print("available moves: ", self.available_moves)

                # check if the game is over
                local_winner = self.check_winner()
                if local_winner != -1 or len(self.available_moves) == 0:
                    break

        self.winner = local_winner
        if verbose:
            print("winner: ", self.winner)

    def check_winner(self):
        """Return 1 if player1 (x) has won, 2 if player2 (o) has won, -1 otherwise.

        -1 covers both a draw and a game that is still in progress.
        """
        if self.win(self.state.x):
            return 1
        elif self.win(self.state.o):
            return 2
        else:
            return -1

    def win(self, elements):
        """Return True if any 3 of ``elements`` sum to 15 (the winning condition)."""
        return any(sum(c) == 15 for c in combinations(elements, 3))

`defaultdict` is a subclass of `dict` that returns a default value when a key is not found, so there is no need to check whether a key is already in the dictionary.

frozenset is an immutable version of set, which can be used as a key in a dictionary.

In [None]:
# Tables filled while training; both are keyed by (frozenset(x), frozenset(o)).
value_dictionary = defaultdict(float)  # state -> current value estimate
hit_state = defaultdict(int)  # state -> number of visits during training
epsilon = 0.001  # step size of the incremental update below

for steps in tqdm(range(500_000)):
    trajectory = random_game()
    final_reward = state_value(trajectory[-1])
    for state in trajectory:
        hashable_state = (frozenset(state.x), frozenset(state.o))
        hit_state[hashable_state] += 1
        # Move the estimate a small step toward the reward of the final state.
        value_dictionary[hashable_state] += epsilon * (
            final_reward - value_dictionary[hashable_state]
        )

## Random Player

In [33]:
class RandomPlayer(Player):
    """Player that picks one of the available moves at random."""

    def __init__(self):
        super().__init__()

    def make_move(self, state, available_moves):
        """Return a random element of ``available_moves``; ``state`` is not used."""
        candidates = list(available_moves)
        picked = choice(candidates)
        return picked

## Reinforcement Learning Player

In [None]:
class reinforcement_player():
    """Greedy player that picks moves from a table of learned state values.

    NOTE(review): successor states are always built by adding the candidate
    move to ``state.x``, as in the original code, so this player only makes
    sense as the X (first) player. Confirm before using it as player2.
    """

    def __init__(self):
        self.value_dictionary = defaultdict(float) # state of the game and its value
        self.states = list() # list of states visited during the game, used to update the value_dictionary
        self.hit_state = defaultdict(int) # state of the game and how many times it was visited during the training phase
        self.epsilon = 0.001 # learning rate

    def make_move(self, state, available_moves=None):
        """Return the move whose resulting state has the highest learned value.

        state: the current position; ``state.x`` and ``state.o`` may be sets
            or lists of occupied cells.
        available_moves: the free cells. When omitted, they are derived as
            the numbers 1 to 9 not present in ``state.x`` or ``state.o``.
        return value: one element of the available moves. Ties go to the
            smallest move; states missing from the table count as 0.0.
        raises: ValueError if there is no move left to choose from.
        """
        if available_moves is None:
            available_moves = set(range(1, 10)) - set(state.x) - set(state.o)
        candidates = sorted(available_moves)
        if not candidates:
            raise ValueError("no available moves to choose from")

        own_cells = frozenset(state.x)
        other_cells = frozenset(state.o)

        def value_after(move):
            # .get() instead of indexing, so that merely evaluating a
            # candidate does not insert a new key into the defaultdict.
            return self.value_dictionary.get((own_cells | {move}, other_cells), 0.0)

        return max(candidates, key=value_after)


In [83]:
# Smoke test: one game between two random players; play() prints every move.
player1 = RandomPlayer()
player2 = RandomPlayer()
test_game = Game()
test_game.play(player1, player2)

available moves:  {1, 2, 3, 4, 5, 6, 7, 8, 9}
player1 move:  6
available moves:  {1, 2, 3, 4, 5, 7, 8, 9}
player2 move:  7
available moves:  {1, 2, 3, 4, 5, 8, 9}
player1 move:  9
available moves:  {1, 2, 3, 4, 5, 8}
player2 move:  8
available moves:  {1, 2, 3, 4, 5}
player1 move:  3
available moves:  {1, 2, 4, 5}
player2 move:  1
available moves:  {2, 4, 5}
player1 move:  2
available moves:  {4, 5}
player2 move:  5
available moves:  {4}
player1 move:  4
available moves:  set()
winner:  1
