In [None]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# One stored environment step: (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
  """Bounded FIFO buffer of ``Transition`` records used for experience replay."""

  def __init__(self, capacity):
    # A deque with ``maxlen`` silently drops its oldest entry once it is full.
    self.memory = deque(maxlen=capacity)
    self.capacity = capacity

  def __len__(self):
    """Return how many transitions are currently stored."""
    return len(self.memory)

  def push(self, *fields):
    """Pack ``fields`` into a ``Transition`` and append it to the buffer."""
    self.memory.append(Transition(*fields))

  def sample(self, batch_size):
    """Return ``batch_size`` distinct stored transitions chosen at random.

    ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds the
    number of stored transitions.
    """
    return random.sample(self.memory, k=batch_size)

In [None]:
class DQNConv(nn.Module):
  """Convolutional Q-network: three conv + batch-norm stages, then an MLP head.

  ``forward`` reshapes its input to ``(batch, 4, 4, 4)`` (``Conv2d`` treats the
  first 4 after the batch dimension as channels) and returns a tensor of shape
  ``(batch, n_actions)``.

  Args:
    n_observations: stored on the module; it does not size any layer.
    n_actions: width of the output layer.
    n_hidden: accepted but not used by this implementation.
  """

  def __init__(self, n_observations, n_actions, n_hidden=256):
    super().__init__()
    self.n_observations = n_observations
    self.n_actions = n_actions

    # Convolutional trunk. 3x3 kernels with stride 1 and padding 1 keep the
    # 4x4 spatial size, so only the channel count changes: 4 -> 8 -> 16 -> 32.
    self.conv1 = nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=1)
    self.bn1 = nn.BatchNorm2d(8)
    self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(16)
    self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
    self.bn3 = nn.BatchNorm2d(32)

    # Fully connected head: 512 -> 512 -> 256 -> 128 -> n_actions.
    flat_features = 32 * 4 * 4
    self.fc1 = nn.Linear(flat_features, flat_features)
    self.fc2 = nn.Linear(flat_features, flat_features // 2)
    self.fc3 = nn.Linear(flat_features // 2, flat_features // 4)
    self.output_layer = nn.Linear(flat_features // 4, n_actions)

    # Xavier-uniform weights for every linear layer; biases keep their default init.
    for dense in (self.fc1, self.fc2, self.fc3, self.output_layer):
      nn.init.xavier_uniform_(dense.weight)

  def forward(self, x):
    """Return one Q-value per action for every board in ``x``."""
    relu = nn.functional.relu
    x = x.view(-1, 4, 4, 4)

    # Each convolutional stage is conv -> batch norm -> ReLU.
    for conv, norm in ((self.conv1, self.bn1),
                       (self.conv2, self.bn2),
                       (self.conv3, self.bn3)):
      x = relu(norm(conv(x)))

    # Flatten (batch, 32, 4, 4) feature maps into (batch, 512) vectors.
    x = x.flatten(start_dim=1)

    # Hidden dense layers use ReLU; the output layer is linear (raw Q-values).
    for dense in (self.fc1, self.fc2, self.fc3):
      x = relu(dense(x))
    return self.output_layer(x)

In [None]:
class QAgentConv:
  def __init__(self, logging=False, training=False):
    self.training = training

    self.n_observations = 4*4*4
    self.n_actions = 16

    self.policy_net = DQNConv(self.n_observations, self.n_actions).to(device)
    self.target_net = DQNConv(self.n_observations, self.n_actions).to(device)
    self.target_net.load_state_dict(self.policy_net.state_dict())

    if training:
      self.BATCH_SIZE = 256
      self.memory = ReplayMemory(100000)
      self.steps_done = 0

      self.GAMMA = 0.995
      self.EPS_START = 0.9
      self.EPS_END = 0.05
      self.EPS_DECAY = 6000

      self.LR = 1e-3
      self.TAU = 0.005

      self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.LR, amsgrad=True)
      self.loss = nn.MSELoss()
    else:
      self.policy_net.eval()
      self.target_net.eval()

  def load_weights(self, path):
    state_dict = torch.load(path)
    self.policy_net.load_state_dict(state_dict)
    self.target_net.load_state_dict(state_dict)
    print("Weights loaded successfully from", path)
  
  def save_weights(self, path):
    torch.save(self.policy_net.state_dict(), path)
    print("Weights saved successfully to", path)
  
  def create_indicator_array(self, coords, num_rows=4, num_cols=4):
    indicator_array = torch.zeros(16, device=device)
    
    for row, col in coords:
      index = row * num_cols + col
      indicator_array[index] = 1
    
    return indicator_array
  
  def creat_index_array(self, coords, num_rows=4, num_cols=4):
    index_array = []
    
    for row, col in coords:
      index = row * num_cols + col
      index_array.append(index)
    
    return torch.tensor(index_array, device=device)

  def findBestMove(self, board, possible_move, player):
    if len(possible_move) == 0:
      return None
    
    if self.training:
      sample = random.random()
      eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * math.exp(-1. * self.steps_done / self.EPS_DECAY)
      self.steps_done += 1

      if sample > eps_threshold:
        self.policy_net.eval()
        self.target_net.eval()

        with torch.no_grad():
          # Change the board to the player's perspective
          state = torch.tensor(board, dtype=torch.float32).to(device)
          mask = self.create_indicator_array(possible_move)
          score = self.policy_net(state) * mask - 9999 * (1-mask)
          return score.max(1).indices.view(1, 1)
      else:
        return torch.tensor([[random.choice(self.creat_index_array(possible_move))]], device=device, dtype=torch.long)
    else:
      with torch.no_grad():
        # Change the board to the player's perspective
        state = torch.tensor(board, dtype=torch.float32).to(device)
        mask = self.create_indicator_array(possible_move)
        score = self.policy_net(state) * mask - 9999 * (1-mask)
        return score.max(1).indices.view(1, 1)
  
  def optimize_model(self):
    self.policy_net.train()
    self.target_net.train()

    if len(self.memory) < self.BATCH_SIZE:
      return
    transitions = self.memory.sample(self.BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = self.policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(self.BATCH_SIZE, device=device)
    with torch.no_grad():
      next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values
    expected_state_action_values = (next_state_values * self.GAMMA) + reward_batch

    loss = self.loss(state_action_values, expected_state_action_values.unsqueeze(1))

    self.optimizer.zero_grad()
    loss.backward()
    # torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
    # for param in self.policy_net.parameters():
    #   param.grad.data.clamp_(-1, 1)
    self.optimizer.step()