Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

# LAB4

Use reinforcement learning to devise a tic-tac-toe player.

### Deadlines:

* Submission: [Dies Natalis Solis Invicti](https://en.wikipedia.org/wiki/Sol_Invictus)
* Reviews: [Befana](https://en.wikipedia.org/wiki/Befana)

Notes:

* Reviews will be assigned on Monday, December 4
* You need to commit in order to be selected as a reviewer (i.e., it is better to commit an empty project than not to commit at all)

In [1]:
import numpy as np
from collections import defaultdict, namedtuple
from copy import deepcopy
from itertools import combinations
from random import choice

## Numerical tic-tac-toe
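- ``State`` describes the state of the game. Players alternately pick distinct numbers from 1 to 9: ``x`` is the set of numbers taken by player 1 and ``o`` the set of numbers taken by player 2.
- ``MAGIC`` is a 3×3 magic square, stored row by row as a flat list, in which every row, column and diagonal sums to 15. The eight triples of distinct numbers from 1 to 9 that sum to 15 are exactly its eight lines, so taking three numbers that sum to 15 is the same as getting three in a row. This is why ``win`` only needs to check sums, while ``print_board`` uses ``MAGIC`` to map each number back to its cell.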
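The equivalence between ``MAGIC`` and the sum-15 rule can be checked directly. The sketch below enumerates every triple of distinct numbers from 1 to 9 that sums to 15 and compares those triples with the rows, columns and diagonals of ``MAGIC``; the ``LINES`` index table is an illustration added here, not part of the original notebook.

```python
from itertools import combinations

MAGIC = [2, 7, 6,
         9, 5, 1,
         4, 3, 8]

# Illustrative table: rows, columns and diagonals as index triples into the flat MAGIC list
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
         (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
         (0, 4, 8), (2, 4, 6)]              # diagonals

# The numbers lying on each line of the board
line_triples = {frozenset(MAGIC[i] for i in line) for line in LINES}
# All triples of distinct numbers 1..9 whose sum is 15
sum15_triples = {frozenset(t) for t in combinations(range(1, 10), 3) if sum(t) == 15}

print(len(sum15_triples))             # 8
print(line_triples == sum15_triples)  # True
```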

In [2]:
State = namedtuple('State', ['x', 'o'])
MAGIC = [2, 7, 6, 
         9, 5, 1, 
         4, 3, 8]

## Utility functions

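``print_board``: prints the game board for a given state, using ``MAGIC`` to place each number in its cell (``X`` for player 1, ``O`` for player 2, a blank for an empty cell).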

In [3]:
def print_board(pos):
    for r in range(3):
        for c in range(3):
            index = r * 3 + c
            if MAGIC[index] in pos.x:
                print('X', end='')
            elif MAGIC[index] in pos.o:
                print('O', end='')
            else:
                print(' ', end='')
        print()
    print()
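Since ``print_board`` writes to standard output and uses a blank for empty cells, a string-returning variant is handier for checking the number-to-cell mapping. ``board_to_str`` below is a hypothetical helper, not part of the original notebook; it follows the same loop as ``print_board`` but marks empty cells with ``.``.

```python
from collections import namedtuple

State = namedtuple('State', ['x', 'o'])
MAGIC = [2, 7, 6,
         9, 5, 1,
         4, 3, 8]

def board_to_str(pos: State) -> str:
    """Hypothetical string-returning variant of print_board; '.' marks an empty cell."""
    rows = []
    for r in range(3):
        row = ''
        for c in range(3):
            n = MAGIC[r * 3 + c]  # the number that lives in cell (r, c)
            row += 'X' if n in pos.x else ('O' if n in pos.o else '.')
        rows.append(row)
    return '\n'.join(rows)

# X holds 2, 5, 8 (the main diagonal, 2 + 5 + 8 = 15); O holds 7 and 9
print(board_to_str(State({2, 5, 8}, {7, 9})))
# XO.
# OX.
# ..X
```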

## Reward functions
- ``win``: returns ``True`` if the given set of numbers contains three that sum to 15, i.e. if the player owns a complete row, column or diagonal.
- ``state_value``: evaluates a state and returns 1 if the first player (``x``) has won, -1 if the second player (``o``) has won, and 0 otherwise (a draw or an unfinished game).
- ``random_game``: plays a game in which both players pick moves uniformly at random, and returns the list of every state reached during the game, one per move.

In [4]:
def win(squares):
    return any(sum(c) == 15 for c in combinations(squares, 3))

def state_value(position: State):
    if win(position.x):
        return 1
    elif win(position.o):
        return -1
    else:
        return 0

def random_game():
    trajectory = list()
    state = State(set(), set())
    available = set(range(1, 9+1))
    while available:
        x = choice(list(available))
        state.x.add(x)
        trajectory.append(deepcopy(state))
        available.remove(x)
        if win(state.x) or not available:
            break
        
        o = choice(list(available))
        state.o.add(o)
        trajectory.append(deepcopy(state))
        available.remove(o)
        if win(state.o):
            break
    return trajectory
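A quick way to sanity-check ``random_game`` is to play many games and verify a few invariants: players alternate with X first, each state adds exactly one number, nobody has won before the last state, and a game ends either with a win or with a full board. The sketch below repeats ``win``, ``state_value`` and ``random_game`` so that it runs on its own; ``check_trajectory``, ``outcome_frequencies`` and the fixed seed are assumptions added for illustration.

```python
from collections import Counter, namedtuple
from copy import deepcopy
from itertools import combinations
from random import choice, seed

State = namedtuple('State', ['x', 'o'])

def win(squares):
    return any(sum(c) == 15 for c in combinations(squares, 3))

def state_value(position):
    # same logic as the notebook: +1 if X has won, -1 if O has won, 0 otherwise
    return 1 if win(position.x) else -1 if win(position.o) else 0

def random_game():
    # copy of the notebook's random_game
    trajectory = []
    state = State(set(), set())
    available = set(range(1, 9 + 1))
    while available:
        x = choice(list(available))
        state.x.add(x)
        trajectory.append(deepcopy(state))
        available.remove(x)
        if win(state.x) or not available:
            break
        o = choice(list(available))
        state.o.add(o)
        trajectory.append(deepcopy(state))
        available.remove(o)
        if win(state.o):
            break
    return trajectory

def check_trajectory(trajectory):
    """Hypothetical helper: assert the invariants of a random_game trajectory."""
    for i, s in enumerate(trajectory):
        assert len(s.x) + len(s.o) == i + 1   # exactly one new number per state
        assert len(s.x) - len(s.o) in (0, 1)  # X moves first, players alternate
        assert not (s.x & s.o)                # no number is taken twice
        if i < len(trajectory) - 1:
            assert state_value(s) == 0        # the game stops at the first win
    last = trajectory[-1]
    assert state_value(last) != 0 or len(last.x) + len(last.o) == 9

def outcome_frequencies(n_games=10_000):
    """Hypothetical helper: fraction of random games won by O (-1), drawn (0), won by X (1)."""
    counts = Counter()
    for _ in range(n_games):
        trajectory = random_game()
        check_trajectory(trajectory)
        counts[state_value(trajectory[-1])] += 1
    return {k: v / n_games for k, v in sorted(counts.items())}

seed(42)
print(outcome_frequencies())
```

Under uniformly random play the first player wins clearly more often than the second, which is worth keeping in mind when reading the learned values below.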

## Code Test

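Here ``random_game`` is played 10,000 times (the loop variable ``steps`` actually counts games). For each game, ``trajectory`` is the list of all states reached, and ``final_reward`` is the outcome of its last state. Every state in the trajectory is then moved a small step towards that outcome, $V(s) \leftarrow V(s) + \epsilon\,(R - V(s))$, where ``epsilon`` acts as the learning rate. The result, ``value_dict``, is a hash table mapping every state seen in any game to a running estimate of the final reward (from player 1's point of view) of the games that passed through it; since sets are not hashable, each state is stored as a pair of ``frozenset``s.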

In [6]:
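value_dict = defaultdict(float)  # state -> running estimate of the final reward
epsilon = .001                   # step size (learning rate) of the update

for steps in range(10000):       # each iteration plays one full random game
    trajectory = random_game()
    final_reward = state_value(trajectory[-1])  # +1 X won, -1 O won, 0 draw
    for state in trajectory:
        # sets are not hashable: key the table by a pair of frozensets
        hashable_state = (frozenset(state.x), frozenset(state.o))
        # move the estimate a small step towards the observed final reward
        value_dict[hashable_state] += epsilon * (final_reward - value_dict[hashable_state])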
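The update applied in the loop is a constant-step-size Monte Carlo estimate. As a worked illustration, assume a state starts at $V_0(s) = 0$ and every one of its $n$ visits sees the same final reward $R$ (an idealisation: in the notebook rewards are random, and taking expectations gives the same result with $\mathbb{E}[R]$ in place of $R$). The gap to the target then shrinks geometrically:

$$
R - V_{k+1}(s) = R - V_k(s) - \epsilon\,\bigl(R - V_k(s)\bigr) = (1-\epsilon)\,\bigl(R - V_k(s)\bigr)
\quad\Longrightarrow\quad
V_n(s) = \bigl(1 - (1-\epsilon)^n\bigr)\,R
$$

With $\epsilon = 0.001$ and 10,000 games, each opening move is visited about $10000/9 \approx 1111$ times, so $(1-0.001)^{1111} \approx e^{-1.11} \approx 0.33$ and the opening values are only about two thirds of the expected reward under random play. States deeper in the game are visited less often and are shrunk even more, so the table is best read as a ranking rather than as converged estimates.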

In [7]:
sorted(value_dict.items(), key=lambda e: e[1], reverse=True)[:10]  # ten highest-valued states

[((frozenset({5}), frozenset()), 0.3271607449087342),
 ((frozenset({6}), frozenset()), 0.22887483403383943),
 ((frozenset({4}), frozenset()), 0.22595169171854304),
 ((frozenset({2}), frozenset()), 0.20876481660630122),
 ((frozenset({8}), frozenset()), 0.18909459442766977),
 ((frozenset({9}), frozenset()), 0.17073537963913504),
 ((frozenset({3}), frozenset()), 0.1437255238501484),
 ((frozenset({1}), frozenset()), 0.13388542265102382),
 ((frozenset({7}), frozenset()), 0.10630248965396563),
 ((frozenset({6}), frozenset({1})), 0.07626599721500074)]
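The learned table can then drive a player. The sketch below shows one possible approach, assuming the player is X and simply picks the move whose resulting state has the highest learned value, with unseen states treated as 0. ``train_values``, ``greedy_move`` and ``play_greedy_vs_random`` are hypothetical names, not part of the notebook; the other functions repeat the notebook's code so the block runs on its own.

```python
from collections import defaultdict, namedtuple
from copy import deepcopy
from itertools import combinations
from random import choice, seed

State = namedtuple('State', ['x', 'o'])

def win(squares):
    return any(sum(c) == 15 for c in combinations(squares, 3))

def state_value(position):
    # +1 if X has won, -1 if O has won, 0 otherwise (same logic as the notebook)
    return 1 if win(position.x) else -1 if win(position.o) else 0

def random_game():
    # copy of the notebook's random_game: both players move uniformly at random
    trajectory = []
    state = State(set(), set())
    available = set(range(1, 9 + 1))
    while available:
        x = choice(list(available))
        state.x.add(x)
        trajectory.append(deepcopy(state))
        available.remove(x)
        if win(state.x) or not available:
            break
        o = choice(list(available))
        state.o.add(o)
        trajectory.append(deepcopy(state))
        available.remove(o)
        if win(state.o):
            break
    return trajectory

def state_key(state):
    # sets are not hashable, so states are keyed by a pair of frozensets
    return (frozenset(state.x), frozenset(state.o))

def train_values(n_games=10_000, epsilon=0.001):
    """Hypothetical wrapper around the notebook's training loop."""
    value_dict = defaultdict(float)
    for _ in range(n_games):
        trajectory = random_game()
        final_reward = state_value(trajectory[-1])
        for state in trajectory:
            k = state_key(state)
            value_dict[k] += epsilon * (final_reward - value_dict[k])
    return value_dict

def greedy_move(value_dict, state, available):
    """Hypothetical greedy policy for X: the move leading to the highest-valued state."""
    # .get() treats unseen states as 0 without inserting them into the defaultdict
    return max(available,
               key=lambda m: value_dict.get(state_key(State(state.x | {m}, state.o)), 0.0))

def play_greedy_vs_random(value_dict):
    """X follows greedy_move, O plays at random; returns the final state_value."""
    state = State(set(), set())
    available = set(range(1, 9 + 1))
    while available:
        x = greedy_move(value_dict, state, available)
        state.x.add(x)
        available.remove(x)
        if win(state.x) or not available:
            break
        o = choice(list(available))
        state.o.add(o)
        available.remove(o)
        if win(state.o):
            break
    return state_value(state)

seed(0)
values = train_values()
results = [play_greedy_vs_random(values) for _ in range(1000)]
print({outcome: results.count(outcome) for outcome in (1, 0, -1)})  # X wins, draws, O wins
```

Because the values come from random play and unseen states default to 0, this greedy player is only a baseline: nothing in it guarantees that it blocks the opponent's threats.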

In [21]:
any(sum(c) == 15 for c in combinations({1,2,3,4,5,6}, 3))

True
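``win`` only looks at sums, so it accepts any set of numbers, even one that could not occur in a real game (a player never holds six numbers). To see which line made the check above succeed, a hypothetical variant, ``winning_triples``, can return the matching triples instead of a boolean:

```python
from itertools import combinations

def winning_triples(squares):
    """Hypothetical variant of win: list every triple in `squares` that sums to 15."""
    return [c for c in combinations(sorted(squares), 3) if sum(c) == 15]

print(winning_triples({1, 2, 3, 4, 5, 6}))  # [(4, 5, 6)]
```

The only winning triple is (4, 5, 6), which is the anti-diagonal of ``MAGIC``.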

In [26]:
state = State({2,4,3,8}, {6,9,5})
state

State(x={8, 2, 3, 4}, o={9, 5, 6})
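The state above can be scored with the reward function: player 1 holds 3, 4 and 8 (the bottom row of ``MAGIC``), while player 2's only triple is 5 + 6 + 9 = 20, so ``state_value`` should return 1. A self-contained check, repeating ``win`` and ``state_value`` from the notebook:

```python
from collections import namedtuple
from itertools import combinations

State = namedtuple('State', ['x', 'o'])

def win(squares):
    return any(sum(c) == 15 for c in combinations(squares, 3))

def state_value(position: State):
    if win(position.x):
        return 1
    elif win(position.o):
        return -1
    else:
        return 0

state = State({2, 4, 3, 8}, {6, 9, 5})
print(win(state.x), win(state.o))  # True False
print(state_value(state))          # 1
```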