In [1]:
import torch
import gym as g
from gym import spaces
from connect4 import *
from envs import ConnectNEnv
from networks.architecture import RepresentationNetwork, DynamicsNetwork, PredictionNetwork
import numpy as np

In [2]:
env = ConnectNEnv()

In [3]:
# Test
env.step(0)

({'observations': array([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0]], dtype=int8),
  'action_mask': array([1, 1, 1, 1, 1, 1, 1], dtype=int8),
  'player_1_board': array([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0]], dtype=int8),
  'player_2_board': array([[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]], dtype=int8),
  'current_player': 'P2'},
 0.0,
 False,
 {})

In [4]:
a = {'b': 1, 'c': 2}
def foo(b,c):
    return b+c

foo(**a)

3

In [5]:
import typing
from typing import Tuple
import utils as ut

SUPPORT_SIZE_DEFAULT = 601
ENCODED_CHANNELS = 256

class NetworkOutput(typing.NamedTuple):
    value: torch.Tensor
    reward: torch.Tensor
    policy_logits: torch.Tensor
    hidden_state: torch.Tensor

class MuZeroNetwork:

    def __init__(self, num_of_features, board_total_slots, n_possible_actions, configs=None):

        # Setup configs
        configs = self.default_configs(configs)


        self.representation_network = RepresentationNetwork(in_channels=num_of_features,
                                                            **configs['representation'])

        self.prediction_network = PredictionNetwork(in_channels=configs['representation']['n_channels'], 
                                                    board_total_slots=board_total_slots,
                                                    action_space_size=n_possible_actions,
                                                    **configs['prediction'])
        
        self.dynamics_network = DynamicsNetwork(in_channels=configs['representation']['n_channels']+1,
                                                  board_total_slots=board_total_slots,
                                                  **configs['dynamics'])
        
        self.prediction_support_size = configs['prediction']['support_size']
        self.dynamics_support_size = configs['dynamics']['support_size']
        self.action_space_size = n_possible_actions

    def default_configs(self, configs):
        if configs is None:
            configs = {"prediction": {}, "representation": {}, "dynamics": {}}
        # Prediction Network
        prediction = {
            "n_convs": 2,
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3),
            "support_size": SUPPORT_SIZE_DEFAULT
        }
        if "prediction" not in configs: 
            configs["prediction"] = prediction
        else:
            ut.fill_defaults(configs["prediction"], prediction)
            # Check if Support Size is ok
            if not (configs["prediction"]['support_size']-1) % 2 == 0: 
                print("[NETWORK - Prediction] Support Size invalid. Set to default = {}.".format(SUPPORT_SIZE_DEFAULT))
                configs["prediction"]['support_size'] = SUPPORT_SIZE_DEFAULT
        # Representation Network
        representation = {
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3)
        }
        if "representation" not in configs: 
            configs["representation"] = representation
        else:
            ut.fill_defaults(configs["representation"], representation)
        # Dynamics Network
        dynamics = {
            "n_convs": 2,
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3),
            "support_size": SUPPORT_SIZE_DEFAULT
        }
        if "dynamics" not in configs: 
            configs["dynamics"] = dynamics
        else:
            ut.fill_defaults(configs["dynamics"], dynamics)
            # Check if Support Size is ok
            if not (configs["dynamics"]['support_size']-1) % 2 == 0: 
                print("[NETWORK - Dynamics] Support Size invalid. Set to default = {}.".format(SUPPORT_SIZE_DEFAULT))
                configs["dynamics"]['support_size'] = SUPPORT_SIZE_DEFAULT

        return configs

    def representation(self, image: torch.Tensor) -> torch.Tensor:
        state_representation = self.representation_network(image)
        orig_shape = state_representation.shape
        # Scale image along each channel
        max_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).max(2, keepdim=True)[0].unsqueeze(-1)
        min_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).min(2, keepdim=True)[0].unsqueeze(-1)
        scale = max_per_channel - min_per_channel
        scale[scale <= 0] += 1e-5
        return (state_representation - min_per_channel) / scale
    
    def prediction(self, encoded_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Predict via state the policy logits and value function
        return self.prediction_network(encoded_state)
    
    def dynamics(self, encoded_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode the action
        enc_state_shape = encoded_state.shape
        encoded_action = torch.zeros((enc_state_shape[0], 1, enc_state_shape[2], enc_state_shape[3])) / (self.action_space_size-action)
        encoded_action = encoded_action * action[:,:,None,None] / self.action_space_size
        encoded_action = encoded_action.to(action.device)
        encoded_full_state = torch.cat((encoded_state, encoded_action), dim=1)

        state_representation, logits_reward = self.dynamics_network(encoded_full_state)
        orig_shape = state_representation.shape
        # Scale image along each channel
        max_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).max(2, keepdim=True)[0].unsqueeze(-1)
        min_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).min(2, keepdim=True)[0].unsqueeze(-1)
        scale = max_per_channel - min_per_channel
        scale[scale <= 0] += 1e-5

        return (state_representation - min_per_channel) / scale, logits_reward
  
    def initial_inference(self, image: torch.Tensor) -> NetworkOutput:
        # representation + prediction function
        state_representation = self.representation(image)
        logits_value, logits_policy = self.prediction(state_representation)

        logits_reward = torch.ones(image.shape[0], self.prediction_support_size) * -float("inf")
        logits_reward[:, self.prediction_support_size//2] = 0.0

        return NetworkOutput(logits_value, logits_reward, logits_policy, state_representation)

    def recurrent_inference(self, hidden_state: torch.Tensor, action: torch.Tensor) -> NetworkOutput:
        # dynamics + prediction function
        next_state, logits_reward = self.dynamics(hidden_state, action)
        logits_value, logits_policy = self.prediction(next_state)

        return NetworkOutput(logits_value, logits_reward, logits_policy, next_state)

    def get_weights(self):
        # Returns the weights of this network.
        return []

    def training_steps(self) -> int:
        # How many steps / batches the network has been trained for.
        return 0
    
    def from_output_to_scalar(self, network_output: NetworkOutput, softmax=False, type_output="prediction"):
        value = self.from_support_to_scalar(network_output.value, 
                                                      self.prediction_support_size if type_output == "prediction" else self.dynamics_support_size)
        reward = self.from_support_to_scalar(network_output.reward, 
                                                      self.prediction_support_size if type_output == "prediction" else self.dynamics_support_size)
        if softmax: policy_logits = torch.nn.functional.softmax(network_output.policy_logits, dim=1)
        else: policy_logits = network_output.policy_logits
        return NetworkOutput(value, reward, policy_logits, network_output.hidden_state)
    
    def from_support_to_scalar(self, weights: torch.Tensor, support_size: int) -> torch.Tensor:
        # Get value for each support
        support_vector = torch.arange(-(support_size-1)//2, (support_size-1)//2+1).expand(weights.shape).float().to(weights.device)
        w_softmax = torch.nn.functional.softmax(weights, dim=-1)
        result = torch.sum(support_vector*w_softmax, dim=1, keepdim=True) # Keep dims make it N x D -> N x 1
        # Result is trained with a scaling function h(x), apply it inversely
        return inverse_h(result)


def h(x: torch.Tensor, eps = 1e-2) -> torch.Tensor:
    elem = torch.sqrt(torch.sign(x)+1) - 1
    return torch.sign(x) * elem + eps * x

def inverse_h(x: torch.Tensor, eps = 1e-2) -> torch.Tensor:
    elem = torch.abs(x) + 1 + eps
    elem = torch.sqrt(1 + 4 * eps * elem) - 1
    elem = ((elem / 2 * eps) ** 2) - 1
    return torch.sign(x) * elem


# ![](images/inverseh.png)

In [6]:
muzeronet = MuZeroNetwork(3, 42, 7)

In [7]:
# Test representation
image = torch.rand((1,3,6,7))
print(image)
encoded_state = muzeronet.representation(image)

tensor([[[[0.2793, 0.3484, 0.6801, 0.6109, 0.0728, 0.4982, 0.2352],
          [0.5097, 0.6615, 0.9096, 0.8316, 0.2216, 0.8873, 0.1291],
          [0.6508, 0.6248, 0.6954, 0.8079, 0.2231, 0.5972, 0.5354],
          [0.8676, 0.8135, 0.3035, 0.0280, 0.7482, 0.3695, 0.4282],
          [0.9611, 0.8446, 0.1314, 0.7810, 0.8367, 0.4655, 0.8060],
          [0.2959, 0.4835, 0.6410, 0.7972, 0.1725, 0.5003, 0.2296]],

         [[0.0675, 0.8370, 0.9388, 0.4957, 0.1592, 0.8173, 0.7731],
          [0.8025, 0.1598, 0.4814, 0.8783, 0.7965, 0.1854, 0.7862],
          [0.9774, 0.5251, 0.4034, 0.6666, 0.7908, 0.1842, 0.2214],
          [0.8769, 0.8494, 0.1338, 0.8129, 0.7374, 0.9001, 0.9285],
          [0.3529, 0.6490, 0.1805, 0.1140, 0.4561, 0.1123, 0.0419],
          [0.3973, 0.0233, 0.7207, 0.7013, 0.3009, 0.0132, 0.9424]],

         [[0.7181, 0.9676, 0.2415, 0.5392, 0.3161, 0.1329, 0.4151],
          [0.0890, 0.5870, 0.7857, 0.1411, 0.7410, 0.3931, 0.1935],
          [0.5513, 0.0143, 0.8901, 0.8968, 0

In [8]:
# Test dynamics
action = torch.Tensor([2]).view(1,1)
print(action, action.shape)
muzeronet.dynamics(encoded_state=encoded_state, action=action)

tensor([[2.]]) torch.Size([1, 1])


(tensor([[[[0.5175, 0.3643, 0.0323,  ..., 0.1118, 0.1513, 0.1049],
           [0.7420, 0.2397, 0.2805,  ..., 0.0000, 0.2620, 0.2658],
           [0.3583, 0.2325, 0.6319,  ..., 0.0000, 1.0000, 0.6708],
           [0.0000, 0.3259, 0.0518,  ..., 0.2258, 0.4766, 0.1939],
           [0.2224, 0.4339, 0.2925,  ..., 0.5070, 0.0025, 0.0921],
           [0.1938, 0.3775, 0.0000,  ..., 0.1066, 0.1329, 0.0897]],
 
          [[0.2286, 0.4749, 0.1027,  ..., 0.1579, 0.3901, 0.3258],
           [0.0833, 0.0000, 0.1507,  ..., 0.1810, 0.6664, 0.0273],
           [0.0919, 0.3583, 0.0000,  ..., 0.0000, 0.0662, 0.0000],
           [0.0073, 0.8034, 0.2324,  ..., 0.0000, 0.1834, 0.1025],
           [0.3590, 0.2252, 1.0000,  ..., 0.0000, 0.0130, 0.0000],
           [0.2111, 0.1974, 0.2062,  ..., 0.0000, 0.0000, 0.0311]],
 
          [[0.0000, 0.0516, 0.0046,  ..., 0.0000, 0.0000, 0.1962],
           [0.1012, 0.5061, 0.5874,  ..., 0.1837, 0.0000, 0.2807],
           [0.2496, 0.4617, 0.0000,  ..., 0.0753, 0.0405

In [9]:
# Test prediction
muzeronet.prediction(encoded_state=encoded_state)

(tensor([[-3.6757e-02,  5.4064e-02, -1.0949e-01, -1.8319e-01, -6.3007e-01,
           2.0709e-01, -7.0721e-01, -1.6107e-01, -8.3084e-01, -4.8846e-02,
          -6.3158e-02, -3.8812e-01,  2.7220e-02,  8.0743e-01,  1.4291e-01,
          -3.8848e-01, -1.0129e-01, -1.5350e-01,  4.1856e-01,  5.6221e-01,
          -2.4436e-02, -4.1840e-02,  1.1696e-01, -1.0826e+00, -2.7318e-01,
           4.4613e-01,  4.6218e-01,  7.0496e-01, -4.9479e-02,  6.4129e-02,
           3.0242e-01, -7.4272e-02,  3.9882e-01,  5.8185e-01, -2.2959e-01,
           8.8912e-02,  1.9694e-02,  2.5652e-01, -5.6458e-01,  1.8985e-01,
           1.7924e-01,  5.1286e-01,  5.5696e-02, -1.2432e-01,  5.5975e-01,
           5.3409e-02, -9.2141e-01, -3.7296e-01,  7.9450e-02,  7.5594e-01,
           1.1389e-01,  6.3021e-01, -7.0905e-01,  2.1813e-01, -5.6908e-01,
          -3.0731e-02, -4.1262e-01, -1.9667e-01,  4.3636e-02,  1.2757e-02,
          -3.8744e-01, -3.5967e-01, -4.4385e-01, -5.9582e-01, -6.8166e-01,
           4.2717e-01,  2

In [10]:
# Test network units
muzeronet.recurrent_inference(encoded_state, action)
muzeronet.initial_inference(image).policy_logits

tensor([[-0.3521, -0.5187,  0.6819, -0.4228, -0.1884,  0.0488,  0.4418]],
       grad_fn=<AddmmBackward0>)

In [11]:
muzeronet.from_support_to_scalar(muzeronet.initial_inference(image).reward, muzeronet.prediction_support_size)

tensor([[-0.]])

In [12]:
from architecture.game import Game
from typing import List

MAXIMUM_FLOAT_VALUE = float("inf")

class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds: Tuple = None):
    assert known_bounds is None or len(known_bounds) == 2
    self.maximum = known_bounds[1] if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds[0] if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value

class Node:

    def __init__(self, prior: float):
        self.visit_count = 0
        self.player = None
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        if self.visit_count == 0:
           return 0
        return self.value_sum / self.visit_count

class MuZeroConfig:

    def __init__(self, game_class, game_configs: dict = {},
                 history_len = 7, max_moves = 30,
                 root_dirichlet_alpha = 0.3, root_exploration_factor = 0.25,
                 known_bounds = None, c1 = 1.25, c2 = 19.652, num_simulations = 100,
                 discount = 0.99):
        self.game_class = game_class
        self.game_configs = game_configs

        self.history_len = history_len
        self.max_moves = max_moves
        self.root_dirichlet_alpha = root_dirichlet_alpha
        self.root_exploration_factor = root_exploration_factor
        self.known_bounds = known_bounds
        self.c1 = c1
        self.c2 = c2
        self.num_simulations = num_simulations
        self.discount = discount

    def new_game(self):
        return Game(self.game_class(**self.game_configs), self.history_len) 
    
    def temperature_value(self, num_counts: int) -> float:
        if num_counts < self.max_moves:
            return 1.0
        else:
            return 0.0
        
class MuZeroConnectN(MuZeroNetwork):

    def __init__(self, width = 6, length = 7, len_features = 7, configs = {}):
        super(MuZeroConnectN, self).__init__(len_features*2, width*length, length, configs)

from copy import deepcopy

def play_game(config: MuZeroConfig, network: MuZeroNetwork) -> Game:
    game = config.new_game()

    while not game.terminal() and len(game.history) < config.max_moves:

        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = game.make_image(-1)
        expand_node(root, game.to_play(), game.action_mask(),
                    network.from_output_to_scalar(network.initial_inference(current_observation), softmax=False))
        add_exploration_noise(config, root)

        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the network.
        run_mcts(config, root, game.action_history, game.to_play(), game.action_space(mask=True), network)
        
        print("Simulation finished. Current game history {}.".format(game.length_of_history()))
        game.present_game()
        action = select_action(config, game.length_of_history(), root, network)
        game.step(action)
        game.store_search_statistics(root)
    return game
    
def expand_node(node: Node, player: str, action_mask: torch.Tensor, initial_inference: NetworkOutput):
    node.player = player
    node.hidden_state = initial_inference.hidden_state
    node.reward = initial_inference.reward

    with torch.no_grad():
        action_probabilities = initial_inference.policy_logits.clone()
        action_probabilities[action_mask.view(action_probabilities.shape[0], -1) == 0] = -MAXIMUM_FLOAT_VALUE
        action_probabilities = torch.nn.functional.softmax(action_probabilities, dim=1)

        for action in np.arange(0, action_mask.shape[0])[action_mask==1]:
            node.children[action] = Node(action_probabilities[:, action].item())

def add_exploration_noise(config: MuZeroConfig, node: Node):
    noise = np.random.dirichlet([config.root_dirichlet_alpha]*len(node.children)) * config.root_exploration_factor
    for i, child in enumerate(node.children.values()):
        child.prior = (1-config.root_exploration_factor) * child.prior + noise[i]

def run_mcts(config: MuZeroConfig, root: Node, action_history_global: List, player: str, full_action_mask: torch.Tensor, network: MuZeroNetwork):
    min_max = MinMaxStats(config.known_bounds)

    for _ in range(config.num_simulations):
        # Selection
        # Select the node based on a UCB formula

        action_history = deepcopy(action_history_global)
        node = root
        visited_childs: List[Node] = []
        while node.expanded():
            visited_childs.append(node)
            action, node = select_child(config, node, min_max)
            action_history.append(action)

        # Expansion and Simulation
        # Expand selected node
        last_hidden_state = visited_childs[-1].hidden_state
        action = torch.Tensor([action_history[-1]]).view(1,1)
        network_output = network.from_output_to_scalar(network.recurrent_inference(last_hidden_state, action), softmax=False)
        expand_node(node, player, torch.Tensor(full_action_mask), network_output)
        visited_childs.append(node)

        # Backpropagation
        backpropagate(config, visited_childs, player, network_output.value, min_max)
        
def select_child(config: MuZeroConfig, node: Node, min_max: MinMaxStats):
    # For each node, calculate ucb
    _, action, child = max((ucb_score(config, node, child, min_max), action, child) for action, child in node.children.items())
    return action, child

def ucb_score(config: MuZeroConfig, parent: Node, node: Node, min_max: MinMaxStats):
    probability_term = node.prior * np.sqrt(parent.visit_count) / (node.visit_count + 1)
    probability_term *= config.c1 + np.log((parent.visit_count + config.c2 + 1) / config.c2)
    return min_max.normalize(node.value()) + probability_term

def backpropagate(config: MuZeroConfig, visited_childs: List, player: str, value: float, min_max: MinMaxStats):
    for child in visited_childs[::-1]:
        child.value_sum += value if child.player == player else -value
        child.visit_count += 1
        min_max.update(child.value())

        value = child.reward + config.discount * value

def select_action(config: MuZeroConfig, num_actions: int, node: Node, network: MuZeroNetwork):
    temperature = config.temperature_value(num_actions)
    actions, counts = list(zip(*[(action, child.visit_count) for action, child in node.children.items()]))
    if temperature != 0.0: counts = np.array(counts) ** 1/temperature
    else: return actions[np.argmax(counts)]
    prob_distribution = counts / np.sum(counts)
    return np.random.choice(actions, p=prob_distribution)


In [13]:
# Testing MinMax

minmax = MinMaxStats((1,3)) # minmax = MinMaxStats()
minmax.normalize(10)
# minmax.update(1)
# minmax.update(3)
minmax.normalize(190)

94.5

In [14]:
# testing Muzero config
from envs import ConnectNEnv

muzeroconfig = MuZeroConfig(ConnectNEnv, {})
muzeroconfig.new_game()
muzeroconfig.temperature_value(30)

0.0

In [15]:
# Testing atomic functions
node = Node(0)
node1 = Node(0)
node1.visit_count = 10
node2 = Node(0.1)
node2.visit_count = 20
node.children = {1: node1, 2: node2}
select_action(muzeroconfig, 20, node, muzeronet)

2

In [16]:
muzeronet = MuZeroConnectN()
play_game(muzeroconfig, muzeronet)

Simulation finished. Current game history 0.
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
ooooooooooooooooooooooooo
Simulation finished. Current game history 1.
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || ||1|| || || || |||
ooooooooooooooooooooooooo
Simulation finished. Current game history 2.
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || 

<architecture.game.Game at 0x291265450>

[tensor([6.])]