In [1]:
import torch
import gym as g
from gym import spaces
from connect4 import *
from envs import ConnectNEnv
from networks.architecture import RepresentationNetwork, DynamicsNetwork, PredictionNetwork
import numpy as np

In [2]:
env = ConnectNEnv()

In [3]:
# Test
env.step(0)

({'observations': array([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0]], dtype=int8),
  'action_mask': array([1, 1, 1, 1, 1, 1, 1], dtype=int8),
  'player_1_board': array([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0]], dtype=int8),
  'player_2_board': array([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]], dtype=int8),
  'current_player': 'P2'},
 0.0,
 False,
 {})

In [4]:
a = {'b': 1, 'c': 2}
def foo(b,c):
    return b+c

foo(**a)

3

In [5]:
import typing
from typing import Tuple
import utils as ut

SUPPORT_SIZE_DEFAULT = 601
ENCODED_CHANNELS = 256

class NetworkOutput(typing.NamedTuple):
    value: torch.Tensor
    reward: torch.Tensor
    policy_logits: torch.Tensor
    hidden_state: torch.Tensor

class MuZeroNetwork:

    def __init__(self, num_of_features, board_total_slots, n_possible_actions, configs=None):

        # Setup configs
        configs = self.default_configs(configs)


        self.representation_network = RepresentationNetwork(in_channels=num_of_features,
                                                            **configs['representation'])

        self.prediction_network = PredictionNetwork(in_channels=configs['representation']['n_channels'], 
                                                    board_total_slots=board_total_slots,
                                                    action_space_size=n_possible_actions,
                                                    **configs['prediction'])
        
        self.dynamics_network = DynamicsNetwork(in_channels=configs['representation']['n_channels']+1,
                                                  board_total_slots=board_total_slots,
                                                  **configs['dynamics'])
        
        self.prediction_support_size = configs['prediction']['support_size']
        self.dynamics_support_size = configs['dynamics']['support_size']
        self.action_space_size = n_possible_actions

    def default_configs(self, configs):
        if configs is None:
            configs = {"prediction": {}, "representation": {}, "dynamics": {}}
        # Prediction Network
        prediction = {
            "n_convs": 2,
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3),
            "support_size": SUPPORT_SIZE_DEFAULT
        }
        if "prediction" not in configs: 
            configs["prediction"] = prediction
        else:
            ut.fill_defaults(configs["prediction"], prediction)
            # Check if Support Size is ok
            if not (configs["prediction"]['support_size']-1) % 2 == 0: 
                print("[NETWORK - Prediction] Support Size invalid. Set to default = {}.".format(SUPPORT_SIZE_DEFAULT))
                configs["prediction"]['support_size'] = SUPPORT_SIZE_DEFAULT
        # Representation Network
        representation = {
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3)
        }
        if "representation" not in configs: 
            configs["representation"] = representation
        else:
            ut.fill_defaults(configs["representation"], representation)
        # Dynamics Network
        dynamics = {
            "n_convs": 2,
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3),
            "support_size": SUPPORT_SIZE_DEFAULT
        }
        if "dynamics" not in configs: 
            configs["dynamics"] = dynamics
        else:
            ut.fill_defaults(configs["dynamics"], dynamics)
            # Check if Support Size is ok
            if not (configs["dynamics"]['support_size']-1) % 2 == 0: 
                print("[NETWORK - Dynamics] Support Size invalid. Set to default = {}.".format(SUPPORT_SIZE_DEFAULT))
                configs["dynamics"]['support_size'] = SUPPORT_SIZE_DEFAULT

        return configs

    def representation(self, image: torch.Tensor) -> torch.Tensor:
        state_representation = self.representation_network(image)
        orig_shape = state_representation.shape
        # Scale image along each channel
        max_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).max(2, keepdim=True)[0].unsqueeze(-1)
        min_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).min(2, keepdim=True)[0].unsqueeze(-1)
        scale = max_per_channel - min_per_channel
        scale[scale <= 0] += 1e-5
        return (state_representation - min_per_channel) / scale
    
    def prediction(self, encoded_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Predict via state the policy logits and value function
        return self.prediction_network(encoded_state)
    
    def dynamics(self, encoded_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode the action
        enc_state_shape = encoded_state.shape
        encoded_action = torch.zeros((enc_state_shape[0], 1, enc_state_shape[2], enc_state_shape[3])) / (self.action_space_size-action)
        encoded_action = encoded_action * action[:,:,None,None] / self.action_space_size
        encoded_action = encoded_action.to(action.device)
        encoded_full_state = torch.cat((encoded_state, encoded_action), dim=1)

        print(encoded_full_state.shape)

        state_representation, logits_reward = self.dynamics_network(encoded_full_state)
        orig_shape = state_representation.shape
        # Scale image along each channel
        max_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).max(2, keepdim=True)[0].unsqueeze(-1)
        min_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).min(2, keepdim=True)[0].unsqueeze(-1)
        scale = max_per_channel - min_per_channel
        scale[scale <= 0] += 1e-5

        return (state_representation - min_per_channel) / scale, logits_reward
  
    def initial_inference(self, image: torch.Tensor) -> NetworkOutput:
        # representation + prediction function
        state_representation = self.representation(image)
        logits_value, logits_policy = self.prediction(state_representation)

        logits_reward = torch.ones(image.shape[0], self.prediction_support_size) * -float("inf")
        logits_reward[:, self.prediction_support_size//2] = 0.0

        return NetworkOutput(logits_value, logits_reward, logits_policy, state_representation)

    def recurrent_inference(self, hidden_state: torch.Tensor, action: torch.Tensor) -> NetworkOutput:
        # dynamics + prediction function
        next_state, logits_reward = self.dynamics(hidden_state, action)
        logits_value, logits_policy = self.prediction(next_state)

        return NetworkOutput(logits_value, logits_reward, logits_policy, next_state)

    def get_weights(self):
        # Returns the weights of this network.
        return []

    def training_steps(self) -> int:
        # How many steps / batches the network has been trained for.
        return 0
    
    def from_output_to_scalar(self, network_output: NetworkOutput, softmax=False, type_output="prediction"):
        network_output.value = self.from_support_to_scalar(network_output.value, 
                                                      self.prediction_support_size if type_output == "prediction" else self.dynamics_support_size)
        network_output.reward = self.from_support_to_scalar(network_output.reward, 
                                                      self.prediction_support_size if type_output == "prediction" else self.dynamics_support_size)
        if softmax: network_output.policy_logits = torch.nn.functional.softmax(network_output.policy_logits, dim=1)
        return network_output
    
    def from_support_to_scalar(self, weights: torch.Tensor, support_size: int) -> torch.Tensor:
        # Get value for each support
        support_vector = torch.arange(-(support_size-1)//2, (support_size-1)//2+1).expand(weights.shape).float().to(weights.device)
        w_softmax = torch.nn.functional.softmax(weights, dim=-1)
        result = torch.sum(support_vector*w_softmax, dim=1, keepdim=True) # Keep dims make it N x D -> N x 1
        # Result is trained with a scaling function h(x), apply it inversely
        return inverse_h(result)


def h(x: torch.Tensor, eps = 1e-2) -> torch.Tensor:
    elem = torch.sqrt(torch.sign(x)+1) - 1
    return torch.sign(x) * elem + eps * x

def inverse_h(x: torch.Tensor, eps = 1e-2) -> torch.Tensor:
    elem = torch.abs(x) + 1 + eps
    elem = torch.sqrt(1 + 4 * eps * elem) - 1
    elem = ((elem / 2 * eps) ** 2) - 1
    return torch.sign(x) * elem


# ![](images/inverseh.png)

In [6]:
muzeronet = MuZeroNetwork(3, 42, 7)

In [7]:
# Test representation
image = torch.rand((1,3,6,7))
print(image)
encoded_state = muzeronet.representation(image)

tensor([[[[0.0386, 0.2679, 0.6434, 0.3073, 0.8630, 0.3129, 0.4499],
          [0.4925, 0.2986, 0.3502, 0.7347, 0.1663, 0.4028, 0.7979],
          [0.6699, 0.4449, 0.0810, 0.9388, 0.5247, 0.7416, 0.2737],
          [0.1082, 0.8635, 0.9558, 0.2710, 0.5138, 0.7743, 0.6261],
          [0.2529, 0.4488, 0.0707, 0.1945, 0.8124, 0.0782, 0.0609],
          [0.4100, 0.6586, 0.2137, 0.4823, 0.1103, 0.5244, 0.7674]],

         [[0.2759, 0.8199, 0.7194, 0.1569, 0.5720, 0.8551, 0.8672],
          [0.0933, 0.3919, 0.8842, 0.7446, 0.6284, 0.6497, 0.1409],
          [0.9564, 0.7636, 0.0699, 0.7818, 0.6452, 0.6021, 0.8207],
          [0.6475, 0.8894, 0.8558, 0.5189, 0.8223, 0.4391, 0.8745],
          [0.2931, 0.5125, 0.9808, 0.2686, 0.1246, 0.3248, 0.5832],
          [0.8053, 0.3317, 0.6834, 0.5776, 0.2721, 0.7754, 0.6393]],

         [[0.3497, 0.2233, 0.3920, 0.1998, 0.8273, 0.5209, 0.9985],
          [0.2892, 0.8194, 0.2255, 0.3047, 0.6062, 0.0439, 0.5854],
          [0.4943, 0.4999, 0.2096, 0.6426, 0

In [8]:
# Test dynamics
action = torch.Tensor([2]).view(1,1)
print(action, action.shape)
muzeronet.dynamics(encoded_state=encoded_state, action=action)

tensor([[2.]]) torch.Size([1, 1])
torch.Size([1, 257, 6, 7])


(tensor([[[[2.1336e-01, 2.1317e-01, 5.1032e-01,  ..., 9.9445e-03,
            1.3274e-02, 1.2339e-01],
           [4.6829e-02, 3.4161e-03, 7.2424e-02,  ..., 0.0000e+00,
            5.9160e-02, 3.3592e-01],
           [1.7596e-01, 3.4086e-02, 2.7011e-01,  ..., 0.0000e+00,
            1.0000e+00, 0.0000e+00],
           [1.2625e-01, 1.7195e-01, 7.9518e-01,  ..., 0.0000e+00,
            4.0509e-01, 0.0000e+00],
           [1.1800e-01, 4.5758e-02, 2.5525e-01,  ..., 5.0373e-01,
            2.6358e-01, 4.4332e-01],
           [4.0323e-01, 4.6812e-01, 2.1140e-01,  ..., 1.8294e-01,
            1.6341e-01, 2.4817e-02]],
 
          [[3.1781e-02, 2.5497e-01, 5.8762e-01,  ..., 0.0000e+00,
            2.3663e-01, 1.6260e-01],
           [4.3784e-01, 3.8282e-01, 8.6672e-02,  ..., 8.5371e-01,
            1.0000e+00, 2.6081e-01],
           [1.5693e-01, 1.5369e-01, 5.8974e-01,  ..., 5.9598e-01,
            2.7737e-01, 4.7345e-01],
           [5.8503e-01, 1.5988e-01, 3.1179e-01,  ..., 1.0412e-01,
    

In [9]:
# Test prediction
muzeronet.prediction(encoded_state=encoded_state)

(tensor([[ 4.5555e-01,  4.7511e-01, -3.8508e-01,  2.3560e-01, -6.3837e-01,
          -1.5231e-01, -3.7004e-01,  1.2284e-01, -4.0252e-01, -2.4308e-01,
          -3.7495e-01,  9.5926e-01, -4.5886e-02,  6.4188e-02,  4.8787e-01,
           2.1640e-02,  1.6183e-01, -7.6442e-01,  5.7617e-01, -3.4079e-01,
           2.7731e-01,  6.8389e-01, -9.9300e-01, -4.9711e-01,  2.7927e-01,
           3.2501e-02,  2.8765e-01, -6.9023e-01,  8.7574e-01,  4.2982e-01,
           6.8780e-02, -5.7019e-01, -6.4011e-01,  1.3984e-01, -3.5868e-01,
          -7.6078e-01,  4.3228e-01,  6.1998e-01,  5.2586e-01, -5.5810e-01,
          -5.4709e-01,  2.1746e-01, -6.0428e-01, -6.6087e-01,  3.6234e-01,
          -8.6921e-01, -2.3995e-02,  1.1392e+00, -1.0026e-01, -2.1926e-01,
          -7.4684e-02,  4.3344e-01,  9.8553e-02, -6.4525e-01, -1.4633e-01,
           9.5027e-02,  2.7810e-01, -1.0050e-01, -1.3950e-01, -3.2122e-02,
          -1.8442e-01, -3.3964e-01,  2.0572e-01,  3.2133e-01, -1.4949e-01,
          -4.0256e-01,  9

In [10]:
# Test network units
muzeronet.recurrent_inference(encoded_state, action)
muzeronet.initial_inference(image).policy_logits

torch.Size([1, 257, 6, 7])


tensor([[ 0.6114,  0.1041, -0.3909,  0.7535, -0.1721,  0.0511,  0.0193]],
       grad_fn=<AddmmBackward0>)

In [42]:
muzeronet.from_support_to_scalar(muzeronet.initial_inference(image).reward, muzeronet.prediction_support_size)

tensor([[-0.]])

In [None]:
from architecture.game import Game
from typing import List

MAXIMUM_FLOAT_VALUE = float("inf")

class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds: Tuple = None):
    assert known_bounds is None or len(known_bounds) == 2
    self.maximum = known_bounds[1] if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds[0] if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value

class Node:

    def __init__(self, prior: float):
        self.visit_count = 0
        self.player = None
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        if self.visit_count == 0:
           return 0
        return self.value_sum / self.visit_count

class MuZeroConfig:

    def __init__(self, game_class, game_configs: dict = {},
                 history_len = 7, max_moves = 30,
                 root_dirichlet_alpha = 0.3, root_exploration_factor = 0.25,
                 known_bounds = None, c1 = 1.25, c2 = 19.652, num_simulations = 100,
                 discount = 0.99):
        self.game_class = game_class
        self.game_configs = game_configs

        self.history_len = history_len
        self.max_moves = max_moves
        self.root_dirichlet_alpha = root_dirichlet_alpha
        self.root_exploration_factor = root_exploration_factor
        self.known_bounds = known_bounds
        self.c1 = c1
        self.c2 = c2
        self.num_simulations = num_simulations
        self.discount = discount

    def new_game(self):
        return Game(self.game_class(**self.game_configs), self.history_len) 
    
    def temperature_value(self, num_counts: int) -> float:
        if num_counts < self.max_moves:
            return 1.0
        else:
            return 0.0

def play_game(config: MuZeroConfig, network: MuZeroNetwork) -> Game:
    game = config.new_game()

    while not game.terminal() and len(game.history) < config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = game.make_image(-1)
        expand_node(root, game.to_play(), game.action_mask(),
                    network.from_output_to_scalar(network.initial_inference(current_observation), softmax=False))
        add_exploration_noise(config, root)

        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the network.
        run_mcts(config, root, game.action_history(), game.to_play(), game.action_space(mask=True), network)
        action = select_action(config, len(game.history), root, network)
        game.step(action)
        game.store_search_statistics(root)
    return game
    
def expand_node(node: Node, player: str, action_mask: torch.Tensor, initial_inference: NetworkOutput):
    node.player = player
    node.hidden_state = initial_inference.hidden_state
    node.reward = initial_inference.reward

    with torch.no_grad():
        action_probabilities = initial_inference.policy_logits.clone()
        action_probabilities[action_mask.view(action_probabilities.shape[0], -1) == 0] = -MAXIMUM_FLOAT_VALUE
        action_probabilities = torch.nn.functional.softmax(action_probabilities, dim=1)

        for action in torch.arange(0, action_mask.shape[0])[action_mask==1]:
            node.children[action] = Node(action_probabilities[:, action])

def add_exploration_noise(config: MuZeroConfig, node: Node):
    noise = np.random.dirichlet([config.root_dirichlet_alpha]*len(node.children)) * config.root_exploration_factor
    for action in node.children.keys():
        node[action].prior = (1-config.root_exploration_factor) * node[action].prior + noise[action]

def run_mcts(config: MuZeroConfig, root: Node, action_history_global: List, player: str, full_action_mask: np.array, network: MuZeroNetwork):
    min_max = MinMaxStats(config.known_bounds)

    for _ in range(config.num_simulations):
        # Selection
        # Select the node based on a UCB formula
        action_history = action_history_global.clone()
        node = root
        visited_childs: List[Node] = []
        while node.expanded():
            visited_childs.append(node)
            action, node = select_child(config, node, min_max)
            action_history.append(action)

        # Expansion and Simulation
        # Expand selected node
        last_hidden_state = visited_childs[-1].hidden_state
        network_output = network.from_output_to_scalar(network.recurrent_inference(last_hidden_state, action_history[-1]), softmax=False)
        expand_node(node, player, full_action_mask, network_output)
        visited_childs.append(node)

        # Backpropagation
        backpropagate(config, visited_childs, player, network_output.value)
        
def select_child(config: MuZeroConfig, node: Node, min_max: MinMaxStats):
    # For each node, calculate ucb
    _, action, child = max((ucb_score(config, node, child, min_max), action, child) for action, child in node.children.items())
    return action, child

def ucb_score(config: MuZeroConfig, parent: Node, node: Node, min_max: MinMaxStats):
    probability_term = node.prior * np.sqrt(parent.visit_count) / (node.visit_count + 1)
    probability_term *= config.c1 + np.log((parent.visit_count + config.c2 + 1) / config.c2)
    return min_max.normalize(node.value()) + probability_term

def backpropagate(config: MuZeroConfig, visited_childs: List, player: str, value: float, min_max: MinMaxStats):
    for child in visited_childs[::-1]:
        child.value_sum += value if child.to_play() == player else -value
        child.visit_count += 1
        min_max.update(child.value())

        value = child.reward + config.discount * value

def select_action(config: MuZeroConfig, num_actions: int, node: Node, network: MuZeroNetwork):
    temperature = config.temperature_value(num_actions)
    actions, counts = list(zip(*[(action, child.visit_count) for action, child in node.children.items()]))
    counts = np.array(counts) ** 1/temperature
    prob_distribution = counts / np.sum(counts)
    return np.random.choice(actions, p=prob_distribution)

    


In [55]:
temperature = 1.0
actions, counts= list(zip(*[(1,10),(2,30),(4,20)]))
counts = np.array(counts) ** 1/temperature
prob_distribution = counts / np.sum(counts)
np.random.choice(actions, p=prob_distribution)

4