In [55]:
import numpy as np
import math
from attax_game import Attaxx
from go import Go
#from matplotlib import pyplot
from copy import deepcopy
from tqdm.notebook import trange

import torch
import torch.nn as nn
import torch.nn.functional as F

In [56]:
class ResNet(nn.Module):
    def __init__(self, game, num_resBlocks, num_hidden):
        super().__init__()
        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for i in range(num_resBlocks)]
        )
        
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * game.board_size * game.board_size, game.board_size * game.board_size )
        )
        
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * game.board_size * game.board_size, 1),
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value
        
        
class ResBlock(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)
        
    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = F.relu(x)
        return x

In [57]:

class Node():
    def __init__(self, game, args, state, player = None, parent=None, action_taken=None, prior=0):
        self.game = game
        self.args = args
        self.state = state
        self.player = player
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior
        self.expandable_moves = game.get_valid_moves(self.state, self.player)

        self.children = []

        self.visit_count = 0
        self.value_sum = 0
    
    def is_fully_expanded(self):
        # print("EXPANDABLE MOVES: " + str(len(self.expandable_moves)))
        return len(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb  = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
        
        return best_child

    def get_ucb(self, child):
        q_value = 1 - ((child.value_sum/child.visit_count)+1) / 2
        return q_value + self.C * math.sqrt(math.log(self.visit_count) / child.visit_count)
    
    def expand(self, policy):
        for action, prob in enumerate(policy):
            if prob > 0:
                print("Action: " + str(action) + " Prob: " + str(prob))
                action_coords = self.game.get_action_from_index(action)
                child_state = deepcopy(self.state)
                child_state = self.game.get_next_state(child_state, action_coords, 1)
                child_state = self.game.change_perspective(child_state, player=-1)

                print("Action Before Adding to Children: " + str(action))
                child = Node(game=self.game, args=self.args, state=child_state, parent=self, action_taken=action, prior=prob)
                self.children.append(child)
        print("Children: " + str(len(self.children)))
        return child
    
    def backpropagate(self, value):
        print("Backpropagating in Node" + str(self.action_taken) + " with value " + str(value))
        self.value_sum += value
        self.visit_count += 1

        value = -value
        if self.parent is not None:
            self.parent.backpropagate(value)

In [58]:
class MCTS():
    def __init__(self, game, args, model) -> None:
        self.game = game
        self.args = args
        self.C = args['C']
        self.num_searches = args['num_searches']
        self.model = model
    
    @torch.no_grad()
    def search(self, state, player):
        #define root
        root = Node(self.game, self.args, state, player = player)
        #selection 
        for search in range(self.num_searches):
            node = root
            if search % 10 == 0:
                print("Searches Done: " + str(search))
            # DEBUG
            while node.is_fully_expanded():
                node = node.select()

            print("Node Children: " + str(len(node.children)))
            
            value, is_terminal = self.game.get_value_and_terminated(node.state)
            value = -value
            
            if not is_terminal:
                # print(self.game.get_encoded_state(node.state))
                policy, value = self.model(torch.tensor(self.game.get_encoded_state(node.state), dtype= torch.float32).unsqueeze(0))
                print("POLICY: \n" + str(policy))
                print("VALUE: \n" + str(value))
                print("ENCODED STATE: \n" + str(self.game.get_encoded_state(node.state)))
                policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()

                valid_moves = self.game.get_valid_moves(node.state, player)

                print("VALID MOVES:\n " + str(valid_moves))

                policy *= valid_moves

                policy /= np.sum(policy)
                
                value = value.item()
                
                node.expand(policy)

            print(node.children)
            print("VALUE: " + str(value))
            node.backpropagate(value)
            print("Backpropagated, Value Sum: " + str(node.value_sum) + " Visit Count: " + str(node.visit_count))

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            print("Action: " + str(child.action_taken) + " Visit Count: " + str(child.visit_count))
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs) if np.sum(action_probs) > 0 else 1
        return action_probs

In [59]:
class AlphaZero:
    def __init__(self, model, game, optimizer, args) -> None:
        self.model = model
        self.game = game
        self.optimizer = optimizer
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        memory = []
        player = 1
        state = self.game.get_initial_state()
        while True:
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search(neutral_state, player)
            memory.append((neutral_state, action_probs, player))
            action = np.random.choice(self.game.action_size, p=action_probs)
            
            state = self.game.get_next_state(state, action, player)
            
            value, is_terminal = self.game.get_value_and_terminated(state, action)
            
            if is_terminal:
                returnMemory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    returnMemory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))
                return returnMemory
            
            player = self.game.get_opponent(player)

    def learn(self):
        for iteration in range(self.args['num_iterations']):
            memory = []
            
            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()
                
            self.model.train()
            
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)
            
            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")
    
    def train(self, memory):
        pass

In [60]:
game_args = [9, 5.5]

go = Go(game_args)

model = ResNet(go, 5, 32)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

args = {
    'C': 1,
    'num_searches': 100,
    'num_iterations': 10,
    'num_selfPlay_iterations': 10,
    'num_epochs': 10
}

alpazero = AlphaZero(model, go, optimizer, args)
alpazero.learn()

  0%|          | 0/10 [00:00<?, ?it/s]

Searches Done: 0
Node Children: 0
POLICY: 
tensor([[-0.0178, -0.0707,  0.1988,  0.0014,  0.1965,  0.0647,  0.1416,  0.0036,
          0.0246,  0.2090, -0.1359, -0.0134, -0.1428, -0.1714, -0.0445, -0.0105,
          0.1280, -0.0929, -0.1129, -0.0134,  0.0821, -0.0732, -0.1805,  0.0807,
         -0.0610, -0.0052,  0.1026,  0.1291,  0.1169,  0.1800, -0.0059, -0.0696,
          0.0345,  0.0166,  0.1760,  0.1471, -0.1779, -0.0912, -0.0894, -0.1894,
         -0.0946,  0.0219,  0.1315, -0.0028,  0.0324, -0.1206, -0.2836,  0.0650,
          0.1390,  0.0204, -0.2049,  0.0687,  0.0161,  0.2855, -0.2344, -0.0136,
         -0.0257, -0.0752, -0.0084, -0.1205,  0.0210, -0.0623,  0.0441,  0.0023,
          0.0054,  0.1276,  0.0114, -0.0400, -0.1312, -0.0719,  0.2352,  0.1303,
         -0.1354, -0.0855,  0.1055,  0.0247,  0.0990,  0.0480,  0.1585,  0.0360,
          0.0125]])
VALUE: 
tensor([[0.0722]])
ENCODED STATE: 
[[[0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0

TypeError: cannot unpack non-iterable int object