Importing Libraries

In [10]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
from copy import deepcopy
import matplotlib.pyplot as plt
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
if device.type == "cuda":
    # Report the name of GPU index 0 so the run is easy to identify later.
    print("CUDA device name:", torch.cuda.get_device_name(0))

Using device: cuda
CUDA device name: NVIDIA GeForce RTX 3070 Ti Laptop GPU


In [11]:
def to_device(tensor):
    """Return `tensor` placed on the notebook-wide `device`."""
    moved = tensor.to(device)
    return moved

Connect 4 environment class

In [12]:
class Connect4Env:
    """Two-player Connect 4 on a 6x7 grid with a gym-like reset/step API.

    Board cells hold 1 or -1 for the two players and 0 for empty. Row 0 is
    the top of the board; pieces settle at the highest-index empty row of the
    chosen column.

    Rewards returned by `step` (always for the player who just moved):
        1      the move completed four in a row (episode ends)
        0      the move filled the board without a winner (episode ends)
        -0.01  any other legal move
        -10    the column was not playable (episode ends, board unchanged)
    """

    def __init__(self):
        self.rows = 6
        self.cols = 7
        self.reset()

    def reset(self):
        """Clear the board, give the move to player 1, and return the observation."""
        self.last_move = None
        self.current_player = 1
        self.board = np.zeros((self.rows, self.cols), dtype=int)
        return self._get_obs()

    def _get_obs(self):
        """Return a (2, rows, cols) float32 array of occupancy planes.

        The planes are ordered [player 1, player -1] when player 1 is to
        move and swapped otherwise, so plane 0 belongs to the side to move.
        """
        planes = [
            (self.board == 1).astype(np.float32),
            (self.board == -1).astype(np.float32),
        ]
        if self.current_player != 1:
            planes.reverse()
        return np.stack(planes, axis=0)

    def valid_actions(self):
        """Return the column indices whose top cell is still empty."""
        return np.flatnonzero(self.board[0] == 0).tolist()

    def step(self, action):
        """Drop the current player's piece into `action`.

        Returns:
            (observation, reward, done, info), where info is always an empty
            dict. The turn only passes to the other player when the game
            continues.
        """
        if action not in self.valid_actions():
            return self._get_obs(), -10, True, {}

        # Largest row index that is still empty is where the piece lands.
        landing_row = int(np.flatnonzero(self.board[:, action] == 0)[-1])
        self.board[landing_row, action] = self.current_player
        self.last_move = (landing_row, action)

        if self._check_winner(self.current_player, self.last_move):
            return self._get_obs(), 1, True, {}

        if not self.valid_actions():
            return self._get_obs(), 0, True, {}

        self.current_player = -self.current_player
        return self._get_obs(), -0.01, False, {}

    def _check_winner(self, player, last_move):
        """Return True if `player` has four or more in a line through `last_move`."""
        if last_move is None:
            return False
        row, col = last_move

        def run_length(step_r, step_c):
            # Count consecutive `player` cells starting next to (row, col).
            length = 0
            rr, cc = row + step_r, col + step_c
            while 0 <= rr < self.rows and 0 <= cc < self.cols and self.board[rr, cc] == player:
                length += 1
                rr += step_r
                cc += step_c
            return length

        # Vertical, horizontal, and the two diagonals; each line is measured
        # in both directions from the last piece placed.
        axes = ((1, 0), (0, 1), (1, 1), (1, -1))
        return any(
            1 + run_length(dr, dc) + run_length(-dr, -dc) >= 4
            for dr, dc in axes
        )



Model

In [13]:
class DQN(nn.Module):
    """Convolutional Q-network for the 6x7 Connect 4 board.

    Input:  float tensor of shape (batch, 2, 6, 7) holding two occupancy planes.
    Output: tensor of shape (batch, 7) with one Q-value per column.
    """

    def __init__(self):
        super().__init__()
        # padding=1 with 3x3 kernels keeps the 6x7 spatial size through both convs.
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 6 * 7, 128)
        self.out = nn.Linear(128, 7)

    def forward(self, x):
        """Compute Q-values for a batch of observations."""
        # Follow the module's own weights rather than a notebook global, so the
        # network also works when it lives on a different device than the
        # global default (e.g. a CPU copy while a GPU is present).
        x = x.to(self.conv1.weight.device)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # flatten() copes with non-contiguous activations, where view() raises.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return self.out(x)

Replay Buffer

In [14]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity=10000):
        # A bounded deque silently discards the oldest entry once full.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def push(self, transition):
        """Append one transition, evicting the oldest when at capacity."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        picked = random.sample(self.buffer, batch_size)
        return picked

Selecting Action

In [15]:
# --- Training Functions ---
def select_action(model, state, epsilon, valid_actions):
    """Pick an action epsilon-greedily, restricted to `valid_actions`.

    Args:
        model: Callable mapping a (1, ...) float tensor to (1, n_actions) Q-values.
        state: Single observation (array-like) without a batch dimension.
        epsilon: Probability of choosing uniformly at random among valid actions.
        valid_actions: Sequence of permitted action indices.

    Returns:
        The chosen action index as a Python int.
    """
    if random.random() < epsilon:
        return random.choice(valid_actions)

    with torch.no_grad():
        batch = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        # Send the input to wherever the model's weights live instead of
        # relying on a notebook-level global.
        parameters = getattr(model, "parameters", None)
        first_param = next(parameters(), None) if callable(parameters) else None
        if first_param is not None:
            batch = batch.to(first_param.device)

        q_values = model(batch).squeeze(0)

        # Build the masked copy from the valid indices, so the number of
        # actions comes from the network output rather than a hard-coded 7
        # and the model's output tensor is never mutated in place.
        allowed = torch.tensor(list(valid_actions), dtype=torch.long, device=q_values.device)
        masked = torch.full_like(q_values, float("-inf"))
        masked[allowed] = q_values[allowed]
        return int(torch.argmax(masked).item())

Optimizing

In [16]:
def optimize(model, target_model, buffer, optimizer, batch_size, gamma):
    """Run one DQN gradient step on a batch sampled from `buffer`.

    Args:
        model: Online Q-network being trained.
        target_model: Network used to evaluate next states (no gradients flow into it).
        buffer: Object supporting `len()` and `sample(batch_size)`, yielding
            (state, action, reward, next_state, done) tuples.
        optimizer: Optimizer over `model`'s parameters.
        batch_size: Number of transitions per update.
        gamma: Discount factor applied to the bootstrapped value.

    Returns:
        None. Does nothing when the buffer holds fewer than `batch_size` items.

    NOTE(review): the target below is the standard single-agent form
    `r + gamma * max Q(s')`. Transitions generated by alternating self-play,
    where the next observation is encoded from the opponent's point of view,
    would need the bootstrap term negated — confirm against the caller.
    """
    if len(buffer) < batch_size:
        return

    # sample() returns a list of transition tuples; transpose it into one
    # tuple per field before building tensors.
    states, actions, rewards, next_states, dones = zip(*buffer.sample(batch_size))

    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Stacking through NumPy first avoids building a tensor element by
    # element from a tuple of arrays.
    states = torch.as_tensor(np.stack(states), dtype=torch.float32, device=target_device)
    actions = torch.as_tensor(actions, dtype=torch.int64, device=target_device)
    rewards = torch.as_tensor(rewards, dtype=torch.float32, device=target_device)
    next_states = torch.as_tensor(np.stack(next_states), dtype=torch.float32, device=target_device)
    dones = torch.as_tensor(dones, dtype=torch.float32, device=target_device)

    # squeeze(1) keeps the batch dimension even when batch_size == 1.
    q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q_values = target_model(next_states).max(1)[0]
    targets = rewards + gamma * next_q_values * (1 - dones)

    loss = F.mse_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Training Loop

In [None]:
def _td_update(model, target_model, optimizer, transitions, gamma):
    """Apply one gradient step of the self-play TD loss to `model`.

    `transitions` is a sequence of (state, action, reward, next_state, done)
    tuples as stored by `train`.
    """
    states, actions, rewards, next_states, dones = zip(*transitions)

    states = torch.tensor(np.stack(states), dtype=torch.float32).to(device)
    actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1).to(device)
    # Explicit dtype: rewards mix ints (1, 0) and floats (-0.01).
    rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1).to(device)
    next_states = torch.tensor(np.stack(next_states), dtype=torch.float32).to(device)
    dones = torch.tensor(dones, dtype=torch.bool).unsqueeze(1).to(device)

    q_values = model(states).gather(1, actions)
    with torch.no_grad():
        next_q = target_model(next_states)
        # A column is unplayable when its top cell (row 0) is occupied in
        # either plane; exclude those from the opponent's best reply.
        column_full = next_states[:, :, 0, :].sum(dim=1) > 0
        next_q = next_q.masked_fill(column_full, float("-inf"))
        best_reply = next_q.max(dim=1, keepdim=True)[0]
        # Terminal states have no reply; this also removes the -inf that a
        # completely full (drawn) board would otherwise produce.
        best_reply = torch.where(dones, torch.zeros_like(best_reply), best_reply)
        # The environment encodes a non-terminal next_state from the
        # opponent's point of view, so the opponent's best value counts
        # against the player who just moved: subtract it instead of adding.
        target_q = rewards - gamma * best_reply

    loss = F.mse_loss(q_values, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train(episodes=50000, epsilon=1.0, epsilon_min=0.1, epsilon_decay=0.995,
          gamma=0.99, batch_size=64, update_target_every=10):
    """Train a DQN by self-play on Connect 4 and return the online network.

    A single network plays both sides; every move is stored as a transition
    from the mover's perspective.

    Args:
        episodes: Number of games to play.
        epsilon: Initial exploration rate.
        epsilon_min: Decay stops once epsilon is no longer above this value.
        epsilon_decay: Multiplicative decay applied after each episode.
        gamma: Discount factor.
        batch_size: Transitions per gradient step; updates begin once the
            buffer holds this many.
        update_target_every: Episodes between target-network syncs.

    Returns:
        The trained online model. Weights are also written to
        "connect4_dqn.pth", with checkpoints every 10000 episodes.
    """
    env = Connect4Env()
    model = DQN().to(device)
    target_model = DQN().to(device)
    target_model.load_state_dict(model.state_dict())
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    buffer = ReplayBuffer()

    for episode in range(episodes):
        state = env.reset()
        done = False
        # Sum of the rewards handed to BOTH players over the game; a logging
        # aid only, not a per-player return.
        total_reward = 0

        while not done:
            valid_actions = env.valid_actions()
            action = select_action(model, state, epsilon, valid_actions)
            next_state, reward, done, _ = env.step(action)
            buffer.push((np.array(state, copy=True), action, reward, np.array(next_state, copy=True), done))
            state = next_state
            total_reward += reward

            if len(buffer) >= batch_size:
                _td_update(model, target_model, optimizer, buffer.sample(batch_size), gamma)

        if epsilon > epsilon_min:
            epsilon *= epsilon_decay

        if episode % update_target_every == 0:
            target_model.load_state_dict(model.state_dict())

        if episode % 10000 == 0 and episode > 0:
            torch.save(model.state_dict(), f"connect4_dqn_ep{episode}.pth")

        if episode % 100 == 0:
            print(f"Episode {episode}, Total reward: {total_reward:.2f}, Epsilon: {epsilon:.3f}")

    torch.save(model.state_dict(), "connect4_dqn.pth")
    return model


Train It

In [19]:
trained_model = train(episodes=100000, epsilon_decay=0.9999)
# The save_model helper is defined in a later cell, so calling it here fails
# with NameError on a fresh kernel after training has finished. Write the
# weights directly to the same path with the same message instead.
torch.save(trained_model.state_dict(), "connect4_model_weights.pth")
print("Model saved to connect4_model_weights.pth")

Episode 0, Total reward: 0.70, Epsilon: 1.000
Episode 1, Total reward: 0.73, Epsilon: 1.000
Episode 2, Total reward: 0.78, Epsilon: 1.000
Episode 3, Total reward: 0.67, Epsilon: 1.000
Episode 4, Total reward: 0.80, Epsilon: 1.000
Episode 5, Total reward: 0.68, Epsilon: 0.999
Episode 6, Total reward: 0.82, Epsilon: 0.999
Episode 7, Total reward: 0.87, Epsilon: 0.999
Episode 8, Total reward: 0.78, Epsilon: 0.999
Episode 9, Total reward: 0.79, Epsilon: 0.999
Episode 10, Total reward: 0.77, Epsilon: 0.999
Episode 11, Total reward: 0.87, Epsilon: 0.999
Episode 12, Total reward: 0.79, Epsilon: 0.999
Episode 13, Total reward: 0.74, Epsilon: 0.999
Episode 14, Total reward: 0.83, Epsilon: 0.999
Episode 15, Total reward: 0.82, Epsilon: 0.998
Episode 16, Total reward: 0.84, Epsilon: 0.998
Episode 17, Total reward: 0.84, Epsilon: 0.998
Episode 18, Total reward: 0.82, Epsilon: 0.998
Episode 19, Total reward: 0.78, Epsilon: 0.998
Episode 20, Total reward: 0.68, Epsilon: 0.998
Episode 21, Total rewar

KeyboardInterrupt: 

In [37]:
def save_model(model, path="connect4_model_weights.pth"):
    """Write `model`'s state dict to `path` and print a confirmation line."""
    weights = model.state_dict()
    torch.save(weights, path)
    print("Model saved to {}".format(path))

In [38]:
# Persist the current trained_model weights to the helper's default path.
save_model(trained_model)

Model saved to connect4_model_weights.pth


Load Weights

In [40]:
def load_trained_model(path="connect4_model_weights.pth"):
    """Build a DQN on `device`, load weights from `path`, and return it in eval mode."""
    net = DQN().to(device)
    # map_location remaps the stored tensors onto the active device.
    state_dict = torch.load(path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()  # inference mode for evaluation / play
    print(f"Loaded model from {path}")
    return net

In [41]:
# Rebind trained_model to a fresh DQN loaded from the default weights file.
trained_model = load_trained_model()

Loaded model from connect4_model_weights.pth


Playing

In [50]:
def play_against_model(starting_player=None, model_path="connect4_dqn.pth"):
    """Play an interactive console game of Connect 4 against a saved DQN.

    The human is player -1 ('O') and the model is player 1 ('X').

    Args:
        starting_player: -1 for the human to move first, any other non-None
            value for the model to move first, or None to ask on stdin.
        model_path: Checkpoint file holding the DQN state dict.
    """
    env = Connect4Env()
    model = DQN().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    state = env.reset()
    done = False
    symbol_map = {1: 'X', -1: 'O', 0: '.'}

    if starting_player is None:
        choice = input("Do you want to go first? (y/n): ").lower()
        if choice == 'y':
            env.current_player = -1
        else:
            env.current_player = 1
    else:
        env.current_player = -1 if starting_player == -1 else 1
    print("You are Player -1 (O). Model is Player 1 (X). Columns: 0 to 6")
    print(env.board)

    while not done:
        if env.current_player == -1:
            try:
                user_action = int(input("Your move (0-6): "))
            except ValueError:
                print("Invalid input. Try a number from 0 to 6.")
                continue
            # Reject illegal columns here so the environment's invalid-move
            # penalty never ends the game on a typo.
            if user_action not in env.valid_actions():
                print("Invalid move. Try again.")
                continue
            state, reward, done, _ = env.step(user_action)
        else:
            valid_actions = env.valid_actions()
            # epsilon=0.0 makes the model play greedily.
            action = select_action(model, state, epsilon=0.0, valid_actions=valid_actions)
            state, reward, done, _ = env.step(action)
            print("Model played column:", action)

        for row in env.board:
            print(' '.join(symbol_map[cell] for cell in row))
        if done:
            # The reward belongs to whoever made the final move.
            print("Game Over. Reward:", reward)
            break

In [51]:
# Start an interactive game; with no starting_player given, the function asks on stdin who goes first.
play_against_model()

You are Player -1 (O). Model is Player 1 (X). Columns: 0 to 6
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]]
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . .
Model played column: 3
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . . .
. . . O . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . . .
. . O O . . .
Model played column: 5
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . . .
. . O O . X .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . . .
. O O O . X .
Model played column: 5
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . X .
. O O O . X .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . X .
. O O O O X .
Game Over. Reward: 1
