In [0]:
import datetime
import os
import gym
import numpy
import torch

# from .abstract_game import AbstractGame

observation_shape,
        num_blocks,
        num_channels,
        reduced_channels,
        fc_reward_layers,
        full_support_size,
        block_output_size,

#MuZero Config

In [0]:
class MuZeroConfig:
    def __init__(self, num_node):
        self.seed = 0  # Seed for numpy, torch and the game
        ### Game

        #What are the observations of our TSP?
        self.observation_shape = (1, num_node, 2)  # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, please reshape it to (1, 1, length of array)
        #How many possible actions do we have
        self.action_space = [i for i in range(num_node-1)]  # Fixed list of all possible actions. You should only edit the length
        #Fixed to single player
        self.players = [i for i in range(1)]  # List of players. You should only edit the length
        self.stacked_observations = 0  # Number of previous observation and previous actions to add to the current observation



        ### Self-Play
        self.num_actors = 1  # Number of simultaneous threads self-playing to feed the replay buffer
        self.max_moves = num_node  # Maximum number of moves if game is not finished before !!! NEED TO UNDERSTAND WHAT MAX_MOVES means !!!
        #To be changed later
        self.num_simulations = 25  # Number of future moves self-simulated
        self.discount = 1  # Chronological discount of the reward
        #What?
        self.temperature_threshold = 6  # Number of moves before dropping temperature to 0 (ie playing according to the max)

        # Root prior exploration noise
        self.root_dirichlet_alpha = 0.1
        self.root_exploration_fraction = 0.25

        # UCB formula
        self.pb_c_base = 19652
        self.pb_c_init = 1.25



        ### Network
        self.network = "resnet"  # "resnet" / "fullyconnected"
        self.support_size = 10  # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size

        # Residual Network
        self.downsample = False  # Downsample observations before representation network (See paper appendix Network Architecture)
        self.blocks = 1  # Number of blocks in the ResNet
        self.channels = 16  # Number of channels in the ResNet
        self.reduced_channels = 16  # Number of channels before heads of dynamic and prediction networks
        self.resnet_fc_reward_layers = [8]  # Define the hidden layers in the reward head of the dynamic network
        self.resnet_fc_value_layers = [8]  # Define the hidden layers in the value head of the prediction network
        self.resnet_fc_policy_layers = [8]  # Define the hidden layers in the policy head of the prediction network

        # Fully Connected Network
        self.encoding_size = 32
        self.fc_reward_layers = [16]  # Define the hidden layers in the reward network
        self.fc_value_layers = []  # Define the hidden layers in the value network
        self.fc_policy_layers = []  # Define the hidden layers in the policy network
        self.fc_representation_layers = []  # Define the hidden layers in the representation network
        self.fc_dynamics_layers = [16]  # Define the hidden layers in the dynamics network



        ### Training
        self.results_path = os.path.join(os.path.dirname("/content/drive/My Drive/deep-learning-final"), "/results", os.path.basename("/content/drive/My Drive/deep-learning-final")[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
        self.training_steps = 100000  # Total number of training steps (ie weights update according to a batch)
        self.batch_size = 64  # Number of parts of games to train on at each training step
        self.checkpoint_interval = 10  # Number of training steps before using the model for sef-playing
        self.value_loss_weight = 0.25  # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
        self.training_device = "cuda" if torch.cuda.is_available() else "cpu"  # Train on GPU if available

        self.optimizer = "Adam"  # "Adam" or "SGD". Paper uses SGD
        self.weight_decay = 1e-4  # L2 weights regularization
        self.momentum = 0.9  # Used only if optimizer is SGD

        # Exponential learning rate schedule
        self.lr_init = 0.01  # Initial learning rate
        self.lr_decay_rate = 1  # Set it to 1 to use a constant learning rate
        self.lr_decay_steps = 10000


        ### Replay Buffer
        self.window_size = 3000  # Number of self-play games to keep in the replay buffer
        self.num_unroll_steps = 20  # Number of game moves to keep for every batch element
        self.td_steps = 20  # Number of steps in the future to take into account for calculating the target value
        self.use_last_model_value = False  # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)

        # Prioritized Replay (See paper appendix Training)
        self.PER = True  # Select in priority the elements in the replay buffer which are unexpected for the network
        self.use_max_priority = True  # Use the n-step TD error as initial priority. Better for large replay buffer
        self.PER_alpha = 0.5  # How much prioritization is used, 0 corresponding to the uniform case, paper suggests 1
        self.PER_beta = 1.0


        ### Adjust the self play / training ratio to avoid over/underfitting
        self.self_play_delay = 0  # Number of seconds to wait after each played game
        self.training_delay = 0  # Number of seconds to wait after each training step
        self.ratio = None  # Desired self played games per training step ratio. Equivalent to a synchronous version, training can take much longer. Set it to None to disable it

        def visit_softmax_temperature_fn(self, trained_steps):
          """
          Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
          The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.

          Returns:
              Positive float.
          """
          return 1


In [54]:
class Game():
    """
    Game wrapper.
    """

    def __init__(self, num_node, seed=None):
        print("here")
        self.env = TSP(num_node)

    def step(self, action, state):
        """
        Apply action to the game.
        
        Args:
            action : action of the action_space to take.

        Returns:
            The new observation, the reward and a boolean if the game has ended.
        """
        state, reward, done = self.env.step(action, state)
        # return state, reward * 20, done ---> need to check why reward*20 
        return state, reward, done

    def to_play(self):
        """
        Return the current player.

        Returns:
            The current player, it should be an element of the players list in the config. 
        """
        return self.env.to_play()

    def legal_actions(self, state):
        """
        Should return the legal actions at each turn, if it is not available, it can return
        the whole action space. At each turn, the game have to be able to handle one of returned actions.
        
        For complex game where calculating legal moves is too long, the idea is to define the legal actions
        equal to the action space but to return a negative reward if the action is illegal.
    
        Returns:
            An array of integers, subset of the action space.
        """
        return self.env.legal_actions()

    def reset(self):
        """
        Reset the game for a new game.
        
        Returns:
            Initial state of the game.
        """
        return self.env.reset()

    def close(self):
        """
        Properly close the game.
        """
        pass

    def render(self, state):
        """
        Display the game state and graph
        """
        self.env.render()

    def encode_board(self):
        # return self.env.encode_board()
        pass

    def human_to_action(self):
        """
        For multiplayer games, ask the user for a legal action
        and return the corresponding action number.

        Returns:
            An integer from the action space.
        """
        # while True:
        #     try:
        #         row = int(
        #             input(
        #                 "Enter the row (1, 2 or 3) to play for the player {}: ".format(
        #                     self.to_play()
        #                 )
        #             )
        #         )
        #         col = int(
        #             input(
        #                 "Enter the column (1, 2 or 3) to play for the player {}: ".format(
        #                     self.to_play()
        #                 )
        #             )
        #         )
        #         choice = (row - 1) * 3 + (col - 1)
        #         if (
        #             choice in self.legal_actions()
        #             and 1 <= row
        #             and 1 <= col
        #             and row <= 3
        #             and col <= 3
        #         ):
        #             break
        #     except:
        #         pass
        #     print("Wrong input, try again")
        # return choice

        pass

    def action_to_string(self, action_number):
        # """
        # Convert an action number to a string representing the action.
        
        # Args:
        #     action_number: an integer from the action space.

        # Returns:
        #     String representing the action.
        # """
        # row = 3 - action_number // 3
        # col = action_number % 3 + 1
        # return "Play row {}, column {}".format(row, col)
        pass
      
    def state_to_string(self, state):
      """
      Input:
          state: current state
      Returns:
          index of state
      """
      s = ''
      for i in range(self.num_node):
          s += str(int(state[i, 0]))
      return s

ERROR! Session/line number was not unique in database. History logging moved to new session 61


Class TSP



In [0]:
class TSP:
    def __init__(self, num_node):
        self.graph = np.random.rand(num_node, 2)
        self.player = 1
        self.num_node = num_node
        self.getInitState()
      
    def getInitState(self):
        """
        Returns:
            first_state: a representation of the graph
            left column representing visited nodes
            right column will always have a single 1 and the rest are 0's. index with the 1 in the right column is current node
        """
        
        # Always start with first node as current node 
        first_state = np.zeros([self.num_node, 2])
        first_state[0,0] = 1
        first_state[0,1] = 1
        return first_state

    def to_play(self):
        return 0

    def reset(self):
        self.board = np.random.rand(self.num_node, 2)
        self.player = 1
        return self.getInitState()
    
    def step(self, action, state):
      next_s = state.copy()
      # zero out current node
      next_s[:, 1] = 0
      # 1 in left column for visited, 1 in right column for current node
      next_s[action, :] = 1
      prev_a = np.where(state[:, 1] == 1)[0][0]
      # get xy coordinates for prev_node and current_node from the graph
      prev_node = self.graph[prev_a]
      current_node = self.graph[action]
      reward = 1 - np.linalg.norm(current_node - prev_node)
      if self.num_node == np.sum(next_s[:, 0]): #end of game
          reward += 1 - np.linalg.norm(current_node - self.graph[0])
            
      return next_s, reward, self.is_finished(state)


    def get_observation(self):
      # observation == state
      
      pass

    def legal_actions(self, state):
      return 1 - state[:, 0]

    def is_finished(self, state):
      """
      Input:
        state: current state
      Returns:
        r: 0 if game has not ended. 1 if it has
               
      """
      end = False
      if self.num_node == np.sum(state[:, 0]):
          end = True
      return end

In [0]:
def render(self, state):
    print("State:")
    print(state)
    print("Graph:")
    print(self.graph)

In [64]:
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

Mounted at /content/drive


#MuZero.Py Class

In [97]:
import copy
import importlib
import os
import time

import numpy as np
import ray
import torch
from torch.utils.tensorboard import SummaryWriter



class MuZero:
    """
    Main class to manage MuZero.

    Args:
        game_name (str): Name of the game module, it should match the name of a .py file
        in the "./games" directory.

    Example:
        >>> muzero = MuZero("cartpole")
        >>> muzero.train()
        >>> muzero.test()
    """

    def __init__(self, game_name, num_node):
        self.game_name = game_name

        # Load the game and the config from the module with the game name
        try:
            # game_module = importlib.import_module("games." + self.game_name)
            self.config = MuZeroConfig(num_node)
            self.Game = TSP(num_node)
        except Exception as err:
            print(
                '{} is not a supported game name, try "cartpole" or refer to the documentation for adding a new game.'.format(
                    self.game_name
                )
            )
            raise err

        # Fix random generator seed
        numpy.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        # Weights used to initialize components
        self.muzero_weights = MuZeroNetwork(self.config).get_weights()

    def train(self):
        ray.init(ignore_reinit_error=True)
        training_worker = Trainer.remote(copy.deepcopy(self.muzero_weights), self.config)
        self_play_workers = [
            SelfPlay.remote(
                copy.deepcopy(self.muzero_weights),
                self.Game, #maybe need to add seed back!
                self.config,
            )
            for seed in range(self.config.num_actors)
        ]
        shared_storage_worker = SharedStorage.remote(
            copy.deepcopy(self.muzero_weights), self.game_name, self.config,
        )
        os.makedirs(self.config.results_path, exist_ok=True)
        writer = SummaryWriter(self.config.results_path)

        print("1")
        replay_buffer_worker = ReplayBuffer.remote(self.config)
        print("2")
        test_worker = SelfPlay.remote(
            copy.deepcopy(self.muzero_weights),
            self.Game, # maybe need to put seed back !!!
            self.config,
        )
        print("3")
        # Launch workers
        [
            self_play_worker.continuous_self_play.remote(
                shared_storage_worker, replay_buffer_worker
            )
            for self_play_worker in self_play_workers
        ]
        print("4")
        test_worker.continuous_self_play.remote(shared_storage_worker, None, True)
        print("5")
        training_worker.continuous_update_weights.remote(
            replay_buffer_worker, shared_storage_worker
        )
        print("6")

        print(
            "\nTraining...\nRun tensorboard --logdir ./results and go to http://localhost:6006/ to see in real time the training performance.\n"
        )
        # Save hyperparameters to TensorBoard
        hp_table = [
            "| {} | {} |".format(key, value)
            for key, value in self.config.__dict__.items()
        ]
        writer.add_text(
            "Hyperparameters",
            "| Parameter | Value |\n|-------|-------|\n" + "\n".join(hp_table),
        )
        # Loop for monitoring in real time the workers
        counter = 0
        infos = ray.get(shared_storage_worker.get_infos.remote())
        try:
            while infos["training_step"] < self.config.training_steps:
                # Get and save real time performance
                infos = ray.get(shared_storage_worker.get_infos.remote())
                writer.add_scalar(
                    "1.Total reward/1.Total reward", infos["total_reward"], counter,
                )
                writer.add_scalar(
                    "1.Total reward/2.Episode length", infos["episode_length"], counter,
                )
                writer.add_scalar(
                    "1.Total reward/3.Player 0 MuZero reward",
                    infos["player_0_reward"],
                    counter,
                )
                writer.add_scalar(
                    "1.Total reward/4.Player 1 Random reward",
                    infos["player_1_reward"],
                    counter,
                )
                writer.add_scalar(
                    "2.Workers/1.Self played games",
                    ray.get(replay_buffer_worker.get_self_play_count.remote()),
                    counter,
                )
                writer.add_scalar(
                    "2.Workers/2.Training steps", infos["training_step"], counter
                )
                writer.add_scalar(
                    "2.Workers/3.Self played games per training step ratio",
                    ray.get(replay_buffer_worker.get_self_play_count.remote())
                    / max(1, infos["training_step"]),
                    counter,
                )
                writer.add_scalar("2.Workers/4.Learning rate", infos["lr"], counter)
                writer.add_scalar(
                    "3.Loss/1.Total weighted loss", infos["total_loss"], counter
                )
                writer.add_scalar("3.Loss/Value loss", infos["value_loss"], counter)
                writer.add_scalar("3.Loss/Reward loss", infos["reward_loss"], counter)
                writer.add_scalar("3.Loss/Policy loss", infos["policy_loss"], counter)
                print(
                    "Last test reward: {0:.2f}. Training step: {1}/{2}. Played games: {3}. Loss: {4:.2f}".format(
                        infos["total_reward"],
                        infos["training_step"],
                        self.config.training_steps,
                        ray.get(replay_buffer_worker.get_self_play_count.remote()),
                        infos["total_loss"],
                    ),
                    end="\r",
                )
                counter += 1
                time.sleep(0.5)
        except KeyboardInterrupt as err:
            # Comment the line below to be able to stop the training but keep running
            # raise err
            pass
        self.muzero_weights = ray.get(shared_storage_worker.get_weights.remote())
        # End running actors
        ray.shutdown()

    def test(self, render=True, opponent="self", muzero_player=None):
        """
        Test the model in a dedicated thread.

        Args:
            render: Boolean to display or not the environment.

            opponent: "self" for self-play, "human" for playing against MuZero and "random"
            for a random agent.

            muzero_player: Integer with the player number of MuZero in case of multiplayer
            games, None let MuZero play all players turn by turn.
        """
        print("\nTesting...")
        ray.init(ignore_reinit_error=True)
        self_play = SelfPlay.remote()
        self_play_workers = self_play.remote(
            copy.deepcopy(self.muzero_weights),
            self.Game, # maybe need to return seed back!!!!
            self.config,
        )
        history = ray.get(
            self_play_workers.play_game.remote(0, 0, render, opponent, muzero_player)
        )
        ray.shutdown()
        return sum(history.reward_history)

    def load_model(self, path=None):
        if not path:
            path = os.path.join(self.config.results_path, "model.weights")
        try:
            self.muzero_weights = torch.load(path)
            print("\nUsing weights from {}".format(path))
        except FileNotFoundError:
            print("\nThere is no model saved in {}.".format(path))


if __name__ == "__main__":
    # Initialize MuZero
    num_node = 5
    muzero = MuZero("TSP", num_node)

    while True:
        # Configure running options
        options = [
            "Train",
            "Load pretrained model",
            "Render some self play games",
            "Exit",
        ]
        print()
        for i in range(len(options)):
            print("{}. {}".format(i, options[i]))

        choice = input("Enter a number to choose an action: ")
        valid_inputs = [str(i) for i in range(len(options))]
        while choice not in valid_inputs:
            choice = input("Invalid input, enter a number listed above: ")
        choice = int(choice)
        if choice == 0:
            muzero.train()
        elif choice == 1:
            path = input("Enter a path to the model.weights: ") ### set 1 path and pass it as const 
            while not os.path.isfile(path):
                path = input("Invalid path. Try again: ")
            muzero.load_model(path)
        elif choice == 2:
            muzero.test(render=True, opponent="self", muzero_player=None)
        else:
            break
        print("\nDone")

    ## Successive training, create a new config file for each experiment
    # experiments = ["cartpole", "tictactoe"]
    # for experiment in experiments:
    #     print("\nStarting experiment {}".format(experiment))
    #     try:
    #         muzero = MuZero(experiment)
    #         muzero.train()
    #     except:
    #         print("Skipping {}, an error has occurred.".format(experiment))


0. Train
1. Load pretrained model
2. Render some self play games
3. Exit
Enter a number to choose an action: 0


2020-04-05 15:14:26,373	INFO resource_spec.py:212 -- Starting Ray with 7.13 GiB memory available for workers and up to 3.57 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2020-04-05 15:14:26,686	INFO services.py:1148 -- View the Ray dashboard at [1m[32mlocalhost:8265[39m[22m


1
2
3
4
5
6

Training...
Run tensorboard --logdir ./results and go to http://localhost:6006/ to see in real time the training performance.



Traceback (most recent call last):
  File "/usr/lib/python3.6/asyncio/base_events.py", line 1062, in create_server
    sock.bind(sa)
OSError: [Errno 99] Cannot assign requested address

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/ray/dashboard/dashboard.py", line 1142, in <module>
    dashboard.run()
  File "/usr/local/lib/python3.6/dist-packages/ray/dashboard/dashboard.py", line 570, in run
    aiohttp.web.run_app(self.app, host=self.host, port=self.port)
  File "/usr/local/lib/python3.6/dist-packages/aiohttp/web.py", line 433, in run_app
    reuse_port=reuse_port))
  File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.6/dist-packages/aiohttp/web.py", line 359, in _run_app
    await site.start()
  File "/usr/local/lib/python3.6/dist-packages/aiohttp/web_runner.py", line 104, in start
    reuse

[2m[36m(pid=4521)[0m THCudaCheck FAIL file=/pytorch/aten/src/THC/THCGeneral.cpp line=50 error=100 : no CUDA-capable device is detected
[2m[36m(pid=4506)[0m tensor([[[1., 1.],
[2m[36m(pid=4506)[0m          [0., 0.],
[2m[36m(pid=4506)[0m          [0., 0.],
[2m[36m(pid=4506)[0m          [0., 0.],
[2m[36m(pid=4506)[0m          [0., 0.]]])
[2m[36m(pid=4542)[0m tensor([[[1., 1.],
[2m[36m(pid=4542)[0m          [0., 0.],
[2m[36m(pid=4542)[0m          [0., 0.],
[2m[36m(pid=4542)[0m          [0., 0.],
[2m[36m(pid=4542)[0m          [0., 0.]]])


2020-04-05 15:14:37,217	ERROR worker.py:1012 -- Possible unhandled error from worker: [36mray::Trainer.__init__()[39m (pid=4521, ip=172.28.0.2)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "<ipython-input-22-116e65588825>", line 27, in __init__
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 425, in to
    return self._apply(convert)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 201, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 201, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 223, in _apply
    param_applied = fn(param)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 423, in convert
    return t.to(device, dtype if t.is_floating_p



2020-04-05 15:14:38,222	ERROR worker.py:1012 -- Possible unhandled error from worker: [36mray::SelfPlay.continuous_self_play()[39m (pid=4506, ip=172.28.0.2)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 407, in ray._raylet.execute_task.function_executor
  File "<ipython-input-96-79c2928b38c4>", line 52, in continuous_self_play
  File "<ipython-input-96-79c2928b38c4>", line 131, in play_game
  File "<ipython-input-96-79c2928b38c4>", line 250, in run
  File "<ipython-input-93-bf4957b7be6d>", line 508, in initial_inference
  File "<ipython-input-93-bf4957b7be6d>", line 433, in representation
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-93-bf4957b7be6d>", line 278, in forward
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **


Done

0. Train
1. Load pretrained model
2. Render some self play games
3. Exit
Enter a number to choose an action: 3


In [95]:
!tensorboard --logdir ./results

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.2.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [0]:
import math

import torch


class MuZeroNetwork:
    def __new__(cls, config):
        if config.network == "fullyconnected":
            return MuZeroFullyConnectedNetwork(
                config.observation_shape,
                config.stacked_observations,
                len(config.action_space),
                config.encoding_size,
                config.fc_reward_layers,
                config.fc_value_layers,
                config.fc_policy_layers,
                config.fc_representation_layers,
                config.fc_dynamics_layers,
                config.support_size,
            )
        elif config.network == "resnet":
            return MuZeroResidualNetwork(
                config.observation_shape,
                config.stacked_observations,
                len(config.action_space),
                config.blocks,
                config.channels,
                config.reduced_channels,
                config.resnet_fc_reward_layers,
                config.resnet_fc_value_layers,
                config.resnet_fc_policy_layers,
                config.support_size,
                config.downsample,
            )
        else:
            raise ValueError(
                'The network parameter should be "fullyconnected" or "resnet".'
            )


##################################
######## Fully Connected #########


class MuZeroFullyConnectedNetwork(torch.nn.Module):
    def __init__(
        self,
        observation_shape,
        stacked_observations,
        action_space_size,
        encoding_size,
        fc_reward_layers,
        fc_value_layers,
        fc_policy_layers,
        fc_representation_layers,
        fc_dynamics_layers,
        support_size,
    ):
        super().__init__()
        self.action_space_size = action_space_size
        self.full_support_size = 2 * support_size + 1

        self.representation_network = FullyConnectedNetwork(
            observation_shape[0]
            * observation_shape[1]
            * observation_shape[2]
            * (stacked_observations + 1)
            + stacked_observations * observation_shape[1] * observation_shape[2],
            fc_representation_layers,
            encoding_size,
        )

        self.dynamics_encoded_state_network = FullyConnectedNetwork(
            encoding_size + self.action_space_size, fc_dynamics_layers, encoding_size
        )
        self.dynamics_reward_network = FullyConnectedNetwork(
            encoding_size + self.action_space_size,
            fc_reward_layers,
            self.full_support_size,
        )

        self.prediction_policy_network = FullyConnectedNetwork(
            encoding_size, [], self.action_space_size
        )
        self.prediction_value_network = FullyConnectedNetwork(
            encoding_size, fc_value_layers, self.full_support_size,
        )

    def prediction(self, encoded_state):
        policy_logits = self.prediction_policy_network(encoded_state)
        value = self.prediction_value_network(encoded_state)
        return policy_logits, value

    def representation(self, observation):
        encoded_state = self.representation_network(
            observation.view(observation.shape[0], -1)
        )
        # Scale encoded state between [0, 1] (See appendix paper Training)
        min_encoded_state = encoded_state.min(1, keepdim=True)[0]
        max_encoded_state = encoded_state.max(1, keepdim=True)[0]
        scale_encoded_state = max_encoded_state - min_encoded_state
        scale_encoded_state[scale_encoded_state == 0] = 1
        encoded_state_normalized = (
            encoded_state - min_encoded_state
        ) / scale_encoded_state
        return encoded_state_normalized

    def dynamics(self, encoded_state, action):
        # Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
        action_one_hot = (
            torch.zeros((action.shape[0], self.action_space_size))
            .to(action.device)
            .float()
        )
        action_one_hot.scatter_(1, action.long(), 1.0)
        x = torch.cat((encoded_state, action_one_hot), dim=1)

        next_encoded_state = self.dynamics_encoded_state_network(x)

        # Scale encoded state between [0, 1] (See paper appendix Training)
        min_next_encoded_state = next_encoded_state.min(1, keepdim=True)[0]
        max_next_encoded_state = next_encoded_state.max(1, keepdim=True)[0]
        scale_next_encoded_state = max_next_encoded_state - min_next_encoded_state
        scale_next_encoded_state[scale_next_encoded_state == 0] = 1
        next_encoded_state_normalized = (
            next_encoded_state - min_next_encoded_state
        ) / scale_next_encoded_state

        reward = self.dynamics_reward_network(x)
        return next_encoded_state_normalized, reward

    def initial_inference(self, observation):
        encoded_state = self.representation(observation)
        policy_logits, value = self.prediction(encoded_state)
        # reward equal to 0 for consistency
        reward = (
            torch.zeros(1, self.full_support_size)
            .scatter(1, torch.tensor([[self.full_support_size // 2]]).long(), 1.0)
            .repeat(len(observation), 1)
            .to(observation.device)
        )

        return (
            value,
            reward,
            policy_logits,
            encoded_state,
        )

    def recurrent_inference(self, encoded_state, action):
        next_encoded_state, reward = self.dynamics(encoded_state, action)
        policy_logits, value = self.prediction(next_encoded_state)
        return value, reward, policy_logits, next_encoded_state

    def get_weights(self):
        return {key: value.cpu() for key, value in self.state_dict().items()}

    def set_weights(self, weights):
        self.load_state_dict(weights)


###### End Fully Connected #######
##################################


##################################
############# ResNet #############


def conv3x3(in_channels, out_channels, stride=1):
    return torch.nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )


# Residual block
class ResidualBlock(torch.nn.Module):
    def __init__(self, num_channels, stride=1):
        super().__init__()
        self.conv1 = conv3x3(num_channels, num_channels, stride)
        self.bn1 = torch.nn.BatchNorm2d(num_channels)
        self.relu = torch.nn.ReLU()
        self.conv2 = conv3x3(num_channels, num_channels)
        self.bn2 = torch.nn.BatchNorm2d(num_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += x
        out = self.relu(out)
        return out


# Downsample observations before representation network (See paper appendix Network Architecture)
class DownSample(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        self.resblocks1 = torch.nn.ModuleList(
            [ResidualBlock(out_channels // 2) for _ in range(2)]
        )
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        self.resblocks2 = torch.nn.ModuleList(
            [ResidualBlock(out_channels) for _ in range(3)]
        )
        self.pooling1 = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.resblocks3 = torch.nn.ModuleList(
            [ResidualBlock(out_channels) for _ in range(3)]
        )
        self.pooling2 = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        out = self.conv1(x)
        for block in self.resblocks1:
            out = block(out)
        out = self.conv2(out)
        for block in self.resblocks2:
            out = block(out)
        out = self.pooling1(out)
        for block in self.resblocks3:
            out = block(out)
        out = self.pooling2(out)
        return out


class RepresentationNetwork(torch.nn.Module):
    def __init__(
        self,
        observation_shape,
        stacked_observations,
        num_blocks,
        num_channels,
        downsample,
    ):
        super().__init__()
        self.use_downsample = downsample
        if self.use_downsample:
            self.downsample = DownSample(
                observation_shape[0] * (stacked_observations + 1)
                + stacked_observations,
                num_channels,
            )
        self.conv = conv3x3(
            num_channels
            if downsample
            else observation_shape[0] * (stacked_observations + 1)
            + stacked_observations,
            num_channels,
        )
        self.bn = torch.nn.BatchNorm2d(num_channels)
        self.relu = torch.nn.ReLU()
        self.resblocks = torch.nn.ModuleList(
            [ResidualBlock(num_channels) for _ in range(num_blocks)]
        )

    def forward(self, x):
        if self.use_downsample:
            out = self.downsample(x)
        else:
            out = x
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        for block in self.resblocks:
            out = block(out)
        return out


class DynamicNetwork(torch.nn.Module):
    def __init__(
        self,
        observation_shape,
        num_blocks,
        num_channels,
        reduced_channels,
        fc_reward_layers,
        full_support_size,
        block_output_size,
    ):
        super().__init__()
        self.observation_shape = observation_shape
        self.conv = conv3x3(num_channels, num_channels - 1)
        self.bn = torch.nn.BatchNorm2d(num_channels - 1)
        self.relu = torch.nn.ReLU()
        self.resblocks = torch.nn.ModuleList(
            [ResidualBlock(num_channels - 1) for _ in range(num_blocks)]
        )

        self.conv1x1 = torch.nn.Conv2d(num_channels - 1, reduced_channels, 1)
        self.block_output_size = block_output_size
        self.fc = FullyConnectedNetwork(
            self.block_output_size,
            fc_reward_layers,
            full_support_size,
            activation=None,
        )

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        for block in self.resblocks:
            out = block(out)
        state = out
        out = self.conv1x1(out)
        out = out.view(-1, self.block_output_size)
        reward = self.fc(out)
        return state, reward


class PredictionNetwork(torch.nn.Module):
    def __init__(
        self,
        observation_shape,
        action_space_size,
        num_blocks,
        num_channels,
        reduced_channels,
        fc_value_layers,
        fc_policy_layers,
        full_support_size,
        block_output_size,
    ):
        super().__init__()
        self.observation_shape = observation_shape
        self.resblocks = torch.nn.ModuleList(
            [ResidualBlock(num_channels) for _ in range(num_blocks)]
        )

        self.conv1x1 = torch.nn.Conv2d(num_channels, reduced_channels, 1)
        self.block_output_size = block_output_size
        self.fc_value = FullyConnectedNetwork(
            self.block_output_size, fc_value_layers, full_support_size, activation=None,
        )
        self.fc_policy = FullyConnectedNetwork(
            self.block_output_size,
            fc_policy_layers,
            action_space_size,
            activation=None,
        )

    def forward(self, x):
        out = x
        for block in self.resblocks:
            out = block(out)
        out = self.conv1x1(out)
        out = out.view(-1, self.block_output_size)
        value = self.fc_value(out)
        policy = self.fc_policy(out)
        return policy, value


class MuZeroResidualNetwork(torch.nn.Module):
    def __init__(
        self,
        observation_shape,
        stacked_observations,
        action_space_size,
        num_blocks,
        num_channels,
        reduced_channels,
        fc_reward_layers,
        fc_value_layers,
        fc_policy_layers,
        support_size,
        downsample,
    ):
        super().__init__()
        self.action_space_size = action_space_size
        self.full_support_size = 2 * support_size + 1
        block_output_size = (
            (
                reduced_channels
                * (observation_shape[1] // 16)
                * (observation_shape[2] // 16)
            )
            if downsample
            else (reduced_channels * observation_shape[1] * observation_shape[2])
        )

        self.representation_network = RepresentationNetwork(
            observation_shape,
            stacked_observations,
            num_blocks,
            num_channels,
            downsample,
        )

        self.dynamics_network = DynamicNetwork(
            observation_shape,
            num_blocks,
            num_channels + 1,
            reduced_channels,
            fc_reward_layers,
            self.full_support_size,
            block_output_size,
        )

        self.prediction_network = PredictionNetwork(
            observation_shape,
            action_space_size,
            num_blocks,
            num_channels,
            reduced_channels,
            fc_value_layers,
            fc_policy_layers,
            self.full_support_size,
            block_output_size,
        )

    def prediction(self, encoded_state):
        policy, value = self.prediction_network(encoded_state)
        return policy, value

    def representation(self, observation):
        encoded_state = self.representation_network(observation)

        # Scale encoded state between [0, 1] (See appendix paper Training)
        min_encoded_state = (
            encoded_state.view(
                -1,
                encoded_state.shape[1],
                encoded_state.shape[2] * encoded_state.shape[3],
            )
            .min(2, keepdim=True)[0]
            .unsqueeze(-1)
        )
        max_encoded_state = (
            encoded_state.view(
                -1,
                encoded_state.shape[1],
                encoded_state.shape[2] * encoded_state.shape[3],
            )
            .max(2, keepdim=True)[0]
            .unsqueeze(-1)
        )
        scale_encoded_state = max_encoded_state - min_encoded_state
        scale_encoded_state[scale_encoded_state == 0] = 1
        encoded_state_normalized = (
            encoded_state - min_encoded_state
        ) / scale_encoded_state
        return encoded_state_normalized

    def dynamics(self, encoded_state, action):
        # Stack encoded_state with a game specific one hot encoded action (See paper appendix Network Architecture)
        action_one_hot = (
            torch.ones(
                (
                    encoded_state.shape[0],
                    1,
                    encoded_state.shape[2],
                    encoded_state.shape[3],
                )
            )
            .to(action.device)
            .float()
        )
        action_one_hot = (
            action[:, :, None, None] * action_one_hot / self.action_space_size
        )
        x = torch.cat((encoded_state, action_one_hot), dim=1)
        next_encoded_state, reward = self.dynamics_network(x)

        # Scale encoded state between [0, 1] (See paper appendix Training)
        min_next_encoded_state = (
            next_encoded_state.view(
                -1,
                next_encoded_state.shape[1],
                next_encoded_state.shape[2] * next_encoded_state.shape[3],
            )
            .min(2, keepdim=True)[0]
            .unsqueeze(-1)
        )
        max_next_encoded_state = (
            next_encoded_state.view(
                -1,
                next_encoded_state.shape[1],
                next_encoded_state.shape[2] * next_encoded_state.shape[3],
            )
            .max(2, keepdim=True)[0]
            .unsqueeze(-1)
        )
        scale_next_encoded_state = max_next_encoded_state - min_next_encoded_state
        scale_next_encoded_state[scale_next_encoded_state == 0] = 1
        next_encoded_state_normalized = (
            next_encoded_state - min_next_encoded_state
        ) / scale_next_encoded_state
        return next_encoded_state_normalized, reward

    def initial_inference(self, observation):
        encoded_state = self.representation(observation)
        policy_logits, value = self.prediction(encoded_state)
        # reward equal to 0 for consistency
        reward = (
            torch.zeros(1, self.full_support_size)
            .scatter(1, torch.tensor([[self.full_support_size // 2]]).long(), 1.0)
            .repeat(len(observation), 1)
            .to(observation.device)
        )
        return (
            value,
            reward,
            policy_logits,
            encoded_state,
        )

    def recurrent_inference(self, encoded_state, action):
        next_encoded_state, reward = self.dynamics(encoded_state, action)
        policy_logits, value = self.prediction(next_encoded_state)
        return value, reward, policy_logits, next_encoded_state

    def get_weights(self):
        return {key: value.cpu() for key, value in self.state_dict().items()}

    def set_weights(self, weights):
        self.load_state_dict(weights)


########### End ResNet ###########
##################################


class FullyConnectedNetwork(torch.nn.Module):
    def __init__(self, input_size, layer_sizes, output_size, activation=None):
        super().__init__()
        size_list = [input_size] + layer_sizes
        layers = []
        if 1 < len(size_list):
            for i in range(len(size_list) - 1):
                layers.extend(
                    [
                        torch.nn.Linear(size_list[i], size_list[i + 1]),
                        torch.nn.LeakyReLU(),
                    ]
                )
        layers.append(torch.nn.Linear(size_list[-1], output_size))
        if activation:
            layers.append(activation)
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def support_to_scalar(logits, support_size):
    """
    Transform a categorical representation to a scalar
    See paper appendix Network Architecture
    """
    # Decode to a scalar
    probabilities = torch.softmax(logits, dim=1)
    support = (
        torch.tensor([x for x in range(-support_size, support_size + 1)])
        .expand(probabilities.shape)
        .float()
        .to(device=probabilities.device)
    )
    x = torch.sum(support * probabilities, dim=1, keepdim=True)

    # Invert the scaling (defined in https://arxiv.org/abs/1805.11593)
    x = torch.sign(x) * (
        ((torch.sqrt(1 + 4 * 0.001 * (torch.abs(x) + 1 + 0.001)) - 1) / (2 * 0.001))
        ** 2
        - 1
    )
    return x


def scalar_to_support(x, support_size):
    """
    Transform a scalar to a categorical representation with (2 * support_size + 1) categories
    See paper appendix Network Architecture
    """
    # Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
    x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + 0.001 * x

    # Encode on a vector
    x = torch.clamp(x, -support_size, support_size)
    floor = x.floor()
    prob = x - floor
    logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
    logits.scatter_(
        2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
    )
    indexes = floor + support_size + 1
    prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
    indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
    logits.scatter_(2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
    return logits

In [0]:
import ray
import torch
import os


@ray.remote
class SharedStorage:
    """
    Class which run in a dedicated thread to store the network weights and some information.
    """

    def __init__(self, weights, game_name, config):
        self.config = config
        self.game_name = game_name
        self.weights = weights
        self.infos = {
            "total_reward": 0,
            "player_0_reward": 0,
            "player_1_reward": 0,
            "episode_length": 0,
            "training_step": 0,
            "lr": 0,
            "total_loss": 0,
            "value_loss": 0,
            "reward_loss": 0,
            "policy_loss": 0,
        }

    def get_weights(self):
        return self.weights

    def set_weights(self, weights, path=None):
        self.weights = weights
        if not path:
            path = os.path.join(self.config.results_path, "model.weights")

        torch.save(self.weights, path)

    def get_infos(self):
        return self.infos

    def set_infos(self, key, value):
        self.infos[key] = value

In [0]:
import copy
import math
import time

import numpy
import ray
import torch


@ray.remote
class SelfPlay:
    """
    Class which run in a dedicated thread to play games and save them to the replay-buffer.
    """

    def __init__(self, initial_weights, game, config):
        self.config = config
        self.game = game

        # Fix random generator seed
        numpy.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        # Initialize the network
        self.model = MuZeroNetwork(self.config)
        self.model.set_weights(initial_weights)
        self.model.to(torch.device("cpu"))
        self.model.eval()

    def continuous_self_play(self, shared_storage, replay_buffer, test_mode=False):
        while True:
            self.model.set_weights(
                copy.deepcopy(ray.get(shared_storage.get_weights.remote()))
            )

            # Take the best action (no exploration) in test mode
            # temperature = (
            #     0
            #     if test_mode
            #     else self.config.visit_softmax_temperature_fn(
            #         trained_steps=ray.get(shared_storage.get_infos.remote())[
            #             "training_step"
            #         ]
            #     )
            # )
            # SCHORY WE NEED TO CHANGE 0.5 to temperature
            game_history = self.play_game(
                0.5,
                self.config.temperature_threshold,
                False,
                "self",
                0,
            )

            # Save to the shared storage
            if test_mode:
                shared_storage.set_infos.remote(
                    "total_reward", sum(game_history.reward_history)
                )
                shared_storage.set_infos.remote(
                    "episode_length", len(game_history.action_history)
                )
                if 1 < len(self.config.players):
                    shared_storage.set_infos.remote(
                        "player_0_reward",
                        sum(
                            [
                                reward
                                for i, reward in enumerate(game_history.reward_history)
                                if game_history.to_play_history[i] == 1
                            ]
                        ),
                    )
                    shared_storage.set_infos.remote(
                        "player_1_reward",
                        sum(
                            [
                                reward
                                for i, reward in enumerate(game_history.reward_history)
                                if game_history.to_play_history[i] == 0
                            ]
                        ),
                    )
            if not test_mode:
                replay_buffer.save_game.remote(game_history)

            # Managing the self-play / training ratio
            if not test_mode and self.config.self_play_delay:
                time.sleep(self.config.self_play_delay)
            if not test_mode and self.config.ratio:
                while (
                    ray.get(replay_buffer.get_self_play_count.remote())
                    / max(
                        1, ray.get(shared_storage.get_infos.remote())["training_step"]
                    )
                    > self.config.ratio
                ):
                    time.sleep(0.5)

    def play_game(
        self, temperature, temperature_threshold, render, opponent, muzero_player
    ):
        """
        Play one game with actions based on the Monte Carlo tree search at each moves.
        """
        game_history = GameHistory()
        observation = self.game.reset()
        game_history.action_history.append(0)
        game_history.observation_history.append(observation)
        game_history.reward_history.append(0)
        game_history.to_play_history.append(self.game.to_play())

        done = False

        if render:
            self.game.render()

        with torch.no_grad():
            while (
                not done and len(game_history.action_history) <= self.config.max_moves
            ):
                stacked_observations = game_history.get_stacked_observations(
                    -1, self.config.stacked_observations,
                )

                root, priority, tree_depth = MCTS(self.config).run(
                    self.model,
                    stacked_observations,
                    self.game.legal_actions(observation),
                    self.game.to_play(),
                    False if temperature == 0 else True,
                )

                if render:
                    print("Tree depth: {}".format(tree_depth))
                    print(
                        "Root value for player {0}: {1:.2f}".format(
                            self.game.to_play(), root.value()
                        )
                    )

                # Choose the action
                if opponent == "self" or muzero_player == self.game.to_play():
                    action = self.select_action(
                        root,
                        temperature
                        if not temperature_threshold
                        or len(game_history.action_history) < temperature_threshold
                        else 0,
                    )
                elif opponent == "human":
                    print(
                        "Player {} turn. MuZero suggests {}".format(
                            self.game.to_play(),
                            self.game.action_to_string(self.select_action(root, 0)),
                        )
                    )
                    action = self.game.human_to_action()
                elif opponent == "random":
                    action = numpy.random.choice(self.game.legal_actions(observation))
                else:
                    raise ValueError(
                        'Wrong argument: "opponent" argument should be "self", "human" or "random"'
                    )

                observation, reward, done = self.game.step(action, observation)

                if render:
                    print(
                        "Played action: {}".format(self.game.action_to_string(action))
                    )
                    self.game.render()

                game_history.store_search_statistics(root, self.config.action_space)
                if not self.config.use_max_priority:
                    game_history.priorities.append(priority)

                # Next batch
                game_history.action_history.append(action)
                game_history.observation_history.append(observation)
                game_history.reward_history.append(reward)
                game_history.to_play_history.append(self.game.to_play())

        self.game.close()
        return game_history

    @staticmethod
    def select_action(node, temperature):
        """
        Select action according to the visit count distribution and the temperature.
        The temperature is changed dynamically with the visit_softmax_temperature function 
        in the config.
        """
        visit_counts = numpy.array(
            [child.visit_count for child in node.children.values()]
        )
        actions = [action for action in node.children.keys()]
        if temperature == 0:
            action = actions[numpy.argmax(visit_counts)]
        elif temperature == float("inf"):
            action = numpy.random.choice(actions)
        else:
            # See paper appendix Data Generation
            visit_count_distribution = visit_counts ** (1 / temperature)
            visit_count_distribution = visit_count_distribution / sum(
                visit_count_distribution
            )
            action = numpy.random.choice(actions, p=visit_count_distribution)

        return action


# Game independent
class MCTS:
    """
    Core Monte Carlo Tree Search algorithm.
    To decide on an action, we run N simulations, always starting at the root of
    the search tree and traversing the tree according to the UCB formula until we
    reach a leaf node.
    """

    def __init__(self, config):
        self.config = config

    def run(self, model, observation, legal_actions, to_play, add_exploration_noise):
        """
        At the root of the search tree we use the representation function to obtain a
        hidden state given the current observation.
        We then run a Monte Carlo Tree Search using only action sequences and the model
        learned by the network.
        """
        root = Node(0)
        dn = DynamicNetwork(self.config.observation_shape, self.config.blocks, 
                            self.config.channels, self.config.reduced_channels, 
                            self.config.resnet_fc_reward_layers, self.config.support_size, 
                            self.config.blocks)
        observation = (
            torch.tensor(observation)
            .float()
            .unsqueeze(0)
            .to(next(model.parameters()).device)
        )
        # SCHORY
        print(observation)
        (
            root_predicted_value,
            reward,
            policy_logits,
            hidden_state,
        ) = model.initial_inference(observation)
        root_predicted_value = dn.support_to_scalar(
            root_predicted_value, self.config.support_size
        ).item()
        reward = dn.support_to_scalar(reward, self.config.support_size).item()
        root.expand(
            legal_actions, to_play, reward, policy_logits, hidden_state,
        )
        if add_exploration_noise:
            root.add_exploration_noise(
                dirichlet_alpha=self.config.root_dirichlet_alpha,
                exploration_fraction=self.config.root_exploration_fraction,
            )

        min_max_stats = MinMaxStats()

        max_tree_depth = 0
        for _ in range(self.config.num_simulations):
            virtual_to_play = to_play
            node = root
            search_path = [node]
            current_tree_depth = 0

            while node.expanded():
                current_tree_depth += 1
                action, node = self.select_child(node, min_max_stats)
                search_path.append(node)

                # Players play turn by turn
                if virtual_to_play + 1 < len(self.config.players):
                    virtual_to_play = self.config.players[virtual_to_play + 1]
                else:
                    virtual_to_play = self.config.players[0]

            # Inside the search tree we use the dynamics function to obtain the next hidden
            # state given an action and the previous hidden state
            parent = search_path[-2]
            value, reward, policy_logits, hidden_state = model.recurrent_inference(
                parent.hidden_state,
                torch.tensor([[action]]).to(parent.hidden_state.device),
            )
            value = dn.support_to_scalar(value, self.config.support_size).item()
            reward = dn.support_to_scalar(reward, self.config.support_size).item()
            node.expand(
                self.config.action_space,
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
            )

            self.backpropagate(search_path, value, virtual_to_play, min_max_stats)

            max_tree_depth = max(max_tree_depth, current_tree_depth)

        priority = (
            None
            if self.config.use_max_priority
            else numpy.abs(root_predicted_value - root.value()) ** self.config.PER_alpha
        )

        return root, priority, max_tree_depth

    def select_child(self, node, min_max_stats):
        """
        Select the child with the highest UCB score.
        """
        _, action, child = max(
            (self.ucb_score(node, child, min_max_stats), action, child)
            for action, child in node.children.items()
        )
        return action, child

    def ucb_score(self, parent, child, min_max_stats):
        """
        The score for a node is based on its value, plus an exploration bonus based on the prior.
        """
        pb_c = (
            math.log(
                (parent.visit_count + self.config.pb_c_base + 1) / self.config.pb_c_base
            )
            + self.config.pb_c_init
        )
        pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

        prior_score = pb_c * child.prior

        if child.visit_count > 0:
            # mean value Q
            value_score = min_max_stats.normalize(
                child.reward + self.config.discount * child.value()
            )
        else:
            value_score = 0

        return prior_score + value_score

    def backpropagate(self, search_path, value, to_play, min_max_stats):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.
        """
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1
            min_max_stats.update(node.reward + self.config.discount * node.value())

            value = node.reward + self.config.discount * value


class Node:
    def __init__(self, prior):
        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def expand(self, actions, to_play, reward, policy_logits, hidden_state):
        """
        We expand a node using the value, reward and policy prediction obtained from the
        neural network.
        """
        self.to_play = to_play
        self.reward = reward
        self.hidden_state = hidden_state
        policy = {}
        for a in actions:
            try:
                policy[a] = 1 / sum(torch.exp(policy_logits[0] - policy_logits[0][a]))
            except OverflowError:
                print("Warning: prior has been approximated")
                policy[a] = 0.0
        for action, p in policy.items():
            self.children[action] = Node(p)

    def add_exploration_noise(self, dirichlet_alpha, exploration_fraction):
        """
        At the start of each search, we add dirichlet noise to the prior of the root to
        encourage the search to explore new actions.
        """
        actions = list(self.children.keys())
        noise = numpy.random.dirichlet([dirichlet_alpha] * len(actions))
        frac = exploration_fraction
        for a, n in zip(actions, noise):
            self.children[a].prior = self.children[a].prior * (1 - frac) + n * frac


class GameHistory:
    """
    Store only usefull information of a self-play game.
    """

    def __init__(self):
        self.observation_history = []
        self.action_history = []
        self.reward_history = []
        self.to_play_history = []
        self.child_visits = []
        self.root_values = []
        self.priorities = []

    def store_search_statistics(self, root, action_space):
        # Turn visit count from root into a policy
        sum_visits = sum(child.visit_count for child in root.children.values())
        self.child_visits.append(
            [
                root.children[a].visit_count / sum_visits if a in root.children else 0
                for a in action_space
            ]
        )

        self.root_values.append(root.value())

    def get_stacked_observations(self, index, num_stacked_observations):
        """
        Generate a new observation with the observation at the index position
        and num_stacked_observations past observations and actions stacked.
        """
        # Convert to positive index
        index = index % len(self.observation_history)

        stacked_observations = self.observation_history[index].copy()
        for past_observation_index in reversed(
            range(index - num_stacked_observations, index)
        ):
            if 0 <= past_observation_index:
                previous_observation = numpy.concatenate(
                    (
                        self.observation_history[past_observation_index],
                        [
                            numpy.ones_like(stacked_observations[0])
                            * self.action_history[past_observation_index + 1]
                        ],
                    )
                )
            else:
                previous_observation = numpy.concatenate(
                    (
                        numpy.zeros_like(self.observation_history[index]),
                        [numpy.zeros_like(stacked_observations[0])],
                    )
                )

            stacked_observations = numpy.concatenate(
                (stacked_observations, previous_observation)
            )

        return stacked_observations


class MinMaxStats:
    """
    A class that holds the min-max values of the tree.
    """

    def __init__(self):
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            # We normalize only when we have set the maximum and minimum values
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value

In [0]:
import copy

import numpy
import ray
import torch



@ray.remote
class ReplayBuffer:
    """
    Class which run in a dedicated thread to store played games and generate batch.
    """

    def __init__(self, config):
        self.config = config
        self.buffer = []
        self.game_priorities = []
        self.max_recorded_game_priority = 1.0
        self.self_play_count = 0

        self.model = MuZeroNetwork(self.config)

        # Fix random generator seed
        numpy.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

    def save_game(self, game_history):
        if len(self.buffer) > self.config.window_size:
            self.buffer.pop(0)
            self.game_priorities.pop(0)

        if self.config.use_max_priority:
            game_history.priorities = (
                numpy.ones(len(game_history.root_values))
                * self.max_recorded_game_priority
            )
        self.buffer.append(game_history)
        self.game_priorities.append(numpy.mean(game_history.priorities))
        self.self_play_count += 1

    def get_self_play_count(self):
        return self.self_play_count

    def get_batch(self, model_weights):
        (
            index_batch,
            observation_batch,
            action_batch,
            reward_batch,
            value_batch,
            policy_batch,
            weight_batch,
            gradient_scale_batch,
        ) = ([], [], [], [], [], [], [], [])

        total_samples = sum(
            (len(game_history.priorities) for game_history in self.buffer)
        )

        if self.config.use_last_model_value:
            self.model.set_weights(model_weights)

        for _ in range(self.config.batch_size):
            game_index, game_history, game_prob = self.sample_game(self.buffer)
            game_pos, pos_prob = self.sample_position(game_history)

            values, rewards, policies, actions = self.make_target(
                game_history, game_pos
            )

            index_batch.append([game_index, game_pos])
            observation_batch.append(game_history.get_stacked_observations(game_pos, self.config.stacked_observations))
            action_batch.append(actions)
            value_batch.append(values)
            reward_batch.append(rewards)
            policy_batch.append(policies)
            weight_batch.append(
                (total_samples * game_prob * pos_prob) ** (-self.config.PER_beta)
            )
            gradient_scale_batch.append(
                [
                    min(
                        self.config.num_unroll_steps,
                        len(game_history.action_history) - game_pos,
                    )
                ]
                * len(actions)
            )

        weight_batch = numpy.array(weight_batch) / max(weight_batch)

        # observation_batch: batch, channels, height, width
        # action_batch: batch, num_unroll_steps+1
        # value_batch: batch, num_unroll_steps+1
        # reward_batch: batch, num_unroll_steps+1
        # policy_batch: batch, num_unroll_steps+1, len(action_space)
        # weight_batch: batch
        # gradient_scale_batch: batch, num_unroll_steps+1
        return (
            index_batch,
            (
                observation_batch,
                action_batch,
                value_batch,
                reward_batch,
                policy_batch,
                weight_batch,
                gradient_scale_batch,
            ),
        )

    def sample_game(self, buffer):
        """
        Sample game from buffer either uniformly or according to some priority.
        See paper appendix Training.
        """
        game_probs = numpy.array(self.game_priorities) / sum(self.game_priorities)
        game_index_candidates = numpy.arange(0, len(self.buffer), dtype=int)
        game_index = numpy.random.choice(game_index_candidates, p=game_probs)
        game_prob = game_probs[game_index]

        return game_index, self.buffer[game_index], game_prob

    def sample_position(self, game_history):
        """
        Sample position from game either uniformly or according to some priority.
        See paper appendix Training.
        """
        position_probs = numpy.array(game_history.priorities) / sum(
            game_history.priorities
        )
        position_index_candidates = numpy.arange(0, len(position_probs), dtype=int)
        position_index = numpy.random.choice(
            position_index_candidates, p=position_probs
        )
        position_prob = position_probs[position_index]

        return position_index, position_prob

    def update_priorities(self, priorities, index_info):
        """
        Update game and position priorities with priorities calculated during the training.
        See Distributed Prioritized Experience Replay https://arxiv.org/abs/1803.00933
        """
        for i in range(len(index_info)):
            game_index, game_pos = index_info[i]

            # update position priorities
            priority = priorities[i, :]
            start_index = game_pos
            end_index = min(
                game_pos + len(priority), len(self.buffer[game_index].priorities)
            )
            self.buffer[game_index].priorities[start_index:end_index] = priority[
                : end_index - start_index
            ]

            # update game priorities
            self.game_priorities[game_index] = numpy.max(
                self.buffer[game_index].priorities
            )  # option: mean, sum, max

            self.max_recorded_game_priority = numpy.max(self.game_priorities)

    def make_target(self, game_history, state_index):
        """
        Generate targets for every unroll steps.
        """
        dn = DynamicNetwork()
        target_values, target_rewards, target_policies, actions = [], [], [], []
        for current_index in range(
            state_index, state_index + self.config.num_unroll_steps + 1
        ):
            # The value target is the discounted root value of the search tree td_steps into the
            # future, plus the discounted sum of all rewards until then.
            bootstrap_index = current_index + self.config.td_steps
            if bootstrap_index < len(game_history.root_values):
                if self.config.use_last_model_value:
                    # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                    observation = torch.tensor(
                        game_history.get_stacked_observations(bootstrap_index, self.config.stacked_observations)
                    ).float()
                    last_step_value = dn.support_to_scalar(
                        self.model.initial_inference(observation)[0],
                        self.config.support_size,
                    ).item()
                else:
                    last_step_value = game_history.root_values[bootstrap_index]

                value = last_step_value * self.config.discount ** self.config.td_steps
            else:
                value = 0

            for i, reward in enumerate(
                game_history.reward_history[current_index + 1 : bootstrap_index + 1]
            ):
                value += (
                    reward
                    if game_history.to_play_history[current_index]
                    == game_history.to_play_history[current_index + 1 + i]
                    else -reward
                ) * self.config.discount ** i

            if current_index < len(game_history.root_values):
                target_values.append(value)
                target_rewards.append(game_history.reward_history[current_index])
                target_policies.append(game_history.child_visits[current_index])
                actions.append(game_history.action_history[current_index])
            elif current_index == len(game_history.root_values):
                target_values.append(0)
                target_rewards.append(game_history.reward_history[current_index])
                # Uniform policy
                target_policies.append(
                    [
                        1 / len(game_history.child_visits[0])
                        for _ in range(len(game_history.child_visits[0]))
                    ]
                )
                actions.append(game_history.action_history[current_index])
            else:
                # States past the end of games are treated as absorbing states
                target_values.append(0)
                target_rewards.append(0)
                # Uniform policy
                target_policies.append(
                    [
                        1 / len(game_history.child_visits[0])
                        for _ in range(len(game_history.child_visits[0]))
                    ]
                )
                actions.append(numpy.random.choice(game_history.action_history))

        return target_values, target_rewards, target_policies, actions

**Trainer**

In [0]:
import time

import numpy
import ray
import torch



@ray.remote
class Trainer:
    """
    Class which run in a dedicated thread to train a neural network and save it
    in the shared storage.
    """

    def __init__(self, initial_weights, config):
        self.config = config
        self.training_step = 0

        # Fix random generator seed
        numpy.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        # Initialize the network
        self.model = MuZeroNetwork(self.config)
        self.model.set_weights(initial_weights)
        self.model.to(torch.device(config.training_device))
        self.model.train()

        if self.config.optimizer == "SGD":
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=self.config.lr_init,
                momentum=self.config.momentum,
                weight_decay=self.config.weight_decay,
            )
        elif self.config.optimizer == "Adam":
            self.optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=self.config.lr_init,
                weight_decay=self.config.weight_decay,
            )
        else:
            raise ValueError(
                "{} is not implemented. You can change the optimizer manually in trainer.py."
            )

    def continuous_update_weights(self, replay_buffer, shared_storage_worker):
        # Wait for the replay buffer to be filled
        while ray.get(replay_buffer.get_self_play_count.remote()) < 1:
            time.sleep(0.1)

        # Training loop
        while True:
            index_batch, batch = ray.get(replay_buffer.get_batch.remote(self.model.get_weights()))
            self.update_lr()
            (
                priorities,
                total_loss,
                value_loss,
                reward_loss,
                policy_loss,
            ) = self.update_weights(batch)

            if self.config.PER:
                # Save new priorities in the replay buffer (See https://arxiv.org/abs/1803.00933)
                replay_buffer.update_priorities.remote(priorities, index_batch)

            # Save to the shared storage
            if self.training_step % self.config.checkpoint_interval == 0:
                shared_storage_worker.set_weights.remote(self.model.get_weights())
            shared_storage_worker.set_infos.remote("training_step", self.training_step)
            shared_storage_worker.set_infos.remote(
                "lr", self.optimizer.param_groups[0]["lr"]
            )
            shared_storage_worker.set_infos.remote("total_loss", total_loss)
            shared_storage_worker.set_infos.remote("value_loss", value_loss)
            shared_storage_worker.set_infos.remote("reward_loss", reward_loss)
            shared_storage_worker.set_infos.remote("policy_loss", policy_loss)

            # Managing the self-play / training ratio
            if self.config.training_delay:
                time.sleep(self.config.training_delay)
            if self.config.ratio:
                while (
                    ray.get(replay_buffer.get_self_play_count.remote())
                    / max(1, self.training_step)
                    < self.config.ratio
                ):
                    time.sleep(0.5)

    def update_weights(self, batch):
        """
        Perform one training step.
        """

        (
            observation_batch,
            action_batch,
            target_value,
            target_reward,
            target_policy,
            weight_batch,
            gradient_scale_batch,
        ) = batch

        # Keep values as scalars for calculating the priorities for the prioritized replay
        target_value_scalar = numpy.array(target_value)
        priorities = numpy.zeros_like(target_value_scalar)
        dn = DynamicNetwork()

        device = next(self.model.parameters()).device
        weight_batch = torch.tensor(weight_batch).float().to(device)
        observation_batch = torch.tensor(observation_batch).float().to(device)
        action_batch = torch.tensor(action_batch).float().to(device).unsqueeze(-1)
        target_value = torch.tensor(target_value).float().to(device)
        target_reward = torch.tensor(target_reward).float().to(device)
        target_policy = torch.tensor(target_policy).float().to(device)
        gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(device)
        # observation_batch: batch, channels, height, width
        # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
        # target_value: batch, num_unroll_steps+1
        # target_reward: batch, num_unroll_steps+1
        # target_policy: batch, num_unroll_steps+1, len(action_space)
        # gradient_scale_batch: batch, num_unroll_steps+1

        target_value = dn.scalar_to_support(target_value, self.config.support_size)
        target_reward = dn.scalar_to_support(target_reward, self.config.support_size)
        # target_value: batch, num_unroll_steps+1, 2*support_size+1
        # target_reward: batch, num_unroll_steps+1, 2*support_size+1

        ## Generate predictions
        value, reward, policy_logits, hidden_state = self.model.initial_inference(
            observation_batch
        )
        predictions = [(value, reward, policy_logits)]
        for i in range(1, action_batch.shape[1]):
            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                hidden_state, action_batch[:, i]
            )
            # Scale the gradient at the start of the dynamics function (See paper appendix Training)
            hidden_state.register_hook(lambda grad: grad * 0.5)
            predictions.append((value, reward, policy_logits))
        # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

        ## Compute losses
        value_loss, reward_loss, policy_loss = (0, 0, 0)
        value, reward, policy_logits = predictions[0]
        # Ignore reward loss for the first batch step
        current_value_loss, _, current_policy_loss = self.loss_function(
            value.squeeze(-1),
            reward.squeeze(-1),
            policy_logits,
            target_value[:, 0],
            target_reward[:, 0],
            target_policy[:, 0],
        )
        value_loss += current_value_loss
        policy_loss += current_policy_loss
        # Compute priorities for the prioritized replay (See paper appendix Training)
        pred_value_scalar = (
            models.support_to_scalar(value, self.config.support_size)
            .detach()
            .cpu()
            .numpy()
            .squeeze()
        )
        priorities[:, 0] = (
            numpy.abs(pred_value_scalar - target_value_scalar[:, 0])
            ** self.config.PER_alpha
        )

        for i in range(1, len(predictions)):
            value, reward, policy_logits = predictions[i]
            (
                current_value_loss,
                current_reward_loss,
                current_policy_loss,
            ) = self.loss_function(
                value.squeeze(-1),
                reward.squeeze(-1),
                policy_logits,
                target_value[:, i],
                target_reward[:, i],
                target_policy[:, i],
            )

            # Scale gradient by the number of unroll steps (See paper appendix Training)
            current_value_loss.register_hook(
                lambda grad: grad / gradient_scale_batch[:, i]
            )
            current_reward_loss.register_hook(
                lambda grad: grad / gradient_scale_batch[:, i]
            )
            current_policy_loss.register_hook(
                lambda grad: grad / gradient_scale_batch[:, i]
            )

            value_loss += current_value_loss
            reward_loss += current_reward_loss
            policy_loss += current_policy_loss

            # Compute priorities for the prioritized replay (See paper appendix Training)
            pred_value_scalar = (
                models.support_to_scalar(value, self.config.support_size)
                .detach()
                .cpu()
                .numpy()
                .squeeze()
            )
            priorities[:, i] = (
                numpy.abs(pred_value_scalar - target_value_scalar[:, i])
                ** self.config.PER_alpha
            )

        # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
        loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
        if self.config.PER:
            # Correct PER bias by using importance-sampling (IS) weights
            loss *= weight_batch
        # Mean over batch dimension (pseudocode do a sum)
        loss = loss.mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.training_step += 1

        return (
            priorities,
            # For log purpose
            loss.item(),
            value_loss.mean().item(),
            reward_loss.mean().item(),
            policy_loss.mean().item(),
        )

    def update_lr(self):
        """
        Update learning rate
        """
        lr = self.config.lr_init * self.config.lr_decay_rate ** (
            self.training_step / self.config.lr_decay_steps
        )
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    @staticmethod
    def loss_function(
        value, reward, policy_logits, target_value, target_reward, target_policy,
    ):
        # Cross-entropy seems to have a better convergence than MSE
        value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)
        reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1)
        policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(
            1
        )
        return value_loss, reward_loss, policy_loss

#MCTS

In [0]:
# import math
# import numpy as np
# EPS = 1e-8

# class MCTS():
#     """
#     This class handles the MCTS tree.
#     """

#     def __init__(self, game, nnet, args):
#         self.game = game
#         self.nnet = nnet
#         self.args = args
#         self.Qsa = {}       # stores Q values for s,a (as defined in the paper)
#         self.Nsa = {}       # stores #times edge s,a was visited
#         self.Ns = {}        # stores #times board s was visited
#         self.Ps = {}        # stores initial policy (returned by neural net)

#         self.Es = {}        # stores game.getGameEnded ended for board s
#         self.Vs = {}        # stores game.getValidMoves for board s
        
#         self.plot = [1000]
#         self.num_sim = [0]

#     def getActionProb(self, graphState, temp=1):
#         """
#         This function performs numMCTSSims simulations of MCTS starting from
#         canonicalBoard.
#         Returns:
#             probs: a policy vector where the probability of the ith action is
#                    proportional to Nsa[(s,a)]**(1./temp)
#         """
#         for i in range(self.args.numMCTSSims):
#             self.search(graphState, i)

#         s = self.game.stringRepresentation(graphState)
#         counts = [self.Nsa[(s,a)] if (s,a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

#         if temp==0:
#             bestA = np.argmax(counts)
#             probs = [0]*len(counts)
#             probs[bestA]=1
#             return probs

#         counts = [x**(1./temp) for x in counts]
#         counts_sum = float(sum(counts))
#         probs = [x/counts_sum for x in counts]
#         return probs


#     def search(self, graphState, num_sim):
#         """
#         This function performs one iteration of MCTS. It is recursively called
#         till a leaf node is found. The action chosen at each node is one that
#         has the maximum upper confidence bound as in the paper.
#         Once a leaf node is found, the neural network is called to return an
#         initial policy P and a value v for the state. This value is propagated
#         up the search path. In case the leaf node is a terminal state, the
#         outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
#         updated.
#         NOTE: the return values are the negative of the value of the current
#         state. This is done since v is in [-1,1] and if v is the value of a
#         state for the current player, then its value is -v for the other player.
#         Returns:
#             v: the negative of the value of the current canonicalBoard
#         """

#         s = self.game.stringRepresentation(graphState)

#         if s not in self.Es:
#             self.Es[s] = self.game.getGameEnded(graphState)
#         if self.Es[s]!=0:
#             # terminal node
#             return 0

#         if s not in self.Ps:
#             # leaf node
#             if self.nnet is not None:
#                 self.Ps[s], v = self.nnet.predict(graphState, self.game.graph)
#             else:
#                 self.Ps[s] = np.ones(self.game.getActionSize()) # random policy
#                 v = 0
#             valids = self.game.getValidMoves(graphState)
#             self.Ps[s] = self.Ps[s]*valids      # masking invalid moves
#             sum_Ps_s = np.sum(self.Ps[s])
#             if sum_Ps_s > 0:
#                 self.Ps[s] /= sum_Ps_s    # renormalize
#             else:
#                 # if all valid moves were masked make all valid moves equally probable
                
#                 # NB! All valid moves may be masked if either your NNet architecture is insufficient or you've get overfitting or something else.
#                 # If you have got dozens or hundreds of these messages you should pay attention to your NNet and/or training process.   
#                 print("All valid moves were masked, do workaround.")
#                 self.Ps[s] = self.Ps[s] + valids
#                 self.Ps[s] /= np.sum(self.Ps[s])

#             self.Vs[s] = valids
#             self.Ns[s] = 0
#             return v

#         valids = self.Vs[s]
#         cur_best = -float('inf')
#         best_act = -1

#         # pick the action with the highest upper confidence bound
#         for a in range(self.game.getActionSize()):
#             if valids[a]:
#                 if (s,a) in self.Qsa:
#                     u = self.Qsa[(s,a)] + self.args.cpuct*self.Ps[s][a]*math.sqrt(self.Ns[s])/(1+self.Nsa[(s,a)])
#                 else:
#                     u = self.args.cpuct*self.Ps[s][a]*math.sqrt(self.Ns[s] + EPS)     # Q = 0 ?

#                 if u > cur_best:
#                     cur_best = u
#                     best_act = a

#         a = best_act
#         next_s, reward = self.game.getNextState(graphState, a)
#         # next_s = self.game.getCanonicalForm(next_s, next_player)

#         v = self.search(next_s, num_sim) + reward

#         if (s,a) in self.Qsa:
#             self.Qsa[(s,a)] = (self.Nsa[(s,a)]*self.Qsa[(s,a)] + v)/(self.Nsa[(s,a)]+1)
#             self.Nsa[(s,a)] += 1

#         else:
#             self.Qsa[(s,a)] = v
#             self.Nsa[(s,a)] = 1
            
#         if v > self.game.getActionSize() - self.plot[-1]:
#             self.plot.append(self.game.getActionSize() - v)
#             self.num_sim.append(num_sim)

#         self.Ns[s] += 1
#         return v