In [None]:
!pip install stable-baselines3


Collecting stable-baselines3
  Downloading stable_baselines3-2.5.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (

In [None]:
import numpy as np
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.dqn.policies import CnnPolicy, MultiInputPolicy
from stable_baselines3.common.torch_layers import NatureCNN

import gymnasium as gym
from gymnasium import spaces
import numpy as np


In [None]:
class Connect4Env(gym.Env):
    """Connect 4 on a 6x7 board as a Gymnasium environment for DQN training.

    The board is a ``(rows, cols)`` uint8 grid where 0 is an empty cell,
    1 is a player-1 piece and 2 is a player-2 piece. Row 0 is the top row;
    pieces are dropped into the lowest empty cell of the chosen column.
    Observations add a leading channel axis: shape ``(1, rows, cols)``.

    Turns alternate between player 1 and player 2 after every legal move.
    The environment has no built-in opponent, so whoever calls ``step``
    plays both sides.

    Rewards:
        * +1 when player 1 completes four in a row, -1 when player 2 does.
        * 0 for every other legal move, including the move that fills the
          board (a draw).
        * -1 and episode termination for an illegal move (full column or a
          column index outside ``range(cols)``).
    """

    def __init__(self):
        super().__init__()
        self.rows = 6
        self.cols = 7
        self.board = np.zeros((self.rows, self.cols), dtype=np.uint8)
        self.current_player = 1
        self.action_space = spaces.Discrete(self.cols)
        # Channel-first (C, H, W) observation with a single channel.
        # NOTE(review): cell values are only 0..2; the 0..255 uint8 bounds
        # presumably exist so SB3's CnnPolicy accepts this as an image space.
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(1, self.rows, self.cols), dtype=np.uint8
        )

    def _get_obs(self):
        """Return the board as a channel-first array of shape (1, rows, cols).

        A copy is returned rather than a view so that observations handed out
        earlier are not silently mutated by later moves on ``self.board``.
        """
        return self.board[np.newaxis, :, :].copy()

    def reset(self, seed=None, options=None):
        """Clear the board and give the first move to player 1.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so the
                environment's ``np_random`` generator is seeded as the
                Gymnasium API expects.
            options: Unused; accepted for API compatibility.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.board = np.zeros((self.rows, self.cols), dtype=np.uint8)
        self.current_player = 1
        return self._get_obs(), {}

    def step(self, action):
        """Drop a piece for the current player into column ``action``.

        Args:
            action: Column index to play.

        Returns:
            ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always False and ``info`` is always empty.
        """
        if action not in self.possible_actions():
            # Illegal move: end the episode with a penalty. The board and the
            # player to move are left unchanged.
            return self._get_obs(), -1, True, False, {}

        row = self._drop_piece(action, self.current_player)
        reward = 0
        terminated = False
        if self._check_win(row, action):
            # Reward is from player 1's perspective.
            # NOTE(review): if a single agent plays both sides, it is
            # penalised for winning as player 2 — confirm this is intended.
            reward = 1 if self.current_player == 1 else -1
            terminated = True
        elif np.all(self.board != 0):
            # Board full with no winner: draw.
            terminated = True

        self.current_player = 2 if self.current_player == 1 else 1
        return self._get_obs(), reward, terminated, False, {}

    def _drop_piece(self, col, player):
        """Place ``player``'s piece in the lowest empty cell of ``col``.

        Returns:
            The row index where the piece landed, or -1 if the column is full
            (in which case the board is unchanged).
        """
        for row in reversed(range(self.rows)):
            if self.board[row, col] == 0:
                self.board[row, col] = player
                return row
        return -1

    def _check_win(self, row, col):
        """Return True if the piece at (row, col) is part of a line of >= 4.

        For each axis (horizontal, vertical and both diagonals), count the
        contiguous run of cells with the same value as (row, col), walking
        outward in both directions from that cell.
        """
        player = self.board[row, col]
        directions = [
            (0, 1),   # Horizontal
            (1, 0),   # Vertical
            (1, 1),   # Diagonal (top-left to bottom-right)
            (1, -1),  # Diagonal (top-right to bottom-left)
        ]
        for dr, dc in directions:
            count = 1  # the piece just placed
            for delta in (-1, 1):
                r, c = row + delta * dr, col + delta * dc
                while (
                    0 <= r < self.rows
                    and 0 <= c < self.cols
                    and self.board[r, c] == player
                ):
                    count += 1
                    r += delta * dr
                    c += delta * dc
            if count >= 4:
                return True
        return False

    def possible_actions(self):
        """Return the columns whose top cell (row 0) is still empty."""
        return [col for col in range(self.cols) if self.board[0, col] == 0]

    def render(self):
        """Print the raw board array followed by a blank line."""
        print(self.board)
        print()
        print()

In [None]:
import torch.nn as nn
import torch
from stable_baselines3.common.torch_layers import NatureCNN
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCNN(BaseFeaturesExtractor):
    """Small CNN feature extractor for the Connect 4 board.

    Three 3x3 convolutions (stride 1, padding 1) keep the board's spatial
    size unchanged. The result is flattened and projected to
    ``features_dim`` with a ReLU-activated linear layer.

    Args:
        observation_space: Channel-first Box space. Its first dimension is
            used as the number of input channels.
        features_dim: Size of the output feature vector.
        normalized_image: Not used by this extractor.
            NOTE(review): presumably kept to mirror SB3's ``NatureCNN``
            signature; it has no effect here.
    """

    def __init__(self, observation_space, features_dim=512, normalized_image=False):
        super().__init__(observation_space, features_dim)
        # Derive channels from the space instead of hard-coding 1, so stacked
        # or multi-plane board encodings also work.
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            # Small kernels: the board is only 6x7.
            nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened CNN output size with one dummy forward pass.
        with torch.no_grad():
            sample_input = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample_input).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations):
        """Map a batch of observations to ``features_dim`` features."""
        return self.linear(self.cnn(observations))

# train and save DQN with a specific policy
def train_and_save_dqn(policy, policy_name, total_timesteps=100_000, seed=None):
    """Train a DQN agent on ``Connect4Env`` with ``CustomCNN`` and save it.

    Args:
        policy: SB3 DQN policy to use (e.g. ``CnnPolicy``).
        policy_name: Short tag used in the TensorBoard log directory and in
            the saved model's file name.
        total_timesteps: Number of environment steps to train for.
        seed: Optional seed forwarded to ``DQN`` for reproducible runs.
            ``None`` (the default) keeps the original unseeded behaviour.

    Returns:
        The trained ``DQN`` model. It is also saved under
        ``dqn_{policy_name}_connect4_model``.
    """
    print(f"Training DQN with {policy_name} policy...")

    # Pass the environment class itself as the factory. The previous
    # ``env = DummyVecEnv([lambda: env])`` relied on the lambda running before
    # ``env`` was rebound on that same line.
    vec_env = DummyVecEnv([Connect4Env])

    # NOTE(review): ``normalized_image=False`` is only forwarded to CustomCNN,
    # which ignores it; it does not turn off the policy's own image
    # normalization (SB3 ``normalize_images``). If that normalization is
    # active, board values 1/2 reach the CNN scaled by 1/255 — confirm whether
    # ``normalize_images=False`` in policy_kwargs is wanted.
    model = DQN(
        policy,
        vec_env,
        verbose=1,
        tensorboard_log=f"./dqn_{policy_name}_tensorboard/",
        seed=seed,
        policy_kwargs=dict(
            features_extractor_class=CustomCNN,
            features_extractor_kwargs=dict(normalized_image=False),
        ),
    )

    model.learn(total_timesteps=total_timesteps)
    model_path = f"dqn_{policy_name}_connect4_model"
    model.save(model_path)
    print(f"Training complete. Model saved as {model_path}.")
    return model

# Run training: DQN with the CNN policy for 100k environment steps, saved as "cnn".
train_and_save_dqn(policy=CnnPolicy, policy_name="cnn", total_timesteps=100_000)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
|    time_elapsed     | 1149     |
|    total_timesteps  | 64142    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00548  |
|    n_updates        | 16010    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 6240     |
|    fps              | 55       |
|    time_elapsed     | 1151     |
|    total_timesteps  | 64230    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00726  |
|    n_updates        | 16032    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 6244     |
|    fps              | 55       |
|    time_elapsed     | 1