In [1]:
import numpy as np

class GomokuEnv:
    """Minimal two-player Gomoku (n-in-a-row) environment.

    The board is a ``board_size x board_size`` integer array: 0 marks an empty
    cell, 1 / 2 mark the stones of player 1 / player 2. Actions are flat cell
    indices ``row * board_size + col``. Observations returned by :meth:`reset`
    and :meth:`step` are flattened copies of the board.

    Args:
        board_size: Side length of the square board.
        win_length: Consecutive stones needed to win (default 5, classic Gomoku).
    """

    def __init__(self, board_size=5, win_length=5):
        self.board_size = board_size
        self.win_length = win_length
        self.reset()

    def reset(self):
        """Clear the board, give the move to player 1 and return the flat observation."""
        self.board = np.zeros((self.board_size, self.board_size), dtype=int)
        self.current_player = 1
        return self.board.flatten()

    def step(self, action):
        """Place a stone for the current player at flat index ``action``.

        Returns:
            ``(observation, reward, done, info)``: reward is 10 for a winning
            move, -10 for playing on an occupied cell (the episode ends), and 0
            otherwise, including the draw when the board fills up. The current
            player only switches after a non-terminal move.

        Raises:
            ValueError: if ``action`` is outside ``[0, board_size**2)``. Without
                this check a negative index would silently wrap around through
                numpy indexing and place a stone on an unrelated cell.
        """
        n_cells = self.board_size * self.board_size
        if not 0 <= action < n_cells:
            raise ValueError(f"action {action} is outside the board (0..{n_cells - 1})")
        row, col = divmod(action, self.board_size)
        if self.board[row, col] != 0:
            return self.board.flatten(), -10, True, {}  # illegal move penalty

        self.board[row, col] = self.current_player
        if self.check_win(row, col):
            return self.board.flatten(), 10, True, {}  # win reward

        if np.all(self.board != 0):
            return self.board.flatten(), 0, True, {}  # draw

        self.current_player = 3 - self.current_player  # switch player (1 <-> 2)
        return self.board.flatten(), 0, False, {}

    def check_win(self, row, col):
        """Return True if the stone at (row, col) is part of a line of ``win_length`` stones.

        Checks the vertical, horizontal and both diagonal axes, counting in both
        directions from (row, col).
        """
        player = self.board[row, col]
        for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1)):
            count = (1
                     + self._run_length(row, col, dr, dc, player)
                     + self._run_length(row, col, -dr, -dc, player))
            if count >= self.win_length:
                return True
        return False

    def _run_length(self, row, col, dr, dc, player):
        """Count consecutive ``player`` stones stepping from (row, col) by (dr, dc), origin excluded.

        Stops at the board edge, at the first non-matching cell, or after
        ``win_length - 1`` stones (enough to decide a win).
        """
        count = 0
        r, c = row + dr, col + dc
        while (count < self.win_length - 1
               and 0 <= r < self.board_size and 0 <= c < self.board_size
               and self.board[r, c] == player):
            count += 1
            r += dr
            c += dc
        return count

    def available_actions(self):
        """Return the flat indices of all empty cells, ascending, as Python ints."""
        return np.flatnonzero(self.board == 0).tolist()


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

class DQN(nn.Module):
    """Fully connected Q-network: flat state vector in, one Q-value per action out.

    Layout: Linear(state_size, 128) -> ReLU -> Linear(128, 128) -> ReLU
    -> Linear(128, action_size). The output layer has no activation.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        width = 128  # size of both hidden layers
        self.fc1 = nn.Linear(state_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, action_size)

    def forward(self, x):
        """Return unbounded Q-values; the last dimension of ``x`` must equal state_size."""
        features = x
        for hidden_layer in (self.fc1, self.fc2):
            features = torch.relu(hidden_layer(features))
        return self.fc3(features)

class DQNAgent:
    """Epsilon-greedy DQN agent with experience replay and a target network.

    Args:
        state_size: Length of the flat state vector fed to the network.
        action_size: Number of discrete actions (one Q-value per action).
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)  # replay buffer; oldest transitions are evicted first
        self.gamma = 0.95  # discount rate
        self.epsilon = 1.0  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995  # multiplicative decay applied once per replay() call
        self.learning_rate = 0.001
        self.model = DQN(state_size, action_size)
        self.target_model = DQN(state_size, action_size)
        self.update_target_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action for ``state``, restricted to cells whose value is 0.

        With probability ``epsilon`` a random empty cell is returned; otherwise
        the empty cell with the highest predicted Q-value. ``state`` may be the
        flat board or the 2-D board; it is flattened before use. Assumes at least
        one empty cell exists.
        """
        state_flat = np.asarray(state).flatten()
        available_actions = np.flatnonzero(state_flat == 0)
        if np.random.rand() <= self.epsilon:
            return random.choice(available_actions)
        with torch.no_grad():  # inference only: no autograd graph needed
            q_values = self.model(torch.FloatTensor(state_flat).unsqueeze(0))[0]
        return available_actions[torch.argmax(q_values[available_actions]).item()]

    def replay(self, batch_size):
        """Fit the online network on ``batch_size`` random transitions, then decay epsilon.

        Each sampled transition gets its own optimizer step, regressing
        Q(state, action) toward ``reward + gamma * max Q_target(next_state)``,
        or toward ``reward`` alone when the transition is terminal.

        Raises:
            ValueError: from ``random.sample`` if fewer than ``batch_size``
                transitions are stored.
        """
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            state = torch.FloatTensor(state).unsqueeze(0)
            target = float(reward)
            if not done:
                # The bootstrap value is a fixed label: compute it without
                # autograd so no gradient flows into (and piles up on) the
                # target network's parameters.
                # NOTE(review): in self-play, next_state is the opponent's turn,
                # so a zero-sum formulation would subtract this term, and the max
                # also ranges over occupied cells — confirm the intended convention.
                with torch.no_grad():
                    next_q = self.target_model(torch.FloatTensor(next_state).unsqueeze(0))
                target += self.gamma * next_q.max().item()
            output = self.model(state)
            # Label = current prediction with only the taken action's entry replaced;
            # clone so the in-place write does not touch the graph's tensor.
            target_f = output.detach().clone()
            target_f[0][action] = target
            self.optimizer.zero_grad()
            loss = self.criterion(output, target_f)
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        """Load online-network weights from ``name`` and sync the target network to them.

        Security: ``torch.load`` unpickles the file — only load trusted checkpoints.
        """
        self.model.load_state_dict(torch.load(name))
        self.update_target_model()  # otherwise bootstrapping would use stale target weights

    def save(self, name):
        """Save the online network's state_dict to ``name``."""
        torch.save(self.model.state_dict(), name)

def train_dqn(board_size=5, episodes=10, batch_size=32, max_steps=None):
    """Train a DQNAgent by self-play on a GomokuEnv.

    A single agent picks the move for whichever player is to move. Every
    transition goes into replay memory. Once the memory holds more than
    ``batch_size`` transitions, a replay step runs after each non-terminal move.
    The target network is synced whenever a game ends.

    Args:
        board_size: Side length of the square board.
        episodes: Number of self-play games; progress is printed after each one.
        batch_size: Transitions sampled per replay step.
        max_steps: Optional cap on moves per episode. Defaults to the number of
            cells, which is already an upper bound: every non-terminal move
            fills an empty cell, and a full board ends the game as a draw.

    Returns:
        ``(agent, env)``: the trained agent and the environment it played in.
        The environment is left in its last episode's final state.
    """
    env = GomokuEnv(board_size)
    state_size = board_size * board_size
    agent = DQNAgent(state_size, state_size)  # one action per board cell
    if max_steps is None:
        max_steps = state_size

    for episode in range(episodes):
        state = env.reset()
        for _ in range(max_steps):
            action = agent.act(state)
            next_state, reward, done, _info = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                agent.update_target_model()  # sync target net once per finished game
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        print(f"Episode {episode + 1}/{episodes}")

    return agent, env


In [13]:
import tkinter as tk
from tkinter import messagebox

class GomokuGUI:
    """Tkinter front-end for playing Gomoku against a trained agent.

    The human moves first as player 1 ("X") and the agent replies as player 2
    ("O"). The environment is the single source of truth for the board and
    for whose turn it is: every move goes through ``env.step``, and the
    buttons are redrawn from ``env.board``.

    Args:
        master: Tk container the cell buttons are gridded into.
        agent: Object exposing ``act(flat_board) -> flat action index``.
        env: GomokuEnv-like environment. It is reset when the GUI is built.
    """

    def __init__(self, master, agent, env):
        self.master = master
        self.agent = agent
        self.env = env
        self.board_size = env.board_size
        self.board = env.board  # alias of env.board; re-bound after every reset
        self.current_player = 1  # mirrors env.current_player
        self.buttons = [[None for _ in range(self.board_size)] for _ in range(self.board_size)]
        self.create_widgets()
        # train_dqn() returns the env in the final state of its last episode;
        # start from an empty board so it matches the blank buttons.
        self.reset_game()

    def create_widgets(self):
        """Create one button per cell on a grid; clicking a button plays that cell."""
        for row in range(self.board_size):
            for col in range(self.board_size):
                button = tk.Button(self.master, width=4, height=2, command=lambda r=row, c=col: self.player_move(r, c))
                button.grid(row=row, column=col)
                self.buttons[row][col] = button

    def player_move(self, row, col):
        """Handle a human click on (row, col), then let the agent reply if the game goes on."""
        if self.env.board[row, col] != 0:
            return  # ignore clicks on occupied cells
        # env.step places the stone itself. Writing it into the shared board
        # first would make env.step see an occupied cell and end the game.
        if self._play(row * self.board_size + col):
            return
        self.agent_move()

    def agent_move(self):
        """Let the agent choose and play a move for the side to move."""
        action = self.agent.act(self.env.board.flatten())
        self._play(action)

    def _play(self, action):
        """Apply ``action`` through the env, redraw, and handle game over.

        Returns True if the game ended; the board has then already been reset.
        """
        mover = self.env.current_player
        _state, reward, done, _info = self.env.step(action)
        self.update_buttons()
        if done:
            messagebox.showinfo("Game Over", self._result_message(mover, reward))
            self.reset_game()
        else:
            self.current_player = self.env.current_player
        return done

    @staticmethod
    def _result_message(player, reward):
        """Build the game-over text from the mover and the env's terminal reward (10 / -10 / 0)."""
        if reward > 0:
            return f"Player {player} wins!"
        if reward < 0:
            return f"Player {player} made an illegal move!"
        return "It's a draw!"

    def update_board(self, action, player):
        """Write ``player``'s stone at flat index ``action`` directly into the board.

        This bypasses ``env.step``: there is no win check and no turn switch.
        The move handlers do not use it.
        """
        row, col = divmod(action, self.board_size)
        self.board[row, col] = player

    def update_buttons(self):
        """Redraw every button from the environment's board (1 -> "X", 2 -> "O")."""
        symbols = {0: " ", 1: "X", 2: "O"}
        board = self.env.board
        for row in range(self.board_size):
            for col in range(self.board_size):
                self.buttons[row][col].config(text=symbols[board[row, col]])

    def reset_game(self):
        """Start a new game: reset the env, re-alias the board and clear the buttons."""
        self.env.reset()
        self.board = self.env.board
        self.current_player = 1
        self.update_buttons()

if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so this runs when the cell executes.
    # Train the agent by self-play first; train_dqn() prints one line per episode.
    agent, env = train_dqn()

    # Then open the Tk window for human-vs-agent play on the same env.
    root = tk.Tk()
    root.title("Gomoku")
    app = GomokuGUI(root, agent, env)
    root.mainloop()  # blocks the cell until the window is closed


Episode 1/10
Episode 2/10
Episode 3/10
Episode 4/10
Episode 5/10
Episode 6/10
Episode 7/10
Episode 8/10
Episode 9/10
Episode 10/10
