# In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import tkinter as tk

class Gomoku:
    """Two-player five-in-a-row game on a square board.

    The board is a ``size x size`` integer array indexed as ``board[x, y]``.
    A cell holds 0 when empty, 1 for the first player and 2 for the second.
    ``current_player`` is the player whose stone the next ``make_move``
    places; it starts at 1 and alternates after every accepted move.
    """

    # The four line orientations through a cell; each is scanned both ways.
    _DIRECTIONS = ((1, 0), (0, 1), (1, 1), (1, -1))

    # Bonus per stone for an unblocked run of the given length.
    _OPEN_RUN_BONUS = {4: 100, 3: 10, 2: 1}

    def __init__(self, size=5):
        self.size = size
        self.board = np.zeros((size, size), dtype=int)
        self.current_player = 1

    def reset(self):
        """Clear the board, give the move to player 1 and return the board."""
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.current_player = 1
        return self.board

    def _in_bounds(self, x, y):
        """Return True when (x, y) lies on the board."""
        return 0 <= x < self.size and 0 <= y < self.size

    def is_valid_move(self, x, y):
        """Return True when (x, y) is on the board and still empty."""
        return self._in_bounds(x, y) and self.board[x, y] == 0

    def make_move(self, x, y):
        """Place the current player's stone at (x, y).

        Returns True and passes the turn when the move is legal; returns
        False and leaves the game untouched otherwise.
        """
        if not self.is_valid_move(x, y):
            return False
        self.board[x, y] = self.current_player
        self.current_player = 3 - self.current_player
        return True

    def check_winner(self):
        """Return the player (1 or 2) owning a line of five, or 0 if none."""
        for x in range(self.size):
            for y in range(self.size):
                if self.board[x, y] != 0 and self.check_direction(x, y):
                    return int(self.board[x, y])
        return 0

    def check_direction(self, x, y):
        """Return True when the stone at (x, y) is part of five in a row.

        For each orientation the nine cells centred on (x, y) are walked in
        order while tracking the current streak of same-coloured stones; an
        off-board or differently coloured cell resets the streak.
        """
        owner = self.board[x, y]
        for dx, dy in self._DIRECTIONS:
            count = 0
            for i in range(-4, 5):
                nx, ny = x + i * dx, y + i * dy
                if self._in_bounds(nx, ny) and self.board[nx, ny] == owner:
                    count += 1
                    if count == 5:
                        return True
                else:
                    count = 0
        return False

    def evaluate_position(self, player):
        """Score the board from ``player``'s point of view.

        Adds ``evaluate_point`` for each of ``player``'s stones and subtracts
        it for each of the opponent's stones.
        """
        score = 0
        for x in range(self.size):
            for y in range(self.size):
                if self.board[x, y] == player:
                    score += self.evaluate_point(x, y, player)
                elif self.board[x, y] == 3 - player:
                    score -= self.evaluate_point(x, y, 3 - player)
        return score

    def evaluate_point(self, x, y, player):
        """Score the lines running through ``player``'s stone at (x, y).

        For each orientation, the contiguous run of ``player`` stones that
        contains (x, y) is measured, and each end of the run counts as
        blocked when the next cell is off the board or occupied. A run of
        five or more scores 10000; an unblocked run of four, three or two
        scores 100, 10 or 1. Returns 0 when (x, y) does not hold a stone of
        ``player``.
        """
        if not self._in_bounds(x, y) or self.board[x, y] != player:
            return 0
        score = 0
        for dx, dy in self._DIRECTIONS:
            count = 1  # the stone at (x, y) itself
            block = 0
            for step in (1, -1):
                nx, ny = x + step * dx, y + step * dy
                while self._in_bounds(nx, ny) and self.board[nx, ny] == player:
                    count += 1
                    nx += step * dx
                    ny += step * dy
                # (nx, ny) is now the first cell past the run on this side.
                if not self._in_bounds(nx, ny) or self.board[nx, ny] != 0:
                    block += 1
            if count >= 5:
                score += 10000  # win
            elif block == 0:
                # open four / open three / open two
                score += self._OPEN_RUN_BONUS.get(count, 0)
        return score

class DQN(nn.Module):
    """Convolutional Q-network for a single-channel square board.

    Three 3x3 convolutions (padding 1, so the spatial size is preserved)
    feed a two-layer fully connected head. The input is expected as
    ``(batch, 1, input_dim, input_dim)`` and the output is
    ``(batch, output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        last_channels = 64
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, last_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(last_channels, last_channels, kernel_size=3, padding=1)
        # The convolutions keep the board size, so the flattened feature
        # vector has one entry per channel per board cell.
        flat_features = last_channels * input_dim * input_dim
        self.fc1 = nn.Linear(flat_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return one value per output unit for every board in the batch."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        batch = features.size(0)
        flat = features.view(batch, -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

class DQNAgent:
    """Epsilon-greedy DQN agent for a square board game.

    A state is the flattened board (``state_size`` cells, 0 meaning empty)
    and an action is the flat index of the cell to play. The agent keeps an
    online network, a target network that is synced on demand, and a bounded
    replay memory.
    """

    def __init__(self, state_size, action_size, hidden_size=128, gamma=0.99, lr=0.001, batch_size=64, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995):
        """Build the networks, optimizer and replay memory.

        Raises:
            ValueError: if ``state_size`` is not a perfect square, since the
                state must reshape into a square board for the network.
        """
        board_size = int(round(state_size ** 0.5))
        if board_size * board_size != state_size:
            raise ValueError(f"state_size must be a perfect square, got {state_size}")

        self.state_size = state_size
        self.action_size = action_size
        self.board_size = board_size  # side length of the board fed to the network
        self.hidden_size = hidden_size
        self.gamma = gamma
        self.lr = lr
        self.batch_size = batch_size
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

        self.memory = deque(maxlen=2000)
        self.model = DQN(board_size, hidden_size, action_size)
        self.target_model = DQN(board_size, hidden_size, action_size)
        self.update_target_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def _to_input(self, state):
        """Convert a flat board into a ``(1, 1, side, side)`` float tensor."""
        side = self.board_size
        return torch.as_tensor(state, dtype=torch.float32).reshape(1, 1, side, side)

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose a legal action for ``state``.

        With probability ``epsilon`` a random empty cell is returned;
        otherwise the empty cell with the highest predicted Q-value.

        Raises:
            ValueError: if ``state`` has no empty cell.
        """
        valid_actions = self.get_valid_actions(state)
        if not valid_actions:
            raise ValueError("no valid actions: the board is full")
        if random.uniform(0, 1) < self.epsilon:
            return random.choice(valid_actions)
        # Action selection needs no gradients.
        with torch.no_grad():
            q_values = self.model(self._to_input(state))[0]
        best = torch.argmax(q_values[valid_actions]).item()
        return valid_actions[best]

    def get_valid_actions(self, state):
        """Return the flat indices of the empty (zero) cells of ``state``."""
        if isinstance(state, torch.Tensor):
            state = state.detach().cpu().numpy()
        cells = np.asarray(state).reshape(-1)
        return [i for i in range(self.action_size) if cells[i] == 0]

    def replay(self):
        """Train on a random minibatch and decay epsilon.

        Does nothing until the memory holds at least ``batch_size``
        transitions. Each sampled transition gets its own optimizer step;
        the loss is the MSE between the predicted Q-vector and a copy of it
        in which the taken action's entry is replaced by the TD target.
        """
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)
        for state, action, reward, next_state, done in batch:
            target = float(reward)
            if not done:
                # The bootstrap value is a constant for this update.
                with torch.no_grad():
                    next_q = self.target_model(self._to_input(next_state))
                target += self.gamma * next_q.max().item()
            prediction = self.model(self._to_input(state))
            # Detached copy: only the taken action's entry contributes error.
            target_q = prediction.detach().clone()
            target_q[0, action] = target
            self.optimizer.zero_grad()
            loss = self.criterion(prediction, target_q)
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load_model(self, path):
        """Load online-network weights from ``path`` and sync the target."""
        self.model.load_state_dict(torch.load(path))
        # Keep the target network consistent with the freshly loaded weights.
        self.update_target_model()

    def save_model(self, path):
        """Save the online network's weights to ``path``."""
        torch.save(self.model.state_dict(), path)

def heuristic_opponent(env):
    """Pick a move for the side to move: win at once if possible, else random.

    Every empty cell is probed by writing the stone straight onto the board
    and removing it again, so neither the board nor ``env.current_player``
    is changed by the time this returns.

    Returns:
        The flat index ``x * env.size + y`` of the chosen cell.

    Raises:
        IndexError: if the board has no empty cell.
    """
    player = env.current_player
    candidates = [
        (x, y)
        for x in range(env.size)
        for y in range(env.size)
        if env.is_valid_move(x, y)
    ]
    for x, y in candidates:
        # Probe without make_move: that would flip the turn on every trial.
        env.board[x, y] = player
        wins = env.check_winner() == player
        env.board[x, y] = 0
        if wins:
            return x * env.size + y
    x, y = random.choice(candidates)
    return x * env.size + y

def train_dqn(agent, env, episodes=1000):
    """Train ``agent`` as player 1 against ``heuristic_opponent``.

    One transition covers the agent's move plus the opponent's reply, so the
    stored ``next_state`` is the board the agent will actually face next and
    the stored reward already includes the penalty for an opponent win.

    Rewards: ``env.evaluate_position(1)`` after the agent's move, +1000 when
    the agent wins, -1000 when the opponent's reply wins, and -100 for an
    illegal move (which also ends the episode). The target network is synced
    after each episode and the weights are saved to ``dqn_gomoku2.pth`` at
    the end.
    """
    for e in range(episodes):
        state = env.reset().flatten()
        done = False
        while not done:
            action = agent.act(state)
            x, y = divmod(action, env.size)
            if env.make_move(x, y):
                reward = env.evaluate_position(1)
                if env.check_winner() == 1:
                    reward += 1000
                    done = True
                elif (env.board != 0).all():
                    done = True  # board full after the agent's move: draw
                else:
                    # Opponent's turn (heuristic move)
                    opp_action = heuristic_opponent(env)
                    opp_x, opp_y = divmod(opp_action, env.size)
                    # The opponent always plays as 2; pin the turn marker in
                    # case move selection touched it while probing the board.
                    env.current_player = 2
                    env.make_move(opp_x, opp_y)
                    if env.check_winner() == 2:
                        reward -= 1000
                        done = True
                    elif (env.board != 0).all():
                        done = True
                next_state = env.board.flatten()
            else:
                # Illegal move: penalize, keep the board, end the episode.
                reward = -100
                next_state = state
                done = True
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            agent.replay()
        agent.update_target_model()
        print(f"Episode {e+1}/{episodes}, epsilon: {agent.epsilon:.2f}")
    agent.save_model('dqn_gomoku2.pth')



# In [None]:
# Train a fresh agent on a 5x5 board; state and action spaces both have one
# entry per board cell.
gomoku_env = Gomoku(size=5)
dqn_agent = DQNAgent(
    state_size=gomoku_env.size ** 2,
    action_size=gomoku_env.size ** 2,
)
train_dqn(agent=dqn_agent, env=gomoku_env, episodes=1000)

# In [12]:
import tkinter as tk
from tkinter import messagebox
import torch

class GomokuGUI:
    """Tk front end: the human plays black (player 1) against the DQN agent.

    Canvas column/row ``(x, y)`` of a click maps directly to ``board[x, y]``.
    The game ends with a message box on a win or when the board is full.
    """

    def __init__(self, master, size=5):
        self.master = master
        self.size = size
        self.cell_size = 60
        self.canvas = tk.Canvas(master, width=self.size * self.cell_size, height=self.size * self.cell_size)
        self.canvas.pack()
        self.gomoku = Gomoku(size)
        # epsilon=0.0: always play the learned policy. With the default of
        # 1.0 every AI move would be random and the loaded weights unused.
        self.agent = DQNAgent(state_size=size*size, action_size=size*size, epsilon=0.0)
        self.agent.load_model('dqn_gomoku2.pth')
        self.draw_board()
        self.canvas.bind("<Button-1>", self.click)

    def draw_board(self):
        """Draw the empty grid, one outlined square per cell."""
        for i in range(self.size):
            for j in range(self.size):
                x0, y0 = i * self.cell_size, j * self.cell_size
                x1, y1 = x0 + self.cell_size, y0 + self.cell_size
                self.canvas.create_rectangle(x0, y0, x1, y1, outline="black")

    def draw_piece(self, x, y, player):
        """Draw a stone in cell (x, y): black for player 1, white otherwise."""
        x0, y0 = x * self.cell_size + 10, y * self.cell_size + 10
        x1, y1 = x0 + self.cell_size - 20, y0 + self.cell_size - 20
        color = "black" if player == 1 else "white"
        self.canvas.create_oval(x0, y0, x1, y1, fill=color)

    def _finish_if_over(self):
        """End the game on a win or a full board; return True if it ended."""
        winner = self.gomoku.check_winner()
        if winner:
            self.game_over(winner)
            return True
        if (self.gomoku.board != 0).all():
            self.game_over(0)
            return True
        return False

    def click(self, event):
        """Handle a human move and, if the game continues, the AI's reply."""
        x, y = event.x // self.cell_size, event.y // self.cell_size
        if not self.gomoku.make_move(x, y):
            return  # off-board or occupied cell: ignore the click
        # make_move already passed the turn, so the mover is the other player.
        self.draw_piece(x, y, 3 - self.gomoku.current_player)
        if not self._finish_if_over():
            self.ai_move()

    def ai_move(self):
        """Let the agent place a stone and check whether the game ended."""
        board = self.gomoku.board
        # train_dqn trains the agent as player 1, but here it plays as
        # player 2. Swap the stone labels (1 <-> 2, empty stays 0) so the
        # network sees its own stones encoded the way it did in training.
        state = ((3 - board) * (board != 0)).flatten()
        action = self.agent.act(state)
        x, y = divmod(action, self.size)
        if self.gomoku.make_move(x, y):
            self.draw_piece(x, y, 3 - self.gomoku.current_player)
            self._finish_if_over()

    def game_over(self, winner):
        """Announce the result (``winner`` 0 means a draw) and close the window."""
        if winner:
            messagebox.showinfo("Game Over", f"Player {winner} wins!")
        else:
            messagebox.showinfo("Game Over", "It's a draw!")
        self.master.destroy()

if __name__ == "__main__":
    # Create the top-level window, attach the game to it and hand control
    # to Tk's event loop until the window is closed.
    root = tk.Tk()
    root.title("Gomoku")
    gomoku_gui = GomokuGUI(master=root, size=5)
    root.mainloop()
