In [1]:
!pip install creversi

Collecting creversi
  Downloading creversi-0.0.1-cp310-cp310-manylinux_2_24_x86_64.whl (711 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m711.0/711.0 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: creversi
Successfully installed creversi-0.0.1


In [2]:
from creversi import *
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
k = 192
fcl_units = 256
class DQN(nn.Module):
  def __init__(self):
    super(DQN, self).__init__()
    self.conv1 = nn.Conv2d(2, k, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(k)
    self.conv2 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(k)
    self.conv3 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn3 = nn.BatchNorm2d(k)
    self.conv4 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn4 = nn.BatchNorm2d(k)
    self.conv5 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn5 = nn.BatchNorm2d(k)
    self.conv6 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn6 = nn.BatchNorm2d(k)
    self.conv7 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn7 = nn.BatchNorm2d(k)
    self.conv8 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn8 = nn.BatchNorm2d(k)
    self.conv9 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn9 = nn.BatchNorm2d(k)
    self.conv10 = nn.Conv2d(k, k, kernel_size=3, padding=1)
    self.bn10 = nn.BatchNorm2d(k)
    self.fcl1 = nn.Linear(k * 64, fcl_units)
    self.fcl2 = nn.Linear(fcl_units, 65)

  def forward(self, x):
    x = F.relu(self.bn1(self.conv1(x)))
    x = F.relu(self.bn2(self.conv2(x)))
    x = F.relu(self.bn3(self.conv3(x)))
    x = F.relu(self.bn4(self.conv4(x)))
    x = F.relu(self.bn5(self.conv5(x)))
    x = F.relu(self.bn6(self.conv6(x)))
    x = F.relu(self.bn7(self.conv7(x)))
    x = F.relu(self.bn8(self.conv8(x)))
    x = F.relu(self.bn9(self.conv9(x)))
    x = F.relu(self.bn10(self.conv10(x)))
    x = F.relu(self.fcl1(x.view(-1, k * 64)))
    x = self.fcl2(x)
    return x.tanh()

In [4]:
class GreedyPlayer:
    def __init__(self, model_path, device, network='dqn'):
        # if network == 'dueling':
        #     from creversi_gym.network.cnn10_dueling import DQN
        # else:
        #     #from creversi_gym.network.cnn5 import DQN
        #     from creversi_gym.network.cnn10 import DQN
        self.device = device
        self.model = DQN().to(device)
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()
        self.features = np.empty((1, 2, 8, 8), np.float32)

    def go(self, board):
        with torch.no_grad():
            board.piece_planes(self.features[0])
            state = torch.from_numpy(self.features).to(self.device)
            q = self.model(state)
            # 合法手に絞る
            legal_moves = list(board.legal_moves)
            next_actions = torch.tensor([legal_moves], device=self.device, dtype=torch.long)
            legal_q = q.gather(1, next_actions)
            return legal_moves[legal_q.argmax(dim=1).item()]

In [5]:
class QLearning:
    def __init__(self):
        self.q_table = {}

    def get_q_value(self, state, action):
        return self.q_table.get((state, action), 0.0)

    def update_q_value(self, state, action, value):
        self.q_table[(state, action)] = value

class QAgent:
    def __init__(self, epsilon=0.1, alpha=0.1, gamma=0.9):
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.q_learning = QLearning()

    def choose_action(self, state, legal_moves):
        if random.uniform(0, 1) < self.epsilon:
            return random.choice(legal_moves)
        else:
            q_values = [self.q_learning.get_q_value(state, action) for action in legal_moves]
            max_q_value = max(q_values)
            best_actions = [action for action, value in zip(legal_moves, q_values) if value == max_q_value]
            return random.choice(best_actions)

    def train(self, state, action, reward, next_state, legal_moves):
        current_q_value = self.q_learning.get_q_value(state, action)
        max_next_q_value = max([self.q_learning.get_q_value(next_state, next_action) for next_action in legal_moves])
        new_q_value = (1 - self.alpha) * current_q_value + self.alpha * (reward + self.gamma * max_next_q_value)
        self.q_learning.update_q_value(state, action, new_q_value)

    def reset_q_table(self):
        self.q_learning = QLearning()

class RandomAgent:
    def choose_action(self, state, legal_moves):
        return random.choice(legal_moves)

In [6]:
def print_board(board):
    print(str(board))

def play_game(q_agent, opponent_agent, board, first = True):
    if first:
      current_agent = q_agent
    else:
      current_agent = opponent_agent

    while not board.is_game_over():
        state = str(board)
        legal_moves = [creversi.move_to_str(move) for move in board.legal_moves]

        # print_board(board)

        if isinstance(current_agent, QAgent):
            action = current_agent.choose_action(state, legal_moves)
            board.move_from_str(action)
        else:
            action = opponent_agent.go(board)
            board.move(action)

        # board.move_from_str(action)
        current_agent = opponent_agent if current_agent == q_agent else q_agent

    # print("end with q_agent" if current_agent == q_agent else "end with opponent_agent")
    if first:
      firstname = "q_agent"
      secondname = "opponent_agent"
    else:
      firstname = "opponent_agent"
      secondname = "q_agent"
    if current_agent == opponent_agent:
      n_black = 64 - board.piece_num()
      n_white = board.piece_num()
    else:
      n_white = 64 - board.piece_num()
      n_black = board.piece_num()
    if n_white > n_black:
      if first:
        winner = 2
      else:
        winner = 1
      print(secondname + " white win", n_white)
    elif n_black > n_white:
      if first:
        winner = 1
      else:
        winner = 2
      print(firstname + " black win", n_black)
    else:
      winner = 0
      print("draw")

    return winner

In [7]:
q_agent = QAgent()
opponent_agent = GreedyPlayer("/content/epsilon_greedy_model.pt", device)
board = creversi.Board()

opponent_agentcount = 0
q_agentcount = 0

In [8]:
round = 500
for i in range(round):
  board = creversi.Board()
  result = play_game(q_agent, opponent_agent, board, True)
  if result == 1:
    q_agentcount += 1
  elif result == 2:
    opponent_agentcount += 1
  board = creversi.Board()
  result = play_game(q_agent, opponent_agent, board, False)
  if result == 1:
    q_agentcount += 1
  elif result == 2:
    opponent_agentcount += 1

opponent_agent white win 33
opponent_agent black win 48
opponent_agent white win 38
opponent_agent black win 37
q_agent black win 37
q_agent white win 46
q_agent black win 37
q_agent white win 38
draw
q_agent white win 44
q_agent black win 34
opponent_agent black win 33
opponent_agent white win 42
opponent_agent black win 40
opponent_agent white win 44
q_agent white win 33
q_agent black win 33
opponent_agent black win 42
q_agent black win 34
q_agent white win 44
q_agent black win 39
opponent_agent black win 41
opponent_agent white win 37
opponent_agent black win 37
opponent_agent white win 55
opponent_agent black win 54
opponent_agent white win 36
opponent_agent black win 51
opponent_agent white win 52
opponent_agent black win 35
opponent_agent white win 37
opponent_agent black win 33
opponent_agent white win 39
q_agent white win 41
q_agent black win 47
q_agent white win 47
opponent_agent white win 36
opponent_agent black win 38
draw
q_agent white win 45
q_agent black win 39
q_agent wh

In [9]:
print("q_agent win", q_agentcount)
print("opponent_agent win", opponent_agentcount)
print("draw", round * 2 - q_agentcount - opponent_agentcount)

q_agent win 430
opponent_agent win 519
draw 51


In [10]:
print("q_agent winrate", q_agentcount / (round * 2))
print("opponent_agent winrate", opponent_agentcount / (round * 2))

q_agent winrate 0.43
opponent_agent winrate 0.519
