In [1]:
import torch
from torch import nn

In [2]:
def create_board():
    """Return an empty board as a (2, 3, 3) tensor of zeros.

    Plane 0 holds the marks of the side to move ("us"), plane 1 those of the
    opponent ("them"); each plane is indexed ``[row, col]``.
    """
    planes, rows, cols = 2, 3, 3
    return torch.zeros(planes, rows, cols)

def get_move_coords(move_index):
    """Convert a row-major cell index into ``(x, y)`` coordinates.

    ``x`` is the column (``move_index % 3``) and ``y`` the row
    (``move_index // 3``), so index 0 is the top-left cell and 8 the
    bottom-right on a 3-wide board.
    """
    column = move_index % 3
    row = move_index // 3
    return column, row

def is_valid_move(state, move_index):
    """Return True if ``move_index`` names an empty cell on the board.

    A move is valid when its index lies in ``0..8`` and neither plane of
    ``state`` has a mark at the corresponding ``[row, col]`` position.
    Returns a plain Python bool.
    """
    if not 0 <= move_index < 9:
        return False
    col, row = get_move_coords(move_index)
    cell_is_empty = state[:, row, col] == 0  # one entry per plane
    return bool(cell_is_empty.all())

def make_move(state, move_index):
    """Place the side-to-move's mark and return the board from the opponent's view.

    Args:
        state: board tensor whose plane 0 is the side to move ("us") and
            plane 1 the opponent ("them"), as produced by ``create_board``.
        move_index: row-major cell index in ``0..8``.

    Returns:
        A new tensor with the mark added to ``us`` and the planes swapped
        (``[them, us]``), so the opponent becomes the side to move. The input
        tensor is not modified.

    Raises:
        ValueError: if ``move_index`` is out of range or the cell is occupied.
    """
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # and then a negative index would silently wrap around (e.g. -1 -> us[-1, 2])
    # or an occupied cell would be overwritten.
    if not is_valid_move(state, move_index):
        raise ValueError(
            f"invalid move {move_index!r}: out of range or cell already occupied"
        )
    x, y = get_move_coords(move_index)
    us, them = state.unbind()
    us = us.clone()  # avoid mutating the caller's tensor
    us[y, x] = 1
    # Swap planes: the opponent is now the side to move.
    return torch.stack([them, us])

def has_won(ply):
    """Return True if ``ply`` contains a complete line of three.

    ``ply`` is one player's plane, indexed ``ply[row, col]`` over a 3x3 grid.
    A column, row, or diagonal counts as complete when its cells sum to 3
    (with 0/1 marks this means all three cells are set). Columns and rows
    are summed as tensors; diagonals are summed as Python scalars.
    """
    straight_lines = [ply[:, col] for col in range(3)] + [ply[row, :] for row in range(3)]
    if any(torch.sum(line) == 3 for line in straight_lines):
        return True
    diagonals = (
        [ply[i, i].item() for i in range(3)],      # top-left -> bottom-right
        [ply[2 - i, i].item() for i in range(3)],  # bottom-left -> top-right
    )
    return any(sum(cells) == 3 for cells in diagonals)

def get_score(state):
    """Score a position from the perspective of the side whose plane is first.

    Returns 1 if plane 0 (``us``) has three in a row, -1 if plane 1
    (``them``) does, and 0 otherwise. ``us`` is checked first. Because
    ``make_move`` swaps planes after each move, a player who has just won
    shows up here as ``them`` (score -1).
    """
    us, them = state.unbind()
    for plane, outcome in ((us, 1), (them, -1)):
        if has_won(plane):
            return outcome
    return 0

def is_complete(state):
    """Return True if the game is over: the board is full or someone has won.

    "Full" means the marks across all planes sum to 9. This assumes 0/1 marks
    and no doubly-marked cell, which ``make_move``'s validity check maintains.
    """
    marks_placed = state.sum(dim=(-3, -2, -1))
    return bool(marks_placed == 9) or get_score(state) != 0

In [3]:
# Start from an empty board; plane 0 is the side to move, plane 1 the opponent.
state = create_board()

In [4]:
# Cell 6 (x=0, y=2: bottom-left) is empty, so the move is legal.
is_valid_move(state, 6)

True

In [5]:
# First player takes the centre (index 4 -> x=1, y=1).
state = make_move(state, 4)

In [6]:
# make_move swaps the planes, so the centre mark now sits in plane 1 ("them").
state

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 1., 0.],
         [0., 0., 0.]]])

In [7]:
# Nobody has three in a row yet.
get_score(state)

0

In [8]:
# Second player takes the top-left cell (index 0).
state = make_move(state, 0)

In [9]:
# First player takes the middle-left cell (index 3 -> x=0, y=1).
state = make_move(state, 3)

In [10]:
# Second player is to move: their mark is in plane 0, the first player's two marks in plane 1.
state

tensor([[[1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [1., 1., 0.],
         [0., 0., 0.]]])

In [11]:
# Second player takes the top-middle cell (index 1).
state = make_move(state, 1)

In [12]:
get_score(state)

0

In [13]:
# First player takes the middle-right cell (index 5), completing the middle row.
state = make_move(state, 5)

In [14]:
# -1: after the swap the winner is in plane 1 ("them"), i.e. the player who just moved.
get_score(state)

-1

In [15]:
state

tensor([[[1., 1., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [1., 1., 1.],
         [0., 0., 0.]]])

In [16]:
# The game is over because someone has won, even though the board is not full.
is_complete(state)

True