# Imports and shit

In [2]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import time
import sys
import threading
from concurrent import futures
from asyncio import locks

sys.setrecursionlimit(10**5)
torch.manual_seed(1337)
np.random.seed(1337)
colab_env = 'google.colab' in sys.modules
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)
print("In colab:", colab_env)

Using device cuda:0
In colab: False


In [3]:
learning_rate = 0.001
batch_size = 100

searches = 16
cpuct = 1.0
l2_weight = 1.0

In [4]:
def get_actions(state):
    actions = []
    for i in range(3):
        for j in range(3):
            if state[i, j] == 0:
                actions.append((i, j))
    return actions

def check_win(state, player):
    # Check rows
    for i in range(3):
        if np.all(state[i] == player):
            return True
    # Check columns
    for i in range(3):
        if np.all(state[:,i] == player):
            return True
    # Check diagonals
    if np.all(np.diag(state) == player) or np.all(np.diag(np.fliplr(state)) == player):
        return True
    return False

def check_terminal(state):
    return (check_win(state, 1), check_win(state, -1))

In [5]:
class TicTacToeEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self):
        self.state = torch.zeros((3,3), dtype=torch.float32, device=device)
        self.turn = 1
        self.action_space = spaces.MultiDiscrete([3, 3])
        self.observation_space = spaces.Box(low=-1, high=2, shape=(3,3), dtype=np.float32)
    
    def get_actions(self):
        actions = []
        for i in range(3):
            for j in range(3):
                if self.state[i, j] == 0:
                    actions.append((i, j))
        return actions

    def step(self, action):
        if self.state[action[0], action[1]] != 0:
            return self.state, 0, False, {"message": "Invalid move"}
        self.state[action[0], action[1]] = self.turn
        if self.turn == 1:
            self.turn = -1
        else:
            self.turn = 1
        
        if self._check_win(1):
            return self.state, 1, True, {"message": "Player 1 wins"}
        elif self._check_win(-1):
            return self.state, -1, True, {"message": "Player 2 wins"}
        elif np.count_nonzero(self.state) == 9:
            return self.state, 0, True, {"message": "Draw"}
        else:
            return self.state, 0, False, {"message": "Next turn"}
    
    def reset(self):
        self.state = np.zeros((3,3), dtype=np.float32)
        self.turn = 1
        return self.state
    
    def render(self, mode='human'):
        print("Current state:")
        print(self.state)
    
    def _check_win(self, player):
        return check_win(self.state, player)

In [6]:
def array_to_board(array):
    board = ''
    for i in range(3):
        for j in range(3):
            if array[i, j] == 1:
                board += ' X '
            elif array[i, j] == -1:
                board += ' O '
            else:
                board += '   '
            if j < 2:
                board += '|'
        board += '\n'
        if i < 2:
            board += '---+---+---\n'
    return board

# MCTS

In [12]:
class Node():
    def __init__(self, state:np.ndarray, parent):
        self.is_leaf = True
        self.state = state.copy()
        self.is_terminal = check_terminal(state)
        self.parent = parent
        self.visits = 0
        self.num_backproped = 0
        self.prior_probability = 0
        self.total_value = 0
        self.children = []
        self.lock = locks.Lock()

        if parent != None:
            parent.children.append(self)
    
    def get_uct(self):
        u = cpuct * self.prior_probability / (1 + self.visits)
        q = math.sqrt(self.total_value) / (1 + self.num_backproped)
        return u + q
    
    def get_actions(self):
        return get_actions(self.state)
    
    def backprop(self, value):
        self.total_value += value
        self.num_backproped += 1
        if self.parent != None:
            self.parent.backprop(value)

# because model used is the same, we don't need to provide opponent

def search(root:Node, model, turn):
    leaf = root
    leaf.visits += 1
    while not leaf.is_leaf:
        maxxed = -10000
        maxxed_node = None
        for i in leaf.children:
            if maxxed < i.get_uct():
                maxxed_node = i
                maxxed = i.get_uct()
        leaf = maxxed_node
        leaf.visits += 1
        turn *= -1
    
    while True:
        player1_win, player2_win = leaf.is_terminal
        if player1_win or player2_win or len(get_actions(leaf.state)) == 0:
            break

        leaf.is_leaf = False

        tmp = None
        tmp_uct = -100000

        turn *= -1
        prior_now, _ = model(torch.tensor(leaf.state, device=device) * turn)

        for x, y in leaf.get_actions():
            new = Node(state=leaf.state, parent=leaf)
            new.state[x][y] = turn

            _, val = model(torch.tensor(new.state, device=device) * turn) # TODO add cache for parralelization (I spelled that wrong L activity)
            val = F.softmax(val, dim=-1)

            new.prior_probability = prior_now[x][y]
            new.backprop(val[0].item())

            if tmp_uct < new.get_uct():
                tmp = new
                tmp_uct = tmp.get_uct()

        leaf = tmp
        leaf.visits += 1

def mcts(env, root:Node, model):
    state = torch.tensor(root.state, device=device)

    prior, _ = model(state)

    if len(root.children) == 0:
        for i, (x, y) in enumerate(root.get_actions()):
            search_node = Node(root.state, root)
            search_node.state[x][y] = env.turn
            search_node.prior_probability = prior[x][y].item()
            search_node.total_value = model(torch.tensor(search_node.state * env.turn, device=device))[1][0]

    with futures.ThreadPoolExecutor(max_workers=9) as t:
        s = 0
        threads = []
        while s < searches:
            threads.append(t.submit(
                search(root.children[s % len(root.children)], model, env.turn)
            ))
            s += 1
        futures.as_completed(threads)

# Model

In [146]:
class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(9, 9)
    
    def forward(self, x):
        temp = F.leaky_relu(self.fc(x))
        return x + temp

class TicTacToeModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            *[Block() for i in range(1)]
        )
        self.policy = nn.Sequential(
            *[Block() for i in range(1)]
        )
        self.value = nn.Sequential(
            Block(),
            nn.Linear(9, 3)
        )

    def forward(self, x):
        x = x.view(torch.prod(torch.tensor(x.size())))
        temp = self.net(x)
        pol = self.policy(temp).view(3, 3)
        val = self.value(temp)

        mask_bool = (x == 0).view(3, 3)
        logits_masked = torch.where(mask_bool, pol, 0)
        return logits_masked, val
    
    def predict(self, x, multinomial=True):
        # returns a tuple of ((x, y), value)
        # value is softmaxxed
        policy, value = self.forward(x)
        policy = policy.view(9)
        prob = F.softmax(policy, dim=-1)
        pred = torch.argmax(prob, dim=-1).item()
        if multinomial:
            pred = torch.multinomial(prob, num_samples=1).item()
        return (pred // 3, pred % 3), F.softmax(value, dim=-1)

In [149]:
env = TicTacToeEnv()
env.reset()

model = TicTacToeModel().to(device=device)

optimizer = optim.Adam(model.parameters(), lr = 0.001)
loss_fn = nn.CrossEntropyLoss()

def learn(buffer):
    av = 0
    for x in buffer:
        state, q, v = x
        policy, val = model(state)
        
        policy_loss = loss_fn(policy, q)
        val_loss = loss_fn(val, v)
        loss = policy_loss + val_loss
        
        optimizer.zero_grad()
        loss.backward()
        av += loss.item()
        optimizer.step()
    print(av / len(buffer))

for i in range(100):
    buffer = []
    done = False
    del root
    root = Node(env.state, None)

    while not done:
        mcts(env, root, model)
        maxxed, maxxed_uct = None, -10000
        soft = np.array(F.softmax(torch.tensor([k.get_uct() for k in root.children])).tolist())
        soft /= soft.sum()
        maxxed = max(root.children, key=lambda x: x.get_uct())
        choice = np.random.choice(root.children, 1, p=soft)[0]

        x0, y0 = np.where((maxxed.state - env.state) == env.turn) 
        x, y = np.where((choice.state - env.state) == env.turn)
        state, rew, done, meta = env.step((x.item(), y.item()))
        root = choice
        root.parent = None

        sparsed = torch.zeros(size=(3, 3), device="cpu", dtype=torch.float32)
        sparsed[x0.item()][y0.item()] = 1.0
        sparsed = sparsed.to(device=device)

        buffer.append([
            torch.tensor(env.state, device=device) * env.turn * -1,
            sparsed,
            torch.zeros(size=(3,), dtype=torch.float32, device="cpu")
        ])

        if i % 10 == 0:
            print(array_to_board(state.astype(np.int16)))

    temp_turn = 0 if rew == 1 else 1
    if rew > 0:
        for i in range(len(buffer)):
            buffer[i][-1][temp_turn] = 1.0
            temp_turn = 0 if temp_turn == 1 else 1
            buffer[i][-1] = buffer[i][-1].to(device=device)
    else:
        for i in range(len(buffer)):
            buffer[i][-1][-1] = 1.0
            buffer[i][-1] = buffer[i][-1].to(device=device)
            
    for i in buffer:
        pass
    learn(buffer)

    env.reset()



tensor([[-0.0030, -0.0022,  0.2602],
        [ 0.2164,  0.2019,  0.0380],
        [ 0.2038,  0.2703, -0.0028]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[ 0.0000,  0.1551,  0.1387],
        [ 0.4883,  0.4135,  0.1213],
        [ 0.2269,  0.8872, -0.0052]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[-0.0029,  0.0000,  0.0422],
        [ 0.4568,  0.2957, -0.0014],
        [ 0.3547,  0.6973,  0.2972]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[-0.0068, -0.0016,  0.0000],
        [ 0.1895,  0.4695,  0.5999],
        [ 0.0016,  0.1790,  0.0731]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[-0.0042,  0.1569, -0.0029],
        [ 0.0000,  0.1439,  0.0790],
        [ 0.4446, -0.0028, -0.0054]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[ 0.2552,  0.0960,  0.3845],
        [-0.0013,  0.0000,  0.3871],
        [ 0.4115,  0.4635,  0.0962]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[ 0.2157,  0.1465,  

  soft = np.array(F.softmax(torch.tensor([k.get_uct() for k in root.children])).tolist())


tensor([[ 0.0000, -0.0028,  0.5512],
        [ 0.1952,  0.3835,  0.2657],
        [-0.0038,  0.0000,  0.0000]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[ 0.0000, -0.0065,  0.1643],
        [ 0.6553,  0.6165,  0.0000],
        [ 0.3578,  0.6286,  0.0000]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[ 0.0000,  0.0000, -0.0069],
        [ 0.9366,  0.4845,  0.0000],
        [ 0.6041,  1.0634,  0.0000]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[0.0000, 0.0286, 0.0000],
        [0.4426, 0.6903, 0.0000],
        [0.2447, 0.4813, 0.0000]], device='cuda:0', grad_fn=<WhereBackward0>)
tensor([[ 0.0000, -0.0054, -0.0040],
        [ 0.0000,  0.4855,  0.0000],
        [ 0.6007,  0.1821,  0.0000]], device='cuda:0',
       grad_fn=<WhereBackward0>)
tensor([[0.0000, 0.0342, 0.3213],
        [0.4429, 0.0000, 0.0000],
        [0.5991, 0.9241, 0.0000]], device='cuda:0', grad_fn=<WhereBackward0>)
tensor([[0.0000, 0.0551, 0.0443],
        [1.1330, 0.4554, 0

KeyboardInterrupt: 

In [139]:
t100 = 1

In [147]:



s = torch.tensor([
    [0, 1, 1],
    [0, -1, 0],
    [0, 0, -1]
], device=device, dtype=torch.float32)

stated = s * ((t100 % 2) * 2 - 1)
print(t100)
t100 += 1
print(stated)
print(model.predict(stated, multinomial=False))


7
tensor([[ 0.,  1.,  1.],
        [ 0., -1.,  0.],
        [ 0.,  0., -1.]], device='cuda:0')
((1, 2), tensor([0.2798, 0.2413, 0.4790], device='cuda:0', grad_fn=<SoftmaxBackward0>))
