In [14]:
from itertools import combinations
from collections import namedtuple, defaultdict
from random import choice
from copy import deepcopy
from abc import ABC, abstractmethod

from tqdm.auto import tqdm
import numpy as np

## Problem : Reinforcement Learning for Tic-Tac-Toe Game

In general, to develop a reinforcement learning algorithm, we need to define the following components:
* **Environment** 
  * Possible states in the game environment
  * Possible actions in each state
  * Rewards for each action in each state

* **Agent**
  * Policy: the strategy to choose an action given a state
  * Value function: the expected return of each state under a given policy
  * Model: the agent's representation of the environment

* **Learning Algorithm**
  * How the agent updates its policy and value function based on the experience

In this problem, we will implement a reinforcement learning algorithm for the Tic-Tac-Toe game.

State is a namedtuple with two fields, x and o, representing the positions of X and O in the board.

MAGIC is a list of values that can be used to check whether a player has won the game. They are based on the magic square of order 3.

<table>
  <tr>
    <td>2</td>
    <td>7</td>
    <td>6</td>
  </tr>
  <tr>
    <td>9</td>
    <td>5</td>
    <td>1</td>
  </tr>
  <tr>
    <td>4</td>
    <td>3</td>
    <td>8</td>
  </tr>
</table>

In this way, the sum of three numbers in any row, column, or diagonal is always 15.

In [15]:
State = namedtuple('State', ['x', 'o'])
MAGIC = [2, 7, 6, 9, 5, 1, 4, 3, 8]

## General Player

In [16]:
# stolen from quixo repo
class Player(ABC):
    def __init__(self) -> None:
        '''You can change this for your player if you need to handle state/have memory'''
        pass

    @abstractmethod
    def make_move(self):
        '''
        game: the Quixo game. You can use it to override the current game with yours, but everything is evaluated by the main game
        return values: this method shall return a tuple of X,Y positions and a move among TOP, BOTTOM, LEFT and RIGHT
        '''
        pass

## Game

In [17]:
class Game:
    def __init__(self):
        self.state = State(set(), set()) # actual state of the game
        self.trajectory = list() # list of states of the game
        self.available_moves = set(range(1, 10)) # available moves
        self.winner = None # winner of the game

    def play(self, player1, player2):
        """Play a game between two players"""
        # print("available moves: ", self.available_moves)
        local_winner = -1
        while local_winner == -1 and len(self.available_moves) > 0:
            # player1 makes a move
            move = player1.make_move(self.state, self.available_moves)
            # print("player1 move: ", move)

            # the move is added to the state
            self.state.x.add(move)

            # the trajectory is updated
            self.trajectory.append(deepcopy(self.state))

            # the move is removed from the available moves
            self.available_moves.remove(move)
            # print("available moves: ", self.available_moves)

            # check if the game is over
            local_winner = self.check_winner()
            if local_winner != -1 or len(self.available_moves) == 0:
                break

            # same for player2
            move = player2.make_move(self.state, self.available_moves)
            # print("player2 move: ", move)
            self.state.o.add(move)
            self.trajectory.append(deepcopy(self.state))
            self.available_moves.remove(move)
            # print("available moves: ", self.available_moves)
            local_winner = self.check_winner()
            if local_winner != -1 or len(self.available_moves) == 0:
                break

        self.winner = local_winner
        # print("winner: ", self.winner)

    def check_winner(self):
        """Set the winner: 1 for player1, 2 for player2, -1 for draw"""
        if self.win(self.state.x):
            return 1
        elif self.win(self.state.o):
            return 2
        else:
            return -1
        
    # win() function checks if any of the combinations of 3 elements in the set sums to 15 (winning condition)
    def win(self, elements):
        """Checks if elements is winning"""
        return any(sum(c) == 15 for c in combinations(elements, 3))

defaultdict is a subclass of dict that returns a default value when the key is not found, so that it is not needed to check whether a key is in the dictionary.

frozenset is an immutable version of set, which can be used as a key in a dictionary.

## Random Player

In [18]:
class RandomPlayer(Player):
    def __init__(self):
        super().__init__()

    def make_move(self, state, available_moves):
        return choice(list(available_moves))

## Reinforcement Learning Player

In [59]:
class reinforcement_player():
    def __init__(self, player_index, random_move = 0.0):
        self.value_dictionary = defaultdict(float) # state of the game and its value
        # self.trajectory = list() # list of states visited during the game, used to update the value_dictionary
        self.hit_state = defaultdict(int) # state of the game and how many times it was visited during the training phase
        self.epsilon = 0.001 # learning rate
        self.player_index = player_index # index of the player (1 or 2)
        self.random_move = random_move # a value between 0 and 1, used to choose a random move when training

    # in make_moves we have to sometimes choose a random move when training
    def make_move(self, state, available_moves):
        """Returns best move for the actual state"""
        # it checks the value of the new_state for each possible move and returns the move with the highest value
        best_move_score = -1
        best_move = None
        if np.random.rand() < self.random_move:
            return choice(list(available_moves))
        else:
            for move in available_moves:
                new_state = deepcopy(state)
                hashable_state = (frozenset(new_state.x), frozenset(new_state.o))
                if self.player_index == 1:
                    new_state.x.add(move)
                else:
                    new_state.o.add(move)
                actual_move_score = self.value_dictionary[hashable_state]
                if actual_move_score > best_move_score:
                    best_move_score = actual_move_score
                    best_move = move

        return best_move
    
    def give_reward(self, reward, trajectory):
        """Updates the value of the states visited during the game"""
        for state in reversed(trajectory):
            hashable_state = (frozenset(state.x), frozenset(state.o))
            self.hit_state[hashable_state] += 1
            self.value_dictionary[hashable_state] += self.epsilon * (reward - self.value_dictionary[hashable_state])

    def print_value_dictionary(self):
        """Prints the value of each state"""
        return sorted(self.value_dictionary.items(), key=lambda e: e[1], reverse=True)[:10]
    
    def set_random_move(self, random_move):
        """Sets the value of random_move"""
        self.random_move = random_move


In [82]:
test_game = Game()
player1 = reinforcement_player(1, 1)
player2 = RandomPlayer()

for _ in tqdm(range(500_000)):
    # decrease the probability of choosing a random move by 50% every 100_000 games
    if _ % 100_000 == 0:
        player1.set_random_move(player1.random_move / 10)
    test_game = Game()
    test_game.play(player1, player2)
    if (test_game.winner == 1):
        player1.give_reward(1, test_game.trajectory)
    elif (test_game.winner == 2):
        player1.give_reward(-1, test_game.trajectory)
    else:
        if(player1.player_index == 1):
            player1.give_reward(0.1, test_game.trajectory)
        else:
            player1.give_reward(0.3, test_game.trajectory)

  0%|          | 0/500000 [00:00<?, ?it/s]

In [83]:
print(len(player1.value_dictionary))
print(len(player1.hit_state))
print(player1.print_value_dictionary())

3060
3059
[((frozenset({1, 2, 3, 4, 8}), frozenset({9, 5, 6, 7})), 0.9999999999999303), ((frozenset({1, 2, 3, 4, 9}), frozenset({8, 5, 6, 7})), 0.9999999999999292), ((frozenset({1, 2, 3, 4}), frozenset({8, 5, 6, 7})), 0.9999999999998813), ((frozenset({1, 2, 3, 4}), frozenset({9, 5, 6, 7})), 0.9999999999998788), ((frozenset({1, 2, 3, 5, 7}), frozenset({8, 9, 4, 6})), 0.9999999998844641), ((frozenset({1, 2, 3, 5, 8}), frozenset({9, 4, 6, 7})), 0.9999999998780503), ((frozenset({1, 2, 3, 5, 9}), frozenset({8, 4, 6, 7})), 0.9999999998770702), ((frozenset({1, 2, 3, 5}), frozenset({8, 9, 4, 6})), 0.9999999998176714), ((frozenset({1, 2, 3, 5}), frozenset({8, 4, 6, 7})), 0.9999999998094659), ((frozenset({1, 2, 3, 5}), frozenset({9, 4, 6, 7})), 0.9999999998071645)]


In [84]:
player1.set_random_move(0.0)

win_rate = 0
draw_rate = 0
loss_rate = 0
for _ in tqdm(range(10000)):
    test_game = Game()
    test_game.play(player1, player2)
    if (test_game.winner == 1):
        win_rate += 1
    elif (test_game.winner == -1):
        draw_rate += 1
    else:
        loss_rate += 1

print("wins: ", win_rate/100, "%")
print("draws: ", draw_rate/100, "%")
print("losses: ", loss_rate/100, "%")  

  0%|          | 0/10000 [00:00<?, ?it/s]

wins:  52.25 %
draws:  30.93 %
losses:  16.82 %
