In [138]:
import torch
import itertools
from tqdm import tqdm
import os
import torch.nn.functional as F
import math

In [139]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

Generate every possible 3x3 board with cell values -1, 0 and 1, where -1 and 1 are the two players and 0 is an empty spot. The number of boards generated matches the mathematically expected count of $3^9$.

In [140]:
def generate_all_matrix_combinations(n):
    """Enumerate every n x n board whose cells are -1, 0 or 1.

    Returns:
        A list of (n, n) tensors in the order produced by
        itertools.product, i.e. the last cell of the board varies fastest.
    """
    cell_values = (-1, 0, 1)
    return [
        torch.tensor(cells).view(n, n)
        for cells in itertools.product(cell_values, repeat=n * n)
    ]

# Stack the generated 3x3 boards into one tensor with one board per index along dim 0.
all_matrix_combinations = torch.stack(generate_all_matrix_combinations(3))
# Sanity check: 9 cells with 3 possible values each must give 3 ** 9 boards.
(all_matrix_combinations.shape[0]) == 3 ** 9

True

Here we decide that player 1 always goes first. Since the players take turns, a board must contain either one more 1 than -1 or an equal number of each, so the sum of all values of each board/state must be 0 or 1. Board combinations that do not satisfy this are invalid and will never be reached in a legal game.

In [141]:
# Total of all cells per board: number of 1 pieces minus number of -1 pieces.
sums = all_matrix_combinations.sum(dim=(1, 2))
# Player 1 moves first, so a legal board is either level (0) or one piece ahead (1).
legal_turn_count = (sums == 0) | (sums == 1)
bi_turns_matrxies = all_matrix_combinations[legal_turn_count]
bi_turns_matrxies.shape

torch.Size([6046, 3, 3])

In [142]:
# Spot check: display the first board that survived the turn-count filter.
bi_turns_matrxies[0]

tensor([[-1, -1, -1],
        [-1,  0,  1],
        [ 1,  1,  1]])

There are no legal states with both players winning. Therefore they can also be removed.

In [143]:
# Per-board line totals: the three column sums followed by the three row sums.
column_totals = bi_turns_matrxies.sum(dim=1)
row_totals = bi_turns_matrxies.sum(dim=2)
streaks = torch.cat((column_totals, row_totals), dim=1)
# A total of 3 is a full line of 1s; a total of -3 is a full line of -1s.
pos_won = (streaks == 3).any(dim=1)
neg_won = (streaks == -3).any(dim=1)
~(pos_won & neg_won), streaks

(tensor([False, False,  True,  ...,  True, False, False]),
 tensor([[-1,  0,  1, -3,  0,  3],
         [-1,  1,  0, -3,  0,  3],
         [-2,  1,  1, -3,  1,  2],
         ...,
         [ 2,  0, -1,  3,  0, -2],
         [ 1,  0,  0,  3,  1, -3],
         [ 1,  1, -1,  3,  1, -3]]))

In [144]:
# Keep only the boards on which at most one player holds a full row or column.
both_players_won = pos_won & neg_won
final_matrxies = bi_turns_matrxies[~both_players_won]
final_matrxies.shape

torch.Size([5890, 3, 3])

The first method builds two (n, n, 3, 3) tensors, assuming that there are n valid boards/states. Together they pair every valid state with every valid state, so every possible combination is covered.
The second one finds the pairs where one can legally move from the first state to the second. There are several conditions for this: exactly one piece must be placed on an empty spot, no pieces can be removed, and the starting state cannot already contain a win, since the game ends there and no more moves can legally be made.

In [145]:
def make_combi(game_states_tensor):
    """Build two grids that together cover every ordered pair of states.

    Args:
        game_states_tensor: tensor of shape (size, game_size, game_size).

    Returns:
        A tuple (a, b) of tensors on `device`, both of shape
        (size, size, game_size, game_size), where a[i, j] is state j and
        b[i, j] is state i.

    The grids are broadcast views created with `expand`, so no per-pair copy
    of the boards is allocated (the previous repeat/reshape version
    materialised size * size boards twice). Because the views share memory
    with the input, they must be treated as read-only.
    """
    size = game_states_tensor.shape[0]
    game_size = game_states_tensor.shape[-1]
    # Move the states once; both views are derived from the same tensor.
    states = game_states_tensor.to(device)
    grid_shape = (size, size, game_size, game_size)
    a = states.unsqueeze(0).expand(grid_shape)
    b = states.unsqueeze(1).expand(grid_shape)
    return a, b
def gather_edges(parent, child):
    """Collect every legal single move parent -> child.

    Args:
        parent: tensor of shape (..., n, n) holding candidate source boards.
        child: tensor of the same shape holding candidate target boards.

    Returns:
        Tensor of shape (num_edges, 2, n, n); entry [k, 0] is the board before
        the move and entry [k, 1] the board after it.

    A pair is kept when exactly one cell differs, that cell was empty (0) in
    the parent and holds 1 or -1 in the child, and the parent does not already
    contain a full row, column or diagonal of one player. The board size is
    read from the input instead of being fixed to 3.
    """
    parent = parent.to(device)
    child = child.to(device)
    board_size = parent.shape[-1]

    add_one_pice_on_empty = torch.sum((parent == 0) & ((child == 1) | (child == -1)), dim=(-2, -1)) == 1
    one_diff = torch.sum(parent != child, dim=(-2, -1)) == 1

    # Index-based diagonals only gather n values per board, so no full copy
    # of the (possibly very large) parent grid is needed.
    idx = torch.arange(board_size, device=parent.device)
    main_diagonal = parent[..., idx, idx]
    anti_diagonal = parent[..., idx.flip(0), idx]

    line_sums = torch.cat(
        (
            parent.sum(dim=-2),                       # column totals
            parent.sum(dim=-1),                       # row totals
            main_diagonal.sum(dim=-1, keepdim=True),
            anti_diagonal.sum(dim=-1, keepdim=True),
        ),
        dim=-1,
    )
    # With cells in {-1, 0, 1}, |sum| == n only when one player fills the line.
    parent_win = torch.any(torch.abs(line_sums) == board_size, dim=-1)

    mask = one_diff & add_one_pice_on_empty & (~parent_win)

    edges = torch.stack((parent[mask], child[mask]), dim=1)
    return edges

16167 edges were found.

In [146]:
# Ask PyTorch to release its unused cached GPU memory before the pairwise comparison.
torch.cuda.empty_cache()
# Pair every valid state with every valid state and keep only the legal moves.
edges = gather_edges(*make_combi(final_matrxies))
edges.shape

torch.Size([16167, 2, 3, 3])

In [147]:
def to_tuple(t):
    """Convert a tensor into nested tuples so it can be used as a dict key."""
    return to_tuple_prim(t.tolist())

def to_tuple_prim(ls):
    """Recursively turn (nested) lists into (nested) tuples.

    Anything that is not a list is returned unchanged. This makes empty
    lists and bare scalars (e.g. the tolist() result of a 0-d tensor) valid
    inputs instead of failing on `ls[0]`.
    """
    if not isinstance(ls, list):
        return ls

    return tuple(to_tuple_prim(item) for item in ls)


Here we translate the states into ids, since we mostly only need a representation of the states, not their actual data.

In [148]:
# Hashable form of every valid board, in the same order as final_matrxies.
board_keys = [to_tuple(m) for m in final_matrxies]
# The position of a board in final_matrxies serves as its id, in both directions.
id2matrix = dict(enumerate(board_keys))
matrix2id = {key: idx for idx, key in id2matrix.items()}
matrix2id

{((-1, -1, -1), (-1, 1, 1), (0, 1, 1)): 0,
 ((-1, -1, -1), (-1, 1, 1), (1, 0, 1)): 1,
 ((-1, -1, -1), (-1, 1, 1), (1, 1, 0)): 2,
 ((-1, -1, -1), (0, 0, 1), (0, 1, 1)): 3,
 ((-1, -1, -1), (0, 0, 1), (1, 0, 1)): 4,
 ((-1, -1, -1), (0, 0, 1), (1, 1, 0)): 5,
 ((-1, -1, -1), (0, 1, 0), (0, 1, 1)): 6,
 ((-1, -1, -1), (0, 1, 0), (1, 0, 1)): 7,
 ((-1, -1, -1), (0, 1, 0), (1, 1, 0)): 8,
 ((-1, -1, -1), (0, 1, 1), (-1, 1, 1)): 9,
 ((-1, -1, -1), (0, 1, 1), (0, 0, 1)): 10,
 ((-1, -1, -1), (0, 1, 1), (0, 1, 0)): 11,
 ((-1, -1, -1), (0, 1, 1), (0, 1, 1)): 12,
 ((-1, -1, -1), (0, 1, 1), (1, -1, 1)): 13,
 ((-1, -1, -1), (0, 1, 1), (1, 0, 0)): 14,
 ((-1, -1, -1), (0, 1, 1), (1, 0, 1)): 15,
 ((-1, -1, -1), (0, 1, 1), (1, 1, -1)): 16,
 ((-1, -1, -1), (0, 1, 1), (1, 1, 0)): 17,
 ((-1, -1, -1), (1, -1, 1), (0, 1, 1)): 18,
 ((-1, -1, -1), (1, -1, 1), (1, 0, 1)): 19,
 ((-1, -1, -1), (1, -1, 1), (1, 1, 0)): 20,
 ((-1, -1, -1), (1, 0, 0), (0, 1, 1)): 21,
 ((-1, -1, -1), (1, 0, 0), (1, 0, 1)): 22,
 ((-1, -1, -

Here we calculate which states are predecessors (parents) to others and which are next states (children) of others. This will be useful for the MCTS.

In [149]:
# Adjacency lists in both directions, keyed by state id.
children_ids = {}
parents_ids = {}
for parent_board, child_board in edges:
    parent = matrix2id[to_tuple(parent_board)]
    child = matrix2id[to_tuple(child_board)]
    # setdefault creates the list on first sight of a key, then we append to it.
    children_ids.setdefault(parent, []).append(child)
    parents_ids.setdefault(child, []).append(parent)


children_ids

{59: [0, 56],
 409: [0, 406],
 1750: [0, 1747],
 62: [1, 60],
 412: [1, 410],
 1753: [1, 1751],
 64: [2, 63],
 414: [2, 413],
 1755: [2, 1754],
 78: [3, 47, 66, 75],
 428: [3, 397, 416, 425],
 1769: [3, 1738, 1757, 1766],
 81: [4, 48, 67, 79],
 431: [4, 398, 417, 429],
 1772: [4, 1739, 1758, 1770],
 83: [5, 49, 68, 82],
 433: [5, 399, 418, 432],
 1774: [5, 1740, 1759, 1773],
 91: [6, 52, 84, 88],
 441: [6, 402, 434, 438],
 1782: [6, 1743, 1775, 1779],
 94: [7, 53, 85, 92],
 444: [7, 403, 435, 442],
 1785: [7, 1744, 1776, 1783],
 96: [8, 54, 86, 95],
 446: [8, 404, 436, 445],
 1787: [8, 1745, 1777, 1786],
 99: [9, 56],
 449: [9, 406],
 1790: [9, 1747],
 102: [10, 57, 97, 100],
 452: [10, 407, 447, 450],
 1793: [10, 1748, 1788, 1791],
 104: [11, 58, 98, 103],
 454: [11, 408, 448, 453],
 1795: [11, 1749, 1789, 1794],
 106: [13, 60],
 456: [13, 410],
 1797: [13, 1751],
 108: [14, 61, 105, 107],
 458: [14, 411, 455, 457],
 1799: [14, 1752, 1796, 1798],
 109: [16, 63],
 459: [16, 413],
 1800

Calculates if a player (1 or -1) has won with a given board.

In [150]:
def win_states(board, pice):
    """Tell whether `pice` occupies a complete line of a 3x3 board.

    Rows are checked first, then columns, then the main diagonal and
    finally the anti-diagonal. Returns True on the first full line found,
    otherwise False.
    """
    # Any row made up entirely of `pice`?
    if any(row[0] == row[1] == row[2] == pice for row in board):
        return True

    # Any column made up entirely of `pice`? (rows of the transpose)
    if any(col[0] == col[1] == col[2] == pice for col in board.t()):
        return True

    # Top-left to bottom-right.
    if board[0][0] == board[1][1] == board[2][2] == pice:
        return True

    # Top-right to bottom-left.
    return bool(board[0][2] == board[1][1] == board[2][0] == pice)

Predictably, the only state without parents is the starting state.

In [151]:
# Root candidates: states that have successors but are nobody's successor.
start = [state_id for state_id in children_ids if state_id not in parents_ids]
assert len(start) == 1
id2matrix[start[0]]

((0, 0, 0), (0, 0, 0), (0, 0, 0))

Modified UCB1. Since the visit counts start at 0, we add 1 to them to avoid undefined behavior (division by zero and log(0)).

In [152]:
def UCB1(s,v,c,pv):
    """UCB1 value with visit counts offset by one.

    Args:
        s: accumulated score of the node.
        v: visit count of the node.
        c: exploration constant.
        pv: visit count of the parent node.
    """
    trials = v + 1
    exploitation = s / trials
    exploration = c * math.sqrt(math.log(pv + 1) / trials)
    return exploitation + exploration

MCTS. Works its way down the graph by picking the child node with the highest UCB1 score, and so on. All these nodes are recorded, and once a terminal state is reached (it has no children), its score is calculated and propagated up through the graph, and the visit count of each recorded node is increased. A reward is given for a player 1 win and a penalty if player -1 wins. A tie is neutral.

In [153]:
# Accumulated playout results (from player 1's point of view) and visit
# counts, both keyed by state id.
scores = {}
visits = {}
iterations = 50000
c_constant = 10

for i in tqdm(range(iterations)):
    parent = start[0]
    children = children_ids[parent]
    node_path = [parent]
    # Player 1 moves first; the mover alternates with every step down the graph.
    to_move = 1
    while children is not None:
        parent_visits = visits.get(parent, 0)
        # Scores are stored for player 1, so they are multiplied by `to_move`
        # to rank the children from the perspective of the player who is
        # actually choosing here. Without this, player -1 would be modelled
        # as picking the moves that are best for player 1.
        # max() keeps the first child on ties.
        parent = max(
            children,
            key=lambda cand: UCB1(
                to_move * scores.get(cand, 0),
                visits.get(cand, 0),
                c_constant,
                parent_visits,
            ),
        )
        node_path.append(parent)
        # States without an entry in children_ids are terminal.
        children = children_ids.get(parent)
        to_move = -to_move
    board = torch.tensor(id2matrix[parent])
    if win_states(board, 1):
        score = 1
    elif win_states(board, -1):
        score = -1
    else:
        score = 0
    # Back-propagate the result to every state visited in this playout.
    for node in node_path:
        scores[node] = scores.get(node, 0) + score
        visits[node] = visits.get(node, 0) + 1



100%|██████████| 50000/50000 [00:24<00:00, 2062.10it/s]


This method uses the scores and visit counts of each child of the given state to weight a random choice among them. The scores are multiplied by ai_pice because the rewards are tuned to player 1. Since this is a zero-sum game, the best move player -1 can make is the one that is most disadvantageous for player 1, so the flipped score can be used as the score for player -1. Afterwards, the distribution of the scores is shifted so that scores below the average end up below 1 and scores above the average end up above 1. When a high power is then applied, each score increases or decreases according to its distance from 1 (the mean), so the high probabilities become more prominent and the low ones fade away. In other words, this increases the variance of the distribution. Therefore, the higher the power, the closer the move-selection policy is to always picking the best move, and the smaller it gets, the closer it is to uniformly random. We use this to create a difficulty scale.

In [154]:
def ai_make_move(state, ai_pice, difficulty = 24):
    """Sample the AI's next board from the MCTS statistics.

    Args:
        state: current board tensor; it must be a key of matrix2id and have
            at least one child in children_ids.
        ai_pice: 1 or -1, the side the AI plays. Scores are stored for
            player 1, so they are multiplied by this value.
        difficulty: exponent applied to the mean-shifted scores. Higher
            values concentrate the probability on the best-rated children;
            0 gives a uniform choice.

    Returns:
        The chosen child board as a new tensor.
    """
    state_id = matrix2id[to_tuple(state)]
    children = children_ids[state_id]
    # Mean result per child from the AI's side, shifted from [-1, 1] to [0, 2].
    # Children the search never visited have no statistics and count as neutral.
    child_scores = torch.tensor(
        [ai_pice * scores.get(c, 0) / max(visits.get(c, 0), 1) for c in children],
        dtype=torch.float32,
    ) + 1
    # Centre on 1. Values more than 1 below the mean would become negative,
    # which an even power turns into a large weight and an odd power into an
    # invalid probability, so they are clamped to 0.
    base = torch.clamp(child_scores - torch.mean(child_scores) + 1, min=0)
    # The mean of the unclamped values is 1, hence the maximum is >= 1.
    # Dividing by it leaves the normalised distribution unchanged but keeps
    # large exponents from overflowing to inf.
    child_scores_high_var = torch.pow(base / torch.max(base), difficulty)
    normalized_scores = child_scores_high_var / torch.sum(child_scores_high_var)
    choosen_child = torch.multinomial(normalized_scores, num_samples=1).item()
    return torch.tensor(id2matrix[children[choosen_child]])
# Example: let the AI, playing as 1, choose an opening move on the empty board.
ai_make_move(torch.tensor(((0, 0, 0), (0, 0, 0), (0, 0, 0))), 1,)

tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])

Play by entering your move as two digits, e.g. 02 = row 0, column 2.
Adjust the difficulty to your liking.

In [159]:

def play(difficulty = 24, player_start = False):
    """Play an interactive game against the AI using input() and print().

    Args:
        difficulty: exponent forwarded to ai_make_move.
        player_start: when True the human plays 1 and moves first;
            otherwise the AI plays 1 and the human plays -1.

    A move is entered as two digits, row then column (e.g. "02"). Input that
    is malformed, outside the board or aimed at an occupied cell is rejected
    and asked for again. The board is shown before every human move and once
    more when the game is over.
    """
    player_pice = 1 if player_start else -1
    ai_pice = -player_pice

    def is_terminal(board):
        # A state without an entry in children_ids has no legal move left.
        return children_ids.get(matrix2id[to_tuple(board)]) is None

    def read_player_move(board):
        # Keep asking until the input names an empty cell on the board.
        while True:
            raw = input("Make move:").strip()
            if len(raw) == 2 and raw[0] in "012" and raw[1] in "012":
                y, x = int(raw[0]), int(raw[1])
                if board[y, x] == 0:
                    return y, x
            print("Invalid move: enter the row and column (0-2) of an empty cell, e.g. 02")

    # Integer board, matching the dtype of the boards ai_make_move returns.
    state = torch.zeros((3, 3), dtype=torch.long)
    players_turn = player_start
    while True:
        if players_turn:
            print(state)
            y, x = read_player_move(state)
            state[y, x] = player_pice
        else:
            state = ai_make_move(state, ai_pice, difficulty)
        if is_terminal(state):
            break
        players_turn = not players_turn

    print(state)
    if win_states(state, player_pice):
        print("Player won")
    elif win_states(state, ai_pice):
        print("AI won")
    else:
        print("Tie")
# Start a game in which the AI opens as player 1 and the human answers as -1.
play(difficulty=30)

tensor([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]])
tensor([[ 1,  0, -1],
        [ 0,  1,  0],
        [ 0,  0,  0]])
tensor([[ 1,  0, -1],
        [ 0,  1,  1],
        [ 0,  0, -1]])
tensor([[ 1,  1, -1],
        [-1,  1,  1],
        [ 0,  0, -1]])
tensor([[ 1,  1, -1],
        [-1,  1,  1],
        [ 1, -1, -1]])
Tie
