Importing Libraries

In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
from copy import deepcopy
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
print("Using device:", device)
if device.type == "cuda":
    print("CUDA device name:", torch.cuda.get_device_name(0))

True
Using device: cuda
CUDA device name: NVIDIA GeForce RTX 3070 Ti Laptop GPU


In [2]:
def to_device(tensor):
    return tensor.to(device)

Class for connect4 environment

In [1]:
class Connect4Env:
    def __init__(self):
        self.rows = 6
        self.cols = 7
        self.reset()

    def reset(self):
        self.board = np.zeros((self.rows, self.cols), dtype=int)
        self.current_player = 1
        self.last_move = None
        return self._get_obs()

    def _get_obs(self):
        p1_board = (self.board == 1).astype(np.float32)
        p2_board = (self.board == -1).astype(np.float32)
        return np.stack([p1_board, p2_board], axis=0) if self.current_player == 1 else np.stack([p2_board, p1_board], axis=0)

    def valid_actions(self):
        return [c for c in range(self.cols) if self.board[0, c] == 0]

    def step(self, action):
        if action not in self.valid_actions():
            return self._get_obs(), -10, True, {}

        for r in range(self.rows - 1, -1, -1):
            if self.board[r, action] == 0:
                self.board[r, action] = self.current_player
                self.last_move = (r, action)
                break

        if self._check_winner(self.current_player, self.last_move):
            return self._get_obs(), 1, True, {}

        if len(self.valid_actions()) == 0:
            return self._get_obs(), 0, True, {}

        self.current_player *= -1
        return self._get_obs(), -0.005, False, {}

    def _check_winner(self, player, last_move):
        if last_move is None:
            return False
        r, c = last_move
        directions = [(1, 0), (0, 1), (1, 1), (1, -1)]
        for dr, dc in directions:
            count = 1
            for d in [-1, 1]:
                nr, nc = r + d * dr, c + d * dc
                while 0 <= nr < self.rows and 0 <= nc < self.cols and self.board[nr, nc] == player:
                    count += 1
                    if count >= 4:
                        return True
                    nr += d * dr
                    nc += d * dc
        return False

    @staticmethod
    def _check_winner_static(board, player):
        for r in range(6):
            for c in range(4):
                if all(board[r, c+i] == player for i in range(4)):
                    return True
        for c in range(7):
            for r in range(3):
                if all(board[r+i, c] == player for i in range(4)):
                    return True
        for r in range(3):
            for c in range(4):
                if all(board[r+i, c+i] == player for i in range(4)):
                    return True
        for r in range(3):
            for c in range(3, 7):
                if all(board[r+i, c-i] == player for i in range(4)):
                    return True
        return False



In [2]:
def find_threats_from_move(board, player, move):
    r, c = move
    rows, cols = 6, 7
    threats = set()
    directions = [(0,1), (1,0), (1,1), (1,-1)]

    for dr, dc in directions:
        for offset in range(-3, 1):
            segment = []
            positions = []
            for i in range(4):
                nr = r + (offset + i) * dr
                nc = c + (offset + i) * dc
                if 0 <= nr < rows and 0 <= nc < cols:
                    segment.append(board[nr, nc])
                    positions.append((nr, nc))
            if len(segment) == 4 and segment.count(player) == 3 and segment.count(0) == 1:
                idx = segment.index(0)
                er, ec = positions[idx]
                if er == rows - 1 or board[er + 1, ec] != 0:
                    threats.add(ec)
    return threats

In [3]:
def analyze_move(prev_board, new_board, current_player, valid_actions, last_move):
    reward = 0.0

    threats = find_threats_from_move(new_board, current_player, last_move)
    if len(threats) >= 2:
        reward += 0.6
    elif len(threats) == 1:
        reward += 0.3

    for action in valid_actions:
        temp = new_board.copy()
        for r in range(5, -1, -1):
            if temp[r, action] == 0:
                temp[r, action] = current_player
                break
        if Connect4Env._check_winner_static(temp, current_player):
            reward -= 0.25
            break

    opponent = -current_player
    blocked = False
    for action in valid_actions:
        temp = prev_board.copy()
        for r in range(5, -1, -1):
            if temp[r, action] == 0:
                temp[r, action] = opponent
                break
        if Connect4Env._check_winner_static(temp, opponent):
            temp2 = new_board.copy()
            for r in range(5, -1, -1):
                if temp2[r, action] == 0:
                    temp2[r, action] = opponent
                    break
            if not Connect4Env._check_winner_static(temp2, opponent):
                reward += 0.3
                blocked = True
                break
    if not blocked:
        reward -= 0.2

    return reward

In [26]:
def count_true_threes(board, player):
    count = 0
    # Horizontal
    for row in range(6):
        for col in range(4):
            line = board[row, col:col+4]
            if np.count_nonzero(line == player) == 3 and np.count_nonzero(line == 0) == 1:
                if is_contiguous(line, player):
                    count += 1
    # Vertical
    for col in range(7):
        for row in range(3):
            line = board[row:row+4, col]
            if np.count_nonzero(line == player) == 3 and np.count_nonzero(line == 0) == 1:
                if is_contiguous(line, player):
                    count += 1
    # Diagonal (\)
    for row in range(3):
        for col in range(4):
            line = [board[row+i, col+i] for i in range(4)]
            if line.count(player) == 3 and line.count(0) == 1:
                if is_contiguous(line, player):
                    count += 1
    # Diagonal (/)
    for row in range(3):
        for col in range(3, 7):
            line = [board[row+i, col-i] for i in range(4)]
            if line.count(player) == 3 and line.count(0) == 1:
                if is_contiguous(line, player):
                    count += 1
    return count

def is_contiguous(line, player):
    # True if there are no gaps between player pieces in the 4-element line
    indexes = [i for i, val in enumerate(line) if val == player]
    return max(indexes) - min(indexes) <= 2

Model

In [None]:
class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 6 * 7, 128)
        self.out = nn.Linear(128, 7)

    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.out(x)

Replay Buffer

In [None]:
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)


Selecting Action

In [None]:
# --- Training Functions ---
def select_action(model, state, epsilon, valid_actions):
    if random.random() < epsilon:
        return random.choice(valid_actions)
    else:
        with torch.no_grad():
            q_values = model(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)).squeeze()
            q_values[[i for i in range(7) if i not in valid_actions]] = -float('inf')
            return torch.argmax(q_values).item()

Training Loop

In [4]:
def train_self_play(episodes=50000, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.9995,
                    gamma=0.99, batch_size=64, update_target_every=10):
    env = Connect4Env()
    model1 = DQN().to(device)
    model2 = DQN().to(device)
    target_model1 = DQN().to(device)
    target_model2 = DQN().to(device)
    target_model1.load_state_dict(model1.state_dict())
    target_model2.load_state_dict(model2.state_dict())
    optimizer1 = optim.Adam(model1.parameters(), lr=1e-3)
    optimizer2 = optim.Adam(model2.parameters(), lr=1e-3)
    buffer1 = ReplayBuffer()
    buffer2 = ReplayBuffer()

    for episode in range(episodes):
        state = env.reset()
        done = False
        reward_p1 = 0
        reward_p2 = 0
        last_moves = {1: None, -1: None}

        while not done:
            current_player = env.current_player
            model = model1 if current_player == 1 else model2
            buffer = buffer1 if current_player == 1 else buffer2
            optimizer = optimizer1 if current_player == 1 else optimizer2
            target_model = target_model1 if current_player == 1 else target_model2

            valid_actions = env.valid_actions()
            action = select_action(model, state, epsilon, valid_actions)

            prev_board = env.board.copy()
            next_state, reward, done, _ = env.step(action)
            reward += analyze_move(prev_board, env.board, current_player, env.valid_actions(), env.last_move)

            last_moves[current_player] = (np.array(state, copy=True), action, np.array(next_state, copy=True))
            buffer.push((np.array(state, copy=True), action, reward, np.array(next_state, copy=True), done))
            state = next_state

            if current_player == 1:
                reward_p1 += reward
            else:
                reward_p2 += reward

            if done and reward >= 1:
                losing_player = -current_player
                if last_moves[losing_player] is not None:
                    s, a, ns = last_moves[losing_player]
                    opp_buffer = buffer1 if losing_player == 1 else buffer2
                    opp_buffer.push((s, a, -1.0, ns, True))
                    if losing_player == 1:
                        reward_p1 += -1.0
                    else:
                        reward_p2 += -1.0

            if len(buffer) >= batch_size:
                transitions = buffer.sample(batch_size)
                states, actions, rewards, next_states, dones = zip(*transitions)
                states = torch.tensor(np.stack(states), dtype=torch.float32).to(device)
                actions = torch.tensor(actions).unsqueeze(1).to(device)
                rewards = torch.tensor(rewards).unsqueeze(1).to(device)
                next_states = torch.tensor(np.stack(next_states), dtype=torch.float32).to(device)
                dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1).to(device)

                q_values = model(states).gather(1, actions)
                with torch.no_grad():
                    max_next_q = target_model(next_states).max(1)[0].unsqueeze(1)
                    target_q = rewards + gamma * max_next_q * (1 - dones)

                loss = F.mse_loss(q_values, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if epsilon > epsilon_min:
            epsilon *= epsilon_decay

        if episode % update_target_every == 0:
            target_model1.load_state_dict(model1.state_dict())
            target_model2.load_state_dict(model2.state_dict())

        if episode % 100 == 0:
            print(f"Episode {episode}, P1 reward: {reward_p1:.2f}, P2 reward: {reward_p2:.2f}, Epsilon: {epsilon:.3f}")

    torch.save(model1.state_dict(), "connect4_model1_final.pth")
    torch.save(model2.state_dict(), "connect4_model2_final.pth")
    return model1, model2


Train It

In [None]:
trained_model1, trained_model2 = train_self_play(episodes=100000,epsilon_decay=0.9999)

Episode 0, P1 reward: 3.30, P2 reward: -1.60, Epsilon: 1.000
Episode 100, P1 reward: 1.40, P2 reward: -1.40, Epsilon: 0.990
Episode 200, P1 reward: -1.90, P2 reward: 3.00, Epsilon: 0.980
Episode 300, P1 reward: 4.60, P2 reward: -0.60, Epsilon: 0.970
Episode 400, P1 reward: 1.90, P2 reward: -1.90, Epsilon: 0.961
Episode 500, P1 reward: 4.60, P2 reward: -2.30, Epsilon: 0.951
Episode 600, P1 reward: 3.40, P2 reward: -1.10, Epsilon: 0.942
Episode 700, P1 reward: 4.40, P2 reward: -1.90, Epsilon: 0.932
Episode 800, P1 reward: 4.10, P2 reward: -1.60, Epsilon: 0.923
Episode 900, P1 reward: -1.90, P2 reward: 3.50, Epsilon: 0.914
Episode 1000, P1 reward: 4.00, P2 reward: -1.90, Epsilon: 0.905
Episode 1100, P1 reward: -1.30, P2 reward: 4.50, Epsilon: 0.896
Episode 1200, P1 reward: -4.20, P2 reward: 1.10, Epsilon: 0.887
Episode 1300, P1 reward: 1.80, P2 reward: -5.50, Epsilon: 0.878
Episode 1400, P1 reward: 3.50, P2 reward: -2.20, Epsilon: 0.869
Episode 1500, P1 reward: 2.40, P2 reward: -4.50, Eps

KeyboardInterrupt: 

: 

In [8]:
def save_model(model, path):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

In [9]:
save_model(trained_model)

NameError: name 'trained_model' is not defined

Load Weights

In [10]:
def load_trained_model(path="connect4_model_weights.pth"):
    model = DQN().to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()  # Set to evaluation mode
    print(f"Loaded model from {path}")
    return model

In [11]:
trained_model = load_trained_model(path="connect4_dqn_ep10000.pth")

Loaded model from connect4_dqn_ep10000.pth


Playing

In [42]:
def play_against_model(starting_player=None):
    model_choice = input("Play against model 1 (X) or model 2 (O)? Enter 1 or 2: ").strip()
    model_path = "connect4_model1_ep40000.pth" if model_choice == "1" else "connect4_model2_ep40000.pth"
    model = DQN().to(device)
    env = Connect4Env()

    model.load_state_dict(torch.load(model_path))
    model.eval()

    state = env.reset()
    done = False

    if starting_player is None:
        choice = input("Do you want to go first? (y/n): ").lower()
        if choice == 'y':
            env.current_player = -1
        else:
            env.current_player = 1
    else:
        env.current_player = -1 if starting_player == -1 else 1
    print("You are Player -1 (O). Model is Player 1 (X). Columns: 0 to 6")
    print(f"Playing against model from: {model_path}")
    print(env.board)

    while not done:
        if env.current_player == -1:
            try:
                user_action = int(input("Your move (0-6): "))
            except ValueError:
                print("Invalid input. Try a number from 0 to 6.")
                continue
            if user_action not in env.valid_actions():
                print("Invalid move. Try again.")
                continue
            state, reward, done, _ = env.step(user_action)
        else:
            valid_actions = env.valid_actions()
            action = select_action(model, state, epsilon=0.0, valid_actions=valid_actions)
            state, reward, done, _ = env.step(action)
            print("Model played column:", action)

        # Always print board after any move
        symbol_map = {1: 'X', -1: 'O', 0: '.'}
        print("\nCurrent board:")
        for row in env.board:
            print(' '.join(symbol_map[cell] for cell in row))
        print()

        if done:
            print("Game Over. Reward:", reward)
            break



In [43]:
play_against_model()

  model.load_state_dict(torch.load(model_path))


You are Player -1 (O). Model is Player 1 (X). Columns: 0 to 6
Playing against model from: connect4_model1_ep40000.pth
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]]

Current board:
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . .

Model played column: 6

Current board:
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . X


Current board:
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . .
. . . O . . X

Model played column: 6

Current board:
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . X
. . . O . . X


Current board:
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . .
. . . O . . X
. . . O . . X

Model played column: 6

Current board:
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . X
. . . O . . X
. . . O . . X


Current board:
. . . . . . .
. . . . . . .
. . . O . . .
. . . O . . X
. . . O . . X
. . . O . . X

Gam