In [1]:
from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import trange

In [2]:
class Resnet(nn.Module):

    def __init__(self, in_channels, out_channels, resnet_blocks = 5):
        super(Resnet, self).__init__()
        self.start = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.blocks = nn.ModuleList(
            [ResidualConnection(out_channels) for _ in range(resnet_blocks)]
        )

        self.policy_head = nn.Sequential(
            nn.Conv2d(out_channels, out_channels = 32, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 6 * 7, 7)
        )

        self.value_head = nn.Sequential(
            nn.Conv2d(out_channels, out_channels = 3, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * 6 * 7, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.start(x)
        for block in self.blocks:
            x = block(x)
        policy = self.policy_head(x)
        value = self.value_head(x)
        return policy, value
        

class ResidualConnection(nn.Module):

    def __init__(self, out_channels):
        super(ResidualConnection, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(out_channels)     
        )

    def forward(self, x):
        features = self.block(x)
        return torch.relu(features + x)

In [3]:
class Connect4:

    def __init__(self):
        self.W = 7
        self.H = 6
        self.board = np.zeros((self.H, self.W))
        self.all_moves = np.arange(self.W)


    def reset(self):
        self.board = np.zeros((self.H, self.W))
        return self.board
        
    def get_valid_moves(self, state):
        return (state[0] == 0).astype(int)

    def step(self, board, move, player):
        board = board.copy()
        row, col = move//self.W, move%self.W
        brow, bcol = self.H-1, col
        while True:
            if board[brow, bcol] == 0:
                board[brow, bcol] = player
                break
            brow -= 1
        done, result = self.is_terminal(board)
        return board, result, done

    def is_terminal(self, board,H = 6, W = 7):
        for r in range(H):
            for c in range(W - 3):
                val = board[r, c]
                if val in [1, -1] and val == board[r, c+1] == board[r, c+2] == board[r, c+3]:
                    return True, val
        for r in range(H - 3):
            for c in range(W):
                val = board[r, c]
                if val in [1, -1] and val == board[r+1, c] == board[r+2, c] == board[r+3, c]:
                    return True, val
        for r in range(H - 3):
            for c in range(W - 3):
                val = board[r, c]
                if val in [1, -1] and val == board[r+1, c+1] == board[r+2, c+2] == board[r+3, c+3]:
                    return True, val
        for r in range(3, H):
            for c in range(W - 3):
                val = board[r, c]
                if val in [1, -1] and val == board[r-1, c+1] == board[r-2, c+2] == board[r-3, c+3]:
                    return True, val
        if not (board == 0).any():
            return True, 0
        return False, 0
        
    def stackedStates(self, state, current_player):
        
        return np.stack((
            state == current_player, 
            state == -current_player,
            state == 0
        )).astype(np.float32)

    
    def show(self, state):
        print(state)

In [4]:
import numpy as np

class Node:

    def __init__(self, state, parent, move, player, prob = 0):
        self.H = 6
        self.W = 7
        self.state = state
        self.parent = parent
        self.move = move

        self.player = player
        self.children = {}
        
        self.prob = prob
        
        self.N = 0
        self.W = 0
        
    def is_fully_expanded(self):
        return len(np.argwhere(self.state[0] == 0)) == 0

    
    def UCB1(self, c = 2):
        if self.N == 0:
            return float('inf')
        return self.W/self.N + c * self.prob * np.sqrt(self.parent.N)/(self.N + 1)

In [5]:
class MCTS:

    def __init__(self, env, root_state, root_player, model, device, PARAMS):
        self.env = env
        self.H = 6
        self.W = 7
        self.policy = 0
        self.model = model
        self.device = device
        self.root = Node(state = root_state,
                         parent = None,
                         move = None,
                         player = root_player
                        )
        self.PARAMS = PARAMS
        
    def selection(self):
        current = self.root
        
        while True:
            is_terminal, _ = self.env.is_terminal(current.state)
            if is_terminal:
                return current
            if current.is_fully_expanded():
                return current
            if len(current.children) == 0:
                return current
            best_child = max(current.children.values(), key = lambda c: c.UCB1())
            current = best_child
            
        return current

    def expansion(self, node, policy):
        valid_states = self.env.get_valid_moves(node.state)

        for action, prob in enumerate(policy):
            if valid_states[action] == 1 and action not in node.children:
                new_state = self.env.step(node.state, action, node.player)[0]
                child = Node(state = new_state,
                             parent = node,
                             move = action,
                             player = -node.player,
                             prob = prob
                            )
                node.children[action] = child
                
    def backpropagation(self, child, value):
        current = child
        while current is not None:
            current.N += 1
            current.W += value
            value = -value
            current = current.parent

    @torch.no_grad()
    def search(self):
    
        for _ in range(self.PARAMS['SEARCHES']):
            node = self.selection()
            is_terminal, value = self.env.is_terminal(node.state)
            
            if is_terminal:
                value = value * node.player
            else:
                stackedStates = self.env.stackedStates(node.state, node.player)
                policy, value = self.model(torch.tensor(stackedStates, dtype = torch.float).unsqueeze(0).to(self.device))
                policy = torch.softmax(policy, dim = 1).squeeze(0).cpu().detach().numpy()
                if node is self.root:
                    alpha = self.PARAMS['ALPHA']
                    eps = self.PARAMS['EPSILON']
                    noise = np.random.dirichlet([alpha] * len(self.env.all_moves))
                    policy = (1 - eps) * policy + eps * noise

                value = value.squeeze(0).cpu().item()
                valid_moves = self.env.get_valid_moves(node.state)
                policy = policy * valid_moves
                if np.sum(policy)>0:
                    policy = policy/(np.sum(policy))
                else:
                    policy = valid_moves/np.sum(valid_moves)
    
                self.expansion(node, policy)
                if len(node.children) > 0:
                    best_action = max(node.children.keys(), key=lambda a: node.children[a].prob)
                    child = node.children[best_action]
                else:
                    child = node
            target_node = node if is_terminal else child
    
            self.backpropagation(target_node, value)
        actions = np.zeros(7)

        if len(self.root.children) == 0:
            valid_moves = self.env.get_valid_moves(self.root.state)
            action_probs = np.array(valid_moves / np.sum(valid_moves))
            return action_probs

        total_visits = sum(child.N for child in self.root.children.values())
        if total_visits > 0:
            for action, child in self.root.children.items():
                actions[action] = child.N
            actions /= total_visits
        else:
            valid_moves = self.env.get_valid_moves(self.root.state)
            actions = np.array(valid_moves / np.sum(valid_moves))
        return actions

    def __repr__(self):
        return f"{self.root.state}\n{self.root.parent}\n{self.root.children}\n{self.root.action}\n{self.root.children.n}"

In [6]:
class AlphaZero:

    def __init__(self, env, model, policy_loss, value_loss, optimizer, device, PARAMS, current_player = 1):
        self.env = env
        self.model = model
        self.model.to(device)
        self.optimizer = optimizer
        self.mainBuffer = deque(maxlen = 30000)
        self.policy_loss = policy_loss
        self.value_loss = value_loss
        self.device = device
        self.current_player = current_player
        self.tau = 1.25
        self.PARAMS = PARAMS

    
    def selfPlayData(self, ):
        buffer = []
        state = self.env.reset()
        current_player = 1
        
        while True:
            root = MCTS(self.env, state, current_player, self.model, self.device, self.PARAMS)

            actions = root.search()
            temperature_probs = actions ** (1/self.PARAMS['TAU'])
            temperature_probs /= np.sum(temperature_probs)
            buffer.append((state, actions, current_player))
            action = np.random.choice(self.env.all_moves, p = temperature_probs)
            state, result, done = self.env.step(state, action, current_player)
            if done:
                mainBuffer = []
                for states, actions, player in buffer:
                    if result == 0:
                        value = 0
                    elif result == player:
                        value = 1
                    else:
                        value = -1
                    mainBuffer.append((self.env.stackedStates(states, current_player), actions, value))
                return mainBuffer

            current_player = -current_player

    def train(self):
        for i in trange(self.PARAMS['TOTAL_ITERATIONS']):
            self.model.eval()
            self.mainBuffer = []
            for _ in trange(self.PARAMS['SELF_PLAY_ITERATIONS']):
                self.mainBuffer.extend(self.selfPlayData())
            
            self.model.train()
            losses = []
            epoch_loss = 0
            previous_model = self.model
            for epoch in trange(self.PARAMS['EPOCHS']):

                epoch_loss = 0
                batch_count = 0
                np.random.shuffle(self.mainBuffer)
                for idx in range(0, len(self.mainBuffer), self.PARAMS["BATCH_SIZE"]):
                    data = self.mainBuffer[idx:idx+self.PARAMS["BATCH_SIZE"]]
                    states, actions, value = zip(*data)
    
                    actions = torch.tensor(np.array(actions), dtype = torch.float).to(self.device)
                    values = torch.tensor(value, dtype = torch.float).unsqueeze(1).to(self.device)
                    states = torch.tensor(states, dtype = torch.float).to(self.device)

                    model_policy, model_value = self.model(states)

                    policy_loss = self.policy_loss(model_policy, actions)
                    value_loss = self.value_loss(model_value, values)
                    total_loss = policy_loss + value_loss
                    
                    self.optimizer.zero_grad()
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()            

                    epoch_loss += total_loss.item()
                    batch_count += 1

                avg_loss = epoch_loss/batch_count
                    
                print(f"Epoch {epoch + 1}\tTotal Loss : {avg_loss}")
                losses.append(avg_loss)
            
            self.evaluation(previous_model)
            self.save_model()

    def save_model(self):

        checkpoint = {
            "model" : self.model.state_dict(),
            "optimizer_state" : self.optimizer.state_dict()
        }
        torch.save(checkpoint, 'checkpoint.pth')
        print("Hit Checkpoint")
        
        
    

    def evaluation(self, pmodel):
        wins = 0
        self.model.eval()

        for game in trange(self.PARAMS['EVALUATION_GAMES']):
                state = self.env.reset()
                current_player = 1

                while True:
                    root = MCTS(self.env, state, current_player, self.model, self.device, self.PARAMS)
                    actions = root.search()
                    action = np.argmax(actions)
                    state, result, done = self.env.step(state, action, current_player)
                    if done:
                        if result == current_player:
                            wins += 1
                        break
                    current_player = -current_player

                    root = MCTS(self.env, state, current_player, pmodel, self.device, self.PARAMS)
                    actions = root.search()
                    action = np.argmax(actions)
                    state, result, done = self.env.step(state, action, current_player)
                    if done:
                        break
                    current_player = -current_player

        win_rate = wins/self.PARAMS["EVALUATION_GAMES"]
        print(f"Evaluation : {wins}/{self.PARAMS['EVALUATION_GAMES']} Wins : {wins}")
                            

In [24]:
PARAMETERS = {
    
    "IN_CHANNELS" : 3,
    "OUT_CHANNELS" : 64,
    "RESNET_BLOCKS" : 5,
    
    "SELF_PLAY_ITERATIONS" : 500,
    "EPOCHS" : 20,
    "BATCH_SIZE" : 128,
    "TOTAL_ITERATIONS" : 20,
    
    "SEARCHES" : 50,
    "EVALUATION_GAMES" : 20,
    
    "ALPHA" : 0.3,
    "EPSILON" : 0.25,
    "TAU" : 1.25
    }

In [8]:
env = Connect4()

model = Resnet(PARAMETERS['IN_CHANNELS'], 
               PARAMETERS['OUT_CHANNELS'], 
               PARAMETERS['RESNET_BLOCKS']
              )

optimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay = 1e-4)
policy_loss = nn.CrossEntropyLoss()
value_loss = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent= AlphaZero(env, 
                 model, 
                 policy_loss, 
                 value_loss, 
                 optimizer, 
                 device, 
                 PARAMETERS
                )

agent.train()

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  states = torch.tensor(states, dtype = torch.float).to(self.device)


Epoch 1	Total Loss : 2.7432805454289473
Epoch 2	Total Loss : 2.3888552718692355
Epoch 3	Total Loss : 2.2750759809105485
Epoch 4	Total Loss : 2.196598357624478
Epoch 5	Total Loss : 2.1468663668190993
Epoch 6	Total Loss : 2.1211720020682723
Epoch 7	Total Loss : 2.0944228768348694
Epoch 8	Total Loss : 2.0807184599064015
Epoch 9	Total Loss : 2.050313030128126
Epoch 10	Total Loss : 2.053801982491105
Epoch 11	Total Loss : 2.034401912380148
Epoch 12	Total Loss : 2.0273858750308
Epoch 13	Total Loss : 2.0292916816693767
Epoch 14	Total Loss : 2.0161350943424083
Epoch 15	Total Loss : 2.0096707730381578
Epoch 16	Total Loss : 2.0008855561415353
Epoch 17	Total Loss : 1.9935759416332952
Epoch 18	Total Loss : 1.9892831919369873
Epoch 19	Total Loss : 1.9857241345776453
Epoch 20	Total Loss : 1.9824634326828852


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 10/20 Wins : 10
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 2.337997275370139
Epoch 2	Total Loss : 2.1385268535878925
Epoch 3	Total Loss : 2.05045297521132
Epoch 4	Total Loss : 1.9891632276552695
Epoch 5	Total Loss : 1.9722012910577986
Epoch 6	Total Loss : 1.9434998609401561
Epoch 7	Total Loss : 1.923920484604659
Epoch 8	Total Loss : 1.8908545750158805
Epoch 9	Total Loss : 1.8733077424543876
Epoch 10	Total Loss : 1.85902147381394
Epoch 11	Total Loss : 1.8418533007303874
Epoch 12	Total Loss : 1.8253557428165719
Epoch 13	Total Loss : 1.8012184191633154
Epoch 14	Total Loss : 1.78668659152808
Epoch 15	Total Loss : 1.7689062588744693
Epoch 16	Total Loss : 1.7515238479331687
Epoch 17	Total Loss : 1.7302292927547738
Epoch 18	Total Loss : 1.7180739321090557
Epoch 19	Total Loss : 1.6957655063381902
Epoch 20	Total Loss : 1.6747384832965002


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 11/20 Wins : 11
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 2.206475030391588
Epoch 2	Total Loss : 1.9768053151051932
Epoch 3	Total Loss : 1.8893096403244438
Epoch 4	Total Loss : 1.8251816592085253
Epoch 5	Total Loss : 1.7848308020775472
Epoch 6	Total Loss : 1.7535849131575418
Epoch 7	Total Loss : 1.7503447292047902
Epoch 8	Total Loss : 1.7085471415738447
Epoch 9	Total Loss : 1.7004295576603041
Epoch 10	Total Loss : 1.696487798603303
Epoch 11	Total Loss : 1.6498172119123127
Epoch 12	Total Loss : 1.6512919095678067
Epoch 13	Total Loss : 1.6395103012749908
Epoch 14	Total Loss : 1.626259811427615
Epoch 15	Total Loss : 1.622946087373506
Epoch 16	Total Loss : 1.5940793687050496
Epoch 17	Total Loss : 1.6016861265952433
Epoch 18	Total Loss : 1.5926500974445168
Epoch 19	Total Loss : 1.5823424420225511
Epoch 20	Total Loss : 1.5412641238728795


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 3/20 Wins : 3
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 2.116706773367795
Epoch 2	Total Loss : 1.8490994388406927
Epoch 3	Total Loss : 1.7573101314631376
Epoch 4	Total Loss : 1.6896341139620001
Epoch 5	Total Loss : 1.6488130135969683
Epoch 6	Total Loss : 1.6243773861364885
Epoch 7	Total Loss : 1.5898540778593584
Epoch 8	Total Loss : 1.5615574609149585
Epoch 9	Total Loss : 1.5385817376050082
Epoch 10	Total Loss : 1.5146176511591132
Epoch 11	Total Loss : 1.500516992265528
Epoch 12	Total Loss : 1.4920683741569518
Epoch 13	Total Loss : 1.473624476519498
Epoch 14	Total Loss : 1.46081819859418
Epoch 15	Total Loss : 1.4539080825718966
Epoch 16	Total Loss : 1.452370979569175
Epoch 17	Total Loss : 1.433916501565413
Epoch 18	Total Loss : 1.431006162816828
Epoch 19	Total Loss : 1.423799368468198
Epoch 20	Total Loss : 1.4209308244965293


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 8/20 Wins : 8
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 2.0588204772384078
Epoch 2	Total Loss : 1.80628576764354
Epoch 3	Total Loss : 1.7052434815300836
Epoch 4	Total Loss : 1.6450863590946905
Epoch 5	Total Loss : 1.6029961186426658
Epoch 6	Total Loss : 1.5790563965285267
Epoch 7	Total Loss : 1.550825297832489
Epoch 8	Total Loss : 1.5264461172951593
Epoch 9	Total Loss : 1.5056547103104767
Epoch 10	Total Loss : 1.4870495277422446
Epoch 11	Total Loss : 1.4817724404511627
Epoch 12	Total Loss : 1.458451341699671
Epoch 13	Total Loss : 1.4505755757844006
Epoch 14	Total Loss : 1.4449674608530823
Epoch 15	Total Loss : 1.4368550137237266
Epoch 16	Total Loss : 1.41942548972589
Epoch 17	Total Loss : 1.4102774637716788
Epoch 18	Total Loss : 1.4068348440859053
Epoch 19	Total Loss : 1.4013730320665572
Epoch 20	Total Loss : 1.3987567491001553


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 5/20 Wins : 5
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 2.083868324188959
Epoch 2	Total Loss : 1.8046649966921124
Epoch 3	Total Loss : 1.7092435961677914
Epoch 4	Total Loss : 1.6367086626234508
Epoch 5	Total Loss : 1.5847464447929747
Epoch 6	Total Loss : 1.5396705729620797
Epoch 7	Total Loss : 1.5165526946385701
Epoch 8	Total Loss : 1.4927082992735363
Epoch 9	Total Loss : 1.4704239220846267
Epoch 10	Total Loss : 1.4501391161055792
Epoch 11	Total Loss : 1.4416071846371605
Epoch 12	Total Loss : 1.4284972054617746
Epoch 13	Total Loss : 1.4097765150524322
Epoch 14	Total Loss : 1.400945614633106
Epoch 15	Total Loss : 1.3884067240215483
Epoch 16	Total Loss : 1.3820077952884493
Epoch 17	Total Loss : 1.3808278333573114
Epoch 18	Total Loss : 1.3749168225697108
Epoch 19	Total Loss : 1.3827991167704263
Epoch 20	Total Loss : 1.3677416472207933


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 7/20 Wins : 7
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 2.0426302000740977
Epoch 2	Total Loss : 1.785800573981811
Epoch 3	Total Loss : 1.6871592051514954
Epoch 4	Total Loss : 1.6370320041603017
Epoch 5	Total Loss : 1.58816635385852
Epoch 6	Total Loss : 1.5472683182386595
Epoch 7	Total Loss : 1.5198937509661523
Epoch 8	Total Loss : 1.49620120102
Epoch 9	Total Loss : 1.4709333961255082
Epoch 10	Total Loss : 1.4640224069078391
Epoch 11	Total Loss : 1.4537063915038777
Epoch 12	Total Loss : 1.4390092074314011
Epoch 13	Total Loss : 1.422025042159535
Epoch 14	Total Loss : 1.4089051507343755
Epoch 15	Total Loss : 1.401145125103888
Epoch 16	Total Loss : 1.4030519534494275
Epoch 17	Total Loss : 1.4029042208306144
Epoch 18	Total Loss : 1.3941384475921916
Epoch 19	Total Loss : 1.3806323024714104
Epoch 20	Total Loss : 1.3768546915499964


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 8/20 Wins : 8
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.9939045337362027
Epoch 2	Total Loss : 1.7399330576625438
Epoch 3	Total Loss : 1.6410349957439878
Epoch 4	Total Loss : 1.5678976293003888
Epoch 5	Total Loss : 1.5219902193874395
Epoch 6	Total Loss : 1.488369794066893
Epoch 7	Total Loss : 1.4663947514437754
Epoch 8	Total Loss : 1.4421983279219461
Epoch 9	Total Loss : 1.4199095450410055
Epoch 10	Total Loss : 1.4044957696844678
Epoch 11	Total Loss : 1.3951166780716782
Epoch 12	Total Loss : 1.381870091508288
Epoch 13	Total Loss : 1.3686826163475667
Epoch 14	Total Loss : 1.3553289743738437
Epoch 15	Total Loss : 1.352611694860896
Epoch 16	Total Loss : 1.3494915787233126
Epoch 17	Total Loss : 1.3532555803246455
Epoch 18	Total Loss : 1.357865571975708
Epoch 19	Total Loss : 1.3482076213994156
Epoch 20	Total Loss : 1.3378612426442837


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 8/20 Wins : 8
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 2.017188572221332
Epoch 2	Total Loss : 1.7410552711398513
Epoch 3	Total Loss : 1.6274328551910542
Epoch 4	Total Loss : 1.5537651413016849
Epoch 5	Total Loss : 1.5097005510771717
Epoch 6	Total Loss : 1.471195348986873
Epoch 7	Total Loss : 1.445181483471835
Epoch 8	Total Loss : 1.427548447140941
Epoch 9	Total Loss : 1.4042034690026883
Epoch 10	Total Loss : 1.3915481291435383
Epoch 11	Total Loss : 1.3782726724942524
Epoch 12	Total Loss : 1.3685916077207636
Epoch 13	Total Loss : 1.353641539812088
Epoch 14	Total Loss : 1.3547370764944289
Epoch 15	Total Loss : 1.3459642010706443
Epoch 16	Total Loss : 1.3277132742934756
Epoch 17	Total Loss : 1.3251852967120983
Epoch 18	Total Loss : 1.3301227456993527
Epoch 19	Total Loss : 1.3201418033352605
Epoch 20	Total Loss : 1.3208986101327118


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 10/20 Wins : 10
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.9263224484665054
Epoch 2	Total Loss : 1.6732527890375681
Epoch 3	Total Loss : 1.5546883993915148
Epoch 4	Total Loss : 1.4909592015402657
Epoch 5	Total Loss : 1.4524572928036963
Epoch 6	Total Loss : 1.410373944256987
Epoch 7	Total Loss : 1.381576025060245
Epoch 8	Total Loss : 1.3544027592454637
Epoch 9	Total Loss : 1.3349893572075027
Epoch 10	Total Loss : 1.3307970400367464
Epoch 11	Total Loss : 1.3168870604463987
Epoch 12	Total Loss : 1.306666826563222
Epoch 13	Total Loss : 1.2989382754479135
Epoch 14	Total Loss : 1.2884962686470576
Epoch 15	Total Loss : 1.2895427271723747
Epoch 16	Total Loss : 1.2744852102228574
Epoch 17	Total Loss : 1.2685637431485313
Epoch 18	Total Loss : 1.2727129746760641
Epoch 19	Total Loss : 1.267725640109607
Epoch 20	Total Loss : 1.2658097382102693


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 11/20 Wins : 11
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.896825294236879
Epoch 2	Total Loss : 1.6455212058247746
Epoch 3	Total Loss : 1.5386970064661525
Epoch 4	Total Loss : 1.472291798204989
Epoch 5	Total Loss : 1.4255447752840884
Epoch 6	Total Loss : 1.3861117706642494
Epoch 7	Total Loss : 1.3577554923994046
Epoch 8	Total Loss : 1.346524666021536
Epoch 9	Total Loss : 1.3231932590673636
Epoch 10	Total Loss : 1.3048456359553982
Epoch 11	Total Loss : 1.2907289502856967
Epoch 12	Total Loss : 1.282976468404134
Epoch 13	Total Loss : 1.280739620999173
Epoch 14	Total Loss : 1.2719727604238837
Epoch 15	Total Loss : 1.2711187644047781
Epoch 16	Total Loss : 1.2755904573578019
Epoch 17	Total Loss : 1.2630388124569043
Epoch 18	Total Loss : 1.2621967158876024
Epoch 19	Total Loss : 1.2489985893438529
Epoch 20	Total Loss : 1.2532505355439745


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 7/20 Wins : 7
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.8450007839540465
Epoch 2	Total Loss : 1.6151954895627183
Epoch 3	Total Loss : 1.5034066079992108
Epoch 4	Total Loss : 1.4388105848194224
Epoch 5	Total Loss : 1.3928646007470324
Epoch 6	Total Loss : 1.363127466851631
Epoch 7	Total Loss : 1.33531300894982
Epoch 8	Total Loss : 1.3109516432854982
Epoch 9	Total Loss : 1.288666445597083
Epoch 10	Total Loss : 1.276175253159177
Epoch 11	Total Loss : 1.2666088300468648
Epoch 12	Total Loss : 1.2563406197370681
Epoch 13	Total Loss : 1.2422516894551505
Epoch 14	Total Loss : 1.2470813335570614
Epoch 15	Total Loss : 1.2458859882523527
Epoch 16	Total Loss : 1.2423554616691792
Epoch 17	Total Loss : 1.2345987220781038
Epoch 18	Total Loss : 1.2344808694535652
Epoch 19	Total Loss : 1.2277978901314524
Epoch 20	Total Loss : 1.2243205066275809


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 13/20 Wins : 13
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.7985841495181443
Epoch 2	Total Loss : 1.5797094205103883
Epoch 3	Total Loss : 1.4779447535856054
Epoch 4	Total Loss : 1.4043837249825855
Epoch 5	Total Loss : 1.3618458870354049
Epoch 6	Total Loss : 1.3287419933791553
Epoch 7	Total Loss : 1.300709277118018
Epoch 8	Total Loss : 1.2799259391399698
Epoch 9	Total Loss : 1.266778615636563
Epoch 10	Total Loss : 1.251825948373987
Epoch 11	Total Loss : 1.2436853111337085
Epoch 12	Total Loss : 1.229133214425603
Epoch 13	Total Loss : 1.2225186398269934
Epoch 14	Total Loss : 1.216623009891685
Epoch 15	Total Loss : 1.2206744600873474
Epoch 16	Total Loss : 1.2110706294348481
Epoch 17	Total Loss : 1.2109298498258678
Epoch 18	Total Loss : 1.209184084463557
Epoch 19	Total Loss : 1.2112296327538448
Epoch 20	Total Loss : 1.204504365221076


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 12/20 Wins : 12
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.7469826852325845
Epoch 2	Total Loss : 1.5486741266419402
Epoch 3	Total Loss : 1.45459589810498
Epoch 4	Total Loss : 1.396093201848258
Epoch 5	Total Loss : 1.3471353022398147
Epoch 6	Total Loss : 1.3122334163800804
Epoch 7	Total Loss : 1.286276478682999
Epoch 8	Total Loss : 1.2695468980654152
Epoch 9	Total Loss : 1.2545491182698614
Epoch 10	Total Loss : 1.2473821808806562
Epoch 11	Total Loss : 1.2329172191366684
Epoch 12	Total Loss : 1.2239178482410127
Epoch 13	Total Loss : 1.2209930694208735
Epoch 14	Total Loss : 1.2205110049880712
Epoch 15	Total Loss : 1.2156170733207095
Epoch 16	Total Loss : 1.2070782079105884
Epoch 17	Total Loss : 1.2045751957766777
Epoch 18	Total Loss : 1.2020954516081683
Epoch 19	Total Loss : 1.2053306102752686
Epoch 20	Total Loss : 1.2061088749792723


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 8/20 Wins : 8
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.7417286851189353
Epoch 2	Total Loss : 1.5474391070279208
Epoch 3	Total Loss : 1.451599244637923
Epoch 4	Total Loss : 1.3807884552262046
Epoch 5	Total Loss : 1.3262858217412776
Epoch 6	Total Loss : 1.2923163479024713
Epoch 7	Total Loss : 1.2713009411638434
Epoch 8	Total Loss : 1.2483467838980935
Epoch 9	Total Loss : 1.2290256023406982
Epoch 10	Total Loss : 1.2189261523160067
Epoch 11	Total Loss : 1.2154327869415282
Epoch 12	Total Loss : 1.2040077166123824
Epoch 13	Total Loss : 1.2030042030594565
Epoch 14	Total Loss : 1.1922303676605224
Epoch 15	Total Loss : 1.1905117533423684
Epoch 16	Total Loss : 1.1910585024140097
Epoch 17	Total Loss : 1.1901806007732045
Epoch 18	Total Loss : 1.1805103540420532
Epoch 19	Total Loss : 1.1794963218949057
Epoch 20	Total Loss : 1.1826800097118724


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 11/20 Wins : 11
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.7186550857784513
Epoch 2	Total Loss : 1.5381695042859327
Epoch 3	Total Loss : 1.4298549353539407
Epoch 4	Total Loss : 1.3580723399514552
Epoch 5	Total Loss : 1.3024517458838385
Epoch 6	Total Loss : 1.2680578049238738
Epoch 7	Total Loss : 1.2421650736181586
Epoch 8	Total Loss : 1.2240783642004203
Epoch 9	Total Loss : 1.2053848333186932
Epoch 10	Total Loss : 1.1924430486318227
Epoch 11	Total Loss : 1.1874955924781594
Epoch 12	Total Loss : 1.181464062080727
Epoch 13	Total Loss : 1.1715291684812255
Epoch 14	Total Loss : 1.174089230932631
Epoch 15	Total Loss : 1.1611365434285756
Epoch 16	Total Loss : 1.164523536020571
Epoch 17	Total Loss : 1.161658077626615
Epoch 18	Total Loss : 1.1612214021854572
Epoch 19	Total Loss : 1.1574167474970087
Epoch 20	Total Loss : 1.1492734363487176


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 10/20 Wins : 10
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.8031715998607398
Epoch 2	Total Loss : 1.6007994466123328
Epoch 3	Total Loss : 1.4939188471937601
Epoch 4	Total Loss : 1.4307976570804561
Epoch 5	Total Loss : 1.3849995347250879
Epoch 6	Total Loss : 1.3518108002907407
Epoch 7	Total Loss : 1.3271693866864769
Epoch 8	Total Loss : 1.3021923324703115
Epoch 9	Total Loss : 1.2852181459950134
Epoch 10	Total Loss : 1.2765270766958725
Epoch 11	Total Loss : 1.2620165021018643
Epoch 12	Total Loss : 1.2553642924907988
Epoch 13	Total Loss : 1.2542275760026105
Epoch 14	Total Loss : 1.249103765572067
Epoch 15	Total Loss : 1.2414118747795577
Epoch 16	Total Loss : 1.2476942307126206
Epoch 17	Total Loss : 1.243776508137188
Epoch 18	Total Loss : 1.2374433817061703
Epoch 19	Total Loss : 1.236423461838106
Epoch 20	Total Loss : 1.2287024111874336


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 14/20 Wins : 14
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.6995865933529966
Epoch 2	Total Loss : 1.5272452025800138
Epoch 3	Total Loss : 1.432543192897831
Epoch 4	Total Loss : 1.3722290563153792
Epoch 5	Total Loss : 1.3184818287153501
Epoch 6	Total Loss : 1.277857586070224
Epoch 7	Total Loss : 1.2499746427879677
Epoch 8	Total Loss : 1.2306889768119331
Epoch 9	Total Loss : 1.2189100723008852
Epoch 10	Total Loss : 1.2104992941693142
Epoch 11	Total Loss : 1.1984807177706882
Epoch 12	Total Loss : 1.1887731401769965
Epoch 13	Total Loss : 1.18048937148876
Epoch 14	Total Loss : 1.1833068669379294
Epoch 15	Total Loss : 1.1767250965307425
Epoch 16	Total Loss : 1.1716758044990334
Epoch 17	Total Loss : 1.165674885114034
Epoch 18	Total Loss : 1.1722815208606892
Epoch 19	Total Loss : 1.1616099439225755
Epoch 20	Total Loss : 1.1594447896287248


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 12/20 Wins : 12
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.6401955025237904
Epoch 2	Total Loss : 1.4700805408912792
Epoch 3	Total Loss : 1.3725352339577257
Epoch 4	Total Loss : 1.315495110394662
Epoch 5	Total Loss : 1.264709888843068
Epoch 6	Total Loss : 1.2273593314907008
Epoch 7	Total Loss : 1.2003587243849771
Epoch 8	Total Loss : 1.1871207251883389
Epoch 9	Total Loss : 1.1710691169688576
Epoch 10	Total Loss : 1.1536897546366642
Epoch 11	Total Loss : 1.1449648873847829
Epoch 12	Total Loss : 1.1442200321900218
Epoch 13	Total Loss : 1.139270633981939
Epoch 14	Total Loss : 1.1360356933192204
Epoch 15	Total Loss : 1.1298904324832715
Epoch 16	Total Loss : 1.124364157517751
Epoch 17	Total Loss : 1.1194941757018106
Epoch 18	Total Loss : 1.1241538566455507
Epoch 19	Total Loss : 1.1303863922754924
Epoch 20	Total Loss : 1.1196770223609187


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 14/20 Wins : 14
Hit Checkpoint


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1	Total Loss : 1.6315011860014081
Epoch 2	Total Loss : 1.4575936074729439
Epoch 3	Total Loss : 1.36762746712109
Epoch 4	Total Loss : 1.3070482445192766
Epoch 5	Total Loss : 1.2585936505515296
Epoch 6	Total Loss : 1.2230820387333363
Epoch 7	Total Loss : 1.197868503965773
Epoch 8	Total Loss : 1.1819635414862417
Epoch 9	Total Loss : 1.167873273024688
Epoch 10	Total Loss : 1.156662035632778
Epoch 11	Total Loss : 1.1518864556475803
Epoch 12	Total Loss : 1.1429979103105563
Epoch 13	Total Loss : 1.1385561850693848
Epoch 14	Total Loss : 1.1357672858882595
Epoch 15	Total Loss : 1.1280222422367818
Epoch 16	Total Loss : 1.1253701373263523
Epoch 17	Total Loss : 1.1214470723727803
Epoch 18	Total Loss : 1.1219836204975575
Epoch 19	Total Loss : 1.1201780959292575
Epoch 20	Total Loss : 1.127270497717299


  0%|          | 0/20 [00:00<?, ?it/s]

Evaluation : 16/20 Wins : 16
Hit Checkpoint


In [25]:
env = Connect4()

model = Resnet(PARAMETERS['IN_CHANNELS'], 
               PARAMETERS['OUT_CHANNELS'], 
               PARAMETERS['RESNET_BLOCKS']
              )


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device= "cpu"
model_checkpoint = torch.load("/kaggle/input/checkpoint1/pytorch/default/1/checkpoint1.pth", map_location= "cpu")
model.load_state_dict(model_checkpoint["model"])

model.eval()

state = env.reset()
current_player = 1

while True:
    root = MCTS(env, state, current_player, model, device, PARAMETERS)
    actions = root.search()
    action = np.argmax(actions)
    state, result, done = env.step(state, action, current_player)
    env.show(state)
    if done:
        if result == current_player:
            wins += 1
            print("Bot Won!!")
        break
    current_player = -current_player

    print(f"Available Moves : {np.where(env.get_valid_moves(state)==1)}")
    action = int(input("Enter action (0-6) : "))
    state, result, done = env.step(state, action, current_player)
    if done:
        if result == current_player:
            print("You Won!!")
        break
    current_player = -current_player
    env.show(state)


[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0.]]
Available Moves : (array([0, 1, 2, 3, 4, 5, 6]),)


Enter action (0-6) :  6


[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0. -1.]]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0. -1.]]
Available Moves : (array([0, 1, 2, 3, 4, 5, 6]),)


Enter action (0-6) :  2


[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0. -1.]]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  1.  0. -1.]]
Available Moves : (array([0, 1, 2, 3, 4, 5, 6]),)


Enter action (0-6) :  3


[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  1. -1.  1.  0. -1.]]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0.  0.]
 [ 1.  0.  1. -1.  1.  0. -1.]]
Available Moves : (array([0, 1, 2, 3, 4, 5, 6]),)


Enter action (0-6) :  1


[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 0.  0.  1.  0.  0.  0.  0.]
 [ 1. -1.  1. -1.  1.  0. -1.]]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 1.  0.  1.  0.  0.  0.  0.]
 [ 1. -1.  1. -1.  1.  0. -1.]]
Available Moves : (array([0, 1, 2, 3, 4, 5, 6]),)


Enter action (0-6) :  3


[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 1.  0.  1. -1.  0.  0.  0.]
 [ 1. -1.  1. -1.  1.  0. -1.]]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.]
 [ 1.  0.  1. -1.  0.  0.  1.]
 [ 1. -1.  1. -1.  1.  0. -1.]]
Available Moves : (array([0, 1, 2, 3, 4, 5, 6]),)


Enter action (0-6) :  3


[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1. -1.  0.  0.  0.]
 [ 1.  0.  1. -1.  0.  0.  1.]
 [ 1. -1.  1. -1.  1.  0. -1.]]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1. -1.  0.  0.  0.]
 [ 1.  1.  1. -1.  0.  0.  1.]
 [ 1. -1.  1. -1.  1.  0. -1.]]
Available Moves : (array([0, 1, 2, 3, 4, 5, 6]),)


Enter action (0-6) :  3


You Won!!
