In [1]:
import torch
import gym as g
from gym import spaces
from connect4 import *
from envs import ConnectNEnv
from networks.architecture import RepresentationNetwork, DynamicsNetwork, PredictionNetwork
from architecture.game import Target
import numpy as np

In [2]:
from architecture.engine import *
from architecture.network import MuZeroNetwork

In [3]:
# Build the network instance that the following smoke-test cells exercise.
# NOTE(review): the positional arguments are undocumented here; presumably
# 3 input planes, 42 board cells (6x7) and 7 actions -- confirm against
# MuZeroNetwork.__init__.
muzeronet = MuZeroNetwork(3, 42, 7)

In [4]:
# Test representation
# Feed a random tensor of shape (1, 3, 6, 7) through the representation
# function. `image` and `encoded_state` are reused by the cells below.
image = torch.rand((1,3,6,7))
print(image)
encoded_state = muzeronet.representation(image)

tensor([[[[0.2797, 0.3040, 0.9832, 0.9988, 0.3749, 0.7835, 0.3174],
          [0.2785, 0.9220, 0.5011, 0.6089, 0.1654, 0.8720, 0.0912],
          [0.9946, 0.6199, 0.9503, 0.0078, 0.0275, 0.2187, 0.7940],
          [0.7271, 0.9436, 0.3944, 0.2669, 0.3033, 0.6880, 0.7317],
          [0.3456, 0.3858, 0.1502, 0.1603, 0.3491, 0.1643, 0.3143],
          [0.9120, 0.0264, 0.7435, 0.9332, 0.0969, 0.4670, 0.9831]],

         [[0.3059, 0.2775, 0.7557, 0.7167, 0.1140, 0.6954, 0.2232],
          [0.8256, 0.3419, 0.1594, 0.6122, 0.8190, 0.7771, 0.7111],
          [0.7063, 0.8975, 0.7128, 0.8875, 0.4645, 0.6768, 0.2822],
          [0.8053, 0.8893, 0.8285, 0.6905, 0.7734, 0.5998, 0.5583],
          [0.2472, 0.6939, 0.1773, 0.2455, 0.5866, 0.6015, 0.4868],
          [0.6315, 0.3113, 0.7685, 0.4551, 0.4608, 0.9445, 0.5995]],

         [[0.6289, 0.5063, 0.9205, 0.0549, 0.2008, 0.5700, 0.7619],
          [0.2182, 0.6626, 0.3760, 0.1435, 0.7542, 0.3110, 0.1492],
          [0.3473, 0.7272, 0.8303, 0.4252, 0

In [5]:
# Test dynamics
# `action` is a (1, 1) float tensor holding the value 2; it is reused by the
# recurrent-inference cell below.
action = torch.Tensor([2]).view(1,1)
print(action, action.shape)
# NOTE(review): `a` is not read by any later cell shown here, and the
# NetworkOutput annotation is unverified -- confirm what dynamics() returns.
a: NetworkOutput = muzeronet.dynamics(encoded_state=encoded_state, action=action)

tensor([[2.]]) torch.Size([1, 1])


In [6]:
# Test prediction
# Run the prediction function on the encoded state from the representation
# cell; the return value is displayed as the cell output.
muzeronet.prediction(encoded_state=encoded_state)

(tensor([[ 2.1049e-01,  1.3975e-01, -1.1045e-01, -8.5051e-01,  5.4529e-01,
           4.1885e-01,  7.6424e-02,  6.0599e-01,  2.6680e-01,  3.4567e-01,
           1.6376e-01, -1.2995e-01, -4.4439e-01,  5.0904e-02, -5.7539e-02,
           1.9377e-01, -1.2343e-01, -5.5293e-01,  1.0115e-01, -1.4732e-01,
           1.6315e-01, -6.1254e-01,  2.0189e-01, -3.5490e-01, -7.7985e-02,
           4.9717e-01, -4.0054e-01, -1.1507e+00, -3.8067e-01,  1.8780e-01,
          -3.0046e-01, -1.2041e-01, -3.9061e-01, -3.8826e-01,  1.0007e-01,
           2.3968e-02, -4.9382e-01,  9.8615e-02, -1.0473e-01,  2.7974e-01,
           2.2425e-01,  7.0073e-02,  6.6448e-01, -5.3587e-01,  9.3668e-01,
          -3.3355e-01, -7.4464e-01,  1.7596e-01,  6.5505e-03, -2.6111e-01,
           2.2988e-01, -1.0597e-01, -2.0222e-01, -1.1842e-01, -5.7459e-01,
          -1.6419e-01,  1.2231e-03, -2.5197e-01,  2.2678e-01, -2.3397e-01,
          -4.5855e-01,  1.4848e-02,  6.8383e-02, -3.3902e-01, -4.5281e-01,
           6.6314e-01, -1

In [7]:
# Test network units
# Exercise both inference entry points: recurrent inference on the encoded
# state plus action, and initial inference on the raw image.
# `ri` is inspected in the next cell; `ii` is not read by any later cell shown here.
ri = muzeronet.recurrent_inference(encoded_state, action)
ii = muzeronet.initial_inference(image)

In [8]:
# Display the shape of the policy logits produced by recurrent inference
# (the saved output shows torch.Size([1, 7])).
ri.policy_logits.shape

torch.Size([1, 7])

In [9]:
# Run initial inference on `image` again and convert its `reward` output with
# from_support_to_scalar, using the network's own prediction_support_size.
muzeronet.from_support_to_scalar(muzeronet.initial_inference(image).reward, muzeronet.prediction_support_size)

tensor([[-0.]])

In [10]:
# Testing MinMax
# MinMaxStats is constructed with the pair (1, 3); the commented-out variant
# builds it without arguments and feeds the same two values through update().

minmax = MinMaxStats((1,3)) # minmax = MinMaxStats()
# The result of this first call is discarded: only the last expression of a
# cell is displayed.
minmax.normalize(10)
# minmax.update(1)
# minmax.update(3)
# The saved output 94.5 is consistent with (190 - 1) / (3 - 1) -- TODO confirm
# against the MinMaxStats implementation.
minmax.normalize(190)

94.5

In [11]:
# testing Muzero config
# NOTE(review): ConnectNEnv is already imported in the first cell; this
# re-import is redundant.
from envs import ConnectNEnv

# Second argument is an empty dict -- presumably environment keyword
# arguments; confirm against MuZeroConfig.
muzeroconfig = MuZeroConfig(ConnectNEnv, {})
# Return value of new_game() is discarded here.
muzeroconfig.new_game()
# Displayed as the cell output (the saved output shows 0.0).
muzeroconfig.temperature_value(30)

0.0

In [12]:
# Testing atomic functions
# Hand-build a root node with two children, keyed 1 and 2, whose visit counts
# are 10 and 20, then ask select_action to choose between them.
# NOTE(review): the Node constructor argument is presumably the prior, and the
# literal 20 passed to select_action presumably a move count -- confirm both.
node = Node(0)
node1 = Node(0)
node1.visit_count = 10
node2 = Node(0.1)
node2.visit_count = 20
node.children = {1: node1, 2: node2}
# The saved output is 2, the key of the child with the larger visit count.
select_action(muzeroconfig, 20, node, muzeronet)

2

In [13]:
from architecture.network import MuZeroConnectN

# NOTE(review): this rebinds `muzeronet`, replacing the MuZeroNetwork(3, 42, 7)
# instance created earlier with a MuZeroConnectN built from default arguments.
muzeronet = MuZeroConnectN()
# Plays one game with the config from the previous cells. The saved output
# ends in KeyboardInterrupt, i.e. this run was stopped before it finished.
play_game(muzeroconfig, muzeronet)

Simulation finished. Current game history 0.
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
ooooooooooooooooooooooooo
Simulation finished. Current game history 1.
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || ||1|||
ooooooooooooooooooooooooo
Simulation finished. Current game history 2.
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || || || || || || |||
||---------------------||
||| || 

KeyboardInterrupt: 

In [14]:
from architecture.replay import ReplayBuffer

In [15]:
from typing import List, Tuple
from torch import Tensor
from torch.nn.functional import softmax, cross_entropy

class MuZero:
    """Drives the self-play / training loop used in this notebook.

    Games are produced with ``play_game`` and stored in a ``ReplayBuffer``;
    batches sampled from that buffer are used to update a ``MuZeroConnectN``
    network, which is reloaded through its own ``load_network`` /
    ``save_network`` methods.
    """

    def __init__(self, config: MuZeroConfig):
        """Store the configuration and load the current network.

        Args:
            config: Settings read by the methods below (buffer and batch
                sizes, optimiser hyper-parameters, unroll / TD steps, ...).
        """
        # Kept for callers; not updated anywhere inside this class.
        self.steps = 0
        self.config = config
        # load_network() needs a network instance to operate on, so build and
        # load one through get_latest_network() instead of calling it bare.
        self.network = self.get_latest_network()

    def execute(self):
        """Alternate forever between two self-play games and one training round.

        This method never returns; interrupt the kernel to stop it.
        """
        replay_buffer = ReplayBuffer(self.config.max_buffer_size, self.config.batch_size)

        while True:
            for _ in range(2):
                self.self_play(replay_buffer)
            self.train_model(replay_buffer)

    def self_play(self, replay_buffer: ReplayBuffer):
        """Play one game with the latest network and add it to the buffer.

        Args:
            replay_buffer: Buffer that receives the finished game.
        """
        network = self.get_latest_network()
        game = play_game(self.config, network)
        replay_buffer.add_game(game)

    def train_model(self, replay_buffer: ReplayBuffer):
        """Run one training round on batches sampled from ``replay_buffer``.

        A fresh SGD optimiser and LambdaLR scheduler are created on every
        call, so momentum and the learning-rate schedule restart each round.

        Args:
            replay_buffer: Source of training batches.
        """
        network = self.get_latest_network()
        optimizer = torch.optim.SGD(network.parameters(), lr=self.config.lr, momentum=self.config.momentum)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=self.config.decaying_func())

        # An integer step count cannot be iterated directly; wrap it in range().
        # Anything else is assumed to already be an iterable of step indices.
        training_steps = self.config.training_steps
        step_indices = range(training_steps) if isinstance(training_steps, int) else training_steps

        for i in step_indices:
            if i % self.config.step_for_saving == 0:
                # save_network() requires the network being trained.
                self.save_network(network)
            batch = replay_buffer.sample_batch(self.config.unroll_steps, self.config.td_steps)
            self.update_weights(batch, optimizer, network, scheduler)
        # NOTE(review): updates made after the last checkpoint are not saved,
        # and get_latest_network() builds a new instance each time -- confirm
        # whether a final save is wanted here.

    def update_weights(self, batch: List[Tuple[Tensor, List, List]], optimizer: torch.optim.Optimizer, network: MuZeroNetwork, scheduler: torch.optim.lr_scheduler.LRScheduler = None):
        """Perform one optimisation step on ``network`` for ``batch``.

        Args:
            batch: Items of ``(image, actions, targets)``.
            optimizer: Optimiser bound to ``network``'s parameters.
            network: Network to unroll and update.
            scheduler: Optional learning-rate scheduler, stepped after the
                optimiser when provided.
        """
        loss = 0
        for image, actions, targets in batch:
            value, reward, policy_logits, hidden_state = network.initial_inference(image).unpack()
            # Grab all predictions as a tuple of:
            # - gradient_scale (based on the number of actions)
            # - predicted value
            # - predicted logits
            # - predicted reward
            predictions: List[Tuple[float, Tensor, Tensor, Tensor]] = [(1.0, value, policy_logits, reward)]

            for action in actions:
                # Positional call, matching how recurrent_inference is invoked
                # elsewhere in this notebook; avoids depending on the name of
                # its first parameter.
                value, reward, policy_logits, hidden_state = network.recurrent_inference(hidden_state, action).unpack()

                # Scale the gradient
                predictions.append((1.0 / len(actions), value, policy_logits, reward))
                # Adjust hidden_state gradient
                hidden_state = ut.scale_torch_gradient(hidden_state, self.config.gradient_scale_factor)

            for pred, target in zip(predictions, targets):
                loss += self.calculate_loss(pred, target, network)

        # Add the weight-decay penalty to the prediction loss. Assigning here
        # would discard everything accumulated above and keep only the term
        # for the last parameter.
        for weight in network.parameters():
            loss = loss + self.config.weight_decay * weight.norm()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # scheduler defaults to None, so only step it when one was supplied.
        if scheduler is not None:
            scheduler.step()

    def calculate_loss(self, pred: Tuple[float, Tensor, Tensor, Tensor], target: Target, network: MuZeroNetwork):
        """Return the gradient-scaled loss for one prediction / target pair.

        Args:
            pred: ``(gradient_scale, value, policy_logits, reward)`` as built
                in ``update_weights``.
            target: Training target, converted by
                ``network.from_target_to_support``.
            network: Network providing the target conversion.

        Returns:
            Sum of the value, reward and policy cross-entropies, passed
            through ``ut.scale_torch_gradient`` with the prediction's scale.
        """
        scaling_factor, value, policy_logits, reward = pred
        t_target_value, t_target_reward, softmaxed_t_target_policy = network.from_target_to_support(target)

        # NOTE(review): the value and reward targets are softmaxed as in the
        # original code; if from_target_to_support already returns probability
        # distributions this second normalisation should be removed.
        softmaxed_t_target_value = softmax(t_target_value, dim=1)
        softmaxed_t_target_reward = softmax(t_target_reward, dim=1)

        # cross_entropy applies log-softmax to its input itself, so the
        # predictions are passed as raw logits rather than softmaxed first.
        lv = cross_entropy(value, softmaxed_t_target_value, reduction="sum")
        lr = cross_entropy(reward, softmaxed_t_target_reward, reduction="sum")
        la = cross_entropy(policy_logits, softmaxed_t_target_policy, reduction="sum")

        return ut.scale_torch_gradient(lv + lr + la, scaling_factor)

    def get_latest_network(self):
        """Create a new ``MuZeroConnectN`` and run ``load_network`` on it.

        Returns:
            The newly created network instance.
        """
        network = MuZeroConnectN()
        self.load_network(network)
        return network

    def load_network(self, network: MuZeroNetwork):
        """Delegate to ``network.load_network()``; its return value is ignored.

        Args:
            network: Network on which to call ``load_network``.
        """
        network.load_network()

    def save_network(self, network: MuZeroNetwork):
        """Delegate to ``network.save_network()``.

        Args:
            network: Network on which to call ``save_network``.
        """
        network.save_network()
