In [0]:
# Lint as: python3
"""Pseudocode description of the MuZero algorithm."""
# pylint: disable=unused-argument
# pylint: disable=missing-docstring
# pylint: disable=assignment-from-no-return

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import typing
from typing import Dict, List, Optional
import enum

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

import threading

##########################
####### Helpers ##########

MAXIMUM_FLOAT_VALUE = float('inf')

KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])

# noinspection PyArgumentList
Winner = enum.Enum("Winner", "black white draw")

# noinspection PyArgumentList
Player = enum.Enum("Player", "black white")

num_filters = 2
num_blocks = 8

class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds: Optional[KnownBounds]):
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value


class MuZeroConfig(object):

  def __init__(self,
               action_space_size: int,
               max_moves: int,
               discount: float,
               dirichlet_alpha: float,
               num_simulations: int,
               batch_size: int,
               td_steps: int,
               num_actors: int,
               lr_init: float,
               lr_decay_steps: float,
               visit_softmax_temperature_fn,
               known_bounds: Optional[KnownBounds] = None):
    ### Self-Play
    self.action_space_size = action_space_size
    self.num_actors = num_actors

    self.visit_softmax_temperature_fn = visit_softmax_temperature_fn
    self.max_moves = max_moves
    self.num_simulations = num_simulations
    self.discount = discount

    # Root prior exploration noise.
    self.root_dirichlet_alpha = dirichlet_alpha
    self.root_exploration_fraction = 0.25

    # UCB formula
    self.pb_c_base = 19652
    self.pb_c_init = 1.25

    # If we already have some information about which values occur in the
    # environment, we can use them to initialize the rescaling.
    # This is not strictly necessary, but establishes identical behaviour to
    # AlphaZero in board games.
    self.known_bounds = known_bounds

    ### Training
    self.training_steps = int(1e6)
    self.checkpoint_interval = int(100)
    self.window_size = int(1e6)
    self.batch_size = batch_size
    self.num_unroll_steps = 4
    self.td_steps = td_steps

    self.weight_decay = 1e-4
    self.momentum = 0.9

    # Exponential learning rate schedule
    self.lr_init = lr_init
    self.lr_decay_rate = 0.1
    self.lr_decay_steps = lr_decay_steps

  def new_game(self):
    return Game(self.action_space_size, self.discount)


def make_board_game_config(action_space_size: int, max_moves: int,
                           dirichlet_alpha: float,
                           lr_init: float) -> MuZeroConfig:

  def visit_softmax_temperature(num_moves, training_steps):
    if num_moves < 30:
      return 1.0
    else:
      return 0.0  # Play according to the max.

  return MuZeroConfig(
      action_space_size=action_space_size,
      max_moves=max_moves,
      discount=1.0,
      dirichlet_alpha=dirichlet_alpha,
      num_simulations=10,
      batch_size=64,
      td_steps=max_moves,  # Always use Monte Carlo return.
      num_actors=1,
      lr_init=lr_init,
      lr_decay_steps=400e3,
      visit_softmax_temperature_fn=visit_softmax_temperature,
      known_bounds=KnownBounds(-1, 1))

def make_connect4_config() -> MuZeroConfig:
  return make_board_game_config(
      action_space_size=7, max_moves=20, dirichlet_alpha=0.03, lr_init=0.01)

class Action(object):

  def __init__(self, index: int):
    self.index = index

  def __hash__(self):
    return self.index

  def __eq__(self, other):
    return self.index == other

  def __gt__(self, other):
    return self.index > other


In [0]:

class Node(object):

  def __init__(self, prior: float):
    self.visit_count = 0
    self.to_play = -1
    self.prior = prior
    self.value_sum = 0
    self.children = {}
    self.hidden_state = None
    self.reward = 0

  def expanded(self) -> bool:
    return len(self.children) > 0

  def value(self) -> float:
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count


class ActionHistory(object):
  """Simple history container used inside the search.

  Only used to keep track of the actions executed.
  """

  def __init__(self, history: List[Action], action_space_size: int):
    self.history = list(history)
    self.action_space_size = action_space_size

  def clone(self):
    return ActionHistory(self.history, self.action_space_size)

  def add_action(self, action: Action):
    self.history.append(action)

  def last_action(self) -> Action:
    return self.history[-1]

  def action_space(self) -> List[Action]:
    return [i for i in range(self.action_space_size)]

  def to_play(self) -> Player:
    if len(self.history) % 2 == 0:
      return Player.white
    else:
      return Player.black

class Environment(object):
  """The environment MuZero is interacting with."""

  def __init__(self):
      self.board = None
      self.turn = 0
      self.done = False
      self.winner = None  # type: Winner
      self.resigned = False

  def reset(self):
      self.board = []
      for i in range(6):
          self.board.append([])
          for j in range(7): # pylint: disable=unused-variable
              self.board[i].append(' ')
      self.turn = 0
      self.done = False
      self.winner = None
      self.resigned = False
      return self

  def update(self, board):
      self.board = numpy.copy(board)
      self.turn = self.turn_n()
      self.done = False
      self.winner = None
      self.resigned = False
      return self

  def turn_n(self):
      turn = 0
      for i in range(6):
          for j in range(7):
              if self.board[i][j] != ' ':
                  turn += 1

      return turn

  def player_turn(self):
      if self.turn % 2 == 0:
          return Player.white
      else:
          return Player.black

  def step(self, action):
      for i in range(6):
          if self.board[i][action] == ' ':
              self.board[i][action] = ('X' if self.player_turn() == Player.white else 'O')
              break

      self.turn += 1

      self.check_for_fours()

      if self.turn >= 42:
          self.done = True
          if self.winner is None:
              self.winner = Winner.draw

      r = 0
      if self.done:
        if self.turn % 2 == 0:
          if Winner.white:
            r = 1
          elif Winner.black:
            r = -1
        else:
          if Winner.black:
            r = 1
          elif Winner.white:
            r = -1

      return r

  def legal_moves(self):
      legal = [0, 0, 0, 0, 0, 0, 0]
      for j in range(7):
          for i in range(6):
              if self.board[i][j] == ' ':
                  legal[j] = 1
                  break
      return legal

  def legal_actions(self):
      legal = []
      for j in range(7):
          for i in range(6):
              if self.board[i][j] == ' ':
                  legal.append(j)
                  break
      return legal

  def check_for_fours(self):
      for i in range(6):
          for j in range(7):
              if self.board[i][j] != ' ':
                  # check if a vertical four-in-a-row starts at (i, j)
                  if self.vertical_check(i, j):
                      self.done = True
                      return

                  # check if a horizontal four-in-a-row starts at (i, j)
                  if self.horizontal_check(i, j):
                      self.done = True
                      return

                  # check if a diagonal (either way) four-in-a-row starts at (i, j)
                  diag_fours = self.diagonal_check(i, j)
                  if diag_fours:
                      self.done = True
                      return

  def vertical_check(self, row, col):
      # print("checking vert")
      four_in_a_row = False
      consecutive_count = 0

      for i in range(row, 6):
          if self.board[i][col].lower() == self.board[row][col].lower():
              consecutive_count += 1
          else:
              break

      if consecutive_count >= 4:
          four_in_a_row = True
          if 'x' == self.board[row][col].lower():
              self.winner = Winner.white
          else:
              self.winner = Winner.black

      return four_in_a_row

  def horizontal_check(self, row, col):
      four_in_a_row = False
      consecutive_count = 0

      for j in range(col, 7):
          if self.board[row][j].lower() == self.board[row][col].lower():
              consecutive_count += 1
          else:
              break

      if consecutive_count >= 4:
          four_in_a_row = True
          if 'x' == self.board[row][col].lower():
              self.winner = Winner.white
          else:
              self.winner = Winner.black

      return four_in_a_row

  def diagonal_check(self, row, col):
      four_in_a_row = False
      count = 0

      consecutive_count = 0
      j = col
      for i in range(row, 6):
          if j > 6:
              break
          elif self.board[i][j].lower() == self.board[row][col].lower():
              consecutive_count += 1
          else:
              break
          j += 1

      if consecutive_count >= 4:
          count += 1
          if 'x' == self.board[row][col].lower():
              self.winner = Winner.white
          else:
              self.winner = Winner.black

      consecutive_count = 0
      j = col
      for i in range(row, -1, -1):
          if j > 6:
              break
          elif self.board[i][j].lower() == self.board[row][col].lower():
              consecutive_count += 1
          else:
              break
          j += 1

      if consecutive_count >= 4:
          count += 1
          if 'x' == self.board[row][col].lower():
              self.winner = Winner.white
          else:
              self.winner = Winner.black

      if count > 0:
          four_in_a_row = True

      return four_in_a_row

  def black_and_white_plane(self):
      board_white = numpy.copy(self.board)
      board_black = numpy.copy(self.board)
      for i in range(6):
          for j in range(7):
              if self.board[i][j] == ' ':
                  board_white[i][j] = 0
                  board_black[i][j] = 0
              elif self.board[i][j] == 'X':
                  board_white[i][j] = 1
                  board_black[i][j] = 0
              else:
                  board_white[i][j] = 0
                  board_black[i][j] = 1

      return numpy.array(board_white), numpy.array(board_black)

  def render(self):
      print("\nRound: " + str(self.turn))

      for i in range(5, -1, -1):
          print("\t", end="")
          for j in range(7):
              print("| " + str(self.board[i][j]), end=" ")
          print("|")
      print("\t  _   _   _   _   _   _   _ ")
      print("\t  1   2   3   4   5   6   7 ")

      if self.done:
          print("Game Over!")
          if self.winner == Winner.white:
              print("X is the winner")
          elif self.winner == Winner.black:
              print("O is the winner")
          else:
              print("Game was a draw")

  @property
  def observation(self):
      return ''.join(''.join(x for x in y) for y in self.board)


class Game(object):
  """A single episode of interaction with the environment."""

  def __init__(self, action_space_size: int, discount: float):
    self.environment = Environment().reset()  # Game specific environment.
    self.history = []
    self.rewards = []
    self.child_visits = []
    self.root_values = []
    self.action_space_size = action_space_size
    self.discount = discount

  def terminal(self) -> bool:
    # Game specific termination rules.
    return self.environment.done

  def legal_actions(self) -> List[Action]:
    # Game specific calculation of legal actions.
    return self.environment.legal_actions()

  def apply(self, action: Action):
    reward = self.environment.step(action)
    reward = reward if self.environment.turn % 2 != 0 and reward == 1 else -reward
    self.rewards.append(reward)
    self.history.append(action)

  def store_search_statistics(self, root: Node):
    sum_visits = sum(child.visit_count for child in root.children.values())
    action_space = (Action(index) for index in range(self.action_space_size))
    self.child_visits.append([
        root.children[a].visit_count / sum_visits if a in root.children else 0
        for a in action_space
    ])
    self.root_values.append(root.value())

  def make_image(self, state_index: int):
    # Game specific feature planes.    
    o = Environment().reset()

    for current_index in range(0, state_index):
      o.step(self.history[current_index])

    black_ary, white_ary = o.black_and_white_plane()
    state = [black_ary, white_ary] if o.player_turn() == Player.black else [white_ary, black_ary]
    return numpy.array(state)

  def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int,
                  to_play: Player):
    # The value target is the discounted root value of the search tree N steps
    # into the future, plus the discounted sum of all rewards until then.
    targets = []
    for current_index in range(state_index, state_index + num_unroll_steps + 1):
      bootstrap_index = current_index + td_steps
      if bootstrap_index < len(self.root_values):
        value = self.root_values[bootstrap_index] * self.discount**td_steps
      else:
        value = 0

      for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
        value += reward * self.discount**i  # pytype: disable=unsupported-operands

      if current_index < len(self.root_values):
        targets.append((value, self.rewards[current_index],
                        self.child_visits[current_index]))
      else:
        # States past the end of games are treated as absorbing states.
        targets.append((0, 0, []))
    return targets

  def to_play(self) -> Player:
    return self.environment.player_turn

  def action_history(self) -> ActionHistory:
    return ActionHistory(self.history, self.action_space_size)



In [0]:

class ReplayBuffer(object):

  def __init__(self, config: MuZeroConfig):
    self.window_size = config.window_size
    self.batch_size = config.batch_size
    self.buffer = []

  def save_game(self, game):    
    if len(self.buffer) > self.window_size:
      self.buffer.pop(0)
    self.buffer.append(game)

  def sample_batch(self, num_unroll_steps: int, td_steps: int):
    games = [self.sample_game() for _ in range(self.batch_size)]
    game_pos = [(g, self.sample_position(g)) for g in games]
    return [(g.make_image(i), g.history[i:i + num_unroll_steps],
             g.make_target(i, num_unroll_steps, td_steps, g.to_play()))
            for (g, i) in game_pos]

  def sample_game(self) -> Game:
    # Sample game from buffer either uniformly or according to some priority.
    return numpy.random.choice(self.buffer)

  def sample_position(self, game) -> int:
    # Sample position from game either uniformly or according to some priority.
    return numpy.random.choice(game.history)

# Nets
class NetworkOutput(typing.NamedTuple):
  value: float
  reward: float
  policy_logits: Dict[Action, float]
  hidden_state: List[float]


class Conv(nn.Module):
    def __init__(self, filters0, filters1, kernel_size, bn=False):
        super().__init__()
        self.conv = nn.Conv2d(filters0, filters1, kernel_size, stride=1, padding=kernel_size//2, bias=False)
        self.bn = None
        if bn:
            self.bn = nn.BatchNorm2d(filters1)

    def forward(self, x):
        h = self.conv(x)
        if self.bn is not None:
            h = self.bn(h)
        return h

class ResidualBlock(nn.Module):
    def __init__(self, filters):
        super().__init__()
        self.conv = Conv(filters, filters, 3, True)

    def forward(self, x):
        return F.relu(x + (self.conv(x)))

class Representation(nn.Module):
    ''' Conversion from observation to inner abstract state '''
    def __init__(self, input_shape):
        super().__init__()
        self.input_shape = input_shape
        self.board_size = self.input_shape[1] * self.input_shape[2]

        self.layer0 = Conv(self.input_shape[0], num_filters, 3, bn=True)
        self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])

    def forward(self, x):
        h = F.relu(self.layer0(x))
        for block in self.blocks:
            h = block(h)
        return h

class Prediction(nn.Module):
    ''' Policy and value prediction from inner abstract state '''
    def __init__(self, action_shape):
        super().__init__()
        self.board_size = 42
        self.action_size = action_shape

        self.conv_p1 = Conv(num_filters, 4, 1, bn=True)
        self.conv_p2 = Conv(4, 1, 1)

        self.conv_v = Conv(num_filters, 4, 1, bn=True)
        self.fc_v = nn.Linear(self.board_size * 4, 1, bias=False)

    def forward(self, rp):
        h_p = F.relu(self.conv_p1(rp))
        h_p = self.conv_p2(h_p).view(-1, self.action_size)

        h_v = F.relu(self.conv_v(rp))
        h_v = self.fc_v(h_v.view(-1, self.board_size * 4))

        # range of value is -1 ~ 1
        return F.softmax(h_p, dim=-1), torch.tanh(h_v)

class Dynamics(nn.Module):
    '''Abstruct state transition'''
    def __init__(self, rp_shape, act_shape):
        super().__init__()
        self.rp_shape = rp_shape
        self.layer0 = Conv(rp_shape[0] + act_shape[0], num_filters, 3, bn=True)
        self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])

    def forward(self, rp, a):
        h = torch.cat([rp, a], dim=1)
        h = self.layer0(h)
        for block in self.blocks:
            h = block(h)
        return h

class Network(nn.Module):

  def __init__(self, action_space_size: int):
    super().__init__()
    self.steps = 0
    self.action_space_size = action_space_size
    input_shape = (2, 6, 7)
    rp_shape = (num_filters, *input_shape[1:])
    self.representation = Representation(input_shape).to(device)
    self.prediction = Prediction(action_space_size).to(device)
    self.dynamics = Dynamics(rp_shape, (2, 6, 7)).to(device)
    self.eval()
  
  def predict_initial_inference(self, x):    
    assert x.ndim in (3, 4)
    assert x.shape == (2, 6, 7) or x.shape[1:] == (2, 6, 7)
    orig_x = x
    if x.ndim == 3:
        x = x.reshape(1, 2, 6, 7)
    
    x = torch.Tensor(x).to(device)
    h = self.representation(x)
    policy, value = self.prediction(h)
    
    if orig_x.ndim == 3:
        return h[0], policy[0], value[0]
    else:
        return h, policy, value

  def predict_recurrent_inference(self, x, a):

    if x.ndim == 3:
      x = x.reshape(1, 2, 6, 7)

    a = numpy.full((1, 2, 6, 7), a)

    g = self.dynamics(x, torch.Tensor(a).to(device))
    policy, value = self.prediction(g)
    
    return g[0], policy[0], value[0]

  def initial_inference(self, image) -> NetworkOutput:
    # representation + prediction function
    h, p, v = self.predict_initial_inference(image.astype(numpy.float32))
    return NetworkOutput(v, 0, p, h)

  def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
    # dynamics + prediction function
    g, p, v = self.predict_recurrent_inference(hidden_state, action)
    return NetworkOutput(v, 0, p, g) 

  def training_steps(self) -> int:
    # How many steps / batches the network has been trained for.
    return self.steps


class SharedStorage(object):

  def __init__(self):
    self._networks = {}

  def latest_network(self) -> Network:
    if self._networks:
      return self._networks[max(self._networks.keys())]
    else:
      # policy -> uniform, value -> 0, reward -> 0
      return make_uniform_network()

  def old_network(self) -> Network:
    if self._networks:
      return self._networks[min(self._networks.keys())]
    else:
      # policy -> uniform, value -> 0, reward -> 0
      return make_uniform_network()

  def save_network(self, step: int, network: Network):
    self._networks[step] = network


In [0]:

################################################################################
############################# Testing the latest net ###########################
################################################################################

# Battle against random agents
def vs_random(network, n=100):
    results = {}
    for i in range(n):
        first_turn = i % 2 == 0
        turn = first_turn
        game = config.new_game()
        r = 0
        while not game.terminal():
            if turn:
              root = Node(0)
              current_observation = game.make_image(-1)
              expand_node(root, game.to_play(), game.legal_actions(),
                          network.initial_inference(current_observation))
              add_exploration_noise(config, root)
              run_mcts(config, root, game.action_history(), network)
              action = select_action(config, len(game.history), root, network)
            else:
              action = numpy.random.choice(game.legal_actions())
            game.apply(action)
            turn = not turn
        if ((game.environment.winner == Winner.white and first_turn) 
            or (game.environment.winner == Winner.black and not first_turn)):
          r = 1
        elif ((game.environment.winner == Winner.black and first_turn) 
            or (game.environment.winner == Winner.white and not first_turn)):
          r = -1
        results[r] = results.get(r, 0) + 1
    return results

def random_vs_random(n=100):
    results = {}
    for i in range(n):
        first_turn = i % 2 == 0
        turn = first_turn
        game = config.new_game()
        r = 0
        while not game.terminal():
            action = numpy.random.choice(game.legal_actions())
            game.apply(action)
            turn = not turn
        if ((game.environment.winner == Winner.white and first_turn) 
            or (game.environment.winner == Winner.black and not first_turn)):
          r = 1
        elif ((game.environment.winner == Winner.black and first_turn) 
            or (game.environment.winner == Winner.white and not first_turn)):
          r = -1
        results[r] = results.get(r, 0) + 1
    return results

def latest_vs_older(last, old, n=100):
    results = {}
    for i in range(n):
        first_turn = i % 2 == 0
        turn = first_turn
        game = config.new_game()
        r = 0
        while not game.terminal():
            if turn:
              root = Node(0)
              current_observation = game.make_image(-1)
              expand_node(root, game.to_play(), game.legal_actions(),
                          last.initial_inference(current_observation))
              add_exploration_noise(config, root)
              run_mcts(config, root, game.action_history(), last)
              action = select_action(config, len(game.history), root, last)
            else:
              root = Node(0)
              current_observation = game.make_image(-1)
              expand_node(root, game.to_play(), game.legal_actions(),
                          old.initial_inference(current_observation))
              add_exploration_noise(config, root)
              run_mcts(config, root, game.action_history(), old)
              action = select_action(config, len(game.history), root, old)
            game.apply(action)
            turn = not turn
        if ((game.environment.winner == Winner.white and first_turn) 
            or (game.environment.winner == Winner.black and not first_turn)):
          r = 1
        elif ((game.environment.winner == Winner.black and first_turn) 
            or (game.environment.winner == Winner.white and not first_turn)):
          r = -1
        results[r] = results.get(r, 0) + 1
    return results

##### End Helpers ########
##########################


# MuZero training is split into two independent parts: Network training and
# self-play data generation.
# These two parts only communicate by transferring the latest network checkpoint
# from the training to the self-play, and the finished games from the self-play
# to the training.
def muzero(config: MuZeroConfig):
  storage = SharedStorage()
  replay_buffer = ReplayBuffer(config)

  # Start n concurrent actor threads
  threads = list()
  for _ in range(config.num_actors):
    t = threading.Thread(target=launch_job, args=(run_selfplay, config, storage, replay_buffer))
    threads.append(t)

  # Start all threads
  for x in threads:
    x.start() 

  train_network(config, storage, replay_buffer)

  return storage.latest_network()


##################################
####### Part 1: Self-Play ########


# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
  while True:
    network = storage.latest_network()
    game = play_game(config, network)
    replay_buffer.save_game(game)


# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: MuZeroConfig, network: Network) -> Game:
  game = config.new_game()  
  while not game.terminal() and len(game.history) < config.max_moves:
    # At the root of the search tree we use the representation function to
    # obtain a hidden state given the current observation.
    root = Node(0)
    current_observation = game.make_image(-1)
    expand_node(root, game.to_play(), game.legal_actions(),
                network.initial_inference(current_observation))
    add_exploration_noise(config, root)

    # We then run a Monte Carlo Tree Search using only action sequences and the
    # model learned by the network.
    run_mcts(config, root, game.action_history(), network)
    action = select_action(config, len(game.history), root, network)
    game.apply(action)
    game.store_search_statistics(root)
  return game


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory,
             network: Network):
  min_max_stats = MinMaxStats(config.known_bounds)

  for _ in range(config.num_simulations):
    history = action_history.clone()
    node = root
    search_path = [node]

    while node.expanded():
      action, node = select_child(config, node, min_max_stats)
      history.add_action(action)
      search_path.append(node)

    # Inside the search tree we use the dynamics function to obtain the next
    # hidden state given an action and the previous hidden state.
    parent = search_path[-2]
    network_output = network.recurrent_inference(parent.hidden_state,
                                                 history.last_action())
    expand_node(node, history.to_play(), history.action_space(), network_output)

    backpropagate(search_path, network_output.value, history.to_play(),
                  config.discount, min_max_stats)


def select_action(config: MuZeroConfig, num_moves: int, node: Node,
                  network: Network):
  visit_counts = [
      (child.visit_count, action) for action, child in node.children.items()
  ]
  t = config.visit_softmax_temperature_fn(
      num_moves=num_moves, training_steps=network.training_steps())
  _, action = softmax_sample(visit_counts, t)
  return action


# Select the child with the highest UCB score.
def select_child(config: MuZeroConfig, node: Node,
                 min_max_stats: MinMaxStats):
  _, action, child = max(
      (ucb_score(config, node, child, min_max_stats), action,
       child) for action, child in node.children.items())
  return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: MuZeroConfig, parent: Node, child: Node,
              min_max_stats: MinMaxStats) -> float:
  pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
  pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

  prior_score = pb_c * child.prior
  value_score = min_max_stats.normalize(child.value())
  return prior_score + value_score


# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node: Node, to_play: Player, actions: List[Action],
                network_output: NetworkOutput):
  node.to_play = to_play
  node.hidden_state = network_output.hidden_state
  node.reward = network_output.reward
  policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
  policy_sum = sum(policy.values())
  for action, p in policy.items():
    node.children[action] = Node(p / policy_sum)


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
  for node in search_path:
    node.value_sum += value if node.to_play == to_play else -value
    node.visit_count += 1
    min_max_stats.update(node.value())

    value = node.reward + discount * value


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: MuZeroConfig, node: Node):
  actions = list(node.children.keys())
  noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


######### End Self-Play ##########
##################################


In [0]:

##################################
####### Part 2: Training #########
def train_network(config: MuZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):

  network = Network(config.action_space_size).to(device)

  while True:

    optimizer = optim.SGD(network.parameters(), lr=0.01, weight_decay=config.lr_decay_rate,
                          momentum=config.momentum)

    while not len(replay_buffer.buffer) > 0:
      pass
    
    for i in range(config.training_steps):
      if i % config.checkpoint_interval == 0 and i > 0:
        storage.save_network(i, network)     
        # Test against random agent    
        vs_random_once = vs_random(network)
        print('network_vs_random = ', sorted(vs_random_once.items()), end='\n')
        vs_older = latest_vs_older(storage.latest_network(), storage.old_network())
        print('lastnet_vs_older = ', sorted(vs_older.items()), end='\n') 
      batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
      update_weights(batch, network, optimizer)
    storage.save_network(config.training_steps, network)

def update_weights(batch, network, optimizer):    

  network.train()    

  p_loss, v_loss = 0, 0

  for image, actions, targets in batch:
    # Initial step, from the real observation.
    value, reward, policy_logits, hidden_state = network.initial_inference(image)
    predictions = [(1.0, value, reward, policy_logits)]

    # Recurrent steps, from action and previous hidden state.
    for action in actions:
      value, reward, policy_logits, hidden_state = network.recurrent_inference(hidden_state, action)
      predictions.append((1.0 / len(actions), value, reward, policy_logits))

    for prediction, target in zip(predictions, targets):
      if(len(target[2]) > 0):
        _ , value, reward, policy_logits = prediction
        target_value, target_reward, target_policy = target

        p_loss += torch.sum(-torch.Tensor(numpy.array(target_policy)).to(device) * torch.log(policy_logits))
        v_loss += torch.sum((torch.Tensor([target_value]).to(device) - value) ** 2)
  
  optimizer.zero_grad()    
  total_loss = (p_loss + v_loss)
  total_loss.backward()
  optimizer.step()
  network.steps += 1
  print('p_loss %f v_loss %f' % (p_loss / len(batch), v_loss / len(batch)))

######### End Training ###########
##################################

################################################################################
############################# End of pseudocode ################################
################################################################################



In [6]:
print(device)

# Stubs to make the typechecker happy.
def softmax_sample(distribution, temperature: float):
  if temperature == 0:
    temperature = 1
  distribution = numpy.array(distribution)**(1/temperature)
  p_sum = distribution.sum()
  sample_temp = distribution/p_sum
  return 0, numpy.argmax(numpy.random.multinomial(1, sample_temp, 1))

def launch_job(f, *args):
  f(*args)

def make_uniform_network():
  return Network(make_connect4_config().action_space_size).to(device)



cuda:0


In [0]:
config = make_connect4_config()
vs_random_once = random_vs_random()
print('random_vs_random = ', sorted(vs_random_once.items()), end='\n')
network = muzero(config)

random_vs_random =  [(-1, 47), (1, 53)]
p_loss 9.893979 v_loss 0.601927
p_loss 8.691347 v_loss 3.955910
p_loss 8.554790 v_loss 4.140779
p_loss 8.814151 v_loss 6.870962
p_loss 9.364299 v_loss 2.571944
p_loss 9.304607 v_loss 5.093513
p_loss 9.411983 v_loss 5.341243
p_loss 9.229336 v_loss 5.062500
p_loss 9.419596 v_loss 5.937500
p_loss 9.608006 v_loss 6.640625
p_loss 10.105663 v_loss 7.640625
p_loss 9.597007 v_loss 7.640625
p_loss 9.869622 v_loss 7.015625
p_loss 9.549665 v_loss 7.109375
p_loss 9.376898 v_loss 6.093750
p_loss 9.610895 v_loss 5.859375
p_loss 9.609826 v_loss 5.078125
p_loss 9.732680 v_loss 7.875000
p_loss 9.679747 v_loss 6.250000
p_loss 9.736465 v_loss 6.484375
p_loss 9.609228 v_loss 6.000000
p_loss 9.397275 v_loss 6.484375
p_loss 9.608634 v_loss 5.859375
p_loss 9.546373 v_loss 5.953125
p_loss 9.638430 v_loss 5.796875
p_loss 9.276681 v_loss 5.468750
p_loss 9.638519 v_loss 6.796875
p_loss 9.609526 v_loss 5.859375
p_loss 9.485651 v_loss 6.156250
p_loss 9.641564 v_loss 5.843750