In [22]:
import numpy as np
print(np.__version__)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

1.22.3


In [26]:
class TicTacToe:
    """Rules engine for 3x3 tic-tac-toe played on a NumPy board.

    A board is a (row_count, column_count) array whose cells hold 1 or -1 for
    the two players and 0 when empty. Actions are flat, row-major cell
    indices: 0 is the top-left corner, 1 the top middle, ... 8 the
    bottom-right corner.
    """

    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        """Return an empty board (all zeros)."""
        return np.zeros((self.row_count, self.column_count))

    def get_next_state(self, state, action, player):
        """Write `player` into the cell addressed by `action` and return the board.

        The board is modified IN PLACE and the returned array is the same
        object as `state`; callers that still need the previous position must
        pass a copy.
        """
        row, column = divmod(action, self.column_count)
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat uint8 mask with 1 for every empty cell and 0 otherwise."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return whether the piece on the cell addressed by `action` completes a line.

        Returns False when `action` is None (no move has been played yet) or
        when the addressed cell is empty.
        """
        if action is None:
            return False

        row, column = divmod(action, self.column_count)
        player = state[row, column]

        # An empty cell cannot be a winning move. Without this guard the sum
        # comparisons below degenerate to `sum == 0`, which is true for any
        # empty line (or any line whose pieces cancel out).
        if player == 0:
            return False

        # A line is won when every one of its cells holds `player`, i.e. when
        # its sum equals player * line length.
        return (
            # three in a row
            np.sum(state[row, :]) == player * self.column_count or
            # three in a column
            np.sum(state[:, column]) == player * self.row_count or
            # main diagonal (top-left to bottom-right)
            np.sum(np.diag(state)) == player * self.row_count or
            # anti-diagonal (bottom-left to top-right)
            np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
        )

    def get_value_and_terminated(self, state, action):
        """Return `(value, is_terminated)` for the position reached by `action`.

        `(1, True)` when the move at `action` won, `(0, True)` when the board
        is full without a winner (draw), `(0, False)` otherwise.
        """
        if self.check_win(state, action):
            return 1, True
        # Draw: no empty cell is left.
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        # No winner yet and moves remain.
        return 0, False

    def get_opponent(self, player):
        """Return the other player's marker (1 <-> -1)."""
        return -player

    def get_opponent_value(self, value):
        """Return `value` negated, i.e. as seen by the other player."""
        return -value

    def change_perspective(self, state, player):
        """Return a new board with every cell multiplied by `player`.

        With `player == -1` the two players' pieces are swapped; with
        `player == 1` the result equals the input.
        """
        return state * player

    def get_encoded_state(self, state):
        """Return three stacked float32 planes marking cells equal to -1, 0 and 1."""
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)
        return encoded_state

In [27]:
class ResNet(nn.Module):
    """Convolutional network with a shared trunk and separate policy / value heads.

    The input is expected to have 3 channels (the first convolution is built
    with `in_channels=3`). `forward` returns `(policy, value)`: the policy
    head ends in a linear layer with `game.action_size` outputs (raw logits,
    no softmax), the value head ends in a single `tanh` unit.

    Args:
        game: object exposing `row_count`, `column_count` and `action_size`.
        num_resBlocks: number of `ResBlock`s stacked in the trunk.
        num_hidden: channel width of the trunk.
    """

    def __init__(self, game, num_resBlocks, num_hidden):
        super().__init__()
        board_cells = game.row_count * game.column_count
        policy_channels = 32
        value_channels = 3

        # Stem: lift the 3 input planes to the trunk width.
        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
        )

        # Trunk: a stack of residual blocks that keep the spatial size.
        trunk_blocks = [ResBlock(num_hidden) for _ in range(num_resBlocks)]
        self.backBone = nn.ModuleList(trunk_blocks)

        # Policy head: one logit per action.
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, policy_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(policy_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(policy_channels * board_cells, game.action_size),
        )

        # Value head: a single scalar squashed into [-1, 1] by tanh.
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, value_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(value_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(value_channels * board_cells, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the stem and trunk once, then both heads on the shared features."""
        features = self.startBlock(x)
        for block in self.backBone:
            features = block(features)
        policy = self.policyHead(features)
        value = self.valueHead(features)
        return (policy, value)

class ResBlock(nn.Module):
    """Residual block: conv-BN-ReLU, conv-BN, add the block input, ReLU.

    Both convolutions use `num_hidden` input and output channels with a 3x3
    kernel and padding 1, so the output has the same shape as the input.
    """

    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Apply the two conv/BN stages and the skip connection."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        # Skip connection: add the untouched block input before the final ReLU.
        return F.relu(out + x)

In [28]:
# Smoke test: push one hand-built position through an untrained ResNet.
tictactoe = TicTacToe()

# Play two moves so the three encoded planes differ from each other.
state = tictactoe.get_initial_state()
state = tictactoe.get_next_state(state, 2, 1)
state = tictactoe.get_next_state(state, 7, -1)

print(state)

encoded_state = tictactoe.get_encoded_state(state)

print(encoded_state)

# NOTE(review): the output recorded for this cell ends in
# "RuntimeError: Could not infer dtype of numpy.float32". That message
# presumably comes from the torch.tensor(...) call below and looks like an
# installed PyTorch build that cannot interoperate with the installed NumPy
# (1.22.3 per the import cell) -- confirm the environment before re-running.
# unsqueeze(0) adds the leading batch dimension the network expects.
tensor_state = torch.tensor(encoded_state).unsqueeze(0)

model = ResNet(tictactoe, 4, 64)
# Inference only: eval() makes BatchNorm use its running statistics instead
# of the statistics of this single-sample batch (and stops it from updating
# them); no_grad() skips building the autograd graph.
model.eval()

with torch.no_grad():
    policy, value = model(tensor_state)
value = value.item()
# The policy head returns raw logits; softmax over the action dimension turns
# them into a probability distribution.
policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()

print(value, policy)

[[ 0.  0.  1.]
 [ 0.  0.  0.]
 [ 0. -1.  0.]]
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 1. 0.]]

 [[1. 1. 0.]
  [1. 1. 1.]
  [1. 0. 1.]]

 [[0. 0. 1.]
  [0. 0. 0.]
  [0. 0. 0.]]]


RuntimeError: Could not infer dtype of numpy.float32

In [3]:
## Monte Carlo Tree Search
class Node:
    """A single position in the Monte Carlo search tree.

    Every node stores its board in `state`, the move that produced it in
    `action_taken`, a link to its `parent`, and running statistics
    (`visit_count`, `value_sum`). `expand` always plays the new move as
    player 1 and then flips the board with `change_perspective(player=-1)`
    before storing it in the child.

    Args:
        game: rules object (valid moves, next state, terminal check, ...).
        args: dict of search settings; `args["C"]` is the exploration constant.
        state: board for this node.
        parent: parent `Node`, or None for the root.
        action_taken: action that led from the parent to this node, or None.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None):
        self.game, self.args = game, args
        self.state = state
        self.parent, self.action_taken = parent, action_taken

        self.children = []
        # Mask of moves that have not been turned into a child yet; while any
        # entry is 1 this node can still be expanded.
        self.expandable_moves = game.get_valid_moves(state)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """True when no move is left to expand and at least one child exists."""
        no_moves_left = np.sum(self.expandable_moves) == 0
        return no_moves_left and len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (None if there are no children)."""
        best_child, best_ucb = None, -np.inf
        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > best_ucb:
                best_child, best_ucb = candidate, score
        return best_child

    def get_ucb(self, child):
        """UCB score of `child` as seen from this node: Q + C * sqrt(ln(N) / n).

        The child's mean value is mapped from [-1, 1] to [0, 1] and then
        inverted (1 - q): the child's statistics are kept from the opposing
        side's point of view, so a child that looks bad for itself is a good
        choice for this node.
        """
        child_mean = child.value_sum / child.visit_count
        exploitation = 1 - (child_mean + 1) / 2
        exploration = math.sqrt(math.log(self.visit_count) / child.visit_count)
        return exploitation + self.args["C"] * exploration

    def expand(self):
        """Turn one randomly chosen unexpanded move into a child and return it."""
        candidates = np.where(self.expandable_moves == 1)[0]
        action = np.random.choice(candidates)
        self.expandable_moves[action] = 0

        # Play the move on a copy (get_next_state is given the copy, so this
        # node's own board is left untouched), then flip the board for the child.
        board = self.game.get_next_state(self.state.copy(), action, 1)
        board = self.game.change_perspective(board, player=-1)

        child = Node(self.game, self.args, board, self, action)
        self.children.append(child)
        return child

    def simulate(self):
        """Return the value of a random playout started from this node.

        If the node is already terminal, its (negated) terminal value is
        returned without playing any move.
        """
        outcome, finished = self.game.get_value_and_terminated(self.state, self.action_taken)
        outcome = self.game.get_opponent_value(outcome)
        if finished:
            return outcome

        # Random rollout on a private copy of the board.
        board = self.state.copy()
        mover = 1
        while True:
            legal = self.game.get_valid_moves(board)
            move = np.random.choice(np.where(legal == 1)[0])
            board = self.game.get_next_state(board, move, mover)
            outcome, finished = self.game.get_value_and_terminated(board, move)
            if finished:
                # The value is reported for the side that just moved; negate
                # it when that side was rollout player -1.
                if mover == -1:
                    outcome = self.game.get_opponent_value(outcome)
                return outcome
            mover = self.game.get_opponent(mover)

    def backpropagate(self, value):
        """Add `value` to this node and, sign-flipped at each level, to every ancestor."""
        node, signed_value = self, value
        while True:
            node.value_sum += signed_value
            node.visit_count += 1
            if node.parent is None:
                break
            signed_value = node.game.get_opponent_value(signed_value)
            node = node.parent
            

class MCTS:
    """Monte Carlo Tree Search driver built on `Node`.

    Args:
        game: rules object passed on to every `Node`.
        args: dict of search settings; `args["num_searches"]` is the number
            of iterations per `search` call (nodes also read `args["C"]`).
    """

    def __init__(self, game, args):
        self.game = game
        self.args = args

    def search(self, state):
        """Run `num_searches` iterations from `state` and return move probabilities.

        Returns:
            Array of length `game.action_size` holding each root child's
            share of the visits. If the root ends up with no visited children
            (for example `state` is already terminal, or `num_searches` is 0)
            the array is all zeros instead of NaN.
        """
        root = Node(self.game, self.args, state)

        for _ in range(self.args["num_searches"]):
            node = root
            # Selection: descend while the current node has nothing left to expand.
            while node.is_fully_expanded():
                node = node.select()

            value, is_terminated = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminated:
                # Expansion: create a new child for a random unexpanded action.
                node = node.expand()
                # Simulation: random playout from the new child.
                value = node.simulate()
            # Backpropagation: update statistics up to the root.
            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count

        # Normalise visit counts into probabilities; skip the division when
        # nothing was visited so the result is zeros rather than 0/0 = NaN.
        total_visits = np.sum(action_probs)
        if total_visits > 0:
            action_probs /= total_visits
        return action_probs


In [4]:
# Interactive game: the human plays as player 1, MCTS plays as player -1.
tictactoe = TicTacToe()
player = 1

args = {
    "C": 1.41,             # exploration constant used in the UCB formula
    "num_searches": 1000   # MCTS iterations per computer move
}

mcts = MCTS(tictactoe, args)

state = tictactoe.get_initial_state()

while True:
    print(state)

    if player == 1:
        valid_moves = tictactoe.get_valid_moves(state)
        print("valid_moves", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])

        # Reject anything that is not an integer instead of crashing the loop.
        try:
            action = int(input(f"{player} Move: "))
        except ValueError:
            print("Action not valid")
            continue

        # The range check comes first: an index >= action_size would raise
        # IndexError, and a negative index would silently wrap around to
        # another cell.
        if not 0 <= action < tictactoe.action_size or valid_moves[action] == 0:
            print("Action not valid")
            continue
    else:
        # MCTS always searches as player 1, so present the board from the
        # current player's point of view.
        neutral_state = tictactoe.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_probs)

    # action is valid
    state = tictactoe.get_next_state(state, action, player)

    # checking end game
    value, is_terminated = tictactoe.get_value_and_terminated(state, action)

    if is_terminated:
        print(state)
        if value == 1:
            print(f"Player {player} has won")
        else:
            print("DRAW")
        break

    # game continues
    player = tictactoe.get_opponent(player)



[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid_moves [0, 1, 2, 3, 4, 5, 6, 7, 8]
[[0. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]
[[ 0.  0.  0.]
 [ 0.  1.  0.]
 [ 0. -1.  0.]]
valid_moves [0, 1, 2, 3, 5, 6, 8]
[[ 0.  0.  0.]
 [ 1.  1.  0.]
 [ 0. -1.  0.]]
[[ 0.  0.  0.]
 [ 1.  1. -1.]
 [ 0. -1.  0.]]
valid_moves [0, 1, 2, 6, 8]
[[ 1.  0.  0.]
 [ 1.  1. -1.]
 [ 0. -1.  0.]]
[[ 1.  0.  0.]
 [ 1.  1. -1.]
 [-1. -1.  0.]]
valid_moves [1, 2, 8]
[[ 1.  0.  0.]
 [ 1.  1. -1.]
 [-1. -1.  1.]]
Player 1 has won
