In [260]:
import random
from collections import defaultdict


class TicTacToe:
    def __init__(self, curr_player):
        self.board_rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # rows
        self.board_cols = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]  # columns
        self.board_diag = [[1, 5, 9], [3, 5, 7]]             # diagonals
        self.x_positions = set()
        self.o_positions = set()
        self.free_positions = set(range(1, 10)) # at the beginning they are all free
        self.current_player = curr_player

    def make_move(self, position):
        # Adds a position w.r.t the player taken positions
        self.x_positions.add(position) if self.current_player == 'X' else self.o_positions.add(position)



    ### ALL THESE FUNCTIONS ARE CALLED ONCE AFTER THE MOVE HAS BEEN DONE ###
    def check_winner(self):
        # Check rows, columns, and diagonals for a win
        win_patterns = [tuple(row) for row in self.board_rows] + [tuple(col) for col in self.board_cols] + [tuple(diag) for diag in self.board_diag]

        for pattern in win_patterns:
            if all(pos in self.x_positions for pos in pattern):
                return 'X'
            elif all(pos in self.o_positions for pos in pattern):
                return 'O'
        return 'P'  # No winner yet
    
    def check_win(self):
        win_patterns = [tuple(row) for row in self.board_rows] + [tuple(col) for col in self.board_cols] + [tuple(diag) for diag in self.board_diag]
        
        opponent_positions = self.o_positions if self.current_player == 'X' else self.x_positions
        my_positions = self.x_positions if self.current_player == 'X' else self.o_positions

        for pattern in win_patterns:
            if all(pos in my_positions for pos in pattern):
                return 10 #you win
            elif all(pos in opponent_positions for pos in pattern):
                return -10 #opponent wins

        return 0 #nobody wins
    
    
    def check_is_winning(self, possible_positions):
        win_patterns = [tuple(row) for row in self.board_rows] + [tuple(col) for col in self.board_cols] + [tuple(diag) for diag in self.board_diag]
        
        for pattern in win_patterns:
            if all(pos in possible_positions for pos in pattern):
                return True
        return False


    def switch_player(self):
        self.current_player = 'O' if self.current_player == 'X' else 'X'


    ### METHODS USED TO CHECK POSITIVE REWARD ###

    def check_two_pattern(self, position): # checks wether our move is going to fill a row or a colum with 2 X's or O's
        positions = self.x_positions if self.current_player == 'X' else self.o_positions
        
        for i in range(3):
            # Check rows, columns and diags
            if (positions.intersection(self.board_rows[i]) == 2 and position in self.board[i]) or (positions.intersection(self.board_cols[i]) == 2 and position in self.board_cols[i]): # if the position I choose belong to a 2-value column/row or diagonal, then retrieve the reward
                return 1
            
            if i==0 or i==1:
               if (positions.intersection(self.board_diag[i]) == 2 and position in self.board_diag[i]): # if the position I choose belong to a 2-value column/row or diagonal, then retrieve the reward
                return 1 
        return 0
               
    def block_opponent(self, position): # tells if we actually are blocking the opponent to win
        
        opponent_positions = self.o_positions if self.current_player == 'X' else self.x_positions
        opponent_positions_tmp = opponent_positions.union({position})
        #win_patterns = [self.board_rows, self.board_cols, self.board_diag]
        win_patterns = [tuple(row) for row in self.board_rows] + [tuple(col) for col in self.board_cols] + [tuple(diag) for diag in self.board_diag]

        for pattern in win_patterns:
            if all(pos in opponent_positions_tmp for pos in pattern) and position in pattern: # in case the opponent would have already a col/row filled with two values, and we're going to fill the entire rol/col
                return 2
        return 0
    
    def create_fork(self): #tells me if my move is going to let me win ! I have two possible way of winning at once, regardless of which action the opponent will take
        # I traverse all the blanck spaces (number still not assigned), and i look w.r.t. my positions, if i could win at leat twice
        positions = self.x_positions if self.current_player == 'X' else self.o_positions
        num_wins = 0

        for possible_winning_move in range(1, 10):
            if possible_winning_move in self.free_positions:
                possible_positions = positions.union({possible_winning_move})

                if self.check_is_winning(possible_positions):
                    num_wins += 1

        if num_wins >= 2:
            return 5
        
        return 0
    

    ### METHODS USED TO CHECK NEGATIVE REWARD ###

    def allow_opponent_two_pattern(self, position):
        opponent_positions = self.o_positions if self.current_player == 'X' else self.x_positions
        opponent_positions_tmp = opponent_positions.union({position})
        
        for i in range(3):
            # Check rows, columns and diags
            if (opponent_positions_tmp.intersection(self.board_rows[i]) == 2 and position in self.board[i]) or (opponent_positions_tmp.intersection(self.board_cols[i]) == 2 and position in self.board_cols[i]): # if the position I choose belong to a 2-value column/row or diagonal, then retrieve the reward
                return -1
            
            if i==0 or i==1:
               if (opponent_positions_tmp.intersection(self.board_diag[i]) == 2 and position in self.board_diag[i]): # if the position I choose belong to a 2-value column/row or diagonal, then retrieve the reward
                return -1 
        return 0
    
    def allow_opponnent_fork(self, position):
        opponent_positions = self.o_positions if self.current_player == 'X' else self.x_positions
        opponent_positions_tmp = opponent_positions.union({position})
        num_wins = 0

        for possible_winning_move in range(1, 10):
            if possible_winning_move in self.free_positions:
                possible_positions = opponent_positions_tmp.union({possible_winning_move})

                if self.check_is_winning(possible_positions):
                    num_wins += 1

        if num_wins >= 2:
            return -5
        return 0
    

## REINFORCEMENT LEARNING ALGORITHM

In [261]:
#number of matches we will run
n_matches = 100000

#initialize the exploration probability
exploration_prob = 0.2

#discounted factor
gamma = 0.95

#learning rate
lr = 0.01

# Initialize Q-table
Q_table = defaultdict()


In [262]:
def check_Q_table(Q_table: defaultdict, tic_tac_toe: TicTacToe): #return the state-actions related to a specific state
    hashable_state = (frozenset(tic_tac_toe.x_positions), frozenset(tic_tac_toe.o_positions))

    if hashable_state not in Q_table:
        # in case the actual state is not contained inside the Q_table, we insert it and initialize all state-actions to 0.0
        Q_table[hashable_state] = list([(action, 0.0) for action in tic_tac_toe.free_positions])
        
    return Q_table[hashable_state]


def epsilon_greedy_policy(exploration_prob, state_actions, tic_tac_toe): #based on epsilon, i take a random free position, or the one with the highest value in the Q_table
    
    if random.random() < exploration_prob:
        position = random.choice(list(tic_tac_toe.free_positions)) 
    else :
        position = max(state_actions, key=lambda x: x[1])[0] # take the position (action) with the maximum Q-value

    tic_tac_toe.free_positions.remove(position)
    tic_tac_toe.make_move(position)
    #exploration_prob -= exploration_decreasing_decay

    return position

def compute_reward(position, tic_tac_toe):
    total_reward = 0

    total_reward += tic_tac_toe.check_two_pattern(position)
    total_reward += tic_tac_toe.block_opponent(position)
    total_reward += tic_tac_toe.create_fork()
    total_reward += tic_tac_toe.allow_opponent_two_pattern(position)
    total_reward += tic_tac_toe.allow_opponnent_fork(position)
    total_reward += tic_tac_toe.check_win()

    # prova a farlo con una media sul numero di possibili reward
    if total_reward == 0: # this means that nothing meaningful happened in the game with our move
        return 0.5 # we give a little reward
    
    return total_reward


def update_Q_table(Q_table, tic_tac_toe, curr_reward, state_actions_curr, position_chosen, hashable_state_curr):
    if len(tic_tac_toe.free_positions) != 0:
        next_state_actions = check_Q_table(Q_table, tic_tac_toe)  # verifies the next state best action value
        #print(f'{next_state_actions}')
        expected_return = curr_reward + gamma*(max(next_state_actions, key=lambda x: x[1])[1])
        index = next(i for i, v in enumerate(state_actions_curr) if v[0] == position_chosen)

        updated_action_val = Q_table[hashable_state_curr][index][1] + lr*(expected_return - Q_table[hashable_state_curr][index][1])
        
        Q_table[hashable_state_curr][index] = (position_chosen, updated_action_val)


In [263]:
# Example usage:

def train_model(curr_player):
    reward_game_curr = 0
    num_iter = 0
    tic_tac_toe = TicTacToe(curr_player) # curr player starts first

    while len(tic_tac_toe.free_positions) != 0:
        num_iter += 1

        # 1- CHECK THAT THE EXPECTED ENTRY IS ACTUALLY IN THE Q TABLE, OTHERWISE I CREATE AND INSERT IT
        state_actions_curr = check_Q_table(Q_table, tic_tac_toe)
        hashable_state_curr = (frozenset(tic_tac_toe.x_positions), frozenset(tic_tac_toe.o_positions))

        # 2- BASED ON A CERTAIN PROBABILITY THRESHOLD, I'LL CHOOSE A RANDPOM VALUE OR A VALUE FROM THE Q TABLE (EXPLORATION/EXPLOITATION) 
        position_chosen = epsilon_greedy_policy(exploration_prob, state_actions_curr, tic_tac_toe) #we modify the state with the call to this function

        # 3- AT THE END, I COMPUTE THE COST OF MY LAST ACTION
        curr_reward = compute_reward(position_chosen, tic_tac_toe)
        #print(f'{curr_reward}')
        reward_game_curr += curr_reward

        # 4- THE Q TABLE IS UPDATED BASED ON THE NEW STATE AND THE OBTAINED REWARD
        update_Q_table(Q_table, tic_tac_toe, curr_reward, state_actions_curr, position_chosen, hashable_state_curr)

        if tic_tac_toe.check_winner() == 'X' or tic_tac_toe.check_winner() == 'O': # In case someone wins, then stop the game
            break
    
        tic_tac_toe.switch_player()

    return reward_game_curr/num_iter


def use_model(curr_player):
    tic_tac_toe = TicTacToe(curr_player)

    while len(tic_tac_toe.free_positions) != 0:
        if tic_tac_toe.current_player == 'X':
            hashable_state_curr = (frozenset(tic_tac_toe.x_positions), frozenset(tic_tac_toe.o_positions))

            if hashable_state_curr in Q_table: # NOT ALL THE POSSIBLE STATES COULD BE EMBEDDED INTO THE Q TABLE, SO WE TAKE A RANDOM MOVE ONCE IN A WHILE
                position = max(Q_table[hashable_state_curr], key=lambda x: x[1])[0]
            else: position = random.choice(list(tic_tac_toe.free_positions)) 
            
            tic_tac_toe.free_positions.remove(position)
            tic_tac_toe.make_move(position)
        
        else:
            position = random.choice(list(tic_tac_toe.free_positions))
            tic_tac_toe.free_positions.remove(position)
            tic_tac_toe.make_move(position)

        if tic_tac_toe.check_winner() == 'X' or tic_tac_toe.check_winner() == 'O': # In case someone wins, then stop the game
            break
    
        tic_tac_toe.switch_player()

    if tic_tac_toe.check_winner() == 'X':
        return 1
    return 0

## TRAINING 

In [264]:
import itertools


rewards_stats = []

for step in range(n_matches):
    if step%2 == 0:
        rewards_stats.append(train_model('O')) # play starting with O first
    else: rewards_stats.append(train_model('X')) # play starting with X first



## INFERENCE

In [265]:
wins = 0 
n_matches_inference = 1000

for step in range(n_matches_inference):
    if step%2 == 0:
        wins += use_model('O')          
    else: wins += use_model('X')
    

print(f'Accuracy playing with a random gamer: {wins/n_matches_inference}')

Accuracy playing with a random gamer: 0.718


In [266]:
from itertools import chain
import matplotlib.pyplot as plt

for key, value in itertools.islice(Q_table.items(), 10):
    print(key, value)


(frozenset(), frozenset()) [(1, 16.511233098379996), (2, 13.264024049876095), (3, 11.806731712284375), (4, 9.894354189959099), (5, 9.38981471745391), (6, 10.387484106089376), (7, 10.658189390722166), (8, 11.119415794956796), (9, 10.936421767621058)]
(frozenset(), frozenset({1})) [(2, 9.621842116450242), (3, 16.216555037535446), (4, 8.467994690983092), (5, 9.130541669746316), (6, 6.925915894225976), (7, 9.993067625250026), (8, 11.43101964734717), (9, 10.780756780388629)]
(frozenset({3}), frozenset({1})) [(2, 16.508421969562338), (4, 12.236252881742441), (5, 6.259631519253611), (6, 15.325586923556317), (7, 12.995621508709618), (8, 11.813560275490834), (9, 14.697150332415449)]
(frozenset({3}), frozenset({1, 2})) [(4, 16.976324090439675), (5, 4.580828825842993), (6, 11.988136822898475), (7, 15.250304794570345), (8, 12.202154789041087), (9, 9.491113829937804)]
(frozenset({3, 5}), frozenset({1, 2})) [(4, -0.9044137581866071), (6, -0.733009841046057), (7, -0.2115267118609852), (8, 10.22866720