# MuZero

This notebook is just an example to show how to run [MuZero (https://github.com/werner-duvaud/muzero-general)](https://github.com/werner-duvaud/muzero-general) in Google Colab or Jupyter Notebook. You can also launch MuZero directly by cloning the github repository and running the command `python muzero.py`. See [readme](https://github.com/werner-duvaud/muzero-general) for detailed instructions.

In [1]:
import math
from abc import ABC, abstractmethod

import torch

In [2]:
import tensorflow as tf

In [3]:
logits_torch = torch.tensor([[-0.0078,  0.4007, -0.1592,  0.3446,  0.2232, -0.0814, -0.1400,  0.6435,
                     0.5683, -0.4653, -0.3925,  0.0400,  0.1729, -0.0152,  0.0551, -0.1042,
                     0.5908,  0.0915,  0.2475,  0.2067, -0.5062]])
support_size_torch = 10
x_torch = torch.tensor([[-0.4167]])

In [20]:
def support_to_scalar_torch(logits, support_size):
    """
    Transform a categorical representation to a scalar
    See paper appendix Network Architecture
    """
    # Decode to a scalar
    probabilities = torch.softmax(logits, dim=1)
    
    print("stonks:",
        torch.tensor([x for x in range(-support_size, support_size + 1)])
        .expand(probabilities.shape)
    )
    support = (
        torch.tensor([x for x in range(-support_size, support_size + 1)])
        .expand(probabilities.shape)
        .float()
        .to(device=probabilities.device)
    )
    x = torch.sum(support * probabilities, dim=1, keepdim=True)
    print("mul stonks:", support * probabilities)
    print("big stonks:", x)

    # Invert the scaling (defined in https://arxiv.org/abs/1805.11593)
    x = torch.sign(x) * (
        ((torch.sqrt(1 + 4 * 0.001 * (torch.abs(x) + 1 + 0.001)) - 1) / (2 * 0.001))
        ** 2
        - 1
    )
    return x

support_to_scalar_torch(logits_torch, support_size_torch)

stonks: tensor([[-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,
           4,   5,   6,   7,   8,   9,  10]])
mul stonks: tensor([[-0.4143, -0.5610, -0.2849, -0.4125, -0.3132, -0.1925, -0.1452, -0.2384,
         -0.1474, -0.0262,  0.0000,  0.0435,  0.0993,  0.1234,  0.1765,  0.1881,
          0.4523,  0.3203,  0.4279,  0.4621,  0.2517]])
big stonks: tensor([[-0.1907]])


tensor([[-0.4167]])

In [5]:
logits_tf = tf.constant([[-0.0078,  0.4007, -0.1592,  0.3446,  0.2232, -0.0814, -0.1400,  0.6435,
                     0.5683, -0.4653, -0.3925,  0.0400,  0.1729, -0.0152,  0.0551, -0.1042,
                     0.5908,  0.0915,  0.2475,  0.2067, -0.5062]])
support_size_tf = 10
x_tf = tf.constant([[-0.4167]])

In [27]:
def support_to_scalar_tf(logits, support_size):
    """
    Transform a categorical representation to a scalar
    See paper appendix Network Architecture
    """
    # Decode to a scalar
    probabilities = tf.nn.softmax(logits, axis=1)
    support = tf.constant([x for x in range(-support_size, support_size + 1)], dtype=tf.float32)
    support = tf.broadcast_to(support, probabilities.shape)
    x = tf.math.reduce_sum(support * probabilities, axis=1, keepdims=True)
    
    # Invert the scaling (defined in https://arxiv.org/abs/1805.11593)
    x = tf.math.sign(x) * (
        ((tf.math.sqrt(1 + 4 * 0.001 * (tf.math.abs(x) + 1 + 0.001)) - 1) / (2 * 0.001))
        ** 2
        - 1
    )
    return x

support_to_scalar_tf(logits_tf, support_size_tf)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-0.41668355]], dtype=float32)>

### scalar to support

In [52]:
def scalar_to_support_torch(x, support_size):
    """
    Transform a scalar to a categorical representation with (2 * support_size + 1) categories
    See paper appendix Network Architecture
    """
    # Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
    x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + 0.001 * x

    # Encode on a vector
    x = torch.clamp(x, -support_size, support_size)
    floor = x.floor()
    prob = x - floor
#     print("prob:", prob.shape)
    logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
    logits.scatter_(
        2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
    )
    print("stonks:", logits)
    print("stonks:", logits.shape)
    indexes = floor + support_size + 1
    prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
    indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
    logits.scatter_(2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
#     return logits

scalar_to_support_torch(
    torch.tensor([[2.1025, 2.0129, 2.1019, 2.0780, 2.0724, 2.0395],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [2.1156, 1.9219, 1.9878, 2.0014, 2.0230, 1.9698],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.9868, 2.0178, 2.1383, 1.9936, 2.0089, 1.9848],
        [1.9503, 1.9601, 2.0509, 1.9869, 1.9181, 1.9885],
        [2.0072, 1.9588, 2.0606, 2.0278, 2.0264, 2.0652]]), 
    support_size_torch)

stonks: tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0

In [50]:
def scalar_to_support_tf(x, support_size):
    """
    Transform a scalar to a categorical representation with (2 * support_size + 1) categories
    See paper appendix Network Architecture
    """
    # Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
    x = tf.math.sign(x) * (tf.math.sqrt(tf.math.abs(x) + 1) - 1) + 0.001 * x

    # Encode on a vector
    x = tf.clip_by_value(x, -support_size, support_size)
#     print("x:", x)
    floor = tf.math.floor(x)
    prob = x - floor
#     print("prob:", prob.shape)
    logits = tf.zeros([x.shape[0], x.shape[1], 2 * support_size + 1])
    logits = tf.scatter_nd(
        tf.expand_dims(floor + support_size, -1), tf.expand_dims(1 - prob, -1) # shape required, account for dims=2
    )
    print("stonks:", logits)
    indexes = floor + support_size + 1
    prob = tf.where(2 * support_size < indexes, prob, 0.0)
    indexes = tf.where(2 * support_size < indexes, indexes, 0.0)
    logits.tf.scatter_nd(
        tf.expand_dims(indexes.long(), -1), tf.expand_dims(prob, -1) # shape required, account for dims=2
    )
#     return logits
scalar_to_support_tf(
    tf.constant([[2.1025, 2.0129, 2.1019, 2.0780, 2.0724, 2.0395],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [2.1156, 1.9219, 1.9878, 2.0014, 2.0230, 1.9698],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.9868, 2.0178, 2.1383, 1.9936, 2.0089, 1.9848],
        [1.9503, 1.9601, 2.0509, 1.9869, 1.9181, 1.9885],
        [2.0072, 1.9588, 2.0606, 2.0278, 2.0264, 2.0652]]), 
    support_size_tf)

stonks: tf.Tensor(
[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 ...

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]], shape=(16, 6, 21), dtype=f

TypeError: scatter_nd() missing 1 required positional argument: 'shape'