# AlphaGo Zero
---
For this exercise we will implement AlphaZero, a more generalized form of AlphaGo, and train it on the game of 
Connect 4. We will first play against AlphaZero before it has trained at all, then train it for a few epochs, then try 
to beat it again. We will repeat this process a few times so that you can tangibly observe the algorithm getting smarter. 

## 1. Game Definition
---
First, let's set up the Connect 4 game. This part has been done for you. 

In [1]:
import numpy as np

class board():
    """Connect 4 board stored as a 6 x 7 numpy array of strings.

    Each cell holds " " (empty), "O" (dropped while ``player`` was 0) or
    "X" (dropped while ``player`` was 1). Row 0 is the top of a column and
    row 5 the bottom, so a dropped piece lands in the highest-numbered
    free row of its column.
    """

    def __init__(self):
        # np.zeros(...).astype(str) fills every cell with "0.0"; those are
        # replaced by " " so that an empty cell is a single space.
        self.init_board = np.zeros([6,7]).astype(str)
        self.init_board[self.init_board == "0.0"] = " "
        self.player = 0  # 0 -> next piece is "O", 1 -> next piece is "X"
        # NOTE(review): current_board is the same array object as
        # init_board, so init_board is modified by every drop_piece call.
        self.current_board = self.init_board

    def drop_piece(self, column):
        """Drop the current player's piece into ``column`` and switch turns.

        :param column: column index, 0-6
        :return: the string "Invalid move" when the top cell of the column
                 is occupied (the board is left untouched), otherwise None
        """
        if self.current_board[0, column] != " ":
            return "Invalid move"
        # Walk down from the top until the first occupied cell (or the
        # floor); the piece settles in the cell just above it.
        row = 0
        while row < 6 and self.current_board[row, column] == " ":
            row += 1
        landing_row = row - 1
        if self.player == 0:
            self.current_board[landing_row, column] = "O"
            self.player = 1
        elif self.player == 1:
            self.current_board[landing_row, column] = "X"
            self.player = 0

    def check_pieces(self, piece_string):
        """Report whether ``piece_string`` has four in a line on the board.

        Lines are checked downwards, rightwards and along both diagonals
        from every cell. Bounds are tested explicitly, so no line can wrap
        around the board edge through negative indexing.

        :param piece_string: the piece to look for, "O" or "X"
        :return: True if a line of four exists, otherwise False
        """
        # (row step, column step): vertical, horizontal, "\" and "/".
        directions = ((1, 0), (0, 1), (1, 1), (1, -1))
        for row in range(6):
            for col in range(7):
                cell = self.current_board[row, col]
                # Empty cells never start a line, whatever piece_string is.
                if cell == " " or cell != piece_string:
                    continue
                for d_row, d_col in directions:
                    end_row = row + 3 * d_row
                    end_col = col + 3 * d_col
                    if not (0 <= end_row < 6 and 0 <= end_col < 7):
                        continue
                    if all(self.current_board[row + k * d_row, col + k * d_col] == piece_string
                           for k in range(1, 4)):
                        return True
        return False

    def check_winner(self):
        """Report whether the player who moved last has four in a line.

        drop_piece switches ``player`` after placing a piece, so the pieces
        to inspect belong to the opposite of the player about to move.

        :return: True if the previous mover has won, otherwise False
        """
        if self.player == 1:
            return self.check_pieces("O")
        if self.player == 0:
            return self.check_pieces("X")
        return False

    def actions(self): # returns all possible moves
        """Return the list of columns whose top cell is still empty."""
        return [col for col in range(7) if self.current_board[0, col] == " "]
    

## 2. One Network Two Heads
---
Next, let's implement the neural network. Recall that AlphaGo Zero uses a convolutional 
ResNet architecture with two heads: one that outputs a probability distribution over all possible moves $(p)$ and another that 
outputs a single scalar value $(v)$ representing the value of the current state. Because we have two heads, we need to create a 
custom loss function.   

The neural network is defined as:  
$$f_\theta (s) = (\mathbf{p}, v)$$  
The loss function is defined as:  
$$l = (z - v)^2 - \boldsymbol{\pi}^T \log(\mathbf{p}) + c\|\theta\|^2$$  

In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import matplotlib
matplotlib.use("Agg")
import numpy as np

class board_data(Dataset):
    """Dataset view over an array whose columns are (s, p, v).

    Column 0 is kept as ``X``, column 1 as ``y_p`` and column 2 as ``y_v``.
    """

    def __init__(self, dataset): # dataset = np.array of (s, p, v)
        states, policies, values = dataset[:, 0], dataset[:, 1], dataset[:, 2]
        self.X = states
        self.y_p = policies
        self.y_v = values

    def __len__(self):
        """Number of samples, taken from the state column."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return (state, policy, value) for sample ``idx``.

        The state is transposed with axes (2, 0, 1), i.e. its last axis is
        moved to the front, and converted with np.int64.
        """
        channels_first = self.X[idx].transpose(2, 0, 1)
        return np.int64(channels_first), self.y_p[idx], self.y_v[idx]

class ConvBlock(nn.Module):
    """
    Entry block of the ResNet: one 3x3 convolution (3 -> 128 channels,
    padding 1) followed by batch normalisation and ReLU.
    """
    def __init__(self):
        super().__init__()
        self.action_size = 7
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3,
                               stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=128)

    def forward(self, s):
        """Reshape ``s`` to (-1, 3, 6, 7) and apply conv -> batch norm -> ReLU."""
        planes = s.view(-1, 3, 6, 7)  # batch_size x channels x board_x x board_y
        features = self.conv1(planes)
        features = self.bn1(features)
        return F.relu(features)

class ResBlock(nn.Module):
    """
    Residual block: two bias-free 3x3 convolutions, each followed by batch
    normalisation, with the block input added back before the final ReLU.
    ``downsample`` is accepted but not used.
    """
    def __init__(self, inplanes=128, planes=128, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection: add the untouched input back in.
        return F.relu(hidden + x)

class OutBlock(nn.Module):
    """
    Output block with two heads fed by the same 128-channel feature map:
    a value head ending in tanh and a policy head ending in a softmax over
    7 outputs.
    """
    def __init__(self):
        super().__init__()
        # --- value head ---
        self.conv = nn.Conv2d(128, 3, kernel_size=1)
        self.bn = nn.BatchNorm2d(3)
        self.fc1 = nn.Linear(3*6*7, 32)
        self.fc2 = nn.Linear(32, 1)

        # --- policy head ---
        self.conv1 = nn.Conv2d(128, 32, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.fc = nn.Linear(6*7*32, 7)

    def forward(self,s):
        """Return ``(p, v)``: p of shape (batch, 7), v of shape (batch, 1)."""
        # Value head: 1x1 conv -> BN -> ReLU -> flatten -> fc -> ReLU -> fc -> tanh
        value_planes = F.relu(self.bn(self.conv(s)))
        value_flat = value_planes.view(-1, 3*6*7)  # batch_size X (channel * height * width)
        value_hidden = F.relu(self.fc1(value_flat))
        v = torch.tanh(self.fc2(value_hidden))

        # Policy head: 1x1 conv -> BN -> ReLU -> flatten -> fc -> exp(log_softmax)
        policy_planes = F.relu(self.bn1(self.conv1(s)))
        policy_logits = self.fc(policy_planes.view(-1, 6*7*32))
        p = self.logsoftmax(policy_logits).exp()
        return p, v

class ConnectNet(nn.Module):
    """
    Full network: a ConvBlock, ten ResBlocks registered as ``res_0`` ..
    ``res_9``, and an OutBlock that produces the two head outputs.
    """
    def __init__(self):
        super().__init__()
        self.conv = ConvBlock()
        for idx in range(10):
            block_name = "res_%i" % idx
            setattr(self, block_name, ResBlock())
        self.outblock = OutBlock()

    def forward(self,s):
        """Run ``s`` through the conv block, all residual blocks, then the heads."""
        hidden = self.conv(s)
        for idx in range(10):
            res_block = getattr(self, "res_%i" % idx)
            hidden = res_block(hidden)
        return self.outblock(hidden)
    
class AlphaLoss(torch.nn.Module):
    """
    Combined value + policy loss, returned per sample (no reduction).
    """
    def __init__(self):
        super().__init__()

    def forward(self, y_value, value, y_policy, policy):
        """Return ``(value - y_value)**2 - sum(policy * log(y_policy + 1e-8), dim=1)``.

        The logarithm is taken of ``y_policy`` and weighted by ``policy``.
        NOTE(review): presumably ``y_*`` are the network outputs and the
        other two the training targets -- confirm against the caller.
        """
        squared_gap = (value - y_value) ** 2
        # 1e-8 keeps the log finite when an entry of y_policy is exactly 0.
        log_y_policy = (1e-8 + y_policy.float()).float().log()
        cross_entropy = torch.sum(-policy * log_y_policy, 1)
        return squared_gap.view(-1).float() + cross_entropy

## 3. Make the Board Accessible to the NN
---
The game is encoded in a 6 x 7 matrix. We need to make this into a data structure that can be transformed into a tensor 
that our NN will understand. Our NN takes a tensor of 6 x 7 x 3. The first two channels indicate each player's pieces on 
the board and the third indicates whose turn it is. 

In [None]:
def encode_board(board):
    """
    Turn a board into a 6 x 7 x 3 integer array that is easy to convert into a tensor.
    Plane 0 marks "O" pieces, plane 1 marks "X" pieces, and plane 2 is filled
    with ones when ``board.player`` equals 1 (zeros otherwise).
    :param board: object exposing ``current_board`` (6 x 7 matrix of strings) and ``player``
    :return: numpy array of shape (6, 7, 3)
    """
    grid = board.current_board
    planes = np.zeros([6, 7, 3], dtype=int)
    plane_of = {"O": 0, "X": 1}
    for row, col in np.ndindex(6, 7):
        piece = grid[row, col]
        if piece == " ":
            continue
        planes[row, col, plane_of[piece]] = 1
    if board.player == 1:
        planes[..., 2] = 1 # player to move
    return planes

def decode_board(encoded):
    """
    Rebuild a board object from its 6 x 7 x 3 encoding.
    Cells flagged in plane 0 become "O", cells flagged in plane 1 become "X"
    (plane 1 wins if both are set), and ``player`` is read from ``encoded[0, 0, 2]``.
    :param encoded: array indexed as [row, col, plane]
    :return: a new board with ``current_board`` and ``player`` restored
    """
    grid = np.zeros([6,7]).astype(str)
    grid[grid == "0.0"] = " "
    piece_of = {0:"O", 1:"X"}
    for row, col, plane in np.ndindex(6, 7, 2):
        if encoded[row, col, plane] == 1:
            grid[row, col] = piece_of[plane]
    restored = board()
    restored.current_board = grid
    restored.player = encoded[0,0,2]
    return restored

## 4. Monte Carlo Tree Search (MCTS)
---
Now for the guts of the algorithm. To determine the best move (edge) $a$ at each state (node) $s$ we do a MCTS using the neural network 
to guide our exploration. During exploration, we choose the action that maximizes $Q + U$ over all possible actions $A$.  

$$a = \arg\max_{a \in A} (Q + U)$$  
In this equation $Q$ is our state-action value and represents exploitation. The $U$ term represents exploration. 
If we expand $U$ we get the following:  
$$Q + c_{puct} P(s,a) \frac{\sqrt{\sum_b N(s,b)}}{1 + N(s,a)}$$  
$c_{puct}$ is a constant that controls exploration  
$P(s,a)$ is the prior probability of choosing action $a$ from policy $p$ given by $f_\theta (s)$   
$\sum_b N(s,b)$ is the parent node visit count  
$N(s,a)$ is the number of visits to the current node. 

Recall that each node in the tree represents a state $s$. Each edge represents an action $a$ that we can take from that 
state 

In [None]:
import copy

class UCTNode():
    """Node of the search tree; each node holds one game state.

    Statistics for a node's children (priors, accumulated value, visit
    counts) are stored on the node itself as length-7 arrays indexed by
    move. A node's own ``number_visits`` and ``total_value`` are therefore
    read from and written to its parent's arrays, so every node -- the root
    included -- needs a parent object that exposes ``child_number_visits``,
    ``child_total_value`` and ``parent``.
    """

    def __init__(self, game, move, parent=None):
        self.game = game # state s
        self.move = move # action index
        self.is_expanded = False
        self.parent = parent
        self.children = {}  # move -> UCTNode
        self.child_priors = np.zeros([7], dtype=np.float32)
        self.child_total_value = np.zeros([7], dtype=np.float32)
        self.child_number_visits = np.zeros([7], dtype=np.float32)
        self.action_idxes = []  # legal moves, filled in by expand()

    @property
    def number_visits(self):
        """Visit count of this node, stored in the parent's array."""
        return self.parent.child_number_visits[self.move]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.move] = value

    @property
    def total_value(self):
        """Accumulated value of this node, stored in the parent's array."""
        return self.parent.child_total_value[self.move]

    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.move] = value

    def child_Q(self):
        """Per-move value term: total value / (1 + visits)."""
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self):
        """Per-move exploration term: sqrt(own visits) * |prior| / (1 + child visits)."""
        return np.sqrt(self.number_visits) * (
            abs(self.child_priors) / (1 + self.child_number_visits))

    def best_child(self):
        """Return the move with the highest Q + U.

        When legal moves were recorded by expand(), only those are
        considered; otherwise the argmax runs over all 7 entries.
        """
        scores = self.child_Q() + self.child_U()
        if self.action_idxes:
            return self.action_idxes[np.argmax(scores[self.action_idxes])]
        return np.argmax(scores)

    def select_leaf(self):
        """Follow best_child() downwards until reaching an unexpanded node."""
        current = self
        while current.is_expanded:
            best_move = current.best_child()
            current = current.maybe_add_child(best_move)
        return current

    def add_dirichlet_noise(self, action_idxs, child_priors):
        """
        Adds dirichlet noise to priors for actions from root node to ensure exploration.
        The legal entries become 0.75 * prior + 0.25 * noise; ``child_priors``
        is modified in place and also returned.
        :param action_idxs: list of legal move indices
        :param child_priors: numpy array of priors indexed by move
        :return: ``child_priors`` with the legal entries replaced
        """
        # select only legal moves entries in child_priors array
        valid_child_priors = child_priors[action_idxs]
        noise = np.random.dirichlet(
            np.full(len(valid_child_priors), 192, dtype=np.float32))
        child_priors[action_idxs] = 0.75*valid_child_priors + 0.25*noise
        return child_priors

    def expand(self, child_priors):
        """Record legal moves and store masked priors for this node.

        Priors of illegal moves are zeroed. If the node is a root (its
        parent has no parent) and has legal moves, Dirichlet noise is mixed
        into the legal priors. The caller's ``child_priors`` array is not
        modified. A node with no legal moves is left unexpanded.

        :param child_priors: array of priors indexed by move
        """
        action_idxs = self.game.actions()
        self.action_idxes = action_idxs
        # if there are no legal actions do not expand node
        self.is_expanded = bool(action_idxs)

        # Work on a private copy so the masking and noise below never leak
        # back into the array owned by the caller.
        # c_p is set to child priors because c_p = 1 for now
        c_p = np.array(child_priors)

        # mask all illegal actions
        for i in range(len(c_p)):
            if i not in action_idxs:
                c_p[i] = 0.0

        # add dirichlet noise to child_priors in root node; skipped when
        # there is nothing legal to perturb
        if action_idxs and self.parent.parent is None:
            c_p = self.add_dirichlet_noise(action_idxs, c_p)
        self.child_priors = c_p

    def decode_n_move_pieces(self,board,move):
        """Apply ``move`` to ``board`` via drop_piece and return the same board."""
        board.drop_piece(move)
        return board

    def maybe_add_child(self, move):
        """Return the child reached by ``move``, creating it on first use.

        A new child gets a deep copy of this node's game with the move
        applied, so this node's own game state is left unchanged.
        """
        if move not in self.children:
            copy_board = copy.deepcopy(self.game) # make copy of board
            copy_board = self.decode_n_move_pieces(copy_board,move)
            self.children[move] = UCTNode(copy_board, move, parent=self)
        return self.children[move]

    def backup(self, value_estimate: float):
        """Propagate ``value_estimate`` from this node up to the root.

        Every node on the path with a parent gets one more visit. The
        estimate is added as-is where ``game.player`` is 1 and negated
        where it is 0.
        """
        current = self
        while current.parent is not None:
            current.number_visits += 1
            if current.game.player == 1: # same as current.parent.game.player = 0
                current.total_value += (1*value_estimate) # value estimate +1 = O wins
            elif current.game.player == 0: # same as current.parent.game.player = 1
                current.total_value += (-1*value_estimate)
            current = current.parent
 