In [1]:
import copy
import math
import os
import random
from collections import deque
from typing import List, Tuple, Union, Optional, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

from agent import RandomAgent
from env import ConnectFourEnv
from monte_carlo import MonteCarloTreeSearchAgent
from replay_buffer import ReplayBuffer

# NOTE(review): `Agent` (base class of DQNAgent / HumanAgent) and `Transition`
# (used in DQNAgent.optimize_model) are referenced below but never imported in
# this notebook. Presumably they live in `agent` and `replay_buffer` -- confirm
# and import them here so the notebook survives Restart & Run All.

In [3]:
# Apply six moves (columns 0, 1, 2, 1, 3, 1) to a fresh environment, then let
# a Monte Carlo tree search agent play 20 games against a random agent.
env = ConnectFourEnv()
for column in (0, 1, 2, 1, 3, 1):
    env.step(column)

# 10000 search iterations per move is slow: the captured run below needed
# roughly 36 s per game.
monte = MonteCarloTreeSearchAgent(env, n_iterations=10000, c=1.4)

env.play(monte, RandomAgent(env), n_games=20)

 30%|███       | 6/20 [03:37<08:31, 36.53s/it]

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network: flatten -> (Linear -> ReLU)* -> Linear.

    Args:
        input_shape: Shape of ONE observation (without the batch dimension).
            Any rank is accepted; the first layer's width is the product of
            all entries, matching what ``nn.Flatten`` produces.
        num_actions: Number of outputs, one Q-value per action.
        hidden_sizes: Widths of the hidden layers. The default ``(42, 24)``
            reproduces the original fixed architecture, including the
            ``net.1`` / ``net.3`` / ``net.5`` parameter names.
    """

    def __init__(self, input_shape, num_actions, hidden_sizes=(42, 24)):
        super().__init__()
        # nn.Flatten collapses every dimension after the batch axis, so the
        # first Linear must accept the product of *all* observation dims.
        in_features = math.prod(int(dim) for dim in input_shape)

        layers = [nn.Flatten()]
        for width in hidden_sizes:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, num_actions))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for a batch ``x``."""
        return self.net(x)



class DQNAgent(Agent):
    """Deep Q-learning agent that trains by playing both sides of the game.

    The agent owns a policy network (trained) and a target network (a
    periodically refreshed copy used for bootstrapped targets), an Adam
    optimizer over the policy network, and an epsilon-greedy exploration
    schedule driven by ``steps_done``.

    NOTE(review): ``Agent`` and ``Transition`` are used by this class but are
    not imported anywhere in the visible notebook -- presumably they come from
    the ``agent`` and ``replay_buffer`` modules; confirm and import them.

    Args:
        env: Environment exposing ``observation_space.shape``,
            ``action_space.n``, ``board``, ``reset``, ``step``, ``flip_board``,
            ``current_player`` and ``play`` (all are used below).
        replay_buffer: Buffer supporting ``push``, ``sample`` and ``len``.
        evaluation_agent: Opponent used by :meth:`evaluate` by default; a
            ``RandomAgent`` on the same env when omitted.
    """

    def __init__(self, env, replay_buffer, evaluation_agent=None):
        super().__init__(env)
        self.name = "DQNAgent"
        self.state_dim = env.observation_space.shape
        self.action_dim = env.action_space.n
        self.replay_buffer: ReplayBuffer = replay_buffer
        self.policy_net = DQN(self.state_dim, self.action_dim)
        self.target_net = DQN(self.state_dim, self.action_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # Target net is not trained
        self.lr = 0.005
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        self.steps_done = 0
        self.epsilon_start = 1.0
        self.epsilon_end = 0.15 # if you leave overnight, you can decrease this to 0.01
        self.epsilon_decay = 500_000 # if you leave overnight, you can increase this to 1_000_000
        self.batch_size = 256
        self.gamma = 1  # Discount factor
        self.target_update = 1000  # episodes between target-network refreshes
        if evaluation_agent is None:
            self.evaluation_agent = RandomAgent(self.env)
        else:
            self.evaluation_agent = evaluation_agent

    def load_trained_model_from_file(self, path):
        """Replace the policy network with a module saved by ``torch.save``.

        The optimizer and the target network are rebuilt around the loaded
        module. Without that, the optimizer would keep stepping the parameters
        of the discarded network and training after a load would silently
        leave the loaded weights untouched.

        NOTE(security): ``torch.load`` unpickles the file; only load
        checkpoints you trust. On recent PyTorch releases ``torch.load``
        defaults to ``weights_only=True`` and rejects fully pickled modules --
        pass ``weights_only=False`` here if that applies to your version.
        """
        self.policy_net = torch.load(path)
        self.policy_net.eval()
        # deepcopy rather than load_state_dict so a checkpoint whose layer
        # sizes differ from the freshly constructed target net still loads.
        self.target_net = copy.deepcopy(self.policy_net)
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)

    def choose_action(self, explore=False) -> int:
        """Pick an action index for the env's current board.

        With ``explore=True`` the choice is epsilon-greedy, where epsilon
        decays exponentially from ``epsilon_start`` to ``epsilon_end`` as
        ``steps_done`` grows. With ``explore=False`` the greedy action of the
        policy network is always returned.

        NOTE(review): neither branch filters out moves the env may consider
        illegal (e.g. a full column) -- confirm how ``env.step`` treats them.
        """
        sample = random.random()
        epsilon_threshold = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.steps_done / self.epsilon_decay)
        if sample > epsilon_threshold or not explore:
            with torch.no_grad():
                # Add a batch dimension of 1 before the forward pass.
                state = torch.tensor(
                    self.env.board, dtype=torch.float).unsqueeze(0)
                decision = self.policy_net(state)
                return decision.max(1)[1].view(1, 1).item()
        else:
            return random.randrange(self.action_dim)

    def optimize_model(self):
        """Run one gradient step on a sampled batch.

        Returns:
            The scalar loss as a float, or ``None`` when the replay buffer
            holds fewer than ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        # Sample a batch of experiences from the replay buffer
        transitions = self.replay_buffer.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        # Separate the components of each transition.
        # train() pushes the action as a plain Python int and the reward/done
        # as shape-(1,) tensors. Each field is converted element-wise and then
        # flattened, so the batch always comes out as (B, 1) for actions and
        # (B,) for rewards/dones, whatever form the buffer hands back. A stray
        # (B, 1) reward would otherwise broadcast against the (B,) values
        # below into a (B, B) target.
        batch_states = torch.stack(batch.state).float()
        batch_actions = torch.stack(
            [torch.as_tensor(a) for a in batch.action]).view(-1, 1).long()
        batch_rewards = torch.stack(
            [torch.as_tensor(r, dtype=torch.float) for r in batch.reward]).reshape(-1)
        batch_next_states = torch.stack(batch.next_state).float()
        batch_dones = torch.stack(
            [torch.as_tensor(d, dtype=torch.float) for d in batch.done]).reshape(-1)

        # Calculate current Q-values from the policy_net
        current_q_values = self.policy_net(batch_states).gather(
            1, batch_actions).squeeze(1)

        # Calculate the maximum Q-value for the next states from the target_net
        next_state_values = self.target_net(
            batch_next_states).max(1)[0].detach()

        # Apply (1 - done) to zero out the values for terminal states
        next_state_values = next_state_values * (1 - batch_dones)

        # Compute the expected Q values for the current state-action pairs.
        # NOTE(review): in self-play the next state appears to belong to the
        # opponent's turn; confirm whether the bootstrap term should be
        # negated (negamax-style). Left unchanged here.
        expected_q_values = batch_rewards + self.gamma * next_state_values

        # Compute loss
        loss = F.smooth_l1_loss(current_q_values, expected_q_values)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, n_games, print_interval):
        """Train by self-play for ``n_games`` episodes.

        Evaluates once up front. Every ``print_interval`` episodes (once at
        least one loss exists) it prints the mean of the most recent
        ``print_interval`` losses, evaluates against ``self.evaluation_agent``
        and against a ``RandomAgent``, and saves the policy network to
        ``checkpoints/model_<episode>.pt``. The target network is refreshed
        every ``self.target_update`` episodes.
        """
        # Only the latest `print_interval` losses are ever read, so a bounded
        # deque yields the same averages without growing for the whole run.
        losses = deque(maxlen=print_interval)

        self.evaluate(50, message=True)

        for episode in tqdm(range(n_games)):
            observation = self.env.reset()
            observation = torch.tensor(observation, dtype=torch.float)
            done = False

            while not done:
                action = self.choose_action(explore=True)
                self.steps_done += 1

                next_observation, reward, done, _info = self.env.step(action)

                # the agent thinks he is player 1, so we need to flip the board and the player
                self.env.flip_board()
                self.env.current_player = 3 - self.env.current_player

                next_observation = torch.tensor(
                    next_observation, dtype=torch.float)

                reward = torch.tensor([reward], dtype=torch.float)
                done_tensor = torch.tensor([done], dtype=torch.float)

                self.replay_buffer.push(
                    observation, action, reward, next_observation, done_tensor)
                observation = next_observation

                loss = self.optimize_model()

                if loss is not None:
                    losses.append(loss)

            if (episode + 1) % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            if (episode + 1) % print_interval == 0 and len(losses) > 0:
                avg_loss = sum(losses) / len(losses)
                print(f"Episode {episode + 1}: Average Loss = {avg_loss}")
                self.evaluate(50, message=True)
                self.evaluate(50, message=True, evaluation_agent=RandomAgent(self.env))
                print()
                # torch.save does not create missing directories.
                os.makedirs("checkpoints", exist_ok=True)
                torch.save(self.policy_net, f"checkpoints/model_{episode + 1}.pt")

    def evaluate(self, n_games, show=False, message=False, evaluation_agent=None):
        """Play ``n_games`` against an opponent and return the win counts.

        Args:
            n_games: Number of games passed to ``env.play``.
            show: Forwarded to ``env.play`` as its fourth positional argument.
            message: When true, print a one-line summary.
            evaluation_agent: Opponent; defaults to ``self.evaluation_agent``.

        Returns:
            The ``wins`` object returned by ``env.play``; ``wins[1]`` is
            reported as this agent's win count.
        """
        if evaluation_agent is None:
            evaluation_agent = self.evaluation_agent

        wins, avg_length = self.env.play(self, evaluation_agent, n_games, show)
        if message:
            print(f"Out of {n_games} games against {evaluation_agent.name}, the model won {wins[1]} games : {wins[1] / n_games * 100:.2f}% with an average game length of {avg_length}")
        return wins

In [None]:
# Train a new model
# Fresh environment, a replay buffer constructed with 10000, and an untrained agent.
env = ConnectFourEnv()
replay_buffer = ReplayBuffer(10000)
agent = DQNAgent(env, replay_buffer)
# Optional warm start from an earlier checkpoint:
# agent.load_trained_model_from_file("checkpoints/model_6000.pt")
# Evaluate against itself instead of the default RandomAgent opponent.
agent.evaluation_agent = agent

In [None]:


# Long-running: up to 1,000,000 self-play games. Every 1000 games train() prints
# the recent average loss, runs two 50-game evaluations and writes a checkpoint
# to checkpoints/model_<episode>.pt.
agent.train(1_000_000, print_interval=1000)

In [None]:
# Baseline: a single game between two random agents
# (show_outcome presumably prints the result -- confirm in env.play).
env.play(RandomAgent(env), RandomAgent(env), n_games=1, show_outcome=True)

In [None]:
# Watch the agent play one game against itself.
env.play(agent, agent, n_games=1, show_game=True, show_outcome=True)
# Bulk variant without display:
# env.play(agent, agent, n_games=100)

In [None]:
class HumanAgent(Agent):
    """Agent whose moves are typed in by a person at the input prompt."""

    def __init__(self, env):
        self.env = env
        # DQNAgent.evaluate reads the opponent's `.name`; give this agent one
        # the same way DQNAgent does.
        self.name = "HumanAgent"

    def choose_action(self):
        """Prompt until a valid column is entered and return it.

        A move is valid when it is an integer in ``range(env.columns)`` and
        ``env.playable_rows[move]`` is not ``-1``. Non-numeric input is
        rejected with the same message instead of crashing the game.
        """
        while True:
            try:
                action = int(input("Enter your move: "))
            except ValueError:
                # e.g. empty input or letters -- ask again rather than abort.
                print("Invalid move. Try again.")
                continue
            if action in range(self.env.columns) and self.env.playable_rows[action] != -1:
                return action
            print("Invalid move. Try again.")

In [None]:
# play against the trained model
# The agent is passed first and the human second; moves are read from the input prompt.

env.play(agent, HumanAgent(env), n_games=1, show_game=True, show_outcome=True)

In [None]:
# save the model
# This pickles the whole nn.Module (not just its state_dict), so loading it
# later requires the DQN class to be defined/importable in the loading session.
torch.save(agent.policy_net, 'connect_four_model.pt')

In [None]:
# resume training

env = ConnectFourEnv()
replay_buffer = ReplayBuffer(10000)
agent = DQNAgent(env, replay_buffer)
agent.policy_net = torch.load('connect_four_model.pt')
# The optimizer and target network built in DQNAgent.__init__ still refer to
# the freshly initialised network. Re-bind both to the restored one, otherwise
# further training would never update the loaded weights and would bootstrap
# from a random target network.
agent.optimizer = optim.Adam(agent.policy_net.parameters(), lr=agent.lr)
agent.target_net.load_state_dict(agent.policy_net.state_dict())

In [None]:
# 1000 evaluation games against agent.evaluation_agent, printing a summary line.
agent.evaluate(1000, message=True)

In [None]:
# import profilers and check the training bottlenecks of the model

from torch.profiler import profile, record_function, ProfilerActivity


# Profile CPU activity for a short 50-game run. Because 50 < print_interval,
# train() never reaches its periodic report/checkpoint branch here; only the
# initial 50-game evaluation and the training loop are measured.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    agent.train(50, print_interval=100)

# Ten most expensive operators by self CPU time.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))