In [88]:
import torch
from tqdm import tqdm
import random
import numpy as np
import math

In [89]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

class GomokuEnvironment:
  BLACK, WHITE, EMPTY = 1, -1, 0

  def __init__(self, n):
    self.n = n

    self.reset()

    self.w = torch.zeros((729), dtype=torch.float32).to(device)
  
  def __str__(self):
    rv = ""
    for i in range(self.n):
      for j in range(self.n):
        if self.state[i, j] == GomokuEnvironment.BLACK:
          rv += "X "
        if self.state[i, j] == GomokuEnvironment.WHITE:
          rv += "O "
        if self.state[i, j] == GomokuEnvironment.EMPTY:
          rv += "- "
      rv = rv[:-1] + '\n'
    return rv[:-1]

  def __init_action_space(self):
    self.action_space = set()
    for i in range(self.n):
      for j in range(self.n):
        self.action_space.add((i, j))
  
  def __init_contrib_total(self):
    self.contrib_total = torch.zeros((729), dtype=torch.float32).to(device)
    for i in range(self.n):
      for j in range(self.n):
        if i + 6 <= self.n:
          self.contrib_total[364] += 1
        if j + 6 <= self.n:
          self.contrib_total[364] += 1
        if i + 6 and j + 6 <= self.n:
          self.contrib_total[364] += 2

  def __update_contrib_total(self, i, j, di, dj, s):
    i_0, j_0, k_0 = i, j, 0
    while self.is_inside(i_0 - di, j_0 - dj) and k_0 - 1 >= -5:
      i_0 -= di
      j_0 -= dj
      k_0 -= 1

    i_f, j_f, k_f, mask_0, mask_f = i_0, j_0, k_0, 0, 0 
    while self.is_inside(i_f, j_f) and k_f <= 5:
      cell_0 = self.state[i_f, j_f].item()
      cell_f = self.player if i_f == i and j_f == j else cell_0

      mask_0 = (mask_0 * 3 + cell_0 + 1) % 729
      mask_f = (mask_f * 3 + cell_f + 1) % 729

      if k_f - k_0 >= 5:
        self.contrib_total[mask_0] -= s
        self.contrib_total[mask_f] += s

      i_f += di
      j_f += dj
      k_f += 1
  
  def is_inside(self, i, j):
    return i >= 0 and i < self.n and j >= 0 and j < self.n

  def get_utility(self):
    return self.utility

  def get_eval(self):
    return torch.dot(self.w, self.contrib_total)
  
  def get_eval_grad(self):
    return self.contrib_total.clone()

  def get_action_space(self):
    return self.action_space
  
  def get_player(self):
    return self.player
  
  def is_begin(self):
    return len(self.path) == 0
  
  def is_end(self):
    return self.utility != 0 or len(self.action_space) == 0
      
  def forward(self, a):
    self.path.append(a)
    i, j = a
    
    self.action_space.remove(a)

    self.__update_contrib_total(i, j, 0, 1, 1)
    self.__update_contrib_total(i, j, 1, 0, 1)
    self.__update_contrib_total(i, j, 1, 1, 1)
    self.__update_contrib_total(i, j, 1, -1, 1)
    self.state[i, j] = self.player

    if self.player == GomokuEnvironment.BLACK:
      for i in (726, 727, 728, 242, 485):
        if self.contrib_total[i] > 0:
          self.utility = self.player
    else:
      for i in (0, 1, 2, 243, 486):
        if self.contrib_total[i] > 0:
          self.utility = self.player

    self.player = -self.player

  def backward(self):
    a = self.path.pop(-1)
    i, j = a

    self.player = -self.player
    
    self.action_space.add(a)

    self.state[i, j] = GomokuEnvironment.EMPTY
    self.__update_contrib_total(i, j, 0, 1, -1)
    self.__update_contrib_total(i, j, 1, 0, -1)
    self.__update_contrib_total(i, j, 1, 1, -1)
    self.__update_contrib_total(i, j, 1, -1, -1)

    self.utility = 0

    return a

  def reset(self):
    self.state = torch.zeros((self.n, self.n), dtype=torch.int8).to(device)
    self.player = GomokuEnvironment.BLACK
    self.utility = 0
    self.path = []
    
    self.__init_action_space()
    self.__init_contrib_total()

In [90]:
class GTDLearning(GomokuEnvironment):
  def __init__(self, n, gamma):
    super().__init__(n)

    self.gamma = gamma

  def random_policy(self):
    return random.choice(list(self.get_action_space()))

  def greedy_policy(self):
    rv = (-self.get_player() * math.inf, None)
    for a in self.get_action_space():
      self.forward(a)
      rv = max(rv, (self.get_eval(), a))
      self.backward()
    return rv[1]

  def epsilon_greedy_policy(self, epsilon):
    if random.random() > epsilon:
      return self.greedy_policy()
    else:
      return self.random_policy()

  def training_loop(self, training_episodes=2, eta=0.1, epsilon_low=0.05, epsilon_high=1.0, epsilon_rate=0.0005):
    def epsilon(episode):
      return epsilon_low + (epsilon_high - epsilon_low) * np.exp(-episode * epsilon_rate)

    for episode in range(training_episodes):
      self.reset()
      while not self.is_end():
        a = self.epsilon_greedy_policy(epsilon(episode))
        if a == None:
          a = self.random_policy()

        v_0 = self.get_eval()
        grad = self.get_eval_grad()
        self.forward(a)
        v_f = self.get_eval()
        r = self.get_utility()

        print(v_0.item(), end=' ')

        self.w.add_(eta * ((r + self.gamma * v_f) - v_0) * grad)
      print()

learner = GTDLearning(9, 0.8)
learner.training_loop()

0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 
-820.800048828125 389022.84375 -152931856.0 60886175744.0 -20980581793792.0 6905501031858176.0 -2.2119074702452326e+18 7.305884978820866e+20 -2.1706982710196197e+23 4.7129581822234515e+25 -9.296920533642229e+27 1.5010319163161041e+30 -2.7543430847201385e+32 4.629834840277389e+34 -5.207294115873478e+36 inf nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan 


In [91]:
print(learner.w)

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, n

tensor([ 5.4962e+07, -9.4583e+07,  5.3028e+08,  3.8073e+08, -9.9378e+07,
        -2.4892e+08,  3.7594e+08,  2.2639e+08, -3.9905e+08, -5.4860e+08,
         7.6260e+07, -7.3284e+07, -5.5339e+08, -7.0294e+08, -7.8079e+07,
        -2.2762e+08,  2.2762e+08,  7.8079e+07,  7.0294e+08,  5.5339e+08,
         7.3284e+07, -7.6260e+07,  5.4860e+08,  3.9905e+08, -2.2639e+08,
        -3.7594e+08,  2.4892e+08,  9.9378e+07, -3.8073e+08, -5.3028e+08,
         9.4583e+07, -5.4962e+07], requires_grad=True)
00000 +
00001 -
00010 +
00011 +
00100 -
00101 -
00110 +
00111 +
01000 -
01001 -
01010 +
01011 -
01100 -
01101 -
01110 -
01111 -
10000 +
10001 +
10010 +
10011 +
10100 +
10101 -
10110 +
10111 +
11000 -
11001 -
11010 +
11011 +
11100 -
11101 -
11110 +
11111 -