In [1]:
from connect4 import Connect4
from agents.human import Human
from agents.random_agent import RandomAgent
from agents.negamax_agent import NegamaxAgent

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear head.

    Args:
        input_dim: Number of features in a flattened state vector.
        output_dim: Number of actions; one Q-value is produced per action.
        hidden_dim: Width of both hidden layers. Defaults to 128, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        # The attribute names fc1/fc2/fc3 are part of the state_dict keys, so
        # they are kept stable for load_state_dict() compatibility.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unbounded) Q-values for a batch shaped (..., input_dim)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the head: Q-values may be negative.
        return self.fc3(x)

In [57]:
from collections import deque
import numpy as np
import random

class DQNAgent:
    """Deep Q-learning agent for a 6x7 Connect 4 board.

    The agent owns an online Q-network (``model``), a periodically synchronised
    copy (``target_model``) used for bootstrap targets, and a bounded replay
    memory of ``(state, action, reward, next_state, done)`` tuples. Boards are
    stored in memory in their raw symbol form ('X', 'O', ' ') and converted to
    numbers only when they are fed to a network.

    NOTE(review): inside ``train`` this agent chooses every move that is played.
    How ``env.step`` alternates pieces and from whose perspective ``reward`` is
    signed is not visible here -- confirm against the Connect4 environment.
    """

    ROWS = 6
    COLS = 7
    # Numeric encoding of board symbols fed to the network.
    _CELL_VALUES = {'X': 1, 'O': -1, ' ': 0}

    def __init__(self, env):
        """Create the networks, optimiser and replay memory.

        Args:
            env: Environment object; ``act`` calls ``env.get_valid_moves()``.
        """
        self.env = env
        self.memory = deque(maxlen=10000)
        self.gamma = 0.99            # discount factor
        self.epsilon = 1.0           # current exploration probability
        self.epsilon_decay = 0.995   # multiplicative decay applied per replay()
        self.epsilon_min = 0.01
        self.learning_rate = 0.001
        self.batch_size = 64
        n_cells = self.ROWS * self.COLS  # 6x7 board flattened
        self.model = DQN(input_dim=n_cells, output_dim=self.COLS)
        self.target_model = DQN(input_dim=n_cells, output_dim=self.COLS)
        self.update_target_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def convert_board_to_numeric_and_flatten(self, board):
        """Encode a board of symbols as a flat 1-D numeric array."""
        return self.convert_board_to_numeric(board).flatten()

    def convert_board_to_numeric(self, board):
        """Map 'X' -> 1, 'O' -> -1, ' ' -> 0, preserving the input's shape.

        Accepts a flat sequence of cells as well as a 2-D grid.

        Raises:
            KeyError: If a cell holds a symbol other than 'X', 'O' or ' '.
        """
        cells = np.asarray(board)
        encoded = np.array([self._CELL_VALUES[cell] for cell in cells.ravel()])
        return encoded.reshape(cells.shape)

    def _encode_batch(self, boards):
        """Stack symbol boards into a float32 tensor of shape (len(boards), cells)."""
        encoded = np.stack(
            [self.convert_board_to_numeric_and_flatten(board) for board in boards]
        )
        return torch.as_tensor(encoded, dtype=torch.float32)

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def act(self, state):
        """Pick a column with an epsilon-greedy policy restricted to valid moves.

        Args:
            state: Board of symbols (flat or 2-D) describing the current position.

        Returns:
            A move taken from ``env.get_valid_moves()``, or -1 when that list is
            empty.
        """
        valid_moves = self.env.get_valid_moves()
        if not valid_moves:  # If there are no valid moves, return -1
            return -1
        if np.random.rand() <= self.epsilon:
            return random.choice(valid_moves)
        # Pure inference: no autograd graph is needed to rank the moves.
        with torch.no_grad():
            q_values = self.model(self._encode_batch([state]))[0]
        # max() keeps the first best move, i.e. ties resolve to the earliest
        # entry of valid_moves.
        return max(valid_moves, key=lambda move: q_values[move].item())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory (oldest entries drop off)."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        """Run one gradient step on a random minibatch, then decay epsilon.

        Does nothing until the memory holds at least ``batch_size`` transitions.
        The regression target for each sampled transition is
        ``reward + gamma * max_a target_model(next_state)[a]`` (just ``reward``
        when ``done``); only the Q-value of the action actually taken is
        regressed towards it.
        """
        if len(self.memory) < self.batch_size:
            return
        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        state_batch = self._encode_batch(states)
        next_state_batch = self._encode_batch(next_states)
        action_batch = torch.tensor([int(a) for a in actions], dtype=torch.int64)
        reward_batch = torch.tensor([float(r) for r in rewards], dtype=torch.float32)
        not_done = torch.tensor([0.0 if d else 1.0 for d in dones], dtype=torch.float32)

        # Targets are constants: they must not carry gradients into either net.
        with torch.no_grad():
            # NOTE(review): the max ranges over all columns, including ones that
            # may already be full in next_state -- consider masking them.
            next_q_max = self.target_model(next_state_batch).max(dim=1).values
            targets = reward_batch + self.gamma * next_q_max * not_done

        # Q(s, a) for the action that was actually taken in each transition.
        q_taken = self.model(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
        loss = F.mse_loss(q_taken, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def train(self, episodes, verbose=True):
        """Play ``episodes`` games on fresh Connect4 boards, learning as it goes.

        Each step stores the transition and calls ``replay``; the target network
        is synchronised at the end of every episode.

        Args:
            episodes: Number of games to play.
            verbose: When true (default), print the epsilon value and the final
                board after every episode.
        """
        for e in range(episodes):
            self.env = Connect4()
            state = self.env.board.flatten()
            done = False
            while not done:
                action = self.act(state)
                if action == -1:
                    # No legal move is left; never feed the sentinel to the env.
                    break
                next_state, reward, done, _ = self.env.step(action)
                next_state = next_state.flatten()
                self.remember(state, action, reward, next_state, done)
                state = next_state
                self.replay()
            self.update_target_model()
            if verbose:
                print(f"Episode {e+1}/{episodes} - Epsilon: {self.epsilon:.4f}")
                print(state.reshape(self.ROWS, self.COLS))

In [58]:
# Seed the RNGs the agent draws from -- `random` (exploration choice, replay
# sampling), NumPy (epsilon test) and torch (weight initialisation) -- before
# anything is constructed, so a fresh-kernel rerun starts from the same state.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = Connect4()
agent = DQNAgent(env)
agent.train(episodes=1000)

Episode 1/1000 - Epsilon: 1.0000
[[' ' ' ' ' ' ' ' ' ' ' ' ' ']
 ['O' ' ' ' ' ' ' ' ' ' ' ' ']
 ['X' ' ' ' ' ' ' ' ' ' ' ' ']
 ['X' ' ' ' ' 'X' ' ' ' ' 'X']
 ['O' 'X' ' ' 'O' 'O' 'O' 'O']
 ['X' 'O' 'O' 'X' 'X' 'X' 'O']]
Episode 2/1000 - Epsilon: 1.0000
[[' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ']
 ['O' ' ' 'X' ' ' ' ' ' ' ' ']
 ['X' 'O' 'X' ' ' ' ' ' ' ' ']
 ['O' 'X' 'O' 'O' ' ' 'X' ' ']
 ['O' 'X' 'X' 'O' 'X' 'O' ' ']]
Episode 3/1000 - Epsilon: 1.0000
[[' ' ' ' 'O' ' ' ' ' ' ' 'X']
 [' ' ' ' 'O' ' ' ' ' 'O' 'X']
 [' ' 'X' 'O' ' ' ' ' 'O' 'X']
 ['X' 'X' 'O' 'X' ' ' 'X' 'O']
 ['O' 'O' 'X' 'X' ' ' 'X' 'X']
 ['O' 'O' 'O' 'X' 'O' 'O' 'X']]
Episode 4/1000 - Epsilon: 0.8604
[['X' ' ' ' ' ' ' 'O' ' ' 'X']
 ['O' ' ' ' ' ' ' 'X' ' ' 'X']
 ['X' 'O' ' ' 'O' 'O' ' ' 'X']
 ['O' 'X' ' ' 'X' 'O' 'O' 'X']
 ['X' 'O' 'O' 'O' 'X' 'X' 'O']
 ['O' 'X' 'O' 'O' 'X' 'X' 'X']]
Episode 5/1000 - Epsilon: 0.7590
[[' ' ' ' ' ' 'X' ' ' ' ' ' ']
 [' ' ' ' ' ' 'O' ' ' ' ' ' ']
 ['O' 'O' ' ' 'O' ' ' ' '

In [59]:
# Show the trained agent object; DQNAgent defines no __repr__, so this is the
# default object repr.
agent

<__main__.DQNAgent at 0x15ae47020>