In [1]:
import torch
import gym as g
from gym import spaces
from connect4 import *
from envs import ConnectNEnv
from networks.architecture import RepresentationNetwork, DynamicsNetwork, PredictionNetwork
import numpy as np

In [2]:
env = ConnectNEnv()

In [3]:
# Test
env.step(0)

({'observations': array([[ 0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0],
         [-1,  0,  0,  0,  0,  0,  0],
         [ 1,  0,  0,  0,  0,  0,  0]], dtype=int8),
  'action_mask': array([1, 1, 1, 1, 1, 1, 1], dtype=int8)},
 0.0,
 False,
 {})

In [4]:
a = {'b': 1, 'c': 2}
def foo(b,c):
    return b+c

foo(**a)

3

In [5]:
import typing
from typing import Tuple
import utils as ut

SUPPORT_SIZE_DEFAULT = 601
ENCODED_CHANNELS = 256

class NetworkOutput(typing.NamedTuple):
  value: torch.Tensor
  reward: torch.Tensor
  policy_logits: torch.Tensor
  hidden_state: torch.Tensor


class MuZeroNetwork:

    def __init__(self, num_of_features, board_total_slots, n_possible_actions, configs=None):

        # Setup configs
        configs = self.default_configs(configs)


        self.representation_network = RepresentationNetwork(in_channels=num_of_features,
                                                            **configs['representation'])

        self.prediction_network = PredictionNetwork(in_channels=configs['representation']['n_channels'], 
                                                    board_total_slots=board_total_slots,
                                                    action_space_size=n_possible_actions,
                                                    **configs['prediction'])
        
        self.dynamics_network = DynamicsNetwork(in_channels=configs['representation']['n_channels']+1,
                                                  board_total_slots=board_total_slots,
                                                  **configs['dynamics'])
        
        self.prediction_support_size = configs['prediction']['support_size']
        self.dynamics_support_size = configs['dynamics']['support_size']
        self.action_space_size = n_possible_actions

    def default_configs(self, configs):
        if configs is None:
            configs = {"prediction": {}, "representation": {}, "dynamics": {}}
        # Prediction Network
        prediction = {
            "n_convs": 2,
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3),
            "support_size": SUPPORT_SIZE_DEFAULT
        }
        if "prediction" not in configs: 
            configs["prediction"] = prediction
        else:
            ut.fill_defaults(configs["prediction"], prediction)
            # Check if Support Size is ok
            if not (configs["prediction"]['support_size']-1) % 2 == 0: 
                print("[NETWORK - Prediction] Support Size invalid. Set to default = {}.".format(SUPPORT_SIZE_DEFAULT))
                configs["prediction"]['support_size'] = SUPPORT_SIZE_DEFAULT
        # Representation Network
        representation = {
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3)
        }
        if "representation" not in configs: 
            configs["representation"] = representation
        else:
            ut.fill_defaults(configs["representation"], representation)
        # Dynamics Network
        dynamics = {
            "n_convs": 2,
            "n_channels": ENCODED_CHANNELS,
            "n_residual_layers": 10,
            "kernel_size": (3,3),
            "support_size": SUPPORT_SIZE_DEFAULT
        }
        if "dynamics" not in configs: 
            configs["dynamics"] = dynamics
        else:
            ut.fill_defaults(configs["dynamics"], dynamics)
            # Check if Support Size is ok
            if not (configs["dynamics"]['support_size']-1) % 2 == 0: 
                print("[NETWORK - Dynamics] Support Size invalid. Set to default = {}.".format(SUPPORT_SIZE_DEFAULT))
                configs["dynamics"]['support_size'] = SUPPORT_SIZE_DEFAULT

        return configs

    def representation(self, image: torch.Tensor) -> torch.Tensor:
        state_representation = self.representation_network(image)
        orig_shape = state_representation.shape
        # Scale image along each channel
        max_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).max(2, keepdim=True)[0].unsqueeze(-1)
        min_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).min(2, keepdim=True)[0].unsqueeze(-1)
        scale = max_per_channel - min_per_channel
        scale[scale <= 0] += 1e-5
        return (state_representation - min_per_channel) / scale
    
    def prediction(self, encoded_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Predict via state the policy logits and value function
        return self.prediction_network(encoded_state)
    
    def dynamics(self, encoded_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode the action
        enc_state_shape = encoded_state.shape
        encoded_action = torch.zeros((enc_state_shape[0], 1, enc_state_shape[2], enc_state_shape[3])) / (self.action_space_size-action)
        encoded_action = encoded_action * action[:,:,None,None] / self.action_space_size
        encoded_action = encoded_action.to(action.device)
        encoded_full_state = torch.cat((encoded_state, encoded_action), dim=1)

        print(encoded_full_state.shape)

        state_representation, logits_reward = self.dynamics_network(encoded_full_state)
        orig_shape = state_representation.shape
        # Scale image along each channel
        max_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).max(2, keepdim=True)[0].unsqueeze(-1)
        min_per_channel = state_representation.view(
            orig_shape[0],
            orig_shape[1],
            -1
        ).min(2, keepdim=True)[0].unsqueeze(-1)
        scale = max_per_channel - min_per_channel
        scale[scale <= 0] += 1e-5

        return (state_representation - min_per_channel) / scale, logits_reward
  
    def initial_inference(self, image: torch.Tensor) -> NetworkOutput:
        # representation + prediction function
        state_representation = self.representation(image)
        logits_value, logits_policy = self.prediction(state_representation)

        logits_reward = torch.zeros(image.shape[0], self.prediction_support_size)

        return NetworkOutput(logits_value, logits_reward, logits_policy, state_representation)

    def recurrent_inference(self, hidden_state: torch.Tensor, action: torch.Tensor) -> NetworkOutput:
        # dynamics + prediction function
        next_state, logits_reward = self.dynamics(hidden_state, action)
        logits_value, logits_policy = self.prediction(next_state)

        return NetworkOutput(logits_value, logits_reward, logits_policy, next_state)

    def get_weights(self):
        # Returns the weights of this network.
        return []

    def training_steps(self) -> int:
        # How many steps / batches the network has been trained for.
        return 0
    
    def from_support_to_scalar(self, weights: torch.Tensor, support_size: int) -> torch.Tensor:
        # Get value for each support
        support_vector = torch.arange(-(support_size-1)//2, (support_size-1)//2).expand(weights.shape).float().to(weights.device)
        result = torch.sum(support_vector*weights, dim=1, keepdim=True) # Keep dims make it N x D -> N x 1
        # Result is trained with a scaling function h(x), apply it inversely
        return inverse_h(result)


def h(x: torch.Tensor) -> torch.Tensor:
    pass

def inverse_h(x: torch.Tensor, eps = 1e-2) -> torch.Tensor:
    elem = torch.abs(x) + 1 + eps
    elem = torch.sqrt(1 + 4 * eps * elem) - 1
    elem = ((elem / 2 * eps) ** 2) - 1
    return torch.sign(x) * elem



# ![](images/inverseh.png)

In [11]:
muzeronet = MuZeroNetwork(3, 42, 7)

In [12]:
# Test representation
image = torch.rand((1,3,6,7))
print(image)
encoded_state = muzeronet.representation(image)

tensor([[[[0.3396, 0.1705, 0.9067, 0.3015, 0.4843, 0.2493, 0.8091],
          [0.0954, 0.8586, 0.8146, 0.5099, 0.4827, 0.4314, 0.9615],
          [0.8459, 0.8526, 0.1366, 0.2439, 0.9187, 0.1951, 0.9152],
          [0.8723, 0.9504, 0.4819, 0.2677, 0.5266, 0.6212, 0.0015],
          [0.1731, 0.4920, 0.1660, 0.8713, 0.8609, 0.2624, 0.1184],
          [0.9195, 0.8847, 0.8362, 0.9872, 0.9564, 0.0052, 0.5384]],

         [[0.2679, 0.6632, 0.9216, 0.4113, 0.8169, 0.2212, 0.3664],
          [0.2072, 0.8230, 0.9791, 0.6515, 0.2288, 0.0188, 0.2261],
          [0.3840, 0.3647, 0.7040, 0.5691, 0.3923, 0.8403, 0.3259],
          [0.1336, 0.4195, 0.6025, 0.5909, 0.4972, 0.2397, 0.0983],
          [0.6398, 0.3640, 0.1469, 0.2483, 0.4789, 0.7458, 0.9003],
          [0.2490, 0.9299, 0.1819, 0.7008, 0.5127, 0.8322, 0.2555]],

         [[0.5280, 0.7950, 0.4844, 0.4556, 0.1177, 0.0922, 0.4482],
          [0.4907, 0.3488, 0.9549, 0.0738, 0.5507, 0.4129, 0.9673],
          [0.6971, 0.3102, 0.1684, 0.2828, 0

In [13]:
# Test dynamics
action = torch.Tensor([2]).view(1,1)
print(action, action.shape)
muzeronet.dynamics(encoded_state=encoded_state, action=action)

tensor([[2.]]) torch.Size([1, 1])
torch.Size([1, 257, 6, 7])
torch.Size([1, 64, 6, 7])


(tensor([[[[0.1217, 0.6850, 0.8082,  ..., 0.0000, 0.3831, 0.4769],
           [0.1029, 0.5352, 0.0297,  ..., 0.0050, 0.2202, 0.6478],
           [0.0000, 0.3143, 1.0000,  ..., 0.7915, 0.0026, 0.3885],
           [0.0000, 0.1737, 0.2516,  ..., 0.3246, 0.2991, 0.2592],
           [0.1659, 0.0000, 0.0000,  ..., 0.0000, 0.1414, 0.2332],
           [0.0000, 0.0158, 0.1331,  ..., 0.0481, 0.4544, 0.5205]],
 
          [[0.5004, 0.3580, 0.0000,  ..., 0.0999, 0.0000, 0.3197],
           [0.0115, 0.4902, 0.0723,  ..., 0.0000, 0.0830, 0.2309],
           [0.2399, 0.5105, 0.5426,  ..., 0.0000, 0.2799, 0.3287],
           [0.1589, 0.9887, 0.5815,  ..., 0.0000, 0.1023, 0.3007],
           [0.5100, 0.6150, 0.3675,  ..., 0.0000, 0.1369, 0.0000],
           [1.0000, 0.7945, 0.5523,  ..., 0.4332, 0.7778, 0.4109]],
 
          [[0.0000, 0.0000, 0.1399,  ..., 0.0303, 0.0870, 0.0000],
           [0.1868, 0.2726, 0.6370,  ..., 0.4075, 0.0000, 0.0000],
           [0.0000, 0.1975, 0.5502,  ..., 0.4425, 0.5733

In [14]:
# Test prediction
muzeronet.prediction(encoded_state=encoded_state)

torch.Size([1, 64, 6, 7])
torch.Size([1, 64, 6, 7])


(tensor([[0.0015, 0.0012, 0.0017, 0.0011, 0.0018, 0.0024, 0.0035, 0.0015, 0.0017,
          0.0022, 0.0021, 0.0020, 0.0008, 0.0022, 0.0013, 0.0010, 0.0020, 0.0017,
          0.0011, 0.0030, 0.0015, 0.0026, 0.0023, 0.0026, 0.0015, 0.0014, 0.0015,
          0.0013, 0.0012, 0.0008, 0.0017, 0.0018, 0.0017, 0.0010, 0.0014, 0.0017,
          0.0008, 0.0013, 0.0017, 0.0028, 0.0026, 0.0022, 0.0009, 0.0017, 0.0013,
          0.0019, 0.0014, 0.0010, 0.0012, 0.0009, 0.0017, 0.0029, 0.0017, 0.0019,
          0.0023, 0.0009, 0.0017, 0.0014, 0.0024, 0.0015, 0.0018, 0.0020, 0.0013,
          0.0010, 0.0012, 0.0017, 0.0013, 0.0011, 0.0031, 0.0010, 0.0010, 0.0015,
          0.0036, 0.0019, 0.0013, 0.0030, 0.0023, 0.0016, 0.0011, 0.0030, 0.0014,
          0.0017, 0.0010, 0.0028, 0.0023, 0.0020, 0.0021, 0.0017, 0.0009, 0.0021,
          0.0015, 0.0016, 0.0016, 0.0008, 0.0017, 0.0018, 0.0013, 0.0012, 0.0021,
          0.0011, 0.0010, 0.0014, 0.0016, 0.0021, 0.0015, 0.0007, 0.0025, 0.0009,
          0.0019

In [16]:
# Test network units
muzeronet.initial_inference(image)
muzeronet.recurrent_inference(encoded_state, action)

torch.Size([1, 64, 6, 7])
torch.Size([1, 64, 6, 7])
torch.Size([1, 257, 6, 7])
torch.Size([1, 64, 6, 7])
torch.Size([1, 64, 6, 7])
torch.Size([1, 64, 6, 7])


NetworkOutput(value=tensor([[0.0014, 0.0019, 0.0016, 0.0019, 0.0016, 0.0019, 0.0034, 0.0016, 0.0007,
         0.0017, 0.0009, 0.0033, 0.0015, 0.0019, 0.0020, 0.0011, 0.0015, 0.0015,
         0.0023, 0.0024, 0.0008, 0.0014, 0.0036, 0.0017, 0.0012, 0.0011, 0.0012,
         0.0014, 0.0010, 0.0009, 0.0024, 0.0030, 0.0027, 0.0021, 0.0009, 0.0024,
         0.0015, 0.0016, 0.0013, 0.0023, 0.0006, 0.0025, 0.0006, 0.0019, 0.0013,
         0.0020, 0.0025, 0.0019, 0.0019, 0.0015, 0.0014, 0.0020, 0.0018, 0.0015,
         0.0019, 0.0018, 0.0012, 0.0021, 0.0023, 0.0015, 0.0010, 0.0019, 0.0025,
         0.0008, 0.0009, 0.0024, 0.0011, 0.0011, 0.0017, 0.0012, 0.0013, 0.0016,
         0.0030, 0.0014, 0.0013, 0.0035, 0.0028, 0.0016, 0.0011, 0.0016, 0.0015,
         0.0017, 0.0015, 0.0019, 0.0023, 0.0036, 0.0025, 0.0022, 0.0011, 0.0016,
         0.0019, 0.0012, 0.0014, 0.0018, 0.0012, 0.0038, 0.0013, 0.0009, 0.0015,
         0.0030, 0.0014, 0.0023, 0.0012, 0.0022, 0.0009, 0.0007, 0.0038, 0.0009,
        