# Ultimate-Tic-Tac-Toe-Zero Model

In [1]:
import os
import sys

import numpy as np
import torch
import torchx
from multiprocessor.multiprocessor import Multiprocessor

In [2]:
# Extend the import search path with a directory given relative to the current
# working directory, so this only resolves when the notebook is run from its
# expected location.
# NOTE(review): presumably this is where the MCTS / UTTT modules imported in
# the next cell live -- confirm.
sys.path.append("../../UTTT/Python")

In [3]:
from MCTS import MCTS
from UTTT import UTTT

## Environment

In [4]:
def MCTS_numpy_board(m: MCTS):
    """Encode the cells exposed by ``m`` as a ``(9, 3, 3)`` NumPy array.

    ``m`` is read through ``m[i, j]`` for every ``i`` and ``j`` in
    ``range(9)``.  The value found at ``(i, j)`` is stored at
    ``[i, j // 3, j % 3]`` of the result, mapped as follows:

    * equal to ``UTTT.P1``  ->  ``1``
    * equal to ``UTTT.P2``  ->  ``-1``
    * anything else         ->  ``0``

    Returns a newly allocated array with NumPy's default float dtype.
    """
    encoded = np.zeros((9, 3, 3))
    for outer in range(9):
        for inner in range(9):
            owner = m[outer, inner]
            # Unfold the flat 0..8 index into (row, column) of a 3x3 grid.
            row, col = divmod(inner, 3)
            # P1 is tested before P2, matching the priority of the mapping above.
            encoded[outer, row, col] = (
                1 if owner == UTTT.P1 else (-1 if owner == UTTT.P2 else 0)
            )

    return encoded

In [15]:
def select_expand(games):
    """Run one select-then-expand step on every game in ``games``.

    ``games`` is an iterable of 2-tuples.  The first element of each tuple
    must provide ``select()``, whose result must provide ``expand()``; the
    second element is not used here.

    Returns a list holding the ``expand()`` result for each game, in input
    order.
    """
    expanded = []
    for tree, _ in games:
        selected = tree.select()
        expanded.append(selected.expand())
    return expanded

class Environment:
    """A batch of MCTS games that are searched and played in lock-step.

    Typical driving loop::

        with Environment(num_games=25, num_iterations=400) as env:
            while env:
                boards = next(env)      # one encoded leaf per active game
                env.step(winners)       # one value per leaf, same order

    Every ``next(env)`` / ``env.step(...)`` pair is one search round for all
    active games.  After ``num_iterations`` rounds each active game makes a
    move; games whose ``winner`` is no longer ``UTTT.N`` are moved to
    ``finished_games``.

    Attributes set by ``reset()``:
        games: list of ``(MCTS, moves)`` tuples still being played.
        finished_games: list of ``(MCTS, moves)`` tuples that have ended.
        iteration: number of search rounds run since the last move.
        leaves: leaves returned by the latest ``next(env)``, or ``None`` when
            no round is pending.
    """

    def __init__(self, num_games=1000, num_iterations=400):
        """Store the configuration; no games are created until ``reset()``.

        Args:
            num_games: how many games ``reset()`` creates.
            num_iterations: search rounds to run before each move.
        """
        self.num_games = num_games
        self.num_iterations = num_iterations

    def reset(self):
        """Discard all state and start ``num_games`` fresh games."""
        self.games = [(MCTS(), []) for _ in range(self.num_games)]
        self.finished_games = []
        self.iteration = 0
        self.leaves = None

    def __enter__(self):
        """Reset the environment and return it for use in a ``with`` block."""
        self.reset()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the ``with`` block without suppressing exceptions.

        Without this method ``with Environment() as env:`` cannot be used at
        all, because the context-manager protocol needs both halves.  State
        (including ``finished_games``) is kept so results can be inspected
        after the block.
        """
        return False

    def __bool__(self):
        """Return True while at least one unfinished game remains.

        Also False when ``reset()`` has never been called.
        """
        return hasattr(self, "games") and len(self.games) > 0

    def __next__(self):
        """Start a search round and return the encoded leaves.

        Runs ``select_expand`` over the active games, remembers the resulting
        leaves for ``step()`` and returns them encoded with
        ``MCTS_numpy_board`` and stacked into a single NumPy array.
        """
        self.iteration += 1
        # NOTE: a disabled draft that fanned this call out through
        # Multiprocessor(cpus=10).process(select_expand, self.games) and
        # flattened the per-worker results used to sit here; the serial call
        # is the only active path.
        self.leaves = select_expand(self.games)
        return np.array([MCTS_numpy_board(leaf) for leaf in self.leaves])

    def step(self, winners):
        """Finish the pending round by back-propagating one value per leaf.

        Args:
            winners: sequence with one entry per leaf of the latest
                ``next(env)``, in the same order; each entry is handed to
                that leaf's ``backprop``.

        Raises:
            AssertionError: if ``winners`` and the pending leaves differ in
                length.
        """
        assert len(winners) == len(self.leaves), (
            "winners must contain exactly one entry per pending leaf"
        )
        for leaf, winner in zip(self.leaves, winners):
            leaf.backprop(winner)

        self.leaves = None

        # ``iteration`` was already incremented by __next__ for this round, so
        # ``>=`` makes a move after exactly ``num_iterations`` rounds (a strict
        # ``>`` would run one round too many).
        if self.iteration >= self.num_iterations:
            self.iteration = 0
            for game, moves in self.games:
                moves.append(game.make_move())

            # Split the batch into games that just ended and games still open.
            still_running = []
            for game, moves in self.games:
                if game.winner == UTTT.N:
                    still_running.append((game, moves))
                else:
                    self.finished_games.append((game, moves))
            self.games = still_running

In [16]:
# %%time
# env = Environment(num_games=25, num_iterations=400)
# env.reset()
# i = 0
# while env:
#     boards = next(env)
#     winners = np.random.choice((UTTT.T, UTTT.P1, UTTT.P2), size=len(env.leaves))
#     env.step(winners)

CPU times: user 26.4 s, sys: 0 ns, total: 26.4 s
Wall time: 26.4 s


## Model

In [8]:
# Compute device shared by the cells below: the first CUDA GPU when PyTorch
# reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def to_device(obj):
    """Return ``obj.to(device)`` for the module-level ``device``.

    ``obj`` can be anything with a ``.to(device)`` method.
    """
    moved = obj.to(device)
    return moved

In [9]:
class DeepUTTT(torchx.nn.Module):
    """Convolutional network that outputs three values per input sample.

    Layout:
        block1: three ``Conv2dBatch`` layers, 9 -> 64 -> 128 -> 256 channels
            (kernel sizes 1, 1, 2).
        block2: three 256 -> 512 -> 256 pairs of 1x1 ``Conv2dBatch`` layers,
            then 256 -> 512 (kernel 1) and 512 -> 1024 (kernel 2).
        winner_predict_block: bias-free ``Linear`` 1024 -> 256, ReLU,
            bias-free ``Linear`` 256 -> 3.

    NOTE(review): the head needs exactly 1024 flattened features per sample.
    That holds for ``(N, 9, 3, 3)`` inputs only if ``Conv2dBatch`` convolves
    without padding, so the two kernel-2 layers shrink 3x3 to 1x1 -- confirm
    against the torchx implementation.
    """

    def __init__(self):
        super().__init__()

        conv = torchx.nn.Conv2dBatch

        # Layers are created in the same order as they are registered, so
        # parameter ordering and initialisation do not depend on this layout.
        stem_layers = [
            conv(9, 64, kernel_size=1),
            conv(64, 128, kernel_size=1),
            conv(128, 256, kernel_size=2),
        ]
        self.block1 = torch.nn.Sequential(*stem_layers)

        expand_squeeze_pairs = [
            torch.nn.Sequential(
                conv(256, 512, kernel_size=1),
                conv(512, 256, kernel_size=1),
            )
            for _ in range(3)
        ]
        self.block2 = torch.nn.Sequential(
            *expand_squeeze_pairs,
            conv(256, 512, kernel_size=1),
            conv(512, 1024, kernel_size=2),
        )

        head_layers = [
            torch.nn.Linear(1024, 256, bias=False),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 3, bias=False),
        ]
        self.winner_predict_block = torch.nn.Sequential(*head_layers)

        self.reset_parameters()

    def forward(self, x):
        """Apply both conv blocks, flatten per sample, and run the linear head."""
        features = self.block2(self.block1(x))
        flat = torch.flatten(features, start_dim=1)
        return self.winner_predict_block(flat)
    
# Build the network and place it on the selected device in one step.
model = to_device(DeepUTTT())

In [11]:
# Show the model's parameter count (the recorded output follows this cell).
# NOTE(review): num_params is not defined in this file; presumably it comes
# from torchx.nn.Module -- confirm.
model.num_params()

3430272

In [10]:
# Random batch of 512 inputs shaped (9, 3, 3) with integer values drawn from
# {-1, 0, 1}, converted to float32 and placed on the selected device.
data = to_device(
    torch.from_numpy(
        np.random.randint(-1, 2, size=(512, 9, 3, 3))
    ).float()
)

In [12]:
%%timeit
# Times one forward pass of `model` over the `data` batch.
# NOTE(review): `forward` is called directly rather than `model(data)`, so any
# hooks registered on the module are bypassed.
# NOTE(review): if `device` is a CUDA device, kernels are launched
# asynchronously and this figure may not include GPU execution time unless a
# synchronisation is added -- confirm before quoting the number.
output = model.forward(data)

2.37 ms ± 6.93 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
