# Implementation of AlphaZero for the games TicTacToe and Connect Four
This follows the ideas described in idea.md.

In [2]:
import numpy as np
#import torch
#import torch.nn as nn
#import torch.nn.functional as F
#from tqdm.notebook import trange

import random
import math

### Games

In [3]:
class TicTacToe:
    """Rules of TicTacToe on a 3x3 NumPy board.

    Cells hold ``1`` or ``-1`` for the two players and ``0`` when empty.
    An action is a flat, row-major cell index in ``range(action_size)``.
    """

    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count

    def __repr__(self):
        return "TicTacToe"

    def get_initial_state(self):
        """Return an empty board (all zeros)."""
        return np.zeros((self.row_count, self.column_count))

    def get_next_state(self, state, action, player):
        """Write ``player`` into the cell addressed by ``action``.

        ``state`` is modified in place and also returned; callers that need
        the previous position must pass a copy.
        """
        row = action // self.column_count
        column = action % self.column_count
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat uint8 mask with 1 for every empty cell."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return whether the piece placed by ``action`` completes a line.

        ``action`` is ``None`` when no move has been played yet, in which
        case nobody can have won.
        """
        # Identity test: ``== None`` would broadcast on array-like actions.
        if action is None:
            return False

        row = action // self.column_count
        column = action % self.column_count
        player = state[row, column]

        return (
            np.sum(state[row, :]) == player * self.column_count
            or np.sum(state[:, column]) == player * self.row_count
            or np.sum(np.diag(state)) == player * self.row_count
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
        )

    def get_value_and_terminated(self, state, action):
        """Return ``(value, is_terminal)`` after ``action`` was played.

        ``value`` is 1 when the move wins, otherwise 0 (draw or ongoing).
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player (players are encoded as 1 and -1)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` as seen by the opponent."""
        return -value

    def change_perspective(self, state, player):
        """Return a new board in which ``player``'s pieces are the positive ones."""
        return state * player

    def get_encoded_state(self, state):
        """Return three float32 planes marking the -1, 0 and 1 cells.

        When ``state`` has three dimensions it is treated as a batch and the
        plane axis is moved behind the batch axis.
        """
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        if len(state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

In [4]:
class ConnectFour:
    """Rules of Connect Four on a 6x7 NumPy board.

    Cells hold ``1`` or ``-1`` for the two players and ``0`` when empty.
    An action is a column index; row 0 is the top of the board, so pieces
    settle at the largest free row index of their column.
    """

    def __init__(self):
        self.row_count = 6
        self.column_count = 7
        self.action_size = self.column_count
        self.in_a_row = 4

    def __repr__(self):
        return "ConnectFour"

    def get_initial_state(self):
        """Return an empty board (all zeros)."""
        return np.zeros((self.row_count, self.column_count))

    def get_next_state(self, state, action, player):
        """Drop ``player``'s piece into column ``action``.

        ``state`` is modified in place and also returned. A full column makes
        the underlying ``np.max`` call raise ``ValueError``.
        """
        row = np.max(np.where(state[:, action] == 0))
        state[row, action] = player
        return state

    def get_valid_moves(self, state):
        """Return a uint8 mask with 1 for every column whose top cell is free."""
        return (state[0] == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return whether the last piece dropped in column ``action`` wins.

        ``action`` is ``None`` when no move has been played yet, in which
        case nobody can have won.
        """
        # Identity test: ``== None`` would broadcast on array-like actions.
        if action is None:
            return False

        # The most recent piece of a column is its topmost non-empty cell.
        row = np.min(np.where(state[:, action] != 0))
        column = action
        player = state[row][column]

        def count(offset_row, offset_column):
            """Count consecutive ``player`` cells from the new piece along one direction."""
            for i in range(1, self.in_a_row):
                r = row + offset_row * i
                c = action + offset_column * i
                if (
                    r < 0
                    or r >= self.row_count
                    or c < 0
                    or c >= self.column_count
                    or state[r][c] != player
                ):
                    return i - 1
            return self.in_a_row - 1

        return (
            count(1, 0) >= self.in_a_row - 1 # vertical
            or (count(0, 1) + count(0, -1)) >= self.in_a_row - 1 # horizontal
            or (count(1, 1) + count(-1, -1)) >= self.in_a_row - 1 # top left diagonal
            or (count(1, -1) + count(-1, 1)) >= self.in_a_row - 1 # top right diagonal
        )

    def get_value_and_terminated(self, state, action):
        """Return ``(value, is_terminal)`` after ``action`` was played.

        ``value`` is 1 when the move wins, otherwise 0 (draw or ongoing).
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player (players are encoded as 1 and -1)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` as seen by the opponent."""
        return -value

    def change_perspective(self, state, player):
        """Return a new board in which ``player``'s pieces are the positive ones."""
        return state * player

    def get_encoded_state(self, state):
        """Return three float32 planes marking the -1, 0 and 1 cells.

        When ``state`` has three dimensions it is treated as a batch and the
        plane axis is moved behind the batch axis.
        """
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        if len(state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

### OUR GAME

In [1]:
from enum import StrEnum, Flag, auto
from functools import reduce

class Command(StrEnum):
    """Closed set of textual commands; each member's value is its lowercase keyword."""
    # NOTE(review): not referenced by any other code visible in this file;
    # presumably the commands of the engine's text console/protocol - confirm.
    INFO = "info"
    HELP = "help"
    OPTIONS = "options"
    NEWGAME = "newgame"
    VALIDMOVES = "validmoves"
    BESTMOVE = "bestmove"
    PLAY = "play"
    PASS = "pass"
    UNDO = "undo"
    EXIT = "exit"
    GET = "get"
    SET = "set"

class PlayerColor():
    """Integer codes of the two players; negating one code yields the other."""
    WHITE, BLACK = 1, -1

class GameType(Flag):
  Base = auto()
  M = auto()
  L = auto()
  P = auto()

  @classmethod
  def parse(cls, type: str):
    if type:
      base, *expansions = type.split("+")
      try:
        if GameType[base] != GameType.Base or expansions == [""] or len(expansions) > 1 and type.find("+") >= 0: raise KeyError()
        return reduce(lambda type, expansion: type | expansion, [GameType[expansion] for expansion in (expansions[0] if expansions else "")], GameType[base])
      except KeyError:
        raise ValueError(f"'{type}' is not a valid GameType")
    return GameType.Base

  def __str__(self) -> str:
    return "".join(str(gametype.name) + ("+" if gametype is GameType.Base and len(self) > 1 else "") for gametype in self)

class BugType(StrEnum):
  """Kinds of bug piece; each member's value is a one-letter code."""
  QUEEN_BEE = "Q"
  SPIDER = "S"
  BEETLE = "B"
  GRASSHOPPER = "G"
  SOLDIER_ANT = "A"
  # The three codes below coincide with the expansion names declared in GameType (M, L, P).
  MOSQUITO = "M"
  LADYBUG = "L"
  PILLBUG = "P"

class Direction(StrEnum):
  RIGHT = "|-"
  UP_RIGHT = "|/"
  UP_LEFT = "\\|"
  LEFT = "-|"
  DOWN_LEFT = "/|"
  DOWN_RIGHT = "|\\"
  BELOW = ""
  ABOVE = "|"

  @classmethod
  def flat(cls):
    return [direction for direction in Direction if direction is not Direction.ABOVE and direction is not Direction.BELOW]
  
  @classmethod
  def flat_left(cls):
    return [direction for direction in Direction if direction.is_left]

  @classmethod
  def flat_right(cls):
    return [direction for direction in Direction if direction.is_right]

  def __str__(self) -> str:
    return self.replace("|", "")

  @property
  def opposite(self):
    match self:
      case Direction.BELOW | Direction.ABOVE:
        return list(Direction)[(self.delta_index - 5) % 2 + 6]
      case _:
        return list(Direction)[(self.delta_index + 3) % 6]

  @property
  def left_of(self):
    match self:
      case Direction.BELOW | Direction.ABOVE:
        return self
      case _:
        return list(Direction)[(self.delta_index + 1) % 6]
  
  @property
  def right_of(self):
    match self:
      case Direction.BELOW | Direction.ABOVE:
        return self
      case _:
        return list(Direction)[(self.delta_index + 5) % 6]

  @property
  def delta_index(self) -> int:
    return list(Direction).index(self)

  @property
  def is_right(self) -> bool:
    return self is Direction.RIGHT or self is Direction.UP_RIGHT or self is Direction.DOWN_RIGHT

  @property
  def is_left(self) -> bool:
    return self is Direction.LEFT or self is Direction.UP_LEFT or self is Direction.DOWN_LEFT


In [2]:
from typing import Final, Optional
import re

class Position():
  """Hashable pair of integer board coordinates supporting ``+`` and ``-``."""

  def __init__(self, q: int, r: int):
    # The first argument is stored as ``row`` and the second as ``col``.
    self.row: Final[int] = q
    self.col: Final[int] = r

  def __str__(self) -> str:
    return "({}, {})".format(self.row, self.col)

  def __hash__(self) -> int:
    coordinates = (self.row, self.col)
    return hash(coordinates)

  def __eq__(self, value: object) -> bool:
    if self is value:
      return True
    if not isinstance(value, Position):
      return False
    return self.row == value.row and self.col == value.col

  def __add__(self, other: object):
    """Component-wise sum; ``NotImplemented`` for non-``Position`` operands."""
    if not isinstance(other, Position):
      return NotImplemented
    return Position(self.row + other.row, self.col + other.col)

  def __sub__(self, other: object):
    """Component-wise difference; ``NotImplemented`` for non-``Position`` operands."""
    if not isinstance(other, Position):
      return NotImplemented
    return Position(self.row - other.row, self.col - other.col)

class Bug():
  """A single piece, identified by owner colour, bug type and copy number.

  Two ``Bug`` objects with the same three fields are equal and hash alike,
  so a freshly built ``Bug`` can be used to look a piece up in a dict.
  """

  def __init__(self, color: int, bug_type: BugType, bug_id: int = 0) -> None:
    self.color: Final[PlayerColor] = color
    self.type: Final[BugType] = bug_type
    # 0 for a piece that exists once per player, 1..3 for pieces with several copies.
    self.num: Final[int] = bug_id

  def __str__(self) -> str:
    """Short name: ``w``/``b`` for the colour, the type letter, then the copy number if any."""
    prefix = "w" if self.color == PlayerColor.WHITE else "b"
    return f"{prefix}{self.type}{self.num if self.num else ''}"

  def __hash__(self) -> int:
    # Hash exactly the fields __eq__ compares. The previous implementation
    # hashed the default object repr, which differs between equal instances
    # and therefore broke every dict/set lookup keyed by Bug.
    return hash((self.color, self.type, self.num))

  def __eq__(self, value: object) -> bool:
    if self is value:
      return True
    # Compare by value: colours are plain ints, so identity is not reliable.
    return isinstance(value, Bug) and self.color == value.color and self.type == value.type and self.num == value.num

  def to_id(self) -> int:
    """Return a signed integer code for this piece, usable as a board cell value.

    The magnitude is unique per (type, copy number) and never 0, which is
    reserved for empty cells; the sign is the owner's colour (1 or -1).
    """
    type_index = list(BugType).index(self.type)
    # Three slots per type; pieces without a copy number (num == 0) use the first slot.
    magnitude = type_index * 3 + max(self.num, 1)
    return self.color * magnitude

class Move():
  """A piece together with where it comes from and where it goes.

  ``origin`` is ``None`` when the piece is placed from the player's hand.
  """
  PASS: Final[str] = "pass"

  def __init__(self, bug: Bug, origin: Optional[Position], destination: Position) -> None:
    self.bug: Final[Bug] = bug
    self.origin: Final[Optional[Position]] = origin
    self.destination: Final[Position] = destination

  def __hash__(self) -> int:
    key = (self.bug, self.origin, self.destination)
    return hash(key)

  def __eq__(self, value: object) -> bool:
    if self is value:
      return True
    if not isinstance(value, Move):
      return False
    return self.bug == value.bug and self.origin == value.origin and self.destination == value.destination


In [5]:
from typing import Final, Optional, Set

class HiveBoard(): # This is simply the implementation of the board
  # Coordinate offset of the neighbouring tile in each direction. The entries
  # are listed in the declaration order of the Direction enum, because
  # _get_neighbor indexes this tuple with Direction.delta_index.
  # Note that Position(q, r) stores its first argument as `row` and its second as `col`.
  NEIGHBOR_DELTAS: Final[tuple[Position, Position, Position, Position, Position, Position, Position, Position]] = (
    Position(1, 0), # Right
    Position(1, -1), # Up right
    Position(0, -1), # Up left
    Position(-1, 0), # Left
    Position(-1, 1), # Down left
    Position(0, 1), # Down right
    Position(0, 0), # Below (no change)
    Position(0, 0), # Above (no change)
  )

  def __init__(self, gamestring: str = "") -> None:
    """Create an empty board with every piece of both players still in hand.

    ``gamestring`` is accepted but not parsed yet: the piece set is always
    the base game plus the Mosquito, Ladybug and Pillbug.
    """
    #self.type: Final[GameType] = self._parse_gamestring(gamestring) TODO: we learn only on MLP for now
    self.player: int = PlayerColor.WHITE
    self.turn = 1
    self.row_count = 28
    self.column_count = 28
    self.height_count = 5
    self.last_moved: Optional[Bug] = None
    self.board = np.zeros((self.row_count, self.column_count, self.height_count), dtype=int)
    # Piece -> position (None while the piece is still in hand).
    self._bug_to_pos: dict[Bug, Optional[Position]] = {}
    # Position -> stack of pieces, bottom first. Read by _bugs_from_pos.
    self._pos_to_bug: dict[Position, list[Bug]] = {}

    # PlayerColor is a plain class of constants and cannot be iterated.
    for color in (PlayerColor.WHITE, PlayerColor.BLACK):
      self._bug_to_pos[Bug(color, BugType.QUEEN_BEE)] = None
      # Add ids greater than 0 only for bugs with multiple copies.
      for i in range(1, 3):
        self._bug_to_pos[Bug(color, BugType.SPIDER, i)] = None
        self._bug_to_pos[Bug(color, BugType.BEETLE, i)] = None
        self._bug_to_pos[Bug(color, BugType.GRASSHOPPER, i)] = None
        self._bug_to_pos[Bug(color, BugType.SOLDIER_ANT, i)] = None
      self._bug_to_pos[Bug(color, BugType.GRASSHOPPER, 3)] = None
      self._bug_to_pos[Bug(color, BugType.SOLDIER_ANT, 3)] = None
      # Expansion pieces: the letters are the BugType values themselves.
      for expansion in ("M", "L", "P"):
        self._bug_to_pos[Bug(color, BugType(expansion))] = None

  @property
  def current_player_queen_in_play(self) -> bool:
    """Whether the queen of the player to move has a position recorded in ``_bug_to_pos``."""
    queen = Bug(self.player, BugType.QUEEN_BEE)
    queen_position = self._bug_to_pos[queen]
    return bool(queen_position)

  def get_next_state(self, move: Move):
    """Apply ``move`` to the board and hand the turn to the other player.

    Keeps the piece/position maps, the 3-D ``board`` array (row, column,
    stack level) and ``last_moved`` in sync, then advances ``turn`` and flips
    ``player``. ``move.origin`` is ``None`` for a placement from hand.
    """
    def cell(position: Position) -> tuple[int, int]:
      # Positions are unbounded and may be negative; wrap them onto the array.
      return position.row % self.row_count, position.col % self.column_count

    if move.origin:
      # Lift the piece off its current stack and clear the level it occupied.
      origin_stack = self._pos_to_bug[move.origin]
      origin_stack.remove(move.bug)
      row, column = cell(move.origin)
      self.board[row, column, len(origin_stack)] = 0
      if not origin_stack:
        del self._pos_to_bug[move.origin]

    # Put the piece on top of whatever already sits at the destination.
    # NOTE(review): stacks taller than height_count cannot be represented and raise IndexError.
    destination_stack = self._pos_to_bug.setdefault(move.destination, [])
    row, column = cell(move.destination)
    self.board[row, column, len(destination_stack)] = move.bug.to_id()
    destination_stack.append(move.bug)

    self._bug_to_pos[move.bug] = move.destination
    self.last_moved = move.bug
    self.turn += 1
    self.player = -self.player


  def get_value_and_terminated(self):
    """Return ``(value, is_terminal)`` for the current position.

    A queen counts as surrounded when it is in play and all six flat
    neighbours hold at least one piece. The value is 1 when only the black
    queen is surrounded, -1 when only the white one is, and 0 otherwise.
    """
    def surrounded(color):
      queen_pos = self._bug_to_pos[Bug(color, BugType.QUEEN_BEE)]
      if not queen_pos:
        return queen_pos
      return all(self._bugs_from_pos(self._get_neighbor(queen_pos, direction)) for direction in Direction.flat())

    black_lost = surrounded(PlayerColor.BLACK)
    white_lost = surrounded(PlayerColor.WHITE)
    if black_lost:
      return (0, True) if white_lost else (1, True)
    if white_lost:
      return -1, True
    return 0, False

  def get_valid_moves(self) -> Set[Move]:
      # TODO: da qui
      self._valid_moves_cache = []
      for bug, pos in self._bug_to_pos.items():
        # Iterate over available pieces of the current player
        if bug.color is self.player:
          # Turn 1 is White player's first turn
          if self.turn == 1:
            # Can't place the queen on the first turn
            if bug.type is not BugType.QUEEN_BEE and self._can_bug_be_played(bug):
              # Add the only valid placement for the current bug piece
              self._valid_moves_cache.add(Move(bug, None, self.ORIGIN))
          # Turn 0 is Black player's first turn
          elif self.turn == 2:
            # Can't place the queen on the first turn
            if bug.type is not BugType.QUEEN_BEE and self._can_bug_be_played(bug):
              # Add all valid placements for the current bug piece (can be placed only around the first White player's first piece)
              self._valid_moves_cache.update(Move(bug, None, self._get_neighbor(self.ORIGIN, direction)) for direction in Direction.flat())
          # Bug piece has not been played yet
          elif not pos:
            # Check hand placement, and turn and queen placement, related rule.
            if self._can_bug_be_played(bug) and (self.current_player_turn != 4 or (self.current_player_turn == 4 and (self.current_player_queen_in_play or (not self.current_player_queen_in_play and bug.type is BugType.QUEEN_BEE)))):
              # Add all valid placements for the current bug piece
              self._valid_moves_cache.update(Move(bug, None, placement) for placement in self._get_valid_placements())
          # A bug piece in play can move only if it's at the top and its queen is in play and has not been moved in the previous player's turn
          elif self.current_player_queen_in_play and self._bugs_from_pos(pos)[-1] == bug and self._was_not_last_moved(bug):
            # Can't move pieces that would break the hive. Pieces stacked upon other can never break the hive by moving
            if len(self._bugs_from_pos(pos)) > 1 or self._can_move_without_breaking_hive(pos):
              match bug.type:
                case BugType.QUEEN_BEE:
                  self._valid_moves_cache.update(self._get_sliding_moves(bug, pos, 1))
                case BugType.SPIDER:
                  self._valid_moves_cache.update(self._get_sliding_moves(bug, pos, 3))
                case BugType.BEETLE:
                  self._valid_moves_cache.update(self._get_beetle_moves(bug, pos))
                case BugType.GRASSHOPPER:
                  self._valid_moves_cache.update(self._get_grasshopper_moves(bug, pos))
                case BugType.SOLDIER_ANT:
                  self._valid_moves_cache.update(self._get_sliding_moves(bug, pos))
                case BugType.MOSQUITO:
                  self._valid_moves_cache.update(self._get_mosquito_moves(bug, pos))
                case BugType.LADYBUG:
                  self._valid_moves_cache.update(self._get_ladybug_moves(bug, pos))
                case BugType.PILLBUG:
                  self._valid_moves_cache.update(self._get_sliding_moves(bug, pos, 1))
                  self._valid_moves_cache.update(self._get_pillbug_special_moves(pos))
            else:
              match bug.type:
                case BugType.MOSQUITO:
                  self._valid_moves_cache.update(self._get_mosquito_moves(bug, pos, True))
                case BugType.PILLBUG:
                  self._valid_moves_cache.update(self._get_pillbug_special_moves(pos))
                case _:
                  pass
      return self._valid_moves_cache

  def _get_valid_placements(self) -> Set[Position]:
    """Return the empty tiles where the player to move may place a piece from hand.

    A tile qualifies when it is next to one of the player's own top pieces
    and every occupied tile around it is topped by a piece of the same colour.
    """
    placements: Set[Position] = set()
    # Iterate over all placed bug pieces of the current player
    for bug, pos in self._bug_to_pos.items():
      if bug.color != self.player or not pos or not self._is_bug_on_top(bug):
        continue
      # Iterate over all neighbors of the current bug piece
      for direction in Direction.flat():
        neighbor = self._get_neighbor(pos, direction)
        # Only empty tiles can receive a new piece
        if self._bugs_from_pos(neighbor):
          continue
        # Stacks around the candidate tile, skipping the tile we came from (it holds our own piece).
        surrounding = (
          self._bugs_from_pos(self._get_neighbor(neighbor, around))
          for around in Direction.flat()
          if around is not direction.opposite
        )
        if all(not stack or stack[-1].color == self.player for stack in surrounding):
          placements.add(neighbor)
    return placements

  def _get_sliding_moves(self, bug: Bug, origin: Position, depth: int = 0) -> Set[Move]:
    """Return the moves ``bug`` can make by sliding along the ground from ``origin``.

    ``depth`` is the exact number of steps to take; 0 means any number of
    steps. A step goes to an empty neighbouring tile when exactly one of the
    two tiles flanking that step is occupied and neither flank is ``origin``.
    """
    reached: Set[Position] = set()
    seen: Set[Position] = set()
    frontier: Set[tuple[Position, int]] = {(origin, 0)}
    any_distance = depth == 0
    while frontier:
      here, steps = frontier.pop()
      seen.add(here)
      if any_distance or steps == depth:
        reached.add(here)
      if not (any_distance or steps < depth):
        continue
      for direction in Direction.flat():
        target = self._get_neighbor(here, direction)
        if target in seen or self._bugs_from_pos(target):
          continue
        right = self._get_neighbor(here, direction.right_of)
        left = self._get_neighbor(here, direction.left_of)
        # Exactly one flank must be occupied.
        if bool(self._bugs_from_pos(right)) == bool(self._bugs_from_pos(left)):
          continue
        # The moving piece itself does not count as a flank.
        if right == origin or origin == left:
          continue
        frontier.add((target, steps + 1))
    moves: Set[Move] = set()
    for destination in reached:
      if destination != origin:
        moves.add(Move(bug, origin, destination))
    return moves

  def _get_beetle_moves(self, bug: Bug, origin: Position, virtual: bool = False) -> Set[Move]:
    """Return the one-step moves, on the ground or over the hive, from ``origin``.

    With ``virtual`` set, the piece is treated as standing on top of the
    stack at ``origin`` without being part of it (used for multi-step moves).
    """
    # Don't consider the Beetle in the height, unless it's a virtual move.
    height = len(self._bugs_from_pos(origin)) - 1 + virtual
    moves: Set[Move] = set()
    for direction in Direction.flat():
      destination = self._get_neighbor(origin, direction)
      dest_height = len(self._bugs_from_pos(destination))
      left_height = len(self._bugs_from_pos(self._get_neighbor(origin, direction.left_of)))
      right_height = len(self._bugs_from_pos(self._get_neighbor(origin, direction.right_of)))
      # Logic from http://boardgamegeek.com/wiki/page/Hive_FAQ#toc9
      detached = height == 0 and dest_height == 0 and left_height == 0 and right_height == 0
      gated = dest_height < left_height and dest_height < right_height and height < left_height and height < right_height
      if detached or gated:
        continue
      moves.add(Move(bug, origin, destination))
    return moves

  def _get_grasshopper_moves(self, bug: Bug, origin: Position) -> Set[Move]:
    """Return the jumps from ``origin``: straight over occupied tiles to the first empty one."""
    moves: Set[Move] = set()
    for direction in Direction.flat():
      landing: Position = self._get_neighbor(origin, direction)
      # Can only move if there's at least one piece in the way
      if not self._bugs_from_pos(landing):
        continue
      # Keep jumping in the same direction until an empty tile is found
      while self._bugs_from_pos(landing):
        landing = self._get_neighbor(landing, direction)
      moves.add(Move(bug, origin, landing))
    return moves

  def _get_mosquito_moves(self, bug: Bug, origin: Position, special_only: bool = False) -> Set[Move]:
    """Return the moves a Mosquito at ``origin`` gets by copying its neighbours.

    On top of a stack it moves like a Beetle. Otherwise each distinct bug
    type on top of a neighbouring stack contributes that type's moves once.
    With ``special_only`` only the Pillbug's special moves are copied.
    """
    if len(self._bugs_from_pos(origin)) > 1:
      return self._get_beetle_moves(bug, origin)

    # Movement copied from each neighbour type; a neighbouring Mosquito grants nothing.
    copied_movement = {
      BugType.QUEEN_BEE: lambda: self._get_sliding_moves(bug, origin, 1),
      BugType.SPIDER: lambda: self._get_sliding_moves(bug, origin, 3),
      BugType.BEETLE: lambda: self._get_beetle_moves(bug, origin),
      BugType.GRASSHOPPER: lambda: self._get_grasshopper_moves(bug, origin),
      BugType.SOLDIER_ANT: lambda: self._get_sliding_moves(bug, origin),
      BugType.LADYBUG: lambda: self._get_ladybug_moves(bug, origin),
      BugType.PILLBUG: lambda: self._get_sliding_moves(bug, origin, 1),
    }

    moves: Set[Move] = set()
    seen_types: Set[BugType] = set()
    for direction in Direction.flat():
      stack = self._bugs_from_pos(self._get_neighbor(origin, direction))
      if not stack:
        continue
      neighbor_type = stack[-1].type
      if neighbor_type in seen_types:
        continue
      seen_types.add(neighbor_type)
      if special_only:
        if neighbor_type == BugType.PILLBUG:
          moves.update(self._get_pillbug_special_moves(origin))
        continue
      generate = copied_movement.get(neighbor_type)
      if generate is not None:
        moves.update(generate())
    return moves

  def _get_ladybug_moves(self, bug: Bug, origin: Position) -> Set[Move]:
    """Return the Ladybug moves from ``origin``: two steps over occupied tiles, then one step down onto an empty tile."""
    moves: Set[Move] = set()
    for first_move in self._get_beetle_moves(bug, origin, True):
      # First step must climb onto the hive.
      if not self._bugs_from_pos(first_move.destination):
        continue
      for second_move in self._get_beetle_moves(bug, first_move.destination, True):
        # Second step stays on the hive and may not pass over the starting tile.
        if not self._bugs_from_pos(second_move.destination) or second_move.destination == origin:
          continue
        for final_move in self._get_beetle_moves(bug, second_move.destination, True):
          # Last step must come down on an empty tile other than the starting one.
          if self._bugs_from_pos(final_move.destination) or final_move.destination == origin:
            continue
          moves.add(Move(bug, origin, final_move.destination))
    return moves

  def _get_pillbug_special_moves(self, origin: Position) -> Set[Move]:
    """Return the moves by which the piece at ``origin`` relocates a neighbouring piece to an empty neighbouring tile."""
    moves: Set[Move] = set()
    surroundings = [self._get_neighbor(origin, direction) for direction in Direction.flat()]
    empty_positions = [tile for tile in surroundings if not self._bugs_from_pos(tile)]
    # There must be at least one empty neighboring tile for the Pillbug to move another bug piece
    if not empty_positions:
      return moves
    for position in surroundings:
      bugs = self._bugs_from_pos(position)
      # Only unstacked pieces can be picked up
      if len(bugs) != 1:
        continue
      neighbor = bugs[-1]
      # The piece must not be the last moved one and must not hold the hive together
      if not self._was_not_last_moved(neighbor):
        continue
      if not self._can_move_without_breaking_hive(position):
        continue
      # The piece must be able to climb on top of the Pillbug itself
      if Move(neighbor, position, origin) not in self._get_beetle_moves(neighbor, position):
        continue
      for move in self._get_beetle_moves(neighbor, position, True):
        if move.destination in empty_positions:
          moves.add(Move(neighbor, position, move.destination))
    return moves

  def _can_move_without_breaking_hive(self, position: Position) -> bool:
    """Return whether the occupied tiles around ``position`` stay connected without the tile at ``position``."""
    around: list[list[Bug]] = [self._bugs_from_pos(self._get_neighbor(position, direction)) for direction in Direction.flat()]
    # Gaps heuristic: count the groups of occupied neighbours going around the tile
    # (an occupied neighbour whose predecessor in the ring is empty starts a group).
    groups = 0
    for index, stack in enumerate(around):
      if stack and not around[index - 1]:
        groups += 1
    # With at most one group, all neighboring pieces are connected even without the piece at the given position.
    if groups <= 1:
      return True

    # Otherwise walk the hive from one neighbour, never stepping on `position`.
    occupied: list[Position] = []
    for stack in around:
      if stack:
        top_pos = self._pos_from_bug(stack[-1])
        if top_pos:
          occupied.append(top_pos)
    visited: Set[Position] = set()
    pending: Set[Position] = {occupied[0]}
    while pending:
      current = pending.pop()
      visited.add(current)
      for direction in Direction.flat():
        candidate = self._get_neighbor(current, direction)
        if candidate != position and self._bugs_from_pos(candidate) and candidate not in visited:
          pending.add(candidate)
    # Check if all neighbors with bug pieces were visited
    for tile in occupied:
      if tile not in visited:
        return False
    return True

  def _can_bug_be_played(self, piece: Bug) -> bool:
    """Return whether ``piece`` is the next copy of its kind to leave the hand.

    True when no piece of the same type and colour that is still in hand has
    a lower copy number than ``piece``.
    """
    # Bug stores its copy number in ``num`` (there is no ``id`` attribute).
    return all(
      bug.num >= piece.num
      for bug, pos in self._bug_to_pos.items()
      if pos is None and bug.type is piece.type and bug.color == piece.color
    )

  def _was_not_last_moved(self, bug: Bug) -> bool:
    """Return whether ``bug`` is not the piece recorded in ``self.last_moved``."""
    # The board keeps only the last moved piece in ``last_moved`` (None before
    # the first move); it has no ``moves`` history list.
    return self.last_moved is None or self.last_moved != bug

  def _is_bug_on_top(self, bug: Bug) -> bool:
    """Return whether ``bug`` is in play and is the topmost piece of its stack."""
    pos = self._pos_from_bug(bug)
    # Identity test for None; a piece still in hand has no position.
    return pos is not None and self._bugs_from_pos(pos)[-1] == bug

  def _bugs_from_pos(self, position: Position) -> list[Bug]:
    """Return the stack of pieces recorded at ``position``, or an empty list if there is none."""
    return self._pos_to_bug.get(position, [])

  def _pos_from_bug(self, bug: Bug) -> Optional[Position]:
    """Return the position recorded for ``bug``, or ``None`` when it is unknown or still in hand."""
    return self._bug_to_pos.get(bug)

  def _get_neighbor(self, position: Position, direction: Direction) -> Position:
    """Return the position one step away from ``position`` in ``direction``."""
    delta = self.NEIGHBOR_DELTAS[direction.delta_index]
    return position + delta
  
  def get_opponent(self, player):
    """Return ``-player``; with the 1 / -1 codes of ``PlayerColor`` that is the other player."""
    return -player
  
  def get_opponent_value(self, value):
    """Return the negated ``value``, i.e. the same outcome scored for the other side."""
    return -value
  
  def change_perspective(self, state, player):
    """Return ``state * player``.

    NOTE(review): presumably ``state`` is a board array of signed piece codes,
    so multiplying by -1 swaps which side is positive - confirm against callers.
    """
    return state * player
  
  def get_encoded_state(self, state):
    """Split a board of signed piece codes into three float32 planes.

    The planes hold, in order: the negative codes (zero elsewhere), a 0/1
    mask of the empty cells, and the positive codes (zero elsewhere). The
    planes are stacked on a new leading axis; for a batch of boards that axis
    is moved behind the batch axis.
    """
    state = np.asarray(state)
    # Element-wise selection: a Python conditional expression cannot branch on an array.
    encoded_state = np.stack(
        (np.where(state < 0, state, 0), state == 0, np.where(state > 0, state, 0))
    ).astype(np.float32)

    # A single board is already 3-D (row, column, stack level), so only an
    # input with one more axis than that is a batch.
    if state.ndim == 4:
      encoded_state = np.swapaxes(encoded_state, 0, 1)

    return encoded_state

  def state(self):
      """Return ``self.board`` itself (the live array, not a copy)."""
      return self.board
  

In [6]:
# Interactive console game: both sides are played by typing the index of one
# of the listed moves.
game = HiveBoard("Base+MLP")
player = game.player


while True:
    print(game.state())

    # get_valid_moves() returns an unordered collection of Move objects, so
    # freeze it into a list to let the user choose a move by index.
    valid_moves = list(game.get_valid_moves())
    if not valid_moves:
        print("no valid moves")
        break
    print("valid_moves")
    for index, move in enumerate(valid_moves):
        print(index, move.bug.type, move.bug.num, move.origin, "->", move.destination)

    try:
        action = int(input(f"{player}:"))
    except ValueError:
        print("action not valid")
        continue

    if not 0 <= action < len(valid_moves):
        print("action not valid")
        continue

    # The board object holds the state itself and flips the player to move.
    game.get_next_state(valid_moves[action])

    value, is_terminal = game.get_value_and_terminated()

    if is_terminal:
        print(game.state())
        if value != 0:
            # 1 means the black queen is surrounded, -1 the white one, so the
            # value equals the colour code of the winner.
            print(value, "won")
        else:
            print("draw")
        break

    player = game.player


NameError: name 'np' is not defined

### Neural Network Model

In [4]:
class ResNet(nn.Module):
    """Residual network with a policy head and a value head.

    Takes a batch of 3-plane board encodings and returns
    ``(policy_logits, value)``; the value passes through ``tanh``.
    """

    def __init__(self, game, num_resBlocks, num_hidden, device):
        super().__init__()

        self.device = device
        board_cells = game.row_count * game.column_count
        policy_channels, value_channels = 32, 3

        # Lift the 3 input planes to num_hidden feature maps.
        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
        )

        residual_blocks = []
        for _ in range(num_resBlocks):
            residual_blocks.append(ResBlock(num_hidden))
        self.backBone = nn.ModuleList(residual_blocks)

        # One logit per action.
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, policy_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(policy_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(policy_channels * board_cells, game.action_size),
        )

        # A single scalar squashed by tanh.
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, value_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(value_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(value_channels * board_cells, 1),
            nn.Tanh(),
        )

        self.to(device)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for the encoded batch ``x``."""
        features = self.startBlock(x)
        for block in self.backBone:
            features = block(features)
        return self.policyHead(features), self.valueHead(features)
        
        
class ResBlock(nn.Module):
    """Two 3x3 convolutions with batch norm and a skip connection around them."""

    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Return ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))``."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        # Skip connection: add the block input back before the final activation.
        out += x
        return F.relu(out)
        

### Basic MCTS and Alpha Zero

In [6]:
class Node:
    """One node of the search tree.

    ``state`` is stored from the point of view of the player to move at this
    node; ``prior`` is the probability the policy assigned to ``action_taken``.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior

        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return whether this node already has children."""
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (``None`` without children)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child):
        """Return the UCB score of ``child`` as seen from this node."""
        if child.visit_count == 0:
            q_value = 0
        else:
            # The child's mean value is from the opponent's point of view:
            # rescale it from [-1, 1] to [0, 1] and flip it.
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior

    def expand(self, policy):
        """Create one child per action with positive probability in ``policy``.

        Returns the last child created, or ``None`` when ``policy`` has no
        positive entry (previously that case raised ``UnboundLocalError``).
        """
        child = None
        for action, prob in enumerate(policy):
            if prob > 0:
                # This node always moves as player 1; the child state is then
                # flipped so that the opponent also sees itself as player 1.
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, player=-1)

                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)

        return child

    def backpropagate(self, value):
        """Add ``value`` to this node, then pass it up the tree with alternating sign."""
        self.value_sum += value
        self.visit_count += 1

        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)


class MCTS:
    """Monte Carlo tree search guided by a policy/value model."""

    def __init__(self, game, args, model):
        self.game, self.args, self.model = game, args, model

    def _predict(self, state):
        """Run the model on one state; return (policy probabilities as ndarray, value tensor)."""
        encoded = torch.tensor(self.game.get_encoded_state(state), device=self.model.device)
        logits, value = self.model(encoded.unsqueeze(0))
        return torch.softmax(logits, axis=1).squeeze(0).cpu().numpy(), value

    def _restrict_to_valid(self, policy, state):
        """Zero the invalid actions of ``policy`` in place and renormalise it."""
        policy *= self.game.get_valid_moves(state)
        policy /= np.sum(policy)
        return policy

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_searches']`` simulations from ``state``.

        Returns the visit-count distribution over the root's actions.
        """
        root = Node(self.game, self.args, state, visit_count=1)

        # The root policy is blended with Dirichlet noise before masking.
        root_policy, _ = self._predict(state)
        epsilon = self.args['dirichlet_epsilon']
        root_policy = (1 - epsilon) * root_policy + epsilon \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
        root.expand(self._restrict_to_valid(root_policy, state))

        for _ in range(self.args['num_searches']):
            # Walk down to a node that has not been expanded yet.
            node = root
            while node.is_fully_expanded():
                node = node.select()

            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                policy, value = self._predict(node.state)
                policy = self._restrict_to_valid(policy, node.state)
                value = value.item()
                node.expand(policy)

            node.backpropagate(value)

        visit_counts = np.zeros(self.game.action_size)
        for child in root.children:
            visit_counts[child.action_taken] = child.visit_count
        visit_counts /= np.sum(visit_counts)
        return visit_counts
        

In [7]:
class AlphaZero:
    """Train a policy/value model from games it plays against itself.

    Each learning iteration generates games with ``selfPlay`` (moves are
    sampled from the MCTS visit distribution), then fits the model on the
    collected samples with ``train``.

    ``args`` keys read by this class: ``'temperature'``, ``'batch_size'``,
    ``'num_iterations'``, ``'num_selfPlay_iterations'`` and ``'num_epochs'``.
    The same dict is handed to ``MCTS``, which reads its own keys.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        """Play one complete game and return its training samples.

        Returns:
            A list with one ``(encoded_state, action_probs, outcome)`` tuple
            per move played. ``outcome`` is the final game value from the
            point of view of the player who was to move in that position.
        """
        memory = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            # The search always sees the board from the mover's perspective.
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search(neutral_state)

            memory.append((neutral_state, action_probs, player))

            # Raising to 1/temperature breaks the sum-to-1 property, and
            # np.random.choice rejects such a vector, so renormalise first.
            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            temperature_action_probs /= np.sum(temperature_action_probs)
            action = np.random.choice(self.game.action_size, p=temperature_action_probs)

            state = self.game.get_next_state(state, action, player)

            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                returnMemory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    # `player` made the last move; flip the value for the other side.
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    returnMemory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))
                return returnMemory

            player = self.game.get_opponent(player)

    def train(self, memory):
        """Run one optimisation pass over ``memory`` in mini-batches.

        Args:
            memory: List of ``(encoded_state, policy_target, value_target)``
                tuples. It is shuffled in place.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Slicing already clamps at the end of the list, so every sample
            # (including the last one) lands in exactly one batch.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training for ``args['num_iterations']`` rounds.

        After every round the model and optimizer state dicts are saved to
        ``model_{iteration}_{game}.pt`` and ``optimizer_{iteration}_{game}.pt``.
        """
        for iteration in range(self.args['num_iterations']):
            memory = []

            # Self-play only runs inference, so switch the model to eval mode.
            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

### Optimized, parallel version of MCTS and AlphaZero (several games batched together, no threads)

In [8]:
class MCTSParallel:
    """Tree search over several self-play games with batched model calls.

    Every game keeps its own tree (``spg.root``); the leaves selected in one
    simulation step are stacked and evaluated by the model in a single call.
    """

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    @torch.no_grad()
    def search(self, states, spGames):
        """Build a search tree for every game in ``spGames``.

        Args:
            states: Stacked boards, one per entry of ``spGames`` and in the
                same order.
            spGames: Game holders; their ``root`` and ``node`` attributes are
                overwritten. Results are read by the caller from
                ``spg.root.children``; nothing is returned.
        """
        game, args, model = self.game, self.args, self.model
        epsilon = args['dirichlet_epsilon']

        # One batched forward pass yields the root priors of all games.
        root_logits, _ = model(
            torch.tensor(game.get_encoded_state(states), device=model.device)
        )
        root_priors = torch.softmax(root_logits, axis=1).cpu().numpy()
        noise = np.random.dirichlet(
            [args['dirichlet_alpha']] * game.action_size, size=root_priors.shape[0]
        )
        root_priors = (1 - epsilon) * root_priors + epsilon * noise

        for game_idx, spg in enumerate(spGames):
            prior = root_priors[game_idx]
            # Mask illegal moves, then renormalise the remaining mass.
            prior *= game.get_valid_moves(states[game_idx])
            prior /= np.sum(prior)

            spg.root = Node(game, args, states[game_idx], visit_count=1)
            spg.root.expand(prior)

        for _ in range(args['num_searches']):
            # Selection in every tree; terminal leaves are resolved at once.
            for spg in spGames:
                spg.node = None
                leaf = spg.root

                while leaf.is_fully_expanded():
                    leaf = leaf.select()

                outcome, is_terminal = game.get_value_and_terminated(leaf.state, leaf.action_taken)
                outcome = game.get_opponent_value(outcome)

                if is_terminal:
                    leaf.backpropagate(outcome)
                else:
                    # Remember the leaf so it can be evaluated in the batch below.
                    spg.node = leaf

            pending = [idx for idx, spg in enumerate(spGames) if spg.node is not None]
            if not pending:
                continue

            leaf_states = np.stack([spGames[idx].node.state for idx in pending])
            logits, values = model(
                torch.tensor(game.get_encoded_state(leaf_states), device=model.device)
            )
            priors = torch.softmax(logits, axis=1).cpu().numpy()
            values = values.cpu().numpy()

            # Expansion and backpropagation, one row of the batch per game.
            for row, idx in enumerate(pending):
                leaf = spGames[idx].node
                prior, leaf_value = priors[row], values[row]

                prior *= game.get_valid_moves(leaf.state)
                prior /= np.sum(prior)

                leaf.expand(prior)
                leaf.backpropagate(leaf_value)

In [9]:
class AlphaZeroParallel:
    """AlphaZero training loop that plays several self-play games at once.

    ``args['num_parallel_games']`` games advance in lockstep so that
    ``MCTSParallel`` can evaluate their positions in batched model calls.

    ``args`` keys read by this class: ``'num_parallel_games'``,
    ``'temperature'``, ``'batch_size'``, ``'num_iterations'``,
    ``'num_selfPlay_iterations'`` and ``'num_epochs'``. The same dict is
    handed to ``MCTSParallel``, which reads its own keys.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTSParallel(game, args, model)

    def selfPlay(self):
        """Play ``args['num_parallel_games']`` games to completion.

        Returns:
            A list of ``(encoded_state, action_probs, outcome)`` tuples
            collected from all games. ``outcome`` is the final game value
            from the point of view of the player who was to move in that
            position.
        """
        return_memory = []
        player = 1
        spGames = [SPG(self.game) for spg in range(self.args['num_parallel_games'])]

        while len(spGames) > 0:
            states = np.stack([spg.state for spg in spGames])
            neutral_states = self.game.change_perspective(states, player)

            self.mcts.search(neutral_states, spGames)

            # Walk backwards so finished games can be deleted while iterating.
            for i in range(len(spGames))[::-1]:
                spg = spGames[i]

                action_probs = np.zeros(self.game.action_size)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)

                spg.memory.append((spg.root.state, action_probs, player))

                # Raising to 1/temperature breaks the sum-to-1 property, and
                # np.random.choice rejects such a vector, so renormalise first.
                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                spg.state = self.game.get_next_state(spg.state, action, player)

                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)

                if is_terminal:
                    for hist_neutral_state, hist_action_probs, hist_player in spg.memory:
                        # `player` made the last move; flip the value for the other side.
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        return_memory.append((
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_action_probs,
                            hist_outcome
                        ))
                    del spGames[i]

            # All remaining games share the same side to move.
            player = self.game.get_opponent(player)

        return return_memory

    def train(self, memory):
        """Run one optimisation pass over ``memory`` in mini-batches.

        Args:
            memory: List of ``(encoded_state, policy_target, value_target)``
                tuples. It is shuffled in place.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Slicing already clamps at the end of the list, so every sample
            # (including the last one) lands in exactly one batch.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training for ``args['num_iterations']`` rounds.

        Each round calls ``selfPlay``
        ``num_selfPlay_iterations // num_parallel_games`` times. After every
        round the model and optimizer state dicts are saved to
        ``model_{iteration}_{game}.pt`` and ``optimizer_{iteration}_{game}.pt``.
        """
        for iteration in range(self.args['num_iterations']):
            memory = []

            # Self-play only runs inference, so switch the model to eval mode.
            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")
            
class SPG:
    """Bookkeeping for a single self-play game.

    Attributes are plain fields filled in by the code that drives the game:
    ``state`` starts as the game's initial state, ``memory`` starts empty,
    and ``root`` / ``node`` start as ``None``.
    """

    def __init__(self, game):
        # No search has touched this game yet.
        self.root = self.node = None
        # Per-move records are appended here while the game is played.
        self.memory = list()
        self.state = game.get_initial_state()

### Training

In [None]:
# Settings shared by the search and the training loop.
# With these values every learning iteration runs 500 // 100 = 5 batches of
# 100 parallel self-play games, followed by 4 epochs over the samples.
args = {
    'C': 2,
    'num_searches': 600,
    'num_iterations': 8,
    'num_selfPlay_iterations': 500,
    'num_parallel_games': 100,
    'num_epochs': 4,
    'batch_size': 128,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

game = ConnectFour()

# Train on the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): 9 and 128 are presumably the number of residual blocks and
# hidden channels; confirm against the ResNet definition.
model = ResNet(game, 9, 128, device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

alphaZero = AlphaZeroParallel(model, optimizer, game, args)
alphaZero.learn()

### Test to see how it plays

To play against it we only need three things: the model (ResNet), the weights file produced by training (a `.pt` file), and the standard (non-parallel) MCTS algorithm.

In [None]:
# Interactive game: the human is player 1, the trained model is player -1.
game = ConnectFour()
player = 1

# 'dirichlet_epsilon' is 0 so the noise term gets zero weight in the search
# and the model's root policy is used as-is.
args = {
    'C': 2,
    'num_searches': 600,
    'dirichlet_epsilon': 0.,
    'dirichlet_alpha': 0.3
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same architecture arguments as in training, so the saved weights fit.
model = ResNet(game, 9, 128, device)
model.load_state_dict(torch.load("model_7_ConnectFour.pt", map_location=device))
model.eval()

mcts = MCTS(game, args, model)

state = game.get_initial_state()


while True:
    print(state)

    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("valid_moves", [i for i in range(game.action_size) if valid_moves[i] == 1])

        # Reject anything that is not an integer instead of crashing.
        try:
            action = int(input(f"{player}:"))
        except ValueError:
            print("action not valid")
            continue

        # Check the range explicitly: a negative index would otherwise wrap
        # around and select a different move, a too-large one would raise.
        if not 0 <= action < game.action_size or valid_moves[action] == 0:
            print("action not valid")
            continue

    else:
        # The search expects the board from the point of view of the mover.
        neutral_state = game.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        # Play greedily: take the most visited move.
        action = np.argmax(mcts_probs)

    state = game.get_next_state(state, action, player)

    value, is_terminal = game.get_value_and_terminated(state, action)

    if is_terminal:
        print(state)
        if value == 1:
            print(player, "won")
        else:
            print("draw")
        break

    player = game.get_opponent(player)