In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import sys
sys.path.append('../..')
from muzero.config import make_atari_config
from muzero.continous import ContinousActionDecoder, ContinousActionEncoder, ContinousMuzeroNet, VitConfig, tokenizer
from muzero.gym_env import create_atari_environment
import numpy as np

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
runtime_device = "cuda"

random_state = np.random.RandomState(42)

def environment_builder():
    return create_atari_environment(
        env_name="Pong",
        screen_height=224,
        screen_width=224,
        frame_skip=4,
        frame_stack=2,
        max_episode_steps=1000,
        seed=random_state.randint(1, 2**31),
        noop_max=30,
        terminal_on_life_loss=False,
        clip_reward=False,
        output_actions=True,
        resize_and_gray=False
    )

eval_env, eval_actions = environment_builder()

config = make_atari_config(
        num_training_steps=10,
        batch_size=2,
        min_replay_size=2,
        use_tensorboard=False,
        clip_grad=True,
    )
    
formatted_actions = [f"action: {action}" for action in eval_actions]
print(f"formatted actions: {formatted_actions}")

tokenized_actions = tokenizer(formatted_actions, padding=True, return_tensors="pt").to(runtime_device)
# print(f"tokenized actions: {tokenized_actions}")
action_encoder = ContinousActionEncoder()
action_embeddings = action_encoder(tokenized_actions.input_ids, tokenized_actions.attention_mask)

print("action embeddings shape: ", action_embeddings.shape)

action_decoder = ContinousActionDecoder(action_embeddings)

network = ContinousMuzeroNet(
    action_encoder,
    action_decoder,
    action_embeddings.shape[-1],
    VitConfig(),
    config.num_planes,
    config.value_support_size,
    config.reward_support_size,
)

  deprecation(
  deprecation(


formatted actions: ['action: NOOP', 'action: FIRE', 'action: RIGHT', 'action: LEFT', 'action: RIGHTFIRE', 'action: LEFTFIRE']


  logger.warn(
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
  return F.scaled_dot_product_attention(


action embeddings shape:  torch.Size([6, 2048])




In [4]:
%env PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

env: PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python


In [5]:


import copy
import logging
import multiprocessing
from typing import Callable, Iterable, List, Optional

import gym
import torch
from muzero.config import MuZeroConfig
from muzero.mcts import uct_search
from muzero.network import MuZeroNet
from muzero.pipeline import compute_mc_return_target, compute_n_step_target, handle_exit_signal, init_absl_logging, make_unroll_sequence
from muzero.replay import Transition
from muzero.trackers import make_actor_trackers


def make_continous_unroll_sequence(
    observations: List[np.ndarray],
    actions: List[np.ndarray],
    rewards: List[float],
    pi_probs: List[np.ndarray],
    values: List[float],
    priorities: List[float],
    unroll_steps: int,
) -> Iterable[Transition]:
    """Turn a lists of episode history into a list of structured transition object,
    and stack unroll_steps for actions, rewards, values, MCTS policy.

    Args:
        observations: a list of history environment observations.
        actions: a list of history actual actions taken in the environment.
        rewards: a list of history reward received from the environment.
        pi_probs: a list of history policy probabilities from the MCTS search result.
        values: a list of n-step target value.
        priorities: a list of priorities for each transition.
        unroll_steps: number of unroll steps during traning.

    Returns:
        yeilds tuple of structured Transition object and the associated priority for the specific transition.

    """

    T = len(observations)

    # States past the end of games are treated as absorbing states.
    if len(actions) == T:
        actions += [0] * unroll_steps
    if len(rewards) == T:
        rewards += [0] * unroll_steps
    if len(values) == T:
        values += [0] * unroll_steps
    if len(pi_probs) == T:
        absorb_policy = np.ones_like(pi_probs[-1]) / len(pi_probs[-1])
        pi_probs += [absorb_policy] * unroll_steps

    assert len(actions) == len(rewards) == len(values) == len(pi_probs) == T + unroll_steps

    for t in range(T):
        end_index = t + unroll_steps
        action_sequence = torch.stack([action.cpu().float() for action in actions[t:end_index]])
        print(" ================= action_sequence: ", action_sequence)
        print(" ================= action_sequence shape: ", [action.shape for action in action_sequence])
        stacked_action = action_sequence.numpy()
        stacked_reward = np.array(rewards[t:end_index], dtype=np.float32)
        stacked_value = np.array(values[t:end_index], dtype=np.float32)
        stacked_pi_prob = np.array(pi_probs[t:end_index], dtype=np.float32)

        yield (
            Transition(
                state=observations[t],  # no stacking for observation, since it is only used to get initial hidden state.
                action=stacked_action,
                reward=stacked_reward,
                value=stacked_value,
                pi_prob=stacked_pi_prob,
            ),
            priorities[t],
        )
        

@torch.no_grad()
def run_self_play(
    config: MuZeroConfig,
    rank: int,
    network: MuZeroNet,
    device: torch.device,
    env: gym.Env,
    data_queue: multiprocessing.Queue,
    train_steps_counter: multiprocessing.Value,
    stop_event: multiprocessing.Event,
    tag: str = None,
    no_mask: bool = False,
    action_decoder: Optional[Callable] = None,
    action_encoder: Optional[Callable] = None,
) -> None:
    """Run self-play for as long as needed, only stop if `stop_event` is set to True.

    Args:
        config: a MuZeroConfig instance.
        rank: actor process rank.
        network: a MuZeroNet instance for acting.
        device: PyTorch runtime device.
        env: actor's env.
        data_queue: a multiprocessing.Queue instance to send samples to leaner.
        train_steps_counter: a multiprocessing.Value instance to count current training steps.
        stop_event: a multiprocessing.Event instance signals stop run pipeline.
        tag: add tag to tensorboard log dir.
    """

    init_absl_logging()
    handle_exit_signal()
    logging.info(f'Start self-play actor {rank}')

    tb_log_dir = f'actor{rank}'
    if tag is not None and tag != '':
        tb_log_dir = f'{tag}_{tb_log_dir}'

    trackers = make_actor_trackers(tb_log_dir) if config.use_tensorboard else []
    for tracker in trackers:
        tracker.reset()

    network = network.to(device=device)
    network.eval()
    game = 0

    while not stop_event.is_set():  # For each new game.
        obs = env.reset()
        done = False
        episode_trajectory = []
        steps = 0

        # Play and record transitions.
        # the second check is necessary becase the pipeline could have already stopped while the actor is in the middle of a game.
        while not done and not stop_event.is_set():
            # Make a copy of current player id.
            player_id = copy.deepcopy(env.current_player)
            # print(" ================= obs shape: ", obs.shape)
            action, pi_prob, root_value = uct_search(
                state=obs,
                network=network,
                device=device,
                config=config,
                temperature=config.visit_softmax_temperature_fn(steps, train_steps_counter.value),
                actions_mask=env.actions_mask,#None if no_mask else env.actions_mask,
                current_player=env.current_player,
                opponent_player=env.opponent_player,
                action_encoder=action_encoder,
            )
            
            # if action_decoder is not None:
                # action = action_decoder(action)
            # print(" ================= action: ", action)

            next_obs, reward, done, _ = env.step(action)
            steps += 1
            print (" ================= iteration: ", steps)
            if (steps % 10) == 0:
                done = True
            action = action if action_encoder is None else action_encoder(action)
            for tracker in trackers:
                tracker.step(reward, done)

            episode_trajectory.append((obs, action, reward, pi_prob, root_value, player_id))
            obs = next_obs
            
            # Send samples to learner every N steps on Atari games.
            # Here we accmulate N + unroll_steps + td_steps because
            # we needs these extra sequences to compute the target and unroll sequences.
            if (
                not config.is_board_game
                and len(episode_trajectory) == config.acc_seq_length + config.unroll_steps + config.td_steps
            ):
                # Unpack list of tuples into seperate lists.
                observations, actions, rewards, pi_probs, root_values, _ = map(list, zip(*episode_trajectory))
                # Compute n_step target value.
                target_values = compute_n_step_target(rewards, root_values, config.td_steps, config.discount)

                priorities = np.abs(np.array(root_values) - np.array(target_values))

                print(" ================= actions: ", actions[: config.acc_seq_length + config.unroll_steps])
                # Make unroll sequences and send to learner.
                for transition, priority in make_continous_unroll_sequence(
                    observations[: config.acc_seq_length],
                    actions[: config.acc_seq_length + config.unroll_steps],
                    rewards[: config.acc_seq_length + config.unroll_steps],
                    pi_probs[: config.acc_seq_length + config.unroll_steps],
                    target_values[: config.acc_seq_length + config.unroll_steps],
                    priorities[: config.acc_seq_length + config.unroll_steps],
                    config.unroll_steps,
                ):
                    # data_queue.put((transition, priority))
                    print(" ================= transition: ", transition)
                    print(" ================= priority: ", priority)

                del episode_trajectory[: config.acc_seq_length]
                del (observations, actions, rewards, pi_probs, root_values, priorities, target_values)

        game += 1

        # Unpack list of tuples into seperate lists.
        observations, actions, rewards, pi_probs, root_values, player_ids = map(list, zip(*episode_trajectory))
        
        if config.is_board_game:
            # Using MC returns as target value.
            target_values = compute_mc_return_target(rewards, player_ids)
        else:
            # Compute n_step target value.
            target_values = compute_n_step_target(rewards, root_values, config.td_steps, config.discount)

        priorities = np.abs(np.array(root_values) - np.array(target_values))
        print(" =============== full unroll")
        # Make unroll sequences and send to learner.
        for transition, priority in make_continous_unroll_sequence(
            observations, actions, rewards, pi_probs, target_values, priorities, config.unroll_steps
        ):
            data_queue.put((transition, priority))

        del episode_trajectory[:]
        del (observations, actions, rewards, pi_probs, root_values, priorities, player_ids, target_values)

    logging.info(f'Stop self-play actor {rank}')


In [6]:
import multiprocessing
import torch
from torch.optim.lr_scheduler import MultiStepLR
from muzero.atari_v2.run_training import ActionEncoderWith
from muzero.pipeline import run_training
from muzero.replay import PrioritizedReplay

optimizer = torch.optim.Adam(network.parameters(), lr=config.lr_init, weight_decay=config.weight_decay)
lr_scheduler = MultiStepLR(optimizer, milestones=config.lr_milestones, gamma=config.lr_decay_rate)
replay = PrioritizedReplay(
        10,
        0.0,
        0.0,
        random_state,
    )
data_queue = multiprocessing.SimpleQueue()
train_steps_counter = multiprocessing.Value('i', 0)
manager = multiprocessing.Manager()
checkpoint_files = manager.list()
stop_event = multiprocessing.Event()

action_encoder = ActionEncoderWith(action_embeddings)

run_self_play(
            config,
            0,
            network,
            runtime_device,
            eval_env,
            data_queue,
            train_steps_counter,
            stop_event,
            "tag",
            False,
            action_decoder,
            action_encoder,
            )

I0409 00:29:14.819081 31244 1644529528.py:110] Start self-play actor 0


        [ 0.6836, -1.2031,  1.2188,  ..., -2.0156,  0.5859,  2.7500],
        [ 0.4648, -0.3516,  1.2188,  ..., -0.6562,  1.9766,  1.3828],
        [ 0.6836,  0.4941,  1.2188,  ..., -1.2031,  0.9609,  1.3828],
        [ 1.0234, -0.1973,  1.2188,  ..., -0.7070,  1.0312,  1.3828]])


: 