In [3]:
import random
import torch
import torch.optim as optim
import torch.nn as nn
from typing import List, Tuple, Union, Optional, Any
from connect_4_env import ConnectFourEnv
from tqdm import tqdm
import math
import torch.nn.functional as F
from replay_buffer import ReplayBuffer
from monte_carlo_agent import MonteCarloTreeSearchAgent 
from agent import RandomAgent, Agent
from human_agent import HumanAgent
import torch.optim.lr_scheduler as lr_scheduler

In [4]:
env = ConnectFourEnv()
monte10 = MonteCarloTreeSearchAgent(env, n_iterations=10)
monte100 = MonteCarloTreeSearchAgent(env, n_iterations=100)
monte1_000 = MonteCarloTreeSearchAgent(env, n_iterations=1_000)
monte10_000 = MonteCarloTreeSearchAgent(env, n_iterations=10_000)
monte50_000 = MonteCarloTreeSearchAgent(env, n_iterations=50_000)

# Test my own strength
# env.play(monte, HumanAgent(env), n_games=20)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
class DQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(DQN, self).__init__()
        self.input_shape = input_shape
        self.conv = nn.Conv2d(1, 128, kernel_size=4, stride=1, padding=0)
        self.fc1 = nn.Linear(128 * (input_shape[0] - 3) * (input_shape[1] - 3), 64)
        self.fc2 = nn.Linear(64, 64)
        self.output = nn.Linear(64, num_actions)

    def forward(self, x):
        x = x.view(-1, 1, self.input_shape[0], self.input_shape[1])
        x = F.relu(self.conv(x))
        x = x.view(x.size(0), -1) 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.output(x)
        return x

class DQNAgent(Agent):
    def __init__(self, env, replay_buffer, evaluation_agent=None):
        super().__init__(env)
        self.name = "DQNAgent"
        self.state_dim = env.observation_space.shape
        self.action_dim = env.action_space.n
        self.replay_buffer: ReplayBuffer = replay_buffer
        self.policy_net = DQN(self.state_dim, self.action_dim).to(device)
        self.target_net = DQN(self.state_dim, self.action_dim).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # Target net is not trained
        self.lr = 0.005
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        # self.lr_scheduler = lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.99)
        self.lr_scheduler = None
        self.steps_done = 0
        self.epsilon_start = 1.0
        self.epsilon_end = 0.15 # if you leave overnight, you can decrease this to 0.01
        self.epsilon_decay = 500_000 # if you leave overnight, you can increase this to 1_000_000
        self.batch_size = 256
        self.gamma = 0.99  # Discount factor
        self.target_update = 5000
        self.loss = nn.MSELoss()
        if evaluation_agent is None:
            self.evaluation_agent = RandomAgent(self.env)
        else:
            self.evaluation_agent = evaluation_agent

    def load_trained_model_from_file(self, path):
        self.policy_net = torch.load(path)
        self.policy_net.eval()

    def choose_action(self, explore=False) -> int:
        sample = random.random()
        epsilon_threshold = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.steps_done / self.epsilon_decay)
        legal_moves = self.env.get_legal_actions()
        if sample > epsilon_threshold or not explore:
            with torch.no_grad():
                state_tensor = torch.tensor(
                    self.env.board, dtype=torch.float).unsqueeze(0)
  
                predictions = self.policy_net(state_tensor)
                
                # Mask out illegal moves
                for i in range(self.env.action_space.n):
                    if i not in legal_moves:
                        predictions[0][i] = -float('inf')

                # Choose the action with the highest Q value that is legal and convert to int
                best_action = int(torch.argmax(predictions).item())
                return best_action
        else:
            return random.choice(legal_moves)

    def optimize_model(self):
        if len(self.replay_buffer) < self.batch_size:
            return

        # Sample a batch of experiences from the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            self.batch_size)

        # Separate the components of each transition
        states = torch.tensor(states, dtype=torch.float32).to(device)
        actions = torch.tensor(actions, dtype=torch.long).to(device)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        next_states = torch.tensor(next_states, dtype=torch.float32).to(device)
        dones = torch.tensor(dones, dtype=torch.float32).to(device)

        # Calculate current Q-values from the policy_net
        current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
        next_state_values = self.target_net(next_states).max(1)[0]

        # Compute the expected Q values for the current state-action pairs
        expected_q_values = rewards + self.gamma * next_state_values * (1 - dones)

        # Compute loss
        loss = self.loss(current_q_values, expected_q_values)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, n_games, print_interval):
        losses = []

        self.evaluate(50, message=True)

        for episode in tqdm(range(n_games)):
            observation = self.env.reset()
            done = False

            while not done:
                action = self.choose_action(explore=True)
                self.steps_done += 1

                next_observation, reward, done, info = self.env.step(action)
                self.replay_buffer.push(
                    observation, action, reward, next_observation, done)
                
                # the agent thinks he is player 1, so we need to flip the board and the player
                self.env.flip_board()
                self.env.current_player = 3 - self.env.current_player

                loss = self.optimize_model()

                observation = next_observation

                if loss is not None:
                    losses.append(loss)

            if (episode + 1) % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            if (episode + 1) % print_interval == 0 and len(losses) > 0:
                avg_loss = sum(losses[-print_interval:]) / len(losses[-print_interval:])
                print(f"Episode {episode + 1}: Average Loss = {avg_loss}")
                self.evaluate(50, message=True)
                print()
                torch.save(self.policy_net, f"checkpoints/model_{episode + 1}.pt")
               

    def evaluate(self, n_games, show=False, message=False, evaluation_agent=None):
        if evaluation_agent is None:
            evaluation_agent = self.evaluation_agent

        wins, avg_length = self.env.play(self, evaluation_agent, n_games, show)
        if message:
            print(f"Out of {n_games} games against {evaluation_agent.name}, the model won {wins[1]} games : {wins[1] / n_games * 100:.2f}% with an average game length of {avg_length}")
        return wins

In [10]:
# Train a new model

env = ConnectFourEnv()
replay_buffer = ReplayBuffer(10000)
agent = DQNAgent(env, replay_buffer)
# agent.load_trained_model_from_file("checkpoints/model_6000.pt")

# print the number of parameters in the model
print(
    f"Number of parameters in the model: {sum(p.numel() for p in agent.policy_net.parameters())}")

Number of parameters in the model: 105159


In [11]:
agent.train(1_000_000, print_interval=100)

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:00<00:00, 196.06it/s]


Out of 50 games against RandomAgent, the model won 45 games : 90.00% with an average game length of 12.54


  0%|          | 99/1000000 [00:19<64:55:13,  4.28it/s]

Episode 100: Average Loss = 0.0009939866964123211


100%|██████████| 50/50 [00:00<00:00, 365.68it/s]
  0%|          | 100/1000000 [00:20<74:06:16,  3.75it/s]

Out of 50 games against RandomAgent, the model won 39 games : 78.00% with an average game length of 15.66



  0%|          | 199/1000000 [00:42<74:33:16,  3.73it/s] 

Episode 200: Average Loss = 0.003907338668068405


100%|██████████| 50/50 [00:00<00:00, 407.46it/s]
  0%|          | 200/1000000 [00:43<96:20:47,  2.88it/s]

Out of 50 games against RandomAgent, the model won 37 games : 74.00% with an average game length of 13.38



  0%|          | 299/1000000 [01:09<49:46:32,  5.58it/s] 

Episode 300: Average Loss = 0.0038855219157994726


100%|██████████| 50/50 [00:00<00:00, 225.14it/s]
  0%|          | 300/1000000 [01:09<68:22:28,  4.06it/s]

Out of 50 games against RandomAgent, the model won 40 games : 80.00% with an average game length of 16.04



  0%|          | 399/1000000 [01:34<64:53:16,  4.28it/s] 

Episode 400: Average Loss = 0.0031605233182199297


100%|██████████| 50/50 [00:00<00:00, 440.49it/s]
  0%|          | 400/1000000 [01:35<67:44:09,  4.10it/s]

Out of 50 games against RandomAgent, the model won 39 games : 78.00% with an average game length of 11.18



  0%|          | 499/1000000 [01:59<61:04:43,  4.55it/s] 

Episode 500: Average Loss = 0.0042245591018581765


100%|██████████| 50/50 [00:00<00:00, 409.03it/s]
  0%|          | 500/1000000 [01:59<74:05:51,  3.75it/s]

Out of 50 games against RandomAgent, the model won 44 games : 88.00% with an average game length of 13.96



  0%|          | 599/1000000 [02:24<46:24:06,  5.98it/s] 

Episode 600: Average Loss = 0.006626346647390164


100%|██████████| 50/50 [00:00<00:00, 444.92it/s]
  0%|          | 600/1000000 [02:24<64:17:07,  4.32it/s]

Out of 50 games against RandomAgent, the model won 44 games : 88.00% with an average game length of 12.2



  0%|          | 630/1000000 [02:31<66:44:38,  4.16it/s]


KeyboardInterrupt: 

In [None]:
env.play(RandomAgent(env), RandomAgent(env), n_games=10000)

In [None]:
# play against the trained model

env.play(agent, HumanAgent(env), n_games=1, show_game=True, show_outcome=True)

In [None]:
# save the model
torch.save(agent.policy_net, 'connect_four_model.pt')

In [None]:
# resume training

env = ConnectFourEnv()
replay_buffer = ReplayBuffer(10000)
agent = DQNAgent(env, replay_buffer)
agent.policy_net = torch.load('connect_four_model.pt')

In [None]:
agent.evaluate(1000, message=True)

In [12]:
# import profilers and check the training bottlenecks of the model

from torch.profiler import profile, record_function, ProfilerActivity


with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    agent.train(50, print_interval=100)

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

STAGE:2024-03-22 00:52:10 6158:258567 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
100%|██████████| 50/50 [00:00<00:00, 165.46it/s]


Out of 50 games against RandomAgent, the model won 41 games : 82.00% with an average game length of 15.7


100%|██████████| 50/50 [00:10<00:00,  4.73it/s]
STAGE:2024-03-22 00:52:23 6158:258567 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-03-22 00:52:23 6158:258567 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             aten::convolution_backward        25.46%        1.772s        25.63%        1.783s     928.671us          1920  
                               aten::mkldnn_convolution        14.09%     980.305ms        14.25%     991.197ms     425.406us          2330  
                                               aten::mm        13.60%     946.368ms        13.60%     946.459ms      82.158us         11520  
                                            aten::addmm         8.27%     575.178ms         8.86%     616.572ms      88.208us          6990  
      