In [1]:
from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import trange

In [2]:
class Resnet(nn.Module):
    """Policy/value network: conv stem, a tower of residual blocks, and two heads.

    Args:
        in_channels: number of input planes.
        out_channels: channel width of the stem and of every residual block.
        resnet_blocks: number of ``ResidualConnection`` blocks in the tower.
        board_height, board_width: spatial size of the input planes. The defaults
            (6, 7) reproduce the original hard-coded Connect-4 board.
        num_actions: size of the policy output; defaults to ``board_width``
            (one action per column), i.e. 7 for the default board.

    ``forward(x)`` returns ``(policy, value)``: ``policy`` is the raw output of the
    final ``nn.Linear`` (no softmax applied here), shape ``(batch, num_actions)``;
    ``value`` passes through ``nn.Tanh`` and has shape ``(batch, 1)``.
    """

    def __init__(self, in_channels, out_channels, resnet_blocks = 5,
                 board_height = 6, board_width = 7, num_actions = None):
        super().__init__()
        if num_actions is None:
            num_actions = board_width
        # Every conv uses kernel 3 / padding 1, so the spatial size is preserved and
        # each head flattens (channels * board_cells) features.
        board_cells = board_height * board_width

        self.start = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.blocks = nn.ModuleList(
            [ResidualConnection(out_channels) for _ in range(resnet_blocks)]
        )

        self.policy_head = nn.Sequential(
            nn.Conv2d(out_channels, 32, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * board_cells, num_actions)
        )

        self.value_head = nn.Sequential(
            nn.Conv2d(out_channels, 3, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * board_cells, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Run the shared trunk, then return ``(policy, value)`` from the two heads."""
        x = self.start(x)
        for block in self.blocks:
            x = block(x)
        policy = self.policy_head(x)
        value = self.value_head(x)
        return policy, value
        

class ResidualConnection(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN plus an identity skip, then ReLU.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so the
    block's output has exactly the input's shape and the skip can be added directly.
    """

    def __init__(self, out_channels):
        super().__init__()
        width = out_channels
        # Shared settings for the two shape-preserving convolutions.
        conv_args = dict(in_channels = width, out_channels = width, kernel_size = 3, padding = 1)
        # Kept as a single Sequential named `block` so parameter names stay `block.<i>.*`.
        self.block = nn.Sequential(
            nn.Conv2d(**conv_args),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(**conv_args),
            nn.BatchNorm2d(width),
        )

    def forward(self, x):
        """Return ``relu(x + block(x))``."""
        residual = self.block(x)
        return torch.relu(x + residual)

In [3]:
class Connect4:
    """Connect-Four environment holding a mutable ``H x W`` (6 x 7) board.

    Cells are 0 (empty), 1 or -1 (the two players' pieces). Pieces fall to the
    highest-indexed empty row of a column, so row ``H - 1`` is the bottom and a
    column is playable while its row-0 cell is empty.
    """

    def __init__(self):
        self.W = 7
        self.H = 6
        self.board = np.zeros((self.H, self.W))
        self.all_moves = np.arange(self.W)

    def reset(self):
        """Clear the board and return it (the live array, not a copy)."""
        self.board = np.zeros((self.H, self.W))
        return self.board

    def get_valid_moves(self, state):
        """Return an int mask of length W: 1 where the column's top cell is empty."""
        return (state[0] == 0).astype(int)

    def step(self, move, player):
        """Drop ``player``'s piece into column ``move % W`` of ``self.board`` (in place).

        Returns:
            ``(board, result, done)`` where ``result`` is the winner (1 or -1), 0 for a
            draw (board full with no four-in-a-row), or None while the game goes on;
            ``done`` is True exactly when ``result`` is not None.

        Raises:
            ValueError: if the chosen column is already full.
        """
        col = move % self.W

        empty_rows = np.flatnonzero(self.board[:, col] == 0)
        if empty_rows.size == 0:
            # Previously the search walked into negative indices and wrapped around the board.
            raise ValueError(f"column {col} is full")
        # Gravity: the piece lands in the lowest (highest-index) empty cell.
        self.board[empty_rows[-1], col] = player

        result = self.check_winner(1) or self.check_winner(-1)
        if result is None and not (self.board == 0).any():
            # Draw: report 0, matching MCTS.step and AlphaZero's `result == 0` handling.
            result = 0
        done = result is not None

        return self.board, result, done

    def check_winner(self, player):
        """Return ``player`` if it has four in a row on ``self.board``, else None."""
        board = self.board
        for r in range(self.H):
            for c in range(self.W - 3):
                if board[r, c] == board[r, c + 1] == board[r, c + 2] == board[r, c + 3] == player:
                    return player

        for r in range(self.H - 3):
            for c in range(self.W):
                if board[r, c] == board[r + 1, c] == board[r + 2, c] == board[r + 3, c] == player:
                    return player

        for r in range(self.H - 3):
            for c in range(self.W - 3):
                if board[r, c] == board[r + 1, c + 1] == board[r + 2, c + 2] == board[r + 3, c + 3] == player:
                    return player

        for r in range(3, self.H):
            for c in range(self.W - 3):
                if board[r, c] == board[r - 1, c + 1] == board[r - 2, c + 2] == board[r - 3, c + 3] == player:
                    return player

        return None

    def stackedStates(self, state, current_player):
        """Encode ``state`` as 3 float32 planes: own pieces, opponent pieces, empty cells."""
        return np.stack((
            state == current_player,
            state == -current_player,
            state == 0
        )).astype(np.float32)

    def show(self, state):
        """Print the raw board array."""
        print(state)

In [4]:
import numpy as np

class Node:
    """One node of the MCTS search tree.

    Attributes:
        state: board array for this node (6x7 by default).
        parent: parent ``Node``, or None for the root.
        move: column that was played from ``parent`` to reach this node (None at the root).
        player: side to move in ``state`` (1 or -1).
        children: dict mapping column index -> child ``Node``.
        prob: prior probability of ``move`` supplied at expansion time.
        N: visit count.
        W: accumulated value; ``W / N`` is the Q term used by ``UCB1``.
    """

    def __init__(self, state, parent, move, player, prob = 0):
        self.H = 6
        # NOTE: the board width is not stored on the node: `W` is the accumulated
        # value used by UCB1 (an earlier `self.W = 7` was immediately overwritten).
        self.state = state
        self.parent = parent
        self.move = move

        self.player = player
        self.children = {}

        self.prob = prob

        self.N = 0
        self.W = 0

    def is_terminal(self, H = 6, W = 7):
        """Check whether ``state`` ends the game.

        Returns:
            ``(True, winner)`` if some player has four in a row (``winner`` is the
            cell value, 1 or -1); ``(True, 0)`` if the board is full with no winner;
            ``(False, 0)`` otherwise.
        """
        state = self.state
        # Directions checked in the original order: horizontal, vertical,
        # diagonal down-right, diagonal up-right.
        for dr, dc in ((0, 1), (1, 0), (1, 1), (-1, 1)):
            for r in range(H):
                for c in range(W):
                    end_r, end_c = r + 3 * dr, c + 3 * dc
                    if not (0 <= end_r < H and 0 <= end_c < W):
                        continue
                    val = state[r, c]
                    if val in (1, -1) and all(state[r + k * dr, c + k * dc] == val for k in range(1, 4)):
                        return True, val

        if not (state == 0).any():
            return True, 0

        return False, 0

    def is_fully_expanded(self):
        """Return True when the top row is full, i.e. no column can take another piece.

        Despite the name, this does not look at ``children``; it reports that
        there are no legal moves left from ``state``.
        """
        return not (self.state[0] == 0).any()

    def UCB1(self, c = 2):
        """PUCT selection score: ``Q + c * prob * sqrt(parent.N) / (N + 1)``.

        ``Q`` is ``W / N`` for visited nodes and 0 for unvisited ones. Unvisited
        children previously scored ``inf``, which made selection ignore ``prob``
        until every sibling had been tried once. For the root (no parent) the
        node's own visit count stands in for ``parent.N``.
        """
        parent_visits = self.parent.N if self.parent is not None else self.N
        q = self.W / self.N if self.N > 0 else 0.0
        return q + c * self.prob * np.sqrt(parent_visits) / (self.N + 1)

In [5]:
class MCTS:
    """AlphaZero-style Monte Carlo Tree Search guided by a policy/value network.

    Value convention used by ``search``/``backpropagation``: ``node.W`` accumulates
    values from the point of view of the player who made the move *into* ``node``
    (``-node.player``). That is the perspective a parent needs, because
    ``selection`` ranks children by ``UCB1`` (which reads ``W / N``) when
    choosing the parent's own move.
    """

    def __init__(self, env, root_state, root_player, model, device):
        self.env = env
        self.H = 6
        self.W = 7
        self.policy = 0
        self.model = model
        self.device = device
        self.root = Node(
            state = root_state,
            parent = None,
            move = None,
            player = root_player)

    def selection(self):
        """Descend from the root, always taking the child with the highest ``UCB1``.

        Returns the first node reached that is terminal, has no legal moves, or
        has no children yet (an unexpanded leaf).
        """
        current = self.root
        while True:
            is_terminal, _ = current.is_terminal()
            if is_terminal or current.is_fully_expanded() or len(current.children) == 0:
                return current
            current = max(current.children.values(), key = lambda child: child.UCB1())

    def expansion(self, node, policy):
        """Add one child per legal, not-yet-expanded column of ``node``.

        ``policy[action]`` becomes the child's ``prob``; the child's side to move
        is ``-node.player``.
        """
        valid_moves = self.env.get_valid_moves(node.state)
        for action, prob in enumerate(policy):
            if valid_moves[action] == 1 and action not in node.children:
                new_state = self.step(node.state, action, node.player)[0]
                node.children[action] = Node(state = new_state,
                                             parent = node,
                                             move = action,
                                             player = -node.player,
                                             prob = prob)

    def backpropagation(self, child, value):
        """Add a visit and ``value`` to ``child``, negating ``value`` at each level up to the root."""
        current = child
        while current is not None:
            current.N += 1
            current.W += value
            value = -value
            current = current.parent

    @torch.no_grad()
    def _evaluate(self, node):
        """Query the network for ``node``.

        Returns ``(policy, value)``: ``policy`` is a probability vector over
        columns, masked to legal moves and renormalised (Dirichlet noise is mixed
        in at the root); ``value`` is the network's scalar for the planes
        ``env.stackedStates(node.state, node.player)``, i.e. from ``node.player``'s
        side of the board.
        """
        planes = self.env.stackedStates(node.state, node.player)
        logits, value = self.model(torch.tensor(planes, dtype = torch.float).unsqueeze(0).to(self.device))
        policy = torch.softmax(logits, dim = 1).squeeze(0).cpu().numpy()

        if node is self.root:
            # Root exploration noise, mixed in before masking.
            alpha = 0.3
            eps = 0.25
            noise = np.random.dirichlet([alpha] * len(self.env.all_moves))
            policy = (1 - eps) * policy + eps * noise

        value = value.squeeze(0).cpu().item()

        valid_moves = self.env.get_valid_moves(node.state)
        policy = policy * valid_moves
        total = np.sum(policy)
        if total > 0:
            policy = policy / total
        else:
            # All legal moves got zero mass: fall back to uniform over legal moves.
            policy = valid_moves / np.sum(valid_moves)
        return policy, value

    @torch.no_grad()
    def search(self, iterations=10):
        """Run ``iterations`` simulations and return the root's visit distribution.

        Returns:
            Length-7 array of child visit counts normalised to sum to 1, or a
            uniform distribution over legal root moves when no child was visited.
        """
        for _ in range(iterations):
            node = self.selection()

            is_terminal, winner = node.is_terminal()
            if is_terminal:
                # winner is 1/-1 (0 for a draw); express it from node.player's side.
                leaf_value = winner * node.player
            else:
                policy, leaf_value = self._evaluate(node)
                self.expansion(node, policy)

            # leaf_value is from node.player's side; node.W is stored from the side of
            # the player who moved into node, hence the negation. (Terminal leaves were
            # previously backed up with the wrong sign, scoring winning moves as losses.)
            self.backpropagation(node, -leaf_value)

        return self._visit_distribution()

    def _visit_distribution(self):
        """Normalised root-child visit counts, or uniform over legal moves if there are none."""
        counts = np.zeros(self.W)
        for action, child in self.root.children.items():
            counts[action] = child.N
        total_visits = counts.sum()
        if total_visits > 0:
            return counts / total_visits
        valid_moves = self.env.get_valid_moves(self.root.state)
        return valid_moves / np.sum(valid_moves)

    def __repr__(self):
        # The old repr read `root.action` and `root.children.n`, which do not exist.
        return f"{self.root.state}\n{self.root.parent}\n{self.root.children}\n{self.root.move}\n{self.root.N}"

    def train(self):
        """Not supported: the previous body referenced undefined names and could only crash."""
        raise NotImplementedError("MCTS only performs search; train the network with AlphaZero.train")

    @staticmethod
    def check_winner(state, player, H = 6, W = 7):
        """Return ``player`` if it has four in a row, 0 if the board is full, else None.

        Lines are checked BEFORE the full-board test so that a winning final move is
        not misreported as a draw.
        """
        for r in range(H):
            for c in range(W - 3):
                if state[r, c] == state[r, c + 1] == state[r, c + 2] == state[r, c + 3] == player:
                    return player

        for r in range(H - 3):
            for c in range(W):
                if state[r, c] == state[r + 1, c] == state[r + 2, c] == state[r + 3, c] == player:
                    return player

        for r in range(H - 3):
            for c in range(W - 3):
                if state[r, c] == state[r + 1, c + 1] == state[r + 2, c + 2] == state[r + 3, c + 3] == player:
                    return player

        for r in range(3, H):
            for c in range(W - 3):
                if state[r, c] == state[r - 1, c + 1] == state[r - 2, c + 2] == state[r - 3, c + 3] == player:
                    return player

        if not (state == 0).any():
            return 0

        return None

    @staticmethod
    def step(state, action, player, H = 6, W = 7):
        """Return ``(new_state, result, done)`` after ``player`` drops into column ``action % W``.

        ``state`` is not modified (a copy is returned). ``result`` is the winner
        (1/-1), 0 for a draw, or None; ``done`` is True when the game is over.

        Raises:
            ValueError: if the column is already full.
        """
        state = state.copy()
        col = action % W

        empty_rows = np.flatnonzero(state[:, col] == 0)
        if empty_rows.size == 0:
            # Previously the row search ran into negative indices and wrapped around.
            raise ValueError(f"column {col} is full")
        # Gravity: the piece lands in the lowest (highest-index) empty cell.
        state[empty_rows[-1], col] = player

        result = MCTS.check_winner(state, 1, H, W) or MCTS.check_winner(state, -1, H, W)
        done = result is not None or not (state == 0).any()

        return state, result, done

In [6]:
class AlphaZero:
    """Self-play training loop: MCTS self-play -> supervised update -> evaluation.

    Args:
        env: game environment exposing ``reset``, ``stackedStates``, ``all_moves``
            and ``get_valid_moves`` (the Connect4 class in this notebook).
        model: network returning ``(policy_logits, value)``.
        policy_loss, value_loss: loss callables applied to the two network heads.
        optimizer: optimizer over ``model``'s parameters.
        device: torch device the model is moved to.
        current_player: stored on the instance; not read by the methods below.
    """

    def __init__(self, env, model, policy_loss, value_loss, optimizer, device, current_player = 1):
        self.env = env
        self.model = model
        self.model.to(device)
        self.optimizer = optimizer
        self.mainBuffer = deque(maxlen = 30000)
        self.policy_loss = policy_loss
        self.value_loss = value_loss
        self.device = device
        self.current_player = current_player
        # Exponent 1/tau sharpens (tau < 1) or flattens (tau > 1) the visit distribution.
        self.tau = 1.25

    def selfPlayData(self):
        """Play one self-play game and return its training samples.

        Returns:
            List of ``(planes, visit_probs, value)`` tuples, one per move, where
            ``planes`` encodes the position from the side to move in it, and
            ``value`` is +1/-1/0 from that same side's perspective.
        """
        trajectory = []
        state = self.env.reset()
        current_player = 1

        while True:
            root = MCTS(self.env, state, current_player, self.model, self.device)

            visit_probs = root.search(iterations = 100)
            temperature_probs = visit_probs ** (1 / self.tau)
            temperature_probs /= np.sum(temperature_probs)
            trajectory.append((state, visit_probs, current_player))

            action = np.random.choice(self.env.all_moves, p = temperature_probs)

            state, result, done = root.step(state, action, current_player)
            if done:
                samples = []
                for past_state, past_probs, player in trajectory:
                    if result is None or result == 0:
                        value = 0
                    elif result == player:
                        value = 1
                    else:
                        value = -1
                    # Encode from `player`'s side (the side to move in past_state) so the
                    # input planes agree with the value target; the old code used the
                    # final mover for every position.
                    samples.append((self.env.stackedStates(past_state, player), past_probs, value))
                return samples

            current_player = -current_player

    def train(self, selfPlayIterations, epochs, batch_size, total_iterations = 50):
        """Alternate self-play data generation, network training and evaluation.

        Each of ``total_iterations`` rounds plays ``selfPlayIterations`` games,
        trains for ``epochs`` passes over that data in mini-batches of
        ``batch_size``, then evaluates the updated network against a snapshot
        taken before training.
        """
        import copy

        for _ in trange(total_iterations):
            self.model.eval()
            # NOTE: a fresh list per round (replacing the deque from __init__), so
            # samples are not replayed across rounds.
            self.mainBuffer = []
            for _ in trange(selfPlayIterations):
                self.mainBuffer.extend(self.selfPlayData())

            # Snapshot BEFORE training; the old `previous_model = self.model` was an
            # alias, so evaluation pitted the model against itself.
            previous_model = copy.deepcopy(self.model)
            previous_model.eval()

            self.model.train()
            losses = []
            for epoch in trange(epochs):
                epoch_loss = 0.0
                batch_count = 0
                np.random.shuffle(self.mainBuffer)
                for idx in range(0, len(self.mainBuffer), batch_size):
                    batch = self.mainBuffer[idx:idx + batch_size]
                    states, target_policies, target_values = zip(*batch)

                    # Stack into one ndarray first; building a tensor from a tuple of
                    # ndarrays is very slow.
                    states = torch.as_tensor(np.array(states), dtype = torch.float, device = self.device)
                    target_policies = torch.as_tensor(np.array(target_policies), dtype = torch.float, device = self.device)
                    target_values = torch.as_tensor(np.array(target_values), dtype = torch.float, device = self.device).unsqueeze(1)

                    model_policy, model_value = self.model(states)

                    policy_loss = self.policy_loss(model_policy, target_policies)
                    value_loss = self.value_loss(model_value, target_values)
                    total_loss = policy_loss + value_loss

                    self.optimizer.zero_grad()
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()

                    epoch_loss += total_loss.item()
                    batch_count += 1

                # Guard against an empty buffer (previously a ZeroDivisionError).
                avg_loss = epoch_loss / batch_count if batch_count else float("nan")

                print(f"Epoch {epoch + 1}\tTotal Loss : {avg_loss}")
                losses.append(avg_loss)

            self.evaluation(previous_model)

    def evaluation(self, pmodel, num_games = 20):
        """Play ``num_games`` greedy MCTS games: current model vs ``pmodel``.

        The side that moves first alternates between games so neither network
        gets a systematic first-move advantage. Prints a summary and returns the
        current model's win rate.
        """
        wins = 0
        draws = 0
        self.model.eval()
        pmodel.eval()

        for game in range(num_games):
            new_side = 1 if game % 2 == 0 else -1
            players = {new_side: self.model, -new_side: pmodel}
            state = self.env.reset()
            current_player = 1

            while True:
                root = MCTS(self.env, state, current_player, players[current_player], self.device)
                action = np.argmax(root.search(iterations = 100))
                state, result, done = root.step(state, action, current_player)
                if done:
                    if result == new_side:
                        wins += 1
                    elif result is None or result == 0:
                        draws += 1
                    break
                current_player = -current_player

        win_rate = wins / num_games
        print(f"Evaluation : {wins}/{num_games} wins, {draws} draws (win rate {win_rate:.2f})")
        return win_rate
                            

In [None]:
# Seed every RNG the pipeline draws from (torch: weight init; NumPy: Dirichlet
# noise, move sampling, buffer shuffling) so Restart & Run All reproduces the run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters (same values as before, now named).
IN_CHANNELS = 3          # planes produced by Connect4.stackedStates
HIDDEN_CHANNELS = 64
RESNET_BLOCKS = 5
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
SELF_PLAY_GAMES = 500    # self-play games per training iteration
EPOCHS = 10
BATCH_SIZE = 128
# NOTE: AlphaZero.train defaults to 50 iterations and each self-play move runs
# 100 MCTS simulations, so this configuration is very slow without a GPU.

env = Connect4()
model = Resnet(IN_CHANNELS, HIDDEN_CHANNELS, RESNET_BLOCKS)
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY)
policy_loss = nn.CrossEntropyLoss()
value_loss = nn.MSELoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = AlphaZero(env, model, policy_loss, value_loss, optimizer, device)
agent.train(SELF_PLAY_GAMES, EPOCHS, BATCH_SIZE)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]