In [33]:
import copy
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import trange

In [34]:
class Resnet(nn.Module):
    """Residual tower with a policy head and a value head.

    The input goes through a 3x3 conv stem, then `resnet_blocks`
    ResidualConnection blocks, then two independent heads:

    * policy head -> raw logits, one per action (no softmax is applied here);
    * value head  -> one scalar per sample, squashed into [-1, 1] by tanh.

    Every convolution uses kernel_size = 3 with padding = 1, so the board's
    spatial size is preserved up to the flatten layers.

    Args:
        in_channels: number of planes in the input tensor.
        out_channels: number of feature maps used by the stem and the tower.
        resnet_blocks: how many ResidualConnection blocks to stack.
        board_shape: (rows, columns) of the input planes; it sizes the linear
            layers that follow the flatten. Defaults to (6, 7).
        num_actions: number of policy logits. When None it defaults to the
            number of board columns (7 for the default board).
    """

    def __init__(self, in_channels, out_channels, resnet_blocks = 5,
                 board_shape = (6, 7), num_actions = None):
        super(Resnet, self).__init__()

        board_rows, board_cols = board_shape
        board_cells = board_rows * board_cols
        if num_actions is None:
            num_actions = board_cols

        # Channel widths of the 3x3 convs that open each head.
        policy_channels = 32
        value_channels = 3

        self.start = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.blocks = nn.ModuleList(
            [ResidualConnection(out_channels) for _ in range(resnet_blocks)]
        )

        self.policy_head = nn.Sequential(
            nn.Conv2d(out_channels, out_channels = policy_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(policy_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(policy_channels * board_cells, num_actions)
        )

        self.value_head = nn.Sequential(
            nn.Conv2d(out_channels, out_channels = value_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(value_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(value_channels * board_cells, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Run the network on a batch of encoded boards.

        Args:
            x: tensor of shape (batch, in_channels, rows, columns), where
                (rows, columns) matches the `board_shape` given at construction.

        Returns:
            (policy, value): policy logits of shape (batch, num_actions) and
            tanh-squashed values of shape (batch, 1).
        """
        x = self.start(x)
        for block in self.blocks:
            x = block(x)
        policy = self.policy_head(x)
        value = self.value_head(x)
        return policy, value
        

class ResidualConnection(nn.Module):
    """One residual block: conv -> BN -> ReLU -> conv -> BN, plus an identity skip.

    Both convolutions map `out_channels` to `out_channels` with a 3x3 kernel and
    padding = 1, so the output has the same shape as the input and the two can
    be added element-wise.
    """

    def __init__(self, out_channels):
        super().__init__()

        def conv3x3():
            # Shape-preserving convolution shared by both halves of the block.
            return nn.Conv2d(
                in_channels = out_channels,
                out_channels = out_channels,
                kernel_size = 3,
                padding = 1,
            )

        layers = [
            conv3x3(),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            conv3x3(),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return relu(block(x) + x); the output shape equals the input shape."""
        transformed = self.block(x)
        summed = transformed + x
        return torch.relu(summed)

In [35]:
class Connect4:
    """Connect Four rules on an H x W grid (6 rows x 7 columns by default).

    Boards are 2-D numpy arrays: 0 marks an empty cell and +1 / -1 mark the two
    players. Row 0 is the top of the board and row H-1 the bottom, so pieces
    settle at the largest free row index of their column.

    Apart from `reset`, the methods are pure functions of the board they are
    given; `step` returns a modified copy and never mutates its input.
    """

    def __init__(self):
        self.W = 7
        self.H = 6
        self.board = np.zeros((self.H, self.W))
        # One action per column.
        self.all_moves = np.arange(self.W)

    def reset(self):
        """Replace `self.board` with an empty board and return it."""
        self.board = np.zeros((self.H, self.W))
        return self.board

    def get_valid_moves(self, state):
        """Return an int mask over columns: 1 where the top cell is still empty."""
        return (state[0] == 0).astype(int)

    def step(self, board, move, player):
        """Drop `player`'s piece into a column of a copy of `board`.

        Args:
            board: current board; it is left untouched.
            move: action index; the column used is `move % self.W`.
            player: +1 or -1, the value written into the board.

        Returns:
            (new_board, result, done), where `result` and `done` come from
            `is_terminal(new_board)`.

        Raises:
            ValueError: if the chosen column has no empty cell. (Previously the
                scan ran past the top of the column, wrapped around through
                negative indices and only failed by accident with an IndexError.)
        """
        board = board.copy()
        col = move % self.W
        free_rows = np.flatnonzero(board[:, col] == 0)
        if free_rows.size == 0:
            raise ValueError(f"column {col} is full")
        # The piece falls to the lowest free cell, i.e. the largest row index.
        board[free_rows[-1], col] = player
        done, result = self.is_terminal(board)
        return board, result, done

    def is_terminal(self, board,H = None, W = None):
        """Check whether `board` is a finished game.

        Lines of four are searched horizontally, vertically, along the
        down-right diagonal and along the up-right diagonal, in that order.

        Args:
            board: 2-D array of 0 / +1 / -1.
            H, W: number of rows / columns to scan. When omitted they are taken
                from `board.shape`, so boards of any size are handled.

        Returns:
            (True, winner) when a player has four in a row (winner is the cell
            value, +1 or -1), (True, 0) when no empty cell is left, and
            (False, 0) otherwise.
        """
        if H is None:
            H = board.shape[0]
        if W is None:
            W = board.shape[1]

        # (row step, column step) per line direction.
        directions = ((0, 1), (1, 0), (1, 1), (-1, 1))
        for d_row, d_col in directions:
            for r in range(H):
                for c in range(W):
                    end_r, end_c = r + 3 * d_row, c + 3 * d_col
                    # Skip starts whose fourth cell would fall off the board.
                    if not (0 <= end_r < H and 0 <= end_c < W):
                        continue
                    val = board[r, c]
                    if val in (1, -1) and all(
                        board[r + k * d_row, c + k * d_col] == val for k in (1, 2, 3)
                    ):
                        return True, val

        if not (board == 0).any():
            return True, 0
        return False, 0

    def stackedStates(self, state, current_player):
        """Encode `state` as three float32 planes from `current_player`'s view.

        Plane 0 marks the cells equal to `current_player`, plane 1 the cells
        equal to `-current_player`, and plane 2 the empty cells.
        """
        return np.stack((
            state == current_player,
            state == -current_player,
            state == 0
        )).astype(np.float32)

    def show(self, state):
        """Print the raw board array."""
        print(state)

In [36]:
import numpy as np

class Node:
    """One position in the search tree.

    Attributes:
        state: board array for this position.
        parent: the Node this one was expanded from, or None for the root.
        move: action that led from `parent` to this node.
        player: the player to move in `state`.
        children: dict mapping action -> child Node.
        prob: prior probability assigned to `move` when the parent was expanded.
        N: visit count.
        W: accumulated backed-up value (NOT the board width).
    """

    def __init__(self, state, parent, move, player, prob = 0):
        self.H = 6
        self.state = state
        self.parent = parent
        self.move = move

        self.player = player
        self.children = {}

        self.prob = prob

        self.N = 0
        # `W` is the running value sum. The constructor used to set `self.W = 7`
        # (board width) first and then overwrite it here; the dead assignment
        # was removed so the attribute has a single meaning.
        self.W = 0

    def is_fully_expanded(self):
        """Return True when the top row of `state` has no empty (0) cell.

        Despite the name, this does not look at `children`; it only reports
        that no column can accept another piece.
        """
        return not (self.state[0] == 0).any()

    def UCB1(self, c = 2):
        """Selection score: mean value plus a prior-weighted exploration bonus.

        Returns `inf` for a node that has never been visited, so every child is
        tried once before the priors start to matter.
        NOTE(review): AlphaZero's PUCT rule normally scores unvisited children
        from the prior alone instead of returning infinity; confirm this is intended.

        A node without a parent has no parent visit count to scale the bonus
        with, so only the mean value is returned (previously this raised an
        AttributeError).
        """
        if self.N == 0:
            return float('inf')
        mean_value = self.W / self.N
        if self.parent is None:
            return mean_value
        exploration = c * self.prob * np.sqrt(self.parent.N) / (self.N + 1)
        return mean_value + exploration

In [37]:
class MCTS:
    """Network-guided Monte Carlo tree search rooted at a single position.

    Value convention: a node's `W` accumulates values from the point of view of
    the player who moved INTO that node (`-node.player`). A parent can therefore
    choose its move by maximising its children's scores, and
    `backpropagation` flips the sign once per level on the way up.

    Args:
        env: game rules object providing `is_terminal`, `get_valid_moves`,
            `step`, `stackedStates` and `all_moves`.
        root_state: board at the root.
        root_player: player to move at the root.
        model: network called as `model(batch)` and returning
            (policy logits, value).
        device: torch device the input batch is moved to.
        PARAMS: dict read for 'SEARCHES', 'ALPHA' and 'EPSILON'.
    """

    def __init__(self, env, root_state, root_player, model, device, PARAMS):
        self.env = env
        self.H = 6
        self.W = 7
        self.policy = 0
        self.model = model
        self.device = device
        self.root = Node(state = root_state,
                         parent = None,
                         move = None,
                         player = root_player
                        )
        self.PARAMS = PARAMS

    def selection(self):
        """Walk down from the root, following the best-scoring child, to a leaf.

        The walk stops at the first node that is terminal, has a full top row,
        or has no children yet.
        """
        current = self.root

        while True:
            is_terminal, _ = self.env.is_terminal(current.state)
            if is_terminal or current.is_fully_expanded() or len(current.children) == 0:
                return current
            current = max(current.children.values(), key = lambda c: c.UCB1())

    def expansion(self, node, policy):
        """Create a child for every valid action that `node` does not have yet.

        Args:
            node: the node to expand.
            policy: per-action prior probabilities; `policy[action]` is stored
                on the child as `prob`.
        """
        valid_states = self.env.get_valid_moves(node.state)

        for action, prob in enumerate(policy):
            if valid_states[action] == 1 and action not in node.children:
                new_state = self.env.step(node.state, action, node.player)[0]
                child = Node(state = new_state,
                             parent = node,
                             move = action,
                             player = -node.player,
                             prob = prob
                            )
                node.children[action] = child

    def backpropagation(self, child, value):
        """Add one visit to `child` and each ancestor, flipping `value` per level.

        `value` must be expressed from the point of view of the player who
        moved into `child`.
        """
        current = child
        while current is not None:
            current.N += 1
            current.W += value
            value = -value
            current = current.parent

    @torch.no_grad()
    def search(self):
        """Run PARAMS['SEARCHES'] simulations and return the root visit distribution.

        Returns:
            1-D array with one entry per action in `env.all_moves`: the
            fraction of root-child visits for each action, or a uniform
            distribution over the valid moves if no child was visited.
        """
        for _ in range(self.PARAMS['SEARCHES']):
            node = self.selection()
            is_terminal, value = self.env.is_terminal(node.state)

            if is_terminal:
                # `value` is the winner id (+1 / -1 / 0). `value * node.player`
                # is the outcome for the side to move at `node`, but `W` is kept
                # from the view of the player who moved INTO `node`, hence the
                # negation. Without it a winning move was scored as -1 and the
                # search steered away from it.
                value = -value * node.player
                target_node = node
            else:
                stackedStates = self.env.stackedStates(node.state, node.player)
                policy, value = self.model(torch.tensor(stackedStates, dtype = torch.float).unsqueeze(0).to(self.device))
                policy = torch.softmax(policy, dim = 1).squeeze(0).cpu().detach().numpy()
                if node is self.root:
                    # Dirichlet noise on the root priors to diversify play.
                    alpha = self.PARAMS['ALPHA']
                    eps = self.PARAMS['EPSILON']
                    noise = np.random.dirichlet([alpha] * len(self.env.all_moves))
                    policy = (1 - eps) * policy + eps * noise

                # Network value, from the view of the side to move at `node`.
                value = value.squeeze(0).cpu().item()
                valid_moves = self.env.get_valid_moves(node.state)
                # Mask illegal moves and renormalise; fall back to uniform if
                # all the probability mass sat on illegal moves.
                policy = policy * valid_moves
                if np.sum(policy)>0:
                    policy = policy/(np.sum(policy))
                else:
                    policy = valid_moves/np.sum(valid_moves)

                self.expansion(node, policy)
                if len(node.children) > 0:
                    # The leaf's evaluation is credited to its highest-prior
                    # child, which was reached by the side to move at `node`,
                    # so the sign of `value` is already right.
                    # NOTE(review): standard AlphaZero backs up from `node`
                    # itself rather than from a child; confirm this is intended.
                    best_action = max(node.children.keys(), key=lambda a: node.children[a].prob)
                    target_node = node.children[best_action]
                else:
                    # Backing up from `node` itself needs the mover-into-node view.
                    target_node = node
                    value = -value

            self.backpropagation(target_node, value)

        visit_counts = np.zeros(len(self.env.all_moves))
        for action, child in self.root.children.items():
            visit_counts[action] = child.N

        total_visits = visit_counts.sum()
        if total_visits > 0:
            return visit_counts / total_visits

        # No child statistics at all: spread the probability over legal moves.
        valid_moves = self.env.get_valid_moves(self.root.state)
        return valid_moves / np.sum(valid_moves)

    def __repr__(self):
        # Previously read `root.action` and `root.children.n`, neither of which
        # exists, so repr() always raised AttributeError.
        return f"{self.root.state}\n{self.root.parent}\n{self.root.children}\n{self.root.move}\n{self.root.N}"

In [51]:
class AlphaZero:
    """Self-play training loop: generate games with MCTS, fit the network,
    evaluate against the pre-update network, and checkpoint.

    Args:
        env: game rules object (reset / step / stackedStates / all_moves).
        model: policy/value network; it is moved to `device` here.
        policy_loss: callable applied to (policy logits, target distribution).
        value_loss: callable applied to (predicted value, target value).
        optimizer: optimizer over `model`'s parameters.
        device: torch device used for training and search.
        PARAMS: dict read for 'TAU', 'TOTAL_ITERATIONS', 'SELF_PLAY_ITERATIONS',
            'EPOCHS', 'BATCH_SIZE' and 'EVALUATION_GAMES' (and handed to MCTS).
        current_player: stored on the instance; not read by the methods below.
    """

    def __init__(self, env, model, policy_loss, value_loss, optimizer, device, PARAMS, current_player = 1):
        self.env = env
        self.model = model
        self.model.to(device)
        self.optimizer = optimizer
        # NOTE: train() replaces this deque with a fresh list every iteration.
        self.mainBuffer = deque(maxlen = 30000)
        self.policy_loss = policy_loss
        self.value_loss = value_loss
        self.device = device
        self.current_player = current_player
        # Not read by this class; the sampling temperature comes from PARAMS['TAU'].
        self.tau = 1.25
        self.PARAMS = PARAMS

    def selfPlayData(self, ):
        """Play one game against itself and return its training samples.

        Each move is sampled from the MCTS visit distribution raised to
        1 / PARAMS['TAU'] (then renormalised).

        Returns:
            A list with one (encoded_state, search_policy, value) tuple per
            move played. `encoded_state` is `env.stackedStates` seen by the
            player who was to move in that position, `search_policy` is the
            unsharpened visit distribution, and `value` is +1 / -1 / 0 for a
            win / loss / draw of that same player.
        """
        history = []
        state = self.env.reset()
        current_player = 1

        while True:
            root = MCTS(self.env, state, current_player, self.model, self.device, self.PARAMS)

            actions = root.search()
            temperature_probs = actions ** (1/self.PARAMS['TAU'])
            temperature_probs /= np.sum(temperature_probs)
            history.append((state, actions, current_player))
            action = np.random.choice(self.env.all_moves, p = temperature_probs)
            state, result, done = self.env.step(state, action, current_player)
            if done:
                samples = []
                for past_state, past_policy, player in history:
                    if result == 0:
                        value = 0
                    elif result == player:
                        value = 1
                    else:
                        value = -1
                    # Encode from the view of the player who was to move in
                    # THAT position. The code used to pass `current_player`
                    # (the last mover), which mislabelled the planes of every
                    # position where the other player was to move.
                    samples.append((self.env.stackedStates(past_state, player), past_policy, value))
                return samples

            current_player = -current_player

    def train(self):
        """Run PARAMS['TOTAL_ITERATIONS'] rounds of self-play, training,
        evaluation and checkpointing."""
        batch_size = self.PARAMS["BATCH_SIZE"]

        for i in trange(self.PARAMS['TOTAL_ITERATIONS']):
            # --- self-play with the current network ---
            self.model.eval()
            self.mainBuffer = []
            for _ in trange(self.PARAMS['SELF_PLAY_ITERATIONS']):
                self.mainBuffer.extend(self.selfPlayData())

            # Frozen copy of the pre-update weights for evaluation. A plain
            # `previous_model = self.model` is only an alias, so the network
            # used to be evaluated against itself.
            previous_model = copy.deepcopy(self.model)

            # --- fit the network on the fresh games ---
            self.model.train()
            for epoch in trange(self.PARAMS['EPOCHS']):

                epoch_loss = 0
                batch_count = 0
                np.random.shuffle(self.mainBuffer)
                for idx in range(0, len(self.mainBuffer), batch_size):
                    data = self.mainBuffer[idx:idx+batch_size]
                    states, actions, value = zip(*data)

                    # Stack into one ndarray first: building a tensor from a
                    # tuple of ndarrays goes through a slow element-wise path.
                    actions = torch.tensor(np.stack(actions), dtype = torch.float).to(self.device)
                    values = torch.tensor(value, dtype = torch.float).unsqueeze(1).to(self.device)
                    states = torch.tensor(np.stack(states), dtype = torch.float).to(self.device)

                    model_policy, model_value = self.model(states)

                    policy_loss = self.policy_loss(model_policy, actions)
                    value_loss = self.value_loss(model_value, values)
                    total_loss = policy_loss + value_loss

                    self.optimizer.zero_grad()
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()

                    epoch_loss += total_loss.item()
                    batch_count += 1

                avg_loss = epoch_loss/batch_count

                print(f"Epoch {epoch + 1}\tTotal Loss : {avg_loss}")

            self.evaluation(previous_model)
            self.save_model()

    def save_model(self):
        """Write the model and optimizer state dicts to 'checkpoint.pth'."""
        checkpoint = {
            "model" : self.model.state_dict(),
            "optimizer_state" : self.optimizer.state_dict()
        }
        torch.save(checkpoint, 'checkpoint.pth')
        print("Hit Checkpoint")

    def evaluation(self, pmodel):
        """Play PARAMS['EVALUATION_GAMES'] games of `self.model` against `pmodel`.

        `self.model` always moves first (as player +1) and `pmodel` answers as
        player -1; both pick the most-visited MCTS action. Only games won by
        `self.model` are counted.

        Args:
            pmodel: the baseline network to play against.

        Returns:
            The fraction of games won by `self.model` (0.0 when no game is played).
        """
        wins = 0
        candidate_player = 1
        self.model.eval()
        # The baseline must run in inference mode too, otherwise its BatchNorm
        # layers use (and update) batch statistics during search.
        pmodel.eval()
        model_for = {candidate_player: self.model, -candidate_player: pmodel}
        total_games = self.PARAMS['EVALUATION_GAMES']

        for game in trange(total_games):
            state = self.env.reset()
            current_player = candidate_player

            while True:
                root = MCTS(self.env, state, current_player, model_for[current_player], self.device, self.PARAMS)
                actions = root.search()
                action = np.argmax(actions)
                state, result, done = self.env.step(state, action, current_player)
                if done:
                    if result == candidate_player:
                        wins += 1
                    break
                current_player = -current_player

        win_rate = wins / total_games if total_games else 0.0
        print(f"Evaluation : {wins}/{self.PARAMS['EVALUATION_GAMES']} Wins : {wins}")
        return win_rate
                            

In [52]:
# Central configuration: every tunable value used below lives in this dict.
PARAMETERS = {
    # --- network (passed positionally to Resnet) ---
    "IN_CHANNELS": 3,               # input planes of the encoded board
    "OUT_CHANNELS": 64,             # feature maps in the stem and residual tower
    "RESNET_BLOCKS": 5,             # number of residual blocks

    # --- training loop (read by AlphaZero.train) ---
    "SELF_PLAY_ITERATIONS": 500,    # self-play games generated per iteration
    "EPOCHS": 20,                   # passes over the freshly generated games
    "BATCH_SIZE": 128,              # samples per optimizer step
    "TOTAL_ITERATIONS": 20,         # self-play -> train -> evaluate rounds

    # --- search / evaluation ---
    "SEARCHES": 100,                # simulations per MCTS.search() call
    "EVALUATION_GAMES": 20,         # games played in AlphaZero.evaluation

    # --- exploration ---
    "ALPHA": 0.3,                   # Dirichlet concentration for root noise
    "EPSILON": 0.25,                # weight of the noise mixed into root priors
    "TAU": 1.25,                    # temperature: visit probs are raised to 1/TAU
}

In [53]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Game rules and the policy/value network sized from the shared config.
env = Connect4()
model = Resnet(
    PARAMETERS['IN_CHANNELS'],
    PARAMETERS['OUT_CHANNELS'],
    PARAMETERS['RESNET_BLOCKS'],
)

# Cross-entropy on the policy logits, mean-squared error on the value output.
policy_loss = nn.CrossEntropyLoss()
value_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-4)

agent = AlphaZero(
    env,
    model,
    policy_loss,
    value_loss,
    optimizer,
    device,
    PARAMETERS,
)

# Runs the full self-play / training / evaluation loop (long-running).
agent.train()

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 3.398463726043701
Epoch 2	Total Loss : 2.948635458946228
Epoch 3	Total Loss : 2.8905352354049683
Epoch 4	Total Loss : 2.8924070596694946
Epoch 5	Total Loss : 2.837430953979492
Epoch 6	Total Loss : 2.8480865955352783
Epoch 7	Total Loss : 2.816851854324341
Epoch 8	Total Loss : 2.773837447166443
Epoch 9	Total Loss : 2.7471930980682373
Epoch 10	Total Loss : 2.686057925224304
Epoch 11	Total Loss : 2.7202560901641846
Epoch 12	Total Loss : 2.636173129081726
Epoch 13	Total Loss : 2.5652647018432617
Epoch 14	Total Loss : 2.509036064147949
Epoch 15	Total Loss : 2.407310366630554
Epoch 16	Total Loss : 2.412977695465088
Epoch 17	Total Loss : 2.284764528274536
Epoch 18	Total Loss : 2.2325801849365234
Epoch 19	Total Loss : 2.1369320154190063
Epoch 20	Total Loss : 2.116098165512085


  0%|          | 0/5 [00:00<?, ?it/s]

Evaluation : 5/5 Wins : 5
Hit Checkpoint


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 3.3337481021881104
Epoch 2	Total Loss : 2.828295946121216
Epoch 3	Total Loss : 2.6811617612838745
Epoch 4	Total Loss : 2.4485028982162476
Epoch 5	Total Loss : 2.4347630739212036
Epoch 6	Total Loss : 2.2462315559387207
Epoch 7	Total Loss : 2.1551326513290405
Epoch 8	Total Loss : 2.1901028156280518
Epoch 9	Total Loss : 1.9728506803512573
Epoch 10	Total Loss : 2.0071810483932495
Epoch 11	Total Loss : 1.9085388779640198
Epoch 12	Total Loss : 1.8616562485694885
Epoch 13	Total Loss : 2.029250681400299
Epoch 14	Total Loss : 1.983208954334259
Epoch 15	Total Loss : 2.1405811309814453
Epoch 16	Total Loss : 1.882051706314087
Epoch 17	Total Loss : 1.9679740071296692
Epoch 18	Total Loss : 1.8436037302017212
Epoch 19	Total Loss : 1.7245543599128723
Epoch 20	Total Loss : 1.742594838142395


  0%|          | 0/5 [00:00<?, ?it/s]

Evaluation : 3/5 Wins : 3
Hit Checkpoint
