In [2]:
import numpy as np

class Gomoku:
    """Two-player Gomoku (five in a row) on a square ``size`` x ``size`` board.

    ``board[x, y]`` holds 0 for an empty cell and 1 or 2 for the player whose
    stone occupies it. Player 1 moves first; the turn passes to the other
    player after every accepted move.
    """

    # Number of consecutive stones required to win.
    WIN_LENGTH = 5
    # Line orientations to scan: the two axes and both diagonals.
    DIRECTIONS = ((1, 0), (0, 1), (1, 1), (1, -1))

    def __init__(self, size=15):
        """Create an empty board with player 1 to move."""
        self.size = size
        self.board = np.zeros((size, size), dtype=int)
        self.current_player = 1

    def reset(self):
        """Clear the board, give the move back to player 1 and return the board."""
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.current_player = 1
        return self.board

    def is_valid_move(self, x, y):
        """Return True when ``(x, y)`` lies on the board and the cell is empty."""
        in_bounds = 0 <= x < self.size and 0 <= y < self.size
        # bool() so callers always get a plain Python bool, not numpy.bool_.
        return bool(in_bounds and self.board[x, y] == 0)

    def make_move(self, x, y):
        """Place the current player's stone at ``(x, y)``.

        Returns True and passes the turn when the move is legal; returns False
        and leaves the game untouched otherwise.
        """
        if not self.is_valid_move(x, y):
            return False
        self.board[x, y] = self.current_player
        self.current_player = 3 - self.current_player  # toggles 1 <-> 2
        return True

    def check_winner(self):
        """Return the winning player (1 or 2) as an ``int``, or 0 if nobody has won."""
        # argwhere yields occupied cells in row-major order, i.e. the same
        # order as scanning x then y with nested loops.
        for x, y in np.argwhere(self.board != 0):
            if self.check_direction(x, y):
                return int(self.board[x, y])
        return 0

    def check_direction(self, x, y):
        """Return True if the stone at ``(x, y)`` is part of a winning line.

        An empty cell is never part of a winning line.
        """
        player = self.board[x, y]
        if player == 0:
            # Without this guard a run of empty cells would count as a "win".
            return False
        reach = self.WIN_LENGTH - 1
        for dx, dy in self.DIRECTIONS:
            run = 0
            # Slide along the line through (x, y), counting the current
            # unbroken run of this player's stones.
            for step in range(-reach, reach + 1):
                nx, ny = x + step * dx, y + step * dy
                on_board = 0 <= nx < self.size and 0 <= ny < self.size
                if on_board and self.board[nx, ny] == player:
                    run += 1
                    if run == self.WIN_LENGTH:
                        return True
                else:
                    run = 0
        return False

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

class DQN(nn.Module):
    """Three-layer fully connected network mapping a state vector to one value per action.

    Layers are ``fc1`` (input -> hidden), ``fc2`` (hidden -> hidden) and
    ``fc3`` (hidden -> output); ReLU follows the first two layers only.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the raw (un-activated) output of ``fc3`` for input ``x``."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

class DQNAgent:
    """Epsilon-greedy Q-learning agent with a replay buffer and a ``DQN`` value network.

    Transitions are stored with ``remember`` and learned from in ``replay``,
    which also decays the exploration rate ``epsilon`` towards ``epsilon_min``.
    """

    def __init__(self, state_size, action_size, hidden_size=128, gamma=0.99, lr=0.001, batch_size=64, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995):
        """Store hyper-parameters and build the network, optimizer and loss.

        Args:
            state_size: length of the flat state vector fed to the network.
            action_size: number of discrete actions (network outputs).
            hidden_size: width of the two hidden layers.
            gamma: discount factor applied to the bootstrapped next-state value.
            lr: Adam learning rate.
            batch_size: number of transitions sampled per ``replay`` call.
            epsilon: initial probability of taking a random action.
            epsilon_min: floor that ``epsilon`` is not decayed below.
            epsilon_decay: multiplicative decay applied once per ``replay`` call.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.gamma = gamma
        self.lr = lr
        self.batch_size = batch_size
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

        # Bounded buffer: the oldest transitions are dropped once 2000 are stored.
        self.memory = deque(maxlen=2000)
        self.model = DQN(state_size, hidden_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition to memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state, valid_actions=None):
        """Choose an action for ``state`` with an epsilon-greedy policy.

        Args:
            state: flat state vector (sequence or array of numbers).
            valid_actions: optional non-empty iterable of action indices to
                choose from. When omitted, every action in
                ``range(action_size)`` is eligible.

        Returns:
            The chosen action index as an ``int``.

        Raises:
            ValueError: if ``valid_actions`` is given but empty.
        """
        if valid_actions is None:
            candidates = range(self.action_size)
        else:
            candidates = [int(a) for a in valid_actions]
            if not candidates:
                raise ValueError("valid_actions must contain at least one action")

        if random.uniform(0, 1) < self.epsilon:
            return random.choice(candidates)

        # Inference only: no autograd graph is needed to pick an action.
        with torch.no_grad():
            batch = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            q_values = self.model(batch)[0]
        if valid_actions is None:
            return torch.argmax(q_values).item()
        return max(candidates, key=lambda a: q_values[a].item())

    def replay(self):
        """Train on a random minibatch from memory, then decay ``epsilon``.

        Does nothing until at least ``batch_size`` transitions are stored.
        Each sampled transition gets its own optimizer step; only the Q-value
        of the action actually taken contributes to the loss.
        """
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)
        for state, action, reward, next_state, done in batch:
            state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            target = reward
            if not done:
                next_state = torch.as_tensor(next_state, dtype=torch.float32).unsqueeze(0)
                # The bootstrap value is a constant w.r.t. the parameters.
                with torch.no_grad():
                    target += self.gamma * torch.max(self.model(next_state)).item()
            prediction = self.model(state)
            # Target equals the prediction everywhere except the taken action,
            # so the other outputs contribute zero error. Detaching keeps
            # gradients from flowing through the target.
            target_f = prediction.detach().clone()
            target_f[0][action] = target
            self.optimizer.zero_grad()
            loss = self.criterion(prediction, target_f)
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            # Clamp so the decay cannot undershoot the configured floor.
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def load_model(self, path, map_location="cpu"):
        """Load network weights saved by ``save_model`` from ``path``.

        ``map_location`` is forwarded to ``torch.load`` so a checkpoint written
        on another device can still be read.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.model.load_state_dict(torch.load(path, map_location=map_location))

    def save_model(self, path):
        """Write the network's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)


In [14]:
def train_dqn(agent, env, episodes=10, save_path='dqn_gomoku.pth'):
    """Train ``agent`` as player 1 against a uniformly random opponent.

    One stored transition covers a full round: the agent's move followed by
    the opponent's reply. ``next_state`` is therefore the board the agent will
    actually see on its next turn, and a loss caused by the opponent's reply
    is credited to the agent move that allowed it.

    Rewards: +1 when the agent's move wins, -1 when the opponent's reply wins
    or the agent picks an illegal cell (which also ends the episode), 0
    otherwise. An episode also ends when the board is full.

    Args:
        agent: object providing ``act``, ``remember``, ``replay``,
            ``save_model`` and an ``epsilon`` attribute.
        env: game providing ``size``, ``board``, ``reset``, ``is_valid_move``,
            ``make_move`` and ``check_winner``.
        episodes: number of games to play.
        save_path: file the trained weights are written to at the end.
    """
    n_cells = env.size * env.size
    for e in range(episodes):
        state = env.reset().flatten()
        done = False
        while not done:
            action = agent.act(state)
            x, y = divmod(action, env.size)
            if not env.is_valid_move(x, y):
                # Illegal move: board unchanged, penalise and end the episode.
                reward, next_state, done = -1, state, True
            else:
                env.make_move(x, y)
                if env.check_winner():
                    # Only the agent has moved since the last check, so it won.
                    reward, done = 1, True
                elif (env.board != 0).all():
                    reward, done = 0, True
                else:
                    # Opponent's turn (random move)
                    legal = [i for i in range(n_cells) if env.is_valid_move(*divmod(i, env.size))]
                    opp_x, opp_y = divmod(random.choice(legal), env.size)
                    env.make_move(opp_x, opp_y)
                    if env.check_winner():
                        reward, done = -1, True
                    else:
                        reward, done = 0, bool((env.board != 0).all())
                next_state = env.board.flatten()
            # Store every outcome, including penalties, so they can be learned.
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            agent.replay()
        print(f"Episode {e+1}/{episodes}, epsilon: {agent.epsilon:.2f}")
    agent.save_model(save_path)

# Build a default-size board and an agent with one input and one output per
# board cell, then run a short training session.
gomoku_env = Gomoku()
board_cells = gomoku_env.size ** 2
dqn_agent = DQNAgent(state_size=board_cells, action_size=board_cells)
train_dqn(dqn_agent, gomoku_env, episodes=10)


Episode 1/10, epsilon: 1.00
Episode 2/10, epsilon: 1.00
Episode 3/10, epsilon: 1.00
Episode 4/10, epsilon: 1.00
Episode 5/10, epsilon: 0.98
Episode 6/10, epsilon: 0.94
Episode 7/10, epsilon: 0.92
Episode 8/10, epsilon: 0.86
Episode 9/10, epsilon: 0.83
Episode 10/10, epsilon: 0.82


In [None]:
import tkinter as tk
import numpy as np

class GomokuApp:
    """Tkinter board on which a human (player 1, black) plays the trained agent (player 2, red).

    Stones sit on grid intersections. A left click places the human's stone on
    the nearest intersection; the agent then answers with the legal cell that
    has the highest predicted Q-value. When a game ends the result is printed
    and the board is reset.
    """

    MARGIN = 20        # pixels between the canvas edge and the first grid line
    CELL = 40          # pixels between neighbouring grid lines
    STONE_RADIUS = 15  # pixel radius of a drawn stone
    STONE_COLORS = {1: "black", 2: "red"}

    def __init__(self, root, size=15, model_path='dqn_gomoku.pth'):
        """Build the game, load the agent's weights from ``model_path`` and draw the board."""
        self.root = root
        self.size = size
        self.gomoku = Gomoku(size)
        self.agent = DQNAgent(state_size=size*size, action_size=size*size)
        self.agent.load_model(model_path)
        # Grid span plus a margin on both sides (600 px for the default size).
        side = 2 * self.MARGIN + (size - 1) * self.CELL
        self.canvas = tk.Canvas(root, width=side, height=side)
        self.canvas.pack()
        self.canvas.bind("<Button-1>", self.on_click)
        self.draw_board()
        self.reset_game()

    def reset_game(self):
        """Start a fresh game and clear the stones from the canvas."""
        self.gomoku.reset()
        self.update_canvas()

    def draw_board(self):
        """Draw the ``size`` vertical and ``size`` horizontal grid lines."""
        near = self.MARGIN
        far = self.MARGIN + (self.size - 1) * self.CELL
        for i in range(self.size):
            offset = self.MARGIN + i * self.CELL
            self.canvas.create_line(offset, near, offset, far)
            self.canvas.create_line(near, offset, far, offset)

    def update_canvas(self):
        """Redraw every stone from the current board state."""
        self.canvas.delete("piece")
        r = self.STONE_RADIUS
        for x in range(self.size):
            for y in range(self.size):
                color = self.STONE_COLORS.get(int(self.gomoku.board[x, y]))
                if color is None:
                    continue
                cx = self.MARGIN + x * self.CELL
                cy = self.MARGIN + y * self.CELL
                self.canvas.create_oval(cx - r, cy - r, cx + r, cy + r, fill=color, tags="piece")

    def _nearest_line(self, pixel):
        """Return the index of the grid line closest to a pixel coordinate."""
        return (pixel - self.MARGIN + self.CELL // 2) // self.CELL

    def _finish_if_over(self):
        """Announce a win or draw and reset; return True when the game ended."""
        winner = self.gomoku.check_winner()
        if winner:
            print(f"Player {winner} wins!")
        elif not (self.gomoku.board == 0).any():
            print("Draw!")
        else:
            return False
        self.reset_game()
        return True

    def on_click(self, event):
        """Place the human's stone on the clicked intersection, then let the agent reply."""
        x, y = self._nearest_line(event.x), self._nearest_line(event.y)
        if not (0 <= x < self.size and 0 <= y < self.size and self.gomoku.is_valid_move(x, y)):
            return
        self.gomoku.make_move(x, y)
        self.update_canvas()
        if self._finish_if_over():
            return
        self.agent_move()

    def agent_move(self):
        """Play the agent's best legal move.

        The network is queried directly and occupied cells are masked out, so
        the agent always moves while a free cell exists and play does not
        depend on the agent's exploration rate.
        """
        flat = self.gomoku.board.flatten()
        occupied = flat != 0
        if occupied.all():
            return
        with torch.no_grad():
            state = torch.as_tensor(flat, dtype=torch.float32).unsqueeze(0)
            q_values = self.agent.model(state)[0]
        # Occupied cells get -inf so argmax can only land on a free cell.
        q_values = q_values.masked_fill(torch.from_numpy(occupied), float("-inf"))
        action = int(torch.argmax(q_values))
        x, y = divmod(action, self.size)
        self.gomoku.make_move(x, y)
        self.update_canvas()
        self._finish_if_over()

# Launch the GUI when this cell/script is executed directly.
if __name__ == "__main__":
    root = tk.Tk()
    # Constructing GomokuApp loads the saved weights from 'dqn_gomoku.pth'.
    app = GomokuApp(root)
    root.mainloop()


Player 1 wins!
Player 1 wins!
Player 1 wins!
