Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

# LAB10

Use reinforcement learning to devise a tic-tac-toe player.

### Deadlines:

* Submission: [Dies Natalis Solis Invicti](https://en.wikipedia.org/wiki/Sol_Invictus)
* Reviews: [Befana](https://en.wikipedia.org/wiki/Befana)

Notes:

* Reviews will be assigned  on Monday, December 4
* You need to commit in order to be selected as a reviewer (ie. better to commit an empty work than not to commit)

In [156]:
from itertools import permutations
from collections import namedtuple
import numpy as np
from treelib import Node, Tree
from copy import deepcopy
from itertools import product
import tqdm
from tqdm import tqdm


In [129]:
TICTACTOE_MAP = np.array([[1, 6, 5], [8, 4, 0], [3, 2, 7]])

def tictactoe_map(set_pos):
    return {TICTACTOE_MAP[pos//3,pos%3] for pos in set_pos}

def display(x, o):
    for r in range(3):
        for c in range(3):
            if (r*3+c) in x:
                print("X", end=" ")
            elif (r*3+c) in o:
                print("O", end=" ")
            else:
                print(".", end=" ")
        print()

def won(cells):
    return any(sum(h) == 12 for h in permutations(tictactoe_map(cells), 3))

In [139]:
St = namedtuple('State', ['x', 'o','next_turn']) # x={}, o={}
Ply = namedtuple('Ply', ['turn','pos'])

def static_eval(state):
    if won(state.x):
        return 1
    elif won(state.o):
        return -1
    else:
        return 0
        

In [178]:
N_GAMES = 10_000

In [147]:
states_map={} # map (state, total reward)
              #all possible state in tic-tac-toe are <19000
action_map={}


DEBUG = False
#position of x fit in 5bit variable, same for position of o, ply fit to 1bit variable
#FROM PAPER: "quixo is solved"
def from_xo_to_key(x, o, next_turn):

    key = 0

    for pos in x:
        key += 2**pos
    for pos in o:
        key += 2**(pos+9)
    
    key+= next_turn*(2**18)

    return key

def from_key_to_xo(key): 

    x ={bin_mask for bin_mask in range(9) if (key & 1<<bin_mask)!=0}
    o ={bin_mask for bin_mask in range(9) if (key & 1<<(bin_mask+9))!=0}
    next_turn = (key & (1<<18))>>18
    return St(x, o, next_turn)

def from_board_to_xo(board):
    x = set([])
    o = set([])
    for i in range(board.shape[0]):
        for j in range(board.shape[1]):
            if board[i][j] == 0:
                x |= {i*3+j}
            elif board[i][j] == 1:
                o |= {i*3+j}
        
    return (x, o)

def from_xo_to_board(x, o):
    board = np.ones((3,3),dtype=np.int8)*-1
    for pos in x:
        board[pos//3][pos%3] = 0
    for pos in o:
        board[pos//3][pos%3] = 1
    return board

def calculate_equivalent(x, o, next_turn): #return keys
    
    board = np.ones((3,3), dtype=np.uint8) * -1
    
    for pos in x:
        board[pos//3,pos%3] = 0
    for pos in o:
        board[pos//3,pos%3] = 1
    
    
    xo_set = from_board_to_xo(board)
    xo_set_T = from_board_to_xo(board.T)

    equiv_state_ids = [from_xo_to_key(xo_set[0],xo_set[1], next_turn), from_xo_to_key(xo_set_T[0],xo_set_T[1], next_turn)]

    if DEBUG:
        print(board)
        print("%s" % (np.binary_repr(equiv_state_ids[-2], width=32)))
        print(board.T)
        print("%s" % (np.binary_repr(equiv_state_ids[-1], width=32)))

    for _ in range(3):

        board = np.rot90(board)

        xo_set = from_board_to_xo(board)
        xo_set_T = from_board_to_xo(board.T)

        equiv_state_ids.append(from_xo_to_key(xo_set[0],xo_set[1], next_turn))
        equiv_state_ids.append(from_xo_to_key(xo_set_T[0],xo_set_T[1], next_turn))
        if DEBUG:
            print(board)
            print("%s" % (np.binary_repr(equiv_state_ids[-2], width=32)))
            print(board.T)
            print("%s" % (np.binary_repr(equiv_state_ids[-1], width=32)))

    return equiv_state_ids

def key_mapping(x, o, next_turn):
    #check if an equivalent state is already stored in map
    equiv_states_keys = calculate_equivalent(x, o, next_turn)
    equiv_key = None
    for state_key in equiv_states_keys:
        if state_key in states_map.keys():
            equiv_key = state_key
            break

    if equiv_key != None:
        return equiv_key
    else:
        state_key = from_xo_to_key(x, o, next_turn)
        states_map[state_key] = static_eval(St(x, o, next_turn))
        return state_key

def generate_all_states():
    
    for el in list(product([-1,0,1], repeat=9)):
        n_0 = len([1 for i in el if i == 0])
        n_1 = len([1 for i in el if i == 1])
        if 0<= n_0-n_1 <= 1: 
            next_player = n_0-n_1
            x = set([])
            o = set([])
            for i,xo_ in enumerate(el):
                if xo_ == 0:
                    x = x|{i}
                elif xo_ == 1:
                    o = o|{i}
            key_mapping(x, o, next_player)

            if DEBUG:
                xo = St(x, o, next_player)
                
                key = from_xo_to_key(x, o, next_player)
                xo_again = from_key_to_xo(key)
                assert xo == xo_again, f"wrong xo conversion{xo,key,xo_again}"

                board = from_xo_to_board(x, o)
                xo_again = from_board_to_xo(board)
                assert (x, o) == xo_again, f"wrong board conversion{(x, o),board,xo_again}"

def from_ply_to_key(ply): #4 bits are enought to represent pos, fifth bit is turn 1/0
    return (ply.turn<<4) + ply.pos

def from_key_to_ply(key):
    return Ply(key>>4,key-((key>>4)<<4))

def make_ply(state_xo, ply):

    assert state_xo.next_turn == ply.turn, "error: wrong turn"

    if ply.turn == 0:
        return St(state_xo.x|set({ply.pos}), state_xo.o, (ply.turn+1)%2)
    else:
        return St(state_xo.x, state_xo.o|set({ply.pos}), (ply.turn+1)%2)
    
def generate_all_actions():
    for state_key in states_map.keys():
        
        state_xo=from_key_to_xo(state_key)

        for pos_ply in set(range(9)) - state_xo.x - state_xo.o:

            if DEBUG:
                print(pos_ply)

            ply = Ply(state_xo.next_turn, pos_ply)
            next_state_xo = make_ply(state_xo, ply)
            
            ply_result = None
            for equiv_key in calculate_equivalent(next_state_xo.x,next_state_xo.o,next_state_xo.next_turn):
                ply_result = states_map.get(equiv_key)
                if ply_result!=None:
                    break
                
            assert ply_result != None, f"wrong ply {state_xo, ply, next_state_xo, key}"

            action_map[(state_key,from_ply_to_key(ply))] = ply_result

generate_all_states() #4s 850 states
print(f"states generated: {len(states_map.keys())}")
if DEBUG:
    DEBUG = not DEBUG
    print(states_map)
    # X O X
    # O X O
    # X O X
    print("X WINS, 8 SIMMETRIES")
    next = 1
    x={0,2,4,6,8} 
    o={1,3,5,7}
    equiv=calculate_equivalent(x,o,next)
    for e in equiv:
        print(states_map.get(e))

    # X O O
    # X O X
    # O X X
    print("O WINS")
    next = 1
    x = {0,3,5,7,8}
    o = {1,2,4,6}
    equiv=calculate_equivalent(x,o,next)
    for e in equiv:
        print(states_map.get(e))
    
    # X X O
    # - O X 
    # O O X
    print("O WINS")
    next = 0
    x = {0,1,5,8}
    o = {2,4,6,7}
    equiv=calculate_equivalent(x,o,next)
    for e in equiv:
        print(states_map.get(e))

    # X X O
    # - O X 
    # O O X
    print("INVALID")
    next = 1
    x = {0,1,5,8}
    o = {2,4,6,7}
    equiv=calculate_equivalent(x,o,next)
    for e in equiv:
        print(states_map.get(e))

    # - - X
    # O O X 
    # - - X
    # SIMMETRY (T of) 2 correspondences, X wins
    # - O -
    # - O -
    # X X X
    print("X WINS, 2 SIMMETRIES")
    next = 1
    x = {2,5,8}
    o = {3,4}
    equiv=calculate_equivalent(x,o,next)
    for e in equiv:
        print(states_map.get(e))

    DEBUG = not DEBUG

generate_all_actions()
print(f"actions generated: {len(action_map.keys())}")
if DEBUG:
    for player in [0,1]:
        for pos in range(9):
            ply = Ply(player, pos)
            key = from_ply_to_key(ply)
            ply_again = from_key_to_ply(key)
            assert  ply == ply_again, f"error converting ply{ply,key,ply_again}"

    print(action_map)


states generated: 850
actions generated: 2702


In [181]:
state_0 = St({}, {}, 0)
state = state_0
#not use discount rate


def best_ply(state, player=0, test = False) -> (int, int): 
    """
    default "intelligent" player is X (0), default is training (use_learned = False)

    
    """

    assert len(state.x) + len(state.o) < 9, "cannot move more"
    assert static_eval(state) == 0, "someone already won"

    if test:

        equiv_states = calculate_equivalent(state.x, state.o, state.next_turn)
        
        for n_transform,equiv_state_key in enumerate(equiv_states): # we need to check from which state we start but since not all states are mapped we need to find the equivalent mapped

            if states_map.get(equiv_state_key) != None:

                equiv_state_xo = from_key_to_xo(equiv_state_key) # equiv state have different free positions
                possible_pos = set(range(9)) - equiv_state_xo.x - equiv_state_xo.o

                best = None
                if player == 0:
                    best = -1_000_000_000
                else:
                    best =  1_000_000_000
                
                starting_state_best_ply = (None, None)
                for ply_pos in possible_pos:
                    
                    ply_on_equiv = Ply(equiv_state_xo.next_turn, ply_pos)
                    ply_on_equiv_key = from_ply_to_key(ply_on_equiv)
                    
                    credits = action_map.get((equiv_state_key, ply_on_equiv_key))
                    if credits != None:
                        if (credits>best and player == 0) or (credits<best and player == 1):
                            best = credits
                            starting_state_best_ply = (equiv_state_key, ply_on_equiv_key)
                
                assert starting_state_best_ply != (None, None), "best move must exist"
                
                
                for i in range(4):
                    for j in range(2):
                        if n_transform == i*2+j:
                            if HUMAN_PLAYER:
                                if i != 0 and j != 0:
                                    print(f"rotation:{-i*90}°, transpose: {j}")

                if HUMAN_PLAYER:    
                    print(from_xo_to_board(equiv_state_xo.x,equiv_state_xo.o))
                
                return starting_state_best_ply
        
    else:
        equiv_states = calculate_equivalent(state.x, state.o, state.next_turn)
        
        for n_transform,equiv_state_key in enumerate(equiv_states): # we need to check from which state we start but since not all states are mapped we need to find the equivalent mapped
            
            if states_map.get(equiv_state_key) != None: # ok here

                equiv_state_xo = from_key_to_xo(equiv_state_key) # equiv state have different free positions
                possible_pos = set(range(9)) - equiv_state_xo.x - equiv_state_xo.o

                
                for i in range(4):
                    for j in range(2):
                        if n_transform == i*2+j:
                            if HUMAN_PLAYER:
                                if i != 0 and j != 0:
                                    print(f"rotation:{i*-90}°, transpose: {j}")
                
                if HUMAN_PLAYER:    
                    print(from_xo_to_board(equiv_state_xo.x,equiv_state_xo.o))
                    
                return (equiv_state_key, from_ply_to_key(Ply(equiv_state_xo.next_turn,np.random.choice(list(possible_pos)))))

    
eps = 0.005

def assign_rewards(winner, ply_played_keys):
    
    if winner == -1:
        winner = 0
    elif winner == 0:
        winner = 1
    else:
        winner = -1

    for state_ply_key in enumerate(ply_played_keys):
        ply = from_key_to_ply(state_ply_key[1])
        state = from_key_to_xo(state_ply_key[0])
        
        action_map[state_ply_key] += eps*(winner-action_map[state_ply_key])
            


custom_bar_format = "{l_bar}{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]"
progress_bar = tqdm(range(N_GAMES),dynamic_ncols=True,desc="Generation",colour="green",total=N_GAMES,mininterval=0.5,bar_format=custom_bar_format,ncols=100)

HUMAN_PLAYER = False

for game in progress_bar:
    player = 0
    ply_played_keys = []
    winner = -1
    state_ply_xo = St(set(),set(),0)
    
    for turn in range(9): #it finish after 9 ply maximum

        state_ply = best_ply(state_ply_xo, player, False)

        ply_played_keys.append(state_ply)
        state_ply_xo = make_ply(from_key_to_xo(state_ply[0]), from_key_to_ply(state_ply[1]))

        if static_eval(state_ply_xo) != 0:
            winner=player
            break

        player = (player+1)%2
    
    assign_rewards(winner, ply_played_keys)
    

HUMAN_PLAYER = True

wins = 0
draws = 0
for game in range(10000):
    player = 0 #half I start half random start
    winner = -1
    state_ply_xo = St(set(),set(),0)

    print(f"AI turn {game%2}")

    for turn in range(9): #it finish after 9 ply maximum
        
        if (player == (game%2)): #AI system switch player at each game
            state_ply = best_ply(state_ply_xo, player, True)
        else:
            state_ply = best_ply(state_ply_xo, player, False)

        state_ply_xo = make_ply(from_key_to_xo(state_ply[0]), from_key_to_ply(state_ply[1]))

        if static_eval(state_ply_xo) != 0:
            winner=player
            break

        player = (player+1)%2

    if winner == (game%2):
        wins +=1
    elif winner == -1:
        draws +=1
    
    

print(wins/10000)
print(draws/10000)

    

Generation: 100%|[32m██████████████████████████████████████████████████[0m| 10000/10000 [00:27<00:00]


AI turn 0
[[-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
[[-1 -1 -1]
 [-1  0 -1]
 [-1 -1 -1]]
[[-1 -1 -1]
 [-1  0 -1]
 [-1  1 -1]]
[[-1 -1 -1]
 [-1  0 -1]
 [-1  1  0]]
[[-1 -1 -1]
 [-1  0  1]
 [-1  1  0]]
AI turn 1
[[-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
[[-1 -1 -1]
 [-1 -1 -1]
 [-1 -1  0]]
[[-1 -1 -1]
 [-1 -1 -1]
 [-1  1  0]]
[[-1 -1 -1]
 [-1 -1 -1]
 [ 0  1  0]]
rotation:-90°, transpose: 1
[[-1 -1 -1]
 [-1 -1  1]
 [ 0  1  0]]
rotation:-90°, transpose: 1
[[-1 -1  0]
 [ 1 -1 -1]
 [ 0  1  0]]
[[ 0 -1  0]
 [-1 -1  1]
 [ 1  1  0]]
[[ 0 -1  0]
 [ 0 -1  1]
 [ 1  1  0]]
[[ 0 -1  0]
 [ 0  1  1]
 [ 1  1  0]]
AI turn 0
[[-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
[[-1 -1 -1]
 [-1  0 -1]
 [-1 -1 -1]]
[[-1 -1 -1]
 [-1  0 -1]
 [-1  1 -1]]
[[-1 -1 -1]
 [-1  0 -1]
 [-1  1  0]]
[[-1 -1 -1]
 [-1  0  1]
 [ 1 -1  0]]
AI turn 1
[[-1 -1 -1]
 [-1 -1 -1]
 [-1 -1 -1]]
[[-1 -1 -1]
 [-1 -1 -1]
 [-1  0 -1]]
rotation:-180°, transpose: 1
[[-1 -1 -1]
 [ 0 -1  1]
 [-1 -1 -1]]
rotation:-270°, transpose: 1
[[-1 -1 -1]
 [ 0 -1  1]
 [-1  0 