# 2024年 世界モデル コンペティション 参考notebook  

第8回演習で利用したDreamerに修正を加え，Dreamer v2を用いたベースラインコードになっています．  
こちらを動かしていただけば，提出時にエラーが発生しない結果を得ることができます（参考用としてcolabの無料枠で1時間ほどで終わるようにパラメータを変えているため，性能は出ないです）．  

**目次**
1. [準備](#scrollTo=b986f379-97f5-4449-b4c6-7cc385d1f474)
2. [環境の設定](#scrollTo=c7819663-fffc-44e5-842f-779564dd8227)
3. [補助機能の実装](#scrollTo=6b9cdd13-ce4a-44b4-a01d-5a19d4e38bae)
4. [モデルの実装](#scrollTo=0662612e-701b-41a2-8679-25ad03fef367)
5. [学習](#scrollTo=b06c188f-8a87-42e7-9f61-7f385eccc565)
6. [モデルの保存](#scrollTo=aa693a51-a4cb-4ad4-be2b-322cbd68443d)
7. [学習済みパラメータで評価](#scrollTo=c4b31352-bafa-46ed-8bcc-632a24dfced6)

以下良い性能を出すためにできる工夫の例です．  
- ハイパーパラメータを調整する．  
  - バッチサイズを大きくする．
  - 更新回数を増やす（update_freqを小さくする）．
  - モデルの次元数を大きくする．  など
- Dreamer v2の各モデルのアーキテクチャを変更する．
- Dreamer v2以外の学習手法を用いる．

## 1. 準備  

必要なライブラリのインストール．各自必要なライブラリがある場合は追加でインストールしてください．  

In [1]:
# !pip install gym==0.26.2 gym[atari]==0.26.2 gym[accept-rom-license]==0.26.2 autorom ale-py

ローカルで動かすならこっち
```bash
uv sync
```

### 1.1 ライブラリインポート  

In [2]:
# import libraries
from time import time
import os
import random
from copy import deepcopy
from gym.wrappers import ResizeObservation
import torch
from torch import nn
from torch.nn import functional as F
from functorch import combine_state_for_ensemble
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
import re
import sys
import math
import pandas as pd
import wandb
import datetime
from pathlib import Path
import numpy as np
import gym

In [3]:
# エラー回避
from ale_py.env import gym as ale_gym
from typing import Any, Text

# Patch to allow rendering Atari games.
# The AtariEnv's render method expects the mode to be in self._render_mode
# (usually initialized with env.make) instead of taking mode as a param.
_original_atari_render = ale_gym.AtariEnv.render


def atari_render(self, mode: Text = 'rgb_array') -> Any:
  original_render_mode = self._render_mode
  try:
    self._render_mode = mode
    return _original_atari_render(self)
  finally:
    self._render_mode = original_render_mode


ale_gym.AtariEnv.render = atari_render

## 2. 環境の設定  

### 2.1 Repeat Action  
- こちらで実装している環境を用いてOmnicampus上では評価を行います．  
- モデルによって変更する可能性があると想定している部分は以下のとおりです．
    - 画像のレンダリングサイズ(ResizeObervationクラスのshape)．
    - 同じ行動を繰り返す数（RepeatActionクラスのskip）

In [4]:
# define the environment wrapper
class RepeatAction(gym.Wrapper):
    """
    同じ行動を指定された回数自動的に繰り返すラッパー. 観測は最後の行動に対応するものになる
    """
    def __init__(self, env, skip=4, max_steps=100_000):
        gym.Wrapper.__init__(self, env)
        self.max_steps = max_steps if max_steps else float("inf")  # イテレーションの制限
        self.steps = 0  # イテレーション回数のカウント
        self.height = env.observation_space.shape[0]
        self.width = env.observation_space.shape[1]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(3, self.height, self.width),
            dtype=np.uint8,
        )
        self._skip = skip

    def reset(self):
        obs = self.env.reset()
        # obs[0]が[64,64,3]の場合、[3,64,64]に変換
        return np.transpose(obs[0], (2, 0, 1))

    def step(self, action):
        if self.steps >= self.max_steps:  # 100kに達したら何も返さないようにする
            print("Reached max iterations.")
            return None

        total_reward = 0.0
        self.steps += 1
        for _ in range(self._skip):
            # ここでaは離散である必要がある．
            discrete_a = torch.argmax(action).cpu()
            obs, reward, done, truncated, info = self.env.step(discrete_a)

            total_reward += reward
            if self.steps >= self.max_steps:  # 100kに達したら終端にする
                done = True

            if done:
                break

        # obsが[64,64,3]の場合、[3,64,64]に変換して返す
        obs_transposed = np.transpose(obs, (2, 0, 1))
        return obs_transposed, total_reward, done, truncated, info
    def rand_act(self):
        # -1から1の間のサイズが[9]のtorch.tensorを返す
        return torch.rand(9) * 2 - 1

In [5]:
# define make_env
def make_env(seed=None, img_size=64, max_steps=100_000):
    env = gym.make("ALE/MsPacman-v5")

    # シード固定
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    env = ResizeObservation(env, (img_size, img_size))
    env = RepeatAction(env=env, skip=4, max_steps=max_steps)

    return env

## 3. 補助機能の実装  
- モデルを保存する際に利用できるクラス，torchのシード値を固定できる関数です．   
- 提出いただくパラメータの保存や読み込みにこちらのクラスを必ず利用する必要はありません  ．

In [6]:
# モデルパラメータをGoogleDriveに保存・後で読み込みするためのヘルパークラス
class TrainedModels:
    def __init__(self, *models) -> None:
        """
        コンストラクタ．

        Parameters
        ----------
        models : nn.Module
            保存するモデル．複数モデルを渡すことができます．

        使用例: trained_models = TraindModels(encoder, rssm, value_model, action_model)
        """
        assert np.all([nn.Module in model.__class__.__bases__ for model in models]), "Arguments for TrainedModels need to be nn models."

        self.models = models

    def save(self, dir: str) -> None:
        """
        initで渡したモデルのパラメータを保存します．
        パラメータのファイル名は01.pt, 02.pt, ... のように連番になっています．

        Parameters
        ----------
        dir : str
            パラメータの保存先．
        """
        for i, model in enumerate(self.models):
            torch.save(
                model.state_dict(),
                os.path.join(dir, f"{str(i + 1).zfill(2)}.pt")
            )

    def load(self, dir: str, device: str) -> None:
        """
        initで渡したモデルのパラメータを読み込みます．

        Parameters
        ----------
        dir : str
            パラメータの保存先．
        device : str
            モデルをどのデバイス(CPU or GPU)に載せるかの設定．
        """
        for i, model in enumerate(self.models):
            model.load_state_dict(
                torch.load(
                    os.path.join(dir, f"{str(i + 1).zfill(2)}.pt"),
                    map_location=device
                )
            )

In [7]:
# set_seed
def set_seed(seed: int) -> None:
    """
    Pytorch, NumPyのシード値を固定します．これによりモデル学習の再現性を担保できます．

    Parameters
    ----------
    seed : int
        シード値．
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## 4. モデルの実装

https://github.com/DarthUtopian/tdmpc_square_public/tree/main/tdmpc_square/tdmpc_square

TD-M(PC)^2を実装

In [8]:
# init
def weight_init(m):
    """Custom weight initialization for TD-MPC2."""
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        nn.init.uniform_(m.weight, -0.02, 0.02)
    elif isinstance(m, nn.ParameterList):
        for i, p in enumerate(m):
            if p.dim() == 3:  # Linear
                nn.init.trunc_normal_(p, std=0.02)  # Weight
                nn.init.constant_(m[i + 1], 0)  # Bias


def zero_(params):
    """Initialize parameters to zero."""
    for p in params:
        p.data.fill_(0)

In [9]:
# layers
class Ensemble(nn.Module):
    """
    Vectorized ensemble of modules.
    """

    def __init__(self, modules, **kwargs):
        super().__init__()
        modules = nn.ModuleList(modules)
        fn, params, _ = combine_state_for_ensemble(modules)
        self.vmap = torch.vmap(
            fn, in_dims=(0, 0, None), randomness="different", **kwargs
        )
        self.params = nn.ParameterList([nn.Parameter(p) for p in params])
        self._repr = str(modules)

    def forward(self, *args, **kwargs):
        return self.vmap([p for p in self.params], (), *args, **kwargs)

    def __repr__(self):
        return "Vectorized " + self._repr


class ShiftAug(nn.Module):
    """
    Random shift image augmentation.
    Adapted from https://github.com/facebookresearch/drqv2
    """

    def __init__(self, pad=3):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        x = x.float()
        # batch化せずにxが入力されることがある
        if len(x.shape) == 3:
            x = x.unsqueeze(0) # [3, 64, 64] -> [1, 3, 64, 64]
        n, _, h, w = x.size()
        assert h == w
        padding = tuple([self.pad] * 4)
        x = F.pad(x, padding, "replicate")
        eps = 1.0 / (h + 2 * self.pad)
        arange = torch.linspace(
            -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
        )[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
        shift = torch.randint(
            0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
        )
        shift *= 2.0 / (h + 2 * self.pad)
        grid = base_grid + shift
        return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)


class PixelPreprocess(nn.Module):
    """
    Normalizes pixel observations to [-0.5, 0.5].
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.div_(255.0).sub_(0.5)


class SimNorm(nn.Module):
    """
    Simplicial normalization.
    Adapted from https://arxiv.org/abs/2204.00616.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dim = cfg.simnorm_dim

    def forward(self, x):
        shp = x.shape
        x = x.view(*shp[:-1], -1, self.dim)
        x = F.softmax(x, dim=-1)
        return x.view(*shp)

    def __repr__(self):
        return f"SimNorm(dim={self.dim})"


class NormedLinear(nn.Linear):
    """
    Linear layer with LayerNorm, activation, and optionally dropout.
    """

    def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = act
        self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None

    def forward(self, x):
        x = super().forward(x)
        if self.dropout:
            x = self.dropout(x)
        return self.act(self.ln(x))

    def __repr__(self):
        repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
        return (
            f"NormedLinear(in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}{repr_dropout}, "
            f"act={self.act.__class__.__name__})"
        )


def mlp(in_dim, mlp_dims, out_dim, act=None, dropout=0.0):
    """
    Basic building block of TD-MPC2.
    MLP with LayerNorm, Mish activations, and optionally dropout.
    """
    if isinstance(mlp_dims, int):
        mlp_dims = [mlp_dims]
    dims = [in_dim] + mlp_dims + [out_dim]
    mlp = nn.ModuleList()
    for i in range(len(dims) - 2):
        mlp.append(NormedLinear(dims[i], dims[i + 1], dropout=dropout * (i == 0)))
    mlp.append(
        NormedLinear(dims[-2], dims[-1], act=act)
        if act
        else nn.Linear(dims[-2], dims[-1])
    )
    return nn.Sequential(*mlp)


def conv(in_shape, num_channels, act=None):
    """
    Basic convolutional encoder for TD-MPC2 with raw image observations.
    4 layers of convolution with ReLU activations, followed by a linear layer.
    """
    assert in_shape[-1] == 64  # assumes rgb observations to be 64x64
    layers = [
        ShiftAug(),
        PixelPreprocess(),
        nn.Conv2d(in_shape[0], num_channels, 7, stride=2),
        nn.ReLU(inplace=True),
        nn.Conv2d(num_channels, num_channels, 5, stride=2),
        nn.ReLU(inplace=True),
        nn.Conv2d(num_channels, num_channels, 3, stride=2),
        nn.ReLU(inplace=True),
        nn.Conv2d(num_channels, num_channels, 3, stride=1),
        nn.Flatten(),
    ]
    if act:
        layers.append(act)
    return nn.Sequential(*layers)


def enc(cfg, out={}):
    """
    Returns a dictionary of encoders for each observation in the dict.
    """
    for k in cfg.obs_shape.keys():
        if k == "state":
            out[k] = mlp(
                cfg.obs_shape[k][0] + cfg.task_dim,
                max(cfg.num_enc_layers - 1, 1) * [cfg.enc_dim],
                cfg.latent_dim,
                act=SimNorm(cfg),
            )
        elif k == "rgb":
            out[k] = conv(cfg.obs_shape[k], cfg.num_channels, act=SimNorm(cfg))
        else:
            raise NotImplementedError(
                f"Encoder for observation type {k} not implemented."
            )
    return nn.ModuleDict(out)

In [10]:
# math
def soft_ce(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    return -(target * pred).sum(-1, keepdim=True)


@torch.jit.script
def _log_std(x, low, dif):
    return low + 0.5 * dif * (torch.tanh(x) + 1)


@torch.jit.script
def _gaussian_residual(eps, log_std):
    return -0.5 * eps.pow(2) - log_std


@torch.jit.script
def _gaussian_logprob(residual):
    return residual - 0.5 * torch.log(2 * torch.pi)


def gaussian_logprob(eps, log_std, size=None):
    """Compute Gaussian log probability."""
    residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
    if size is None:
        size = eps.size(-1)
    return _gaussian_logprob(residual) * size


@torch.jit.script
def _squash(pi):
    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)


def squash(mu, pi, log_pi):
    """Apply squashing function."""
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    log_pi -= _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi


@torch.jit.script
def symlog(x):
    """
    Symmetric logarithmic function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * torch.log(1 + torch.abs(x))


@torch.jit.script
def symexp(x):
    """
    Symmetric exponential function.
    Adapted from https://github.com/danijar/dreamerv3.
    """
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)


def two_hot(x, cfg):
    """Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot


DREG_BINS = None


def two_hot_inv(x, cfg):
    """Converts a batch of soft two-hot encoded vectors to scalars."""
    global DREG_BINS
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symexp(x)
    if DREG_BINS is None:
        DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
    x = F.softmax(x, dim=-1)
    x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
    return symexp(x)

In [11]:
# world model
class WorldModel(nn.Module):
    """
    TD-MPC2 implicit world model architecture.
    Can be used for both single-task and multi-task experiments.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        if cfg.multitask:
            self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
            self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
            for i in range(len(cfg.tasks)):
                self._action_masks[i, : cfg.action_dims[i]] = 1.0
        self._encoder = enc(cfg)
        self._dynamics = mlp(
            cfg.latent_dim + cfg.action_dim + cfg.task_dim,
            2 * [cfg.mlp_dim],
            cfg.latent_dim,
            act=SimNorm(cfg),
        )
        self._reward = mlp(
            cfg.latent_dim + cfg.action_dim + cfg.task_dim,
            2 * [cfg.mlp_dim],
            max(cfg.num_bins, 1),
        )
        self._pi = mlp(
            cfg.latent_dim + cfg.task_dim, 2 * [cfg.mlp_dim], 2 * cfg.action_dim
        )
        self._Qs = Ensemble(
            [
                mlp(
                    cfg.latent_dim + cfg.action_dim + cfg.task_dim,
                    2 * [cfg.mlp_dim],
                    max(cfg.num_bins, 1),
                    dropout=cfg.dropout,
                )
                for _ in range(cfg.num_q)
            ]
        )
        self.apply(weight_init)
        zero_([self._reward[-1].weight, self._Qs.params[-2]])
        self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
        self.log_std_min = torch.tensor(cfg.log_std_min)
        self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min

    @property
    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def to(self, *args, **kwargs):
        """
        Overriding `to` method to also move additional tensors to device.
        """
        super().to(*args, **kwargs)
        if self.cfg.multitask:
            self._action_masks = self._action_masks.to(*args, **kwargs)
        self.log_std_min = self.log_std_min.to(*args, **kwargs)
        self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
        return self

    def train(self, mode=True):
        """
        Overriding `train` method to keep target Q-networks in eval mode.
        """
        super().train(mode)
        self._target_Qs.train(False)
        return self

    def track_q_grad(self, mode=True):
        """
        Enables/disables gradient tracking of Q-networks.
        Avoids unnecessary computation during policy optimization.
        This method also enables/disables gradients for task embeddings.
        """
        for p in self._Qs.parameters():
            p.requires_grad_(mode)
        if self.cfg.multitask:
            for p in self._task_emb.parameters():
                p.requires_grad_(mode)

    def soft_update_target_Q(self):
        """
        Soft-update target Q-networks using Polyak averaging.
        """
        with torch.no_grad():
            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
                p_target.data.lerp_(p.data, self.cfg.tau)

    def task_emb(self, x, task):
        """
        Continuous task embedding for multi-task experiments.
        Retrieves the task embedding for a given task ID `task`
        and concatenates it to the input `x`.
        """
        if isinstance(task, int):
            task = torch.tensor([task], device=x.device)
        emb = self._task_emb(task.long())
        if x.ndim == 3:
            emb = emb.unsqueeze(0).repeat(x.shape[0], 1, 1)
        elif emb.shape[0] == 1:
            emb = emb.repeat(x.shape[0], 1)
        return torch.cat([x, emb], dim=-1)

    def encode(self, obs, task):
        """
        Encodes an observation into its latent representation.
        This implementation assumes a single state-based observation.
        """
        if self.cfg.multitask:
            obs = self.task_emb(obs, task)
        if self.cfg.obs == "rgb" and obs.ndim == 5:
            return torch.stack([self._encoder[self.cfg.obs](o) for o in obs])
        return self._encoder[self.cfg.obs](obs)

    def next(self, z, a, task):
        """
        Predicts the next latent state given the current latent state and action.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)
        z = torch.cat([z, a], dim=-1)
        return self._dynamics(z)

    def reward(self, z, a, task):
        """
        Predicts instantaneous (single-step) reward.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)
        z = torch.cat([z, a], dim=-1)
        return self._reward(z)

    def pi(self, z, task):
        """
        Samples an action from the policy prior.
        The policy prior is a Gaussian distribution with
        mean and (log) std predicted by a neural network.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)

        # Gaussian policy prior
        mu, log_std = self._pi(z).chunk(2, dim=-1)
        log_std = _log_std(log_std, self.log_std_min, self.log_std_dif)
        eps = torch.randn_like(mu)

        if self.cfg.multitask:  # Mask out unused action dimensions
            mu = mu * self._action_masks[task]
            log_std = log_std * self._action_masks[task]
            eps = eps * self._action_masks[task]
            action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
        else:  # No masking
            action_dims = None

        log_pi = gaussian_logprob(eps, log_std, size=action_dims)
        pi = mu + eps * log_std.exp()
        mu, pi, log_pi = squash(mu, pi, log_pi)

        return mu, pi, log_pi, log_std
    
    def log_pi_action(self, z, a, task):
        """
        Compute the log probability of an action sequence given the latent states.
        """
        if self.cfg.multitask:
            z = self.task_emb(z, task)
        mu, log_std = self._pi(z).chunk(2, dim=-1)
        eps = (a - mu) / (log_std.exp() + 1e-8)

        if self.cfg.multitask:  # Mask out unused action dimensions
            mu = mu * self._action_masks[task]
            log_std = log_std * self._action_masks[task]
            eps = eps * self._action_masks[task]
            action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
        else:  # No masking
            action_dims = None
            
        log_pi = gaussian_logprob(eps, log_std, size=action_dims)
        return log_pi

    def Q(self, z, a, task, return_type="min", target=False):
        """
        Predict state-action value.
        `return_type` can be one of [`min`, `avg`, `all`]:
                - `min`: return the minimum of two randomly subsampled Q-values.
                - `avg`: return the average of two randomly subsampled Q-values.
                - `max`: return the maximum of two randomly subsampled Q-values.
                - `all`: return all Q-values.
        `target` specifies whether to use the target Q-networks or not.
        """
        assert return_type in {"min", "avg", "all", "max"}

        if self.cfg.multitask:
            z = self.task_emb(z, task)
        z = torch.cat([z, a], dim=-1)
        out = (self._target_Qs if target else self._Qs)(z)

        if return_type == "all":
            return out

        Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)]
        Q1, Q2 = two_hot_inv(Q1, self.cfg), two_hot_inv(Q2, self.cfg)

        if return_type == "min":
            return torch.min(Q1, Q2)
        elif return_type == "avg":
            return (Q1 + Q2) / 2
        elif return_type == "max":
            qs_thot = [two_hot_inv(q, self.cfg) for q in out]
            qs_thot = torch.stack(qs_thot, dim=0)
            return torch.max(qs_thot, dim=0)[0]
        ##return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2

In [12]:
# running scale
class RunningScale:
    """Running trimmed scale estimator."""

    def __init__(self, cfg):
        self.cfg = cfg
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self._value = torch.ones(1, dtype=torch.float32, device=self.device)
        self._percentiles = torch.tensor(
            [5, 95], dtype=torch.float32, device=self.device
        )

    def state_dict(self):
        return dict(value=self._value, percentiles=self._percentiles)

    def load_state_dict(self, state_dict):
        self._value.data.copy_(state_dict["value"])
        self._percentiles.data.copy_(state_dict["percentiles"])

    @property
    def value(self):
        return self._value.cpu().item()

    def _percentile(self, x):
        x_dtype, x_shape = x.dtype, x.shape
        x = x.view(x.shape[0], -1)
        in_sorted, _ = torch.sort(x, dim=0)
        positions = self._percentiles * (x.shape[0] - 1) / 100
        floored = torch.floor(positions)
        ceiled = floored + 1
        ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
        weight_ceiled = positions - floored
        weight_floored = 1.0 - weight_ceiled
        d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
        d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
        return (d0 + d1).view(-1, *x_shape[1:]).type(x_dtype)

    def update(self, x):
        percentiles = self._percentile(x.detach())
        value = torch.clamp(percentiles[1] - percentiles[0], min=1.0)
        self._value.data.lerp_(value, self.cfg.tau)

    def __call__(self, x, update=False):
        if update:
            self.update(x)
        return x * (1 / self.value)

    def __repr__(self):
        return f"RunningScale(S: {self.value})"

In [13]:
# td-mpc2
class TDMPC2:
	"""
	Modified TD-MPC2 agent. Implements training + inference.
	Current implementation supports both state and pixel observations.
	Only support Single-task setting is supported.
	"""

	def __init__(self, cfg):
		self.cfg = cfg
		if torch.cuda.is_available():
			self.device = torch.device("cuda")
		else:
			self.device = torch.device("cpu")
		self.model = WorldModel(cfg).to(self.device)
		self.optim = torch.optim.Adam(
			[
				{
					"params": self.model._encoder.parameters(),
					"lr": self.cfg.lr * self.cfg.enc_lr_scale,
				},
				{"params": self.model._dynamics.parameters()},
				{"params": self.model._reward.parameters()},
				{"params": self.model._Qs.parameters()},
				{
					"params": self.model._task_emb.parameters()
					if self.cfg.multitask
					else []
				},
			],
			lr=self.cfg.lr,
		)
		self.pi_optim = torch.optim.Adam(
			self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5
		)
		self.model.eval()
		self.scale = RunningScale(cfg)
		self.log_pi_scale = RunningScale(cfg) # policy log-probability scale
		self.cfg.iterations += 2 * int(
			cfg.action_dim >= 20
		)  # Heuristic for large action spaces
		self.discount = (
			torch.tensor(
				[self._get_discount(ep_len) for ep_len in cfg.episode_lengths],
				device="cuda",
			)
			if self.cfg.multitask
			else self._get_discount(cfg.episode_length)
		)

	def _get_discount(self, episode_length):
		"""
		Returns discount factor for a given episode length.
		Simple heuristic that scales discount linearly with episode length.
		Default values should work well for most tasks, but can be changed as needed.

		Args:
				episode_length (int): Length of the episode. Assumes episodes are of fixed length.

		Returns:
				float: Discount factor for the task.
		"""
		frac = episode_length / self.cfg.discount_denom
		return min(
			max((frac - 1) / (frac), self.cfg.discount_min), self.cfg.discount_max
		)

	def save(self, fp):
		"""
		Save state dict of the agent to filepath.

		Args:
				fp (str): Filepath to save state dict to.
		"""
		torch.save({"model": self.model.state_dict()}, fp)

	def load(self, fp):
		"""
		Load a saved state dict from filepath (or dictionary) into current agent.

		Args:
				fp (str or dict): Filepath or state dict to load.
		"""
		state_dict = fp if isinstance(fp, dict) else torch.load(fp)
		self.model.load_state_dict(state_dict["model"])

	@torch.no_grad()
	def act(self, obs, t0=False, eval_mode=False, task=None, use_pi=False):
		"""
		Select an action by planning in the latent space of the world model.

		Args:
				obs (torch.Tensor): Observation from the environment.
				t0 (bool): Whether this is the first observation in the episode.
				eval_mode (bool): Whether to use the mean of the action distribution.
				task (int): Task index (only used for multi-task experiments).

		Returns:
				torch.Tensor: Action to take in the environment.
		"""
		# obsをtorch.Tensorに変換
		obs = torch.tensor(obs, device=self.device).unsqueeze(0)
		# obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
		if task is not None: # 絶対に実行されない
			task = torch.tensor([task], device=self.device)
		z = self.model.encode(obs, task) # 画像を潜在空間にエンコード
		if self.cfg.mpc and not use_pi:
			a, mu, std = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
		else: # こちらは実行されない
			mu, pi, log_pi, log_std = self.model.pi(z, task)
			if eval_mode:
				a = mu[0]
			else:
				a = pi[0]
			mu, std = mu[0], log_std.exp()[0]
		if len(a.shape) == 1:
			a = a.unsqueeze(0)
		return a.cpu(), mu.cpu(), std.cpu()

	@torch.no_grad()
	def _estimate_value(self, z, actions, task, horizon, eval_mode=False):
		"""Estimate value of a trajectory starting at latent state z and executing given actions."""
		G, discount = 0, 1
		for t in range(horizon):
			reward = two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
			z = self.model.next(z, actions[t], task)
			G += discount * reward
			discount *= (
				self.discount[torch.tensor(task)]
				if self.cfg.multitask
				else self.discount
			)
		return G + discount * self.model.Q(
			z, self.model.pi(z, task)[1], task, return_type="avg"
		)

	@torch.no_grad()
	def _estimate_value_parallel(self, z, actions, task):
		"""Estimate value of a trajectory starting at latent state z and executing given actions."""
		G, discount = 0, 1
		for t in range(self.cfg.horizon):
			reward = two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
			z = self.model.next(z, actions[:, t], task)
			G = G + discount * reward
			discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
			discount = discount * discount_update
		return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')

	@torch.no_grad()
	def plan(self, z, t0=False, eval_mode=False, task=None):
		"""
		Plan a sequence of actions using the learned world model.

		Args:
				z (torch.Tensor): Latent state from which to plan.
				t0 (bool): Whether this is the first observation in the episode.
				eval_mode (bool): Whether to use the mean of the action distribution.
				task (Torch.Tensor): Task index (only used for multi-task experiments).

		Returns:
				torch.Tensor: Action to take in the environment.
		"""
		if self.cfg.num_pi_trajs > 0:
			pi_actions = torch.empty(
				self.cfg.horizon,
				self.cfg.num_pi_trajs,
				self.cfg.action_dim,
				device=self.device,
			)
			_z = z.repeat(self.cfg.num_pi_trajs, 1)
			for t in range(self.cfg.horizon - 1):
				pi_actions[t] = self.model.pi(_z, task)[1]
				_z = self.model.next(_z, pi_actions[t], task)
			pi_actions[-1] = self.model.pi(_z, task)[1]

		# Initialize state and parameters
		z = z.repeat(self.cfg.num_samples, 1)
		mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
		std = self.cfg.max_std * torch.ones(
			self.cfg.horizon, self.cfg.action_dim, device=self.device
		)
		if not t0:
			mean[:-1] = self._prev_mean[1:]
		actions = torch.empty(
			self.cfg.horizon,
			self.cfg.num_samples,
			self.cfg.action_dim,
			device=self.device,
		)
		if self.cfg.num_pi_trajs > 0:
			actions[:, : self.cfg.num_pi_trajs] = pi_actions

		# Iterate MPPI
		for _ in range(self.cfg.iterations):
			# Sample actions
			actions[:, self.cfg.num_pi_trajs :] = (
				mean.unsqueeze(1)
				+ std.unsqueeze(1)
				* torch.randn(
					self.cfg.horizon,
					self.cfg.num_samples - self.cfg.num_pi_trajs,
					self.cfg.action_dim,
					device=std.device,
				)
			).clamp(-1, 1)
			if self.cfg.multitask:
				actions = actions * self.model._action_masks[task]

			# Compute elite actions
			value = self._estimate_value(z, actions, task, self.cfg.horizon).nan_to_num_(0)
			elite_idxs = torch.topk(
				value.squeeze(1), self.cfg.num_elites, dim=0
			).indices
			elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

			# Update parameters
			max_value = elite_value.max(0)[0]
			score = torch.exp(self.cfg.temperature * (elite_value - max_value))
			score /= score.sum(0)
			mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
				score.sum(0) + 1e-9
			)
			std = torch.sqrt(
				torch.sum(
					score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1
				)
				/ (score.sum(0) + 1e-9)
			).clamp_(self.cfg.min_std, self.cfg.max_std)
			if self.cfg.multitask:
				mean = mean * self.model._action_masks[task]
				std = std * self.model._action_masks[task]

		# Select action
		score = score.squeeze(1).cpu().numpy()
		actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
		self._prev_mean = mean
		mu, std = actions[0], std[0]
		if not eval_mode:
			a = mu + std * torch.randn(self.cfg.action_dim, device=std.device)
		else:
			a = mu
		return a.clamp_(-1, 1), mu, std

	def update_pi(self, zs, action, mu, std, task):
		"""
		Update policy using a sequence of latent states.

		Args:
				zs (torch.Tensor): Sequence of latent states.
				action (torch.Tensor): Sequence of actions.
				task (torch.Tensor): Task index (only used for multi-task experiments).

		Returns:
				float: Loss of the policy update.
		"""
		self.pi_optim.zero_grad(set_to_none=True)
		self.model.track_q_grad(False)
		_, pis, log_pis, _ = self.model.pi(zs, task)
		qs = self.model.Q(zs, pis, task, return_type="avg")
		self.scale.update(qs[0])
		qs = self.scale(qs)
			
		rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
		if self.cfg.actor_mode=="sac":
			# TD-MPC2 baseline setting.
			pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1, 2)) * rho).mean()
			prior_loss = torch.zeros_like(pi_loss) # Not used
			q_loss = pi_loss.detach().clone()

		elif self.cfg.actor_mode=="awac":
			# Loss for AWAC-MPC
			with torch.no_grad():
				vs = self.model.Q(zs, action, task, return_type="avg")
				vs = self.scale(vs)
			adv = (qs - vs).detach()
			weights = torch.clamp(torch.exp(adv / self.cfg.awac_lambda), self.cfg.exp_adv_min, self.cfg.exp_adv_max)
			log_pis_action = self.model.log_pi_action(zs, action, task)
			pi_loss = (( - weights * log_pis_action).mean(dim=(1, 2)) * rho).mean()
			q_loss = torch.zeros_like(pi_loss)
			prior_loss = torch.zeros_like(pi_loss)

		elif self.cfg.actor_mode=="residual":
			# Loss for TD-M(PC)^2
			action_dims = None if not self.cfg.multitask else self.model._action_masks.size(-1)
			std = torch.max(std, self.cfg.min_std * torch.ones_like(std))
			eps = (pis - mu) / std
			log_pis_prior = gaussian_logprob(eps, std.log(), size=action_dims).mean(dim=-1)
			#log_pis_prior = torch.clamp(log_pis_prior, -50000, 0.0)

			log_pis_prior = self.scale(log_pis_prior) if self.scale.value > self.cfg.scale_threshold else torch.zeros_like(log_pis_prior)

			q_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1, 2)) * rho).mean()
			prior_loss = - (log_pis_prior.mean(dim=-1) * rho).mean()
			pi_loss = q_loss + (self.cfg.prior_coef * self.cfg.action_dim / 61) * prior_loss

		elif self.cfg.actor_mode=="bc_sac": 
			# Vanilla BC-SAC loss for policy learning
			q_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1, 2)) * rho).mean()
			prior_loss = (((pis - action) ** 2).sum(dim=-1).mean(dim=1) * rho).mean()
			pi_loss = q_loss + self.cfg.prior_coef * prior_loss

		elif self.cfg.actor_mode=="bc":
			# Loss for BC-MPC baseline
			action_dims = None if not self.cfg.multitask else self.model._action_masks.size(-1)
			std = torch.max(std, self.cfg.min_std * torch.ones_like(std))
			eps = (pis - mu) / std
			log_pis_prior = gaussian_logprob(eps, std.log(), size=action_dims).mean(dim=-1)
			log_pis_prior = torch.clamp(log_pis_prior, -50000, 0.0)
			self.log_pi_scale.update(log_pis_prior[0])
			
			log_pis_prior = self.scale(log_pis_prior)
			pi_loss = - (log_pis_prior.mean(dim=-1) * rho).mean()
			prior_loss = pi_loss.detach().clone()
			q_loss = torch.zeros_like(pi_loss) # Not used

		else:
			raise NotImplementedError

		pi_loss.backward()
		torch.nn.utils.clip_grad_norm_(
			self.model._pi.parameters(), self.cfg.grad_clip_norm
		)
		
		self.pi_optim.step()
		self.model.track_q_grad(True)

		return pi_loss.item(), q_loss.item(), prior_loss.item()

	@torch.no_grad()
	def _td_target(self, next_z, reward, task):
		"""
		Compute the TD-target from a reward and the observation at the following time step.

		Args:
				next_z (torch.Tensor): Latent state at the following time step.
				reward (torch.Tensor): Reward at the current time step.
				task (torch.Tensor): Task index (only used for multi-task experiments).

		Returns:
				torch.Tensor: TD-target.
		"""
		pi = self.model.pi(next_z, task)[1]
		discount = (
			self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
		)
		return reward + discount * self.model.Q(
			next_z, pi, task, return_type="min", target=True
		)

	def update(self, buffer):	
		"""
		Main update function. Corresponds to one iteration of model learning.

		Args:
				buffer (common.buffer.Buffer): Replay buffer.

		Returns:
				dict: Dictionary of training statistics.
		"""
		if self.cfg.multitask and self.cfg.task in {"mt30","mt80"}:
			# offline training
			obs, action, reward, task = buffer.sample()
			mu = action.detach().clone()
			std = torch.full_like(action, self.cfg.max_std)
		else:
			# online training
			obs, action, mu, std, reward, task = buffer.sample() # mu and std are from Gaussian policy used for data collection	

		# Compute targets
		with torch.no_grad():
			next_z = self.model.encode(obs[1:], task)
			td_targets = self._td_target(next_z, reward, task)
			
		# Prepare for update
		self.optim.zero_grad(set_to_none=True)
		self.model.train()

		# Latent rollout
		zs = torch.empty(
			self.cfg.horizon + 1,
			self.cfg.batch_size,
			self.cfg.latent_dim,
			device=self.device,
		)
		z = self.model.encode(obs[0], task)
		zs[0] = z
		consistency_loss = 0
		for t in range(self.cfg.horizon):
			z = self.model.next(z, action[t], task)
			consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t
			zs[t + 1] = z

		# Predictions
		_zs = zs[:-1]
		qs = self.model.Q(_zs, action, task, return_type="all")
		reward_preds = self.model.reward(_zs, action, task)

		# Compute losses
		reward_loss, value_loss = 0, 0
		for t in range(self.cfg.horizon):
			reward_loss += (
				soft_ce(reward_preds[t], reward[t], self.cfg).mean()
				* self.cfg.rho**t
			)
			for q in range(self.cfg.num_q):
				value_loss += (
					soft_ce(qs[q][t], td_targets[t], self.cfg).mean()
					* self.cfg.rho**t
				)
		consistency_loss *= 1 / self.cfg.horizon
		reward_loss *= 1 / self.cfg.horizon
		value_loss *= 1 / (self.cfg.horizon * self.cfg.num_q)

		total_loss = (
			self.cfg.consistency_coef * consistency_loss
			+ self.cfg.reward_coef * reward_loss
			+ self.cfg.value_coef * value_loss
		)
		# Update model
		total_loss.backward()
		grad_norm = torch.nn.utils.clip_grad_norm_(
			self.model.parameters(), self.cfg.grad_clip_norm
		)
		self.optim.step()

		# Update policy
		pi_loss, pi_loss_q, pi_loss_prior  = self.update_pi(_zs.detach(), action.detach(), mu.detach(), std.detach(), task)

		# Update target Q-functions
		self.model.soft_update_target_Q()

		# Return training statistics
		self.model.eval()
		return {
			"consistency_loss": float(consistency_loss.mean().item()),
			"reward_loss": float(reward_loss.mean().item()),
			"value_loss": float(value_loss.mean().item()),
			"pi_loss": pi_loss,
			"pi_loss_q": pi_loss_q,
			"pi_loss_prior": pi_loss_prior,
			"total_loss": float(total_loss.mean().item()),
			"grad_norm": float(grad_norm),
			"pi_scale": float(self.scale.value)
		}

## 5. 学習

Num|Action
---|---
0|NOOP
1|UP
2|RIGHT
3|LEFT
4|DOWN
5|UPRIGHT
6|UPLEFT
7|DOWNRIGHT
8|DOWNLEFT

In [14]:
# logger

CONSOLE_FORMAT = [
    ("iteration", "I", "int"),
    ("episode", "E", "int"),
    ("step", "I", "int"),
    ("episode_reward", "R", "float"),
    ("episode_success", "S", "float"),
    ("total_time", "T", "time"),
]

def make_dir(dir_path):
    """Create directory if it does not already exist."""
    try:
        os.makedirs(dir_path)
    except OSError:
        pass
    return dir_path

def print_run(cfg):
    """
    Pretty-printing of current run information.
    Logger calls this method at initialization.
    """
    prefix, color, attrs = "  ", "green", ["bold"]

    def _limstr(s, maxlen=36):
        return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s

    def _pprint(k, v):
        print(prefix + f'{k.capitalize()+":":<15}')

    observations = ", ".join([str(v) for v in cfg.obs_shape.values()])
    kvs = [
        ("task", cfg.task),
        ("steps", f"{int(cfg.steps):,}"),
        ("observations", observations),
        ("actions", cfg.action_dim),
        ("experiment", cfg.exp_name),
    ]
    w = np.max([len(_limstr(str(kv[1]))) for kv in kvs]) + 25
    div = "-" * w
    print(div)
    for k, v in kvs:
        _pprint(k, v)
    print(div)

def cfg_to_group(cfg, return_list=False):
    """
    Return a wandb-safe group name for logging.
    Optionally returns group name as list.
    """
    lst = [cfg.task, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
    return lst if return_list else "-".join(lst)

class VideoRecorder:
    """Utility class for logging evaluation videos."""

    def __init__(self, cfg, wandb, fps=15):
        self.cfg = cfg
        self._save_dir = make_dir(cfg.work_dir / "eval_video")
        self._wandb = wandb
        self.fps = fps
        self.frames = []
        self.enabled = False

    def init(self, env, enabled=True):
        self.frames = []
        self.enabled = self._save_dir and self._wandb and enabled
        self.record(env)

    def record(self, env):
        if self.enabled:
            self.frames.append(env.render("rgb_array"))

    def save(self, step, key="videos/eval_video"):
        if self.enabled and len(self.frames) > 0:
            frames = np.stack(self.frames)
            return self._wandb.log(
                {
                    key: self._wandb.Video(
                        frames.transpose(0, 3, 1, 2), fps=self.fps, format="mp4"
                    )
                },
                step=step,
            )

class Logger:
    """Primary logging object. Logs either locally or using wandb."""

    def __init__(self, cfg):
        self._log_dir = make_dir(cfg.work_dir)
        self._model_dir = make_dir(self._log_dir / "models")
        self._save_csv = cfg.save_csv
        self._save_agent = cfg.save_agent
        self._group = cfg_to_group(cfg)
        self._seed = cfg.seed
        self._eval = []
        print_run(cfg)
        self.project = cfg.wandb_project if cfg.wandb_project is not None else "none"
        self.entity = cfg.wandb_entity if cfg.wandb_entity is not None else "none"
        if cfg.disable_wandb or self.project == "none" or self.entity == "none":
            print("Wandb disabled.")
            cfg.save_agent = False
            cfg.save_video = False
            self._wandb = None
            self._video = None
            return
        os.environ["WANDB_SILENT"] = "true" if cfg.wandb_silent else "false"

        wandb.init(
            project=self.project,
            # entity=self.entity,
            name=f"{cfg.task}.tdmpc.{cfg.exp_name}.{cfg.seed}",
            #group=self._group,
            # tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
            dir=self._log_dir,
            config=cfg.__dict__,
        )
        print("Logs will be synced with wandb.")
        self._wandb = wandb
        self._video = (
            VideoRecorder(cfg, self._wandb) if self._wandb and cfg.save_video else None
        )


    @property
    def video(self):
        return self._video

    @property
    def model_dir(self):
        return self._model_dir

    def save_agent(self, agent=None, identifier="final"):
        if self._save_agent and agent:
            fp = self._model_dir / f"{str(identifier)}.pt"
            agent.save(fp)
            if self._wandb:
                artifact = self._wandb.Artifact(
                    self._group + "-" + str(self._seed) + "-" + str(identifier),
                    type="model",
                )
                artifact.add_file(fp)
                self._wandb.log_artifact(artifact)

    def finish(self, agent=None):
        try:
            self.save_agent(agent)
        except Exception as e:
            print(f"Failed to save model: {e}")
        if self._wandb:
            self._wandb.finish()

    def _format(self, key, value, ty):
        if ty == "int":
            return f'{key+":"} {int(value):,}'
        elif ty == "float":
            return f'{key+":"} {value:.01f}'
        elif ty == "time":
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{key+":"} {value}'
        else:
            raise f"invalid log format type: {ty}"

    def _print(self, d, category):
        pieces = [f" {category:<14}"]
        for k, disp_k, ty in CONSOLE_FORMAT:
            if k in d:
                pieces.append(f"{self._format(disp_k, d[k], ty):<22}")
        print("   ".join(pieces))

    def log(self, d, category="train"):
        if self._wandb:
            if category in {"train", "eval", "results"}:
                xkey = "step"
            elif category == "pretrain":
                xkey = "iteration"
            for k, v in d.items():
                if category == "results" and k == "step":
                    continue
                self._wandb.log({category + "/" + k: v}, step=d[xkey])
        if category == "eval" and self._save_csv:
            keys = ["step", "episode_reward"]
            self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
            pd.DataFrame(np.array(self._eval)).to_csv(
                self._log_dir / "eval.csv", header=keys, index=None
            )
        if category != "results":
            self._print(d, category)

In [15]:
# buffer
class Buffer:
    """
    Replay buffer for TD-MPC2 training. Based on torchrl.
    Uses CUDA memory if available, and CPU memory otherwise.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        if sys.platform == "darwin":
            self._device = torch.device("cpu")
        else:
            self._device = torch.device("cuda")
        self._capacity = min(cfg.buffer_size, cfg.steps)
        self._sampler = SliceSampler(
            num_slices=self.cfg.batch_size,
            end_key=None,
            traj_key="episode",
            truncated_key=None,
        )
        self._batch_size = cfg.batch_size * (cfg.horizon + 1)
        self._num_eps = 0

    @property
    def capacity(self):
        """Return the capacity of the buffer."""
        return self._capacity

    @property
    def num_eps(self):
        """Return the number of episodes in the buffer."""
        return self._num_eps

    def _reserve_buffer(self, storage):
        """
        Reserve a buffer with the given storage.
        """
        return ReplayBuffer(
            storage=storage,
            sampler=self._sampler,
            pin_memory=True,
            prefetch=1,
            batch_size=self._batch_size,
        )

    def _init(self, tds):
        """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
        print(f"Buffer capacity: {self._capacity:,}")
        if sys.platform == "darwin":
            mem_free = 0
        else:
            mem_free, _ = torch.cuda.mem_get_info()
        bytes_per_step = sum(
            [
                (
                    v.numel() * v.element_size()
                    if not isinstance(v, TensorDict)
                    else sum([x.numel() * x.element_size() for x in v.values()])
                )
                for v in tds.values()
            ]
        ) / len(tds)
        total_bytes = bytes_per_step * self._capacity
        print(f"Storage required: {total_bytes/1e9:.2f} GB")
        # Heuristic: decide whether to use CUDA or CPU memory
        storage_device = "cuda" if 2.5 * total_bytes < mem_free else "cpu"
        print(f"Using {storage_device.upper()} memory for storage.")
        return self._reserve_buffer(
            LazyTensorStorage(self._capacity, device=torch.device(storage_device))
        )

    def _to_device(self, *args, device=None):
        if device is None:
            device = self._device
        return (
            arg.to(device, non_blocking=True) if arg is not None else None
            for arg in args
        )

    def _prepare_batch(self, td):
        """
        Prepare a sampled batch for training (post-processing).
        Expects `td` to be a TensorDict with batch size TxB.
        """
        # add capacity to store mu adn std.
        obs = td["obs"]
        action = td["action"][1:]
        mu = td["mu"][1:]#
        std = td["std"][1:]#
        reward = td["reward"][1:].unsqueeze(-1)
        task = td["task"][0] if "task" in td.keys() else None
        return self._to_device(obs, action, mu, std, reward, task)

    def add(self, td):
        """Add an episode to the buffer."""
        td["episode"] = torch.ones_like(td["reward"], dtype=torch.int64) * self._num_eps

        # FIX for HumanoidBench #
        if len(td["episode"]) <= self.cfg.horizon + 1:
            return self._num_eps
        ################################

        if self._num_eps == 0:
            self._buffer = self._init(td)
        self._buffer.extend(td)
        self._num_eps += 1
        return self._num_eps

    def sample(self):
        """Sample a batch of subsequences from the buffer."""
        td = self._buffer.sample().view(-1, self.cfg.horizon + 1).permute(1, 0)
        return self._prepare_batch(td)

In [None]:
# config
class Config:
    def __init__(self, **kwargs):
        # environment
        self.task = "ms_packman" # task name (or mt30/mt80 for multi-task training)
        self.obs = "rgb" # observation type, must be one of `[rgb, state]` (default: rgb)
        # evaluation
        self.checkpoint = None
        self.eval_episodes = 1
        self.eval_pi = True
        self.eval_value = True
        self.eval_freq = 1000
        # training
        self.steps = 100000 # number of training/environment steps (default: 10M)
        self.batch_size = 256
        self.reward_coef = 0.1
        self.value_coef = 0.1
        self.consistency_coef = 20
        self.rho = 0.5
        self.lr = 3e-4
        self.enc_lr_scale = 0.3
        self.grad_clip_norm = 20
        self.tau = 0.01
        self.discount_denom = 5
        self.discount_min = 0.95
        self.discount_max = 0.995
        self.buffer_size = 1_000_000
        self.exp_name = "default"
        # planning
        self.mpc = True
        self.iterations = 6
        self.num_samples = 512
        self.num_elites = 64
        self.num_pi_trajs = 24
        self.horizon = 10 # 15 # default 3
        self.min_std = 0.05
        self.max_std = 2
        self.temperature = 0.5
        # actor
        self.actor_mode = "residual"
        self.log_std_min = -10
        self.log_std_max = 2
        self.prior_coef = 1.0
        self.scale_threshold = 2.0
        self.entropy_coef = 1e-4
        self.awac_lambda = 0.3333
        self.exp_adv_min = 0.1
        self.exp_adv_max = 10.0
        # critic
        self.num_bins = 101
        self.vmin = -10
        self.vmax = +10
        # architecture
        self.model_size = 5 # model size, must be one of `[1, 5, 19, 48, 317]` (default: 5)
        self.num_enc_layers = 2
        self.enc_dim = 256
        self.num_channels = 32
        self.mlp_dim = 512
        self.latent_dim = 512
        self.task_dim = 0
        self.num_q = 5
        self.dropout = 0.01
        self.simnorm_dim = 8
        # logging
        self.wandb_project = "tdmcp-square-test"
        self.wandb_entity = "hirekatsu0523"
        self.wandb_silent = False
        self.disable_wandb = False # いったんTrueにしておく
        self.save_csv = True
        # misc
        self.save_video = True
        self.save_agent = True
        self.seed = 1
        # convenience
        self.work_dir = Path("./log1")
        self.multitask = False
        self.tasks = None # Noneで良い
        self.obs_shape = None
        self.action_dim = None
        self.episode_length = 300 # 良く分からないからとりあえず1000
        self.action_dims = None # Noneで良い
        self.episode_lengths = None # Noneで良い
        self.seed_steps = 1000 # ランダムに行動するステップ数
        self.bin_size = (self.vmax - self.vmin) / self.num_bins
cfg = Config()
MODEL_SIZE = {  # parameters (M)
    1: {
        "enc_dim": 256,
        "mlp_dim": 384,
        "latent_dim": 128,
        "num_enc_layers": 2,
        "num_q": 2,
    },
    5: {"enc_dim": 256, "mlp_dim": 512, "latent_dim": 512, "num_enc_layers": 2},
    19: {"enc_dim": 1024, "mlp_dim": 1024, "latent_dim": 768, "num_enc_layers": 3},
    48: {"enc_dim": 1792, "mlp_dim": 1792, "latent_dim": 768, "num_enc_layers": 4},
    317: {
        "enc_dim": 4096,
        "mlp_dim": 4096,
        "latent_dim": 1376,
        "num_enc_layers": 5,
        "num_q": 8,
    },
}
# cfg.model_sizeに応じてcfgを更新
cfg.__dict__.update(MODEL_SIZE[cfg.model_size])

In [17]:
# 環境の初期化
NUM_ITER = 100_000  # 環境とのインタラクション回数の制限 ※変更しないでください
set_seed(cfg.seed)
env = make_env(max_steps=NUM_ITER)
eval_env = make_env(seed=1234, max_steps=None)  # omnicampus上の環境と同じシード値で評価環境を作成
device = "cuda" if torch.cuda.is_available() else "cpu"
cfg.action_dim = env.action_space.n
cfg.obs_shape = {"rgb":env.observation_space.shape}
# actionは0~8の整数で指定
print("action_dim:", cfg.action_dim)
print("obs_shape:", cfg.obs_shape)

action_dim: 9
obs_shape: {'rgb': (3, 64, 64)}


In [18]:
# online trainer
class OnlineTrainer():
    """Trainer class for single-task online TD-MPC2 training."""

    def __init__(self, cfg: Config, env, eval_env, agent: TDMPC2, buffer: Buffer, logger: Logger):
        self.cfg = cfg
        self.env = env
        self.eval_env = eval_env
        self.agent = agent
        self.buffer = buffer
        self.logger = logger
        print("Architecture:", self.agent.model)
        print("Learnable parameters: {:,}".format(self.agent.model.total_params))
        self._step = 0
        self._ep_idx = 0
        self._start_time = time()

    def common_metrics(self):
        """Return a dictionary of current metrics."""
        return dict(
            step=self._step,
            episode=self._ep_idx,
            total_time=time() - self._start_time,
        )

    def eval(self):
        """Evaluate a TD-MPC2 agent."""
        ep_rewards, ep_successes = [], []
        for i in range(self.cfg.eval_episodes):
            obs, done, ep_reward, t = self.eval_env.reset(), False, 0, 0
            if self.cfg.save_video:
                self.logger.video.init(self.eval_env, enabled=(i == 0))
            while not done:
                action, _, _ = self.agent.act(obs, t0=t == 0, eval_mode=True)
                obs, reward, done, truncated, _ = self.eval_env.step(action)
                done = done or truncated
                ep_reward += reward
                t += 1
                if self.cfg.save_video:
                    self.logger.video.record(self.eval_env)
            ep_rewards.append(ep_reward)
            ep_successes.append(True)
            if self.cfg.save_video:
                # self.logger.video.save(self._step)
                self.logger.video.save(self._step, key='results/video')
        
        if self.cfg.eval_pi:
            # Evaluate nominal policy pi
            ep_rewards_pi, ep_successes_pi = [], []
            for i in range(self.cfg.eval_episodes):
                obs, done, ep_reward, t = self.eval_env.reset(), False, 0, 0
                while not done:
                    action, _, _ = self.agent.act(obs, t0=t == 0, eval_mode=True, use_pi=True)
                    obs, reward, done, truncated, _ = self.eval_env.step(action)
                    done = done or truncated
                    ep_reward += reward
                    t += 1
                ep_rewards_pi.append(ep_reward)
                ep_successes_pi.append(True)
            
        return dict(
            episode_reward=np.nanmean(ep_rewards),
            episode_success=np.nanmean(ep_successes),
            episode_reward_pi=np.nanmean(ep_rewards_pi) if self.cfg.eval_pi else np.nan,
            episode_success_pi=np.nanmean(ep_successes_pi) if self.cfg.eval_pi else np.nan,
        )

    def eval_value(self, n_samples=100):
        """evaluate value approximation."""
        # MC value estimation
        mc_ep_rewards = []
        for i in range(n_samples):
            obs, done, ep_reward, t = self.eval_env.reset(), False, 0, 0
            while not done:
                action, _, _ = self.agent.act(obs, t0=t == 0, eval_mode=True, use_pi=True)
                obs, reward, done, truncated, _ = self.eval_env.step(action)
                done = done or truncated
                ep_reward += reward * self.agent.discount ** t
                t += 1
            mc_ep_rewards.append(ep_reward)

        # Value function approximation
        q_values = []
        for i in range(n_samples):
            obs, done, ep_reward, t = self.eval_env.reset(), False, 0, 0
            
            action, _, _ = self.agent.act(obs, t0=t == 0, eval_mode=True, use_pi=True)
            task = None
            # q_value = self.agent.model.Q(self.agent.model.encode(obs.to(self.agent.device), task), 
            q_value = self.agent.model.Q(self.agent.model.encode(torch.tensor(obs, device=self.agent.device), task),
                                         action.to(self.agent.device), 
                                         task, return_type="avg")
            q_values.append(q_value.detach().cpu().numpy())
        
        return dict(
            mc_value= np.nanmean(mc_ep_rewards),
            q_value= np.nanmean(q_values),
        )

    def to_td(self, obs, action=None, mu=None, std=None, reward=None):
        """Creates a TensorDict for a new episode."""
        if isinstance(obs, dict):
            obs = TensorDict(obs, batch_size=(), device="cpu")
        else:
            obs = torch.tensor(obs, device="cpu")
            obs = obs.unsqueeze(0).cpu()
        if action is None:
            action = torch.full_like(self.env.rand_act(), float("nan"))
        if mu is None:
            mu = torch.full_like(action, float("nan"))
        if std is None:
            std = torch.full_like(action, float("nan"))
        if reward is None:
            reward = torch.tensor(float("nan"))
        # rewardがfloatならtorch.tensorに変換
        if isinstance(reward, float):
            reward = torch.tensor(reward)
        td = TensorDict(
            dict(
                obs=obs,
                action=action.unsqueeze(0) if len(action.shape) == 1 else action,
                mu=mu.unsqueeze(0),
                std=std.unsqueeze(0),
                reward=reward.unsqueeze(0),
            ),
            batch_size=(1,),
        )
        return td

    def train(self):
        """Train a TD-MPC2 agent."""
        train_metrics, done, eval_next = {}, True, True

        while self._step <= self.cfg.steps:
            # Evaluate agent periodically
            if self._step % self.cfg.eval_freq == 0:
                eval_next = True

            # Reset environment
            if done:
                if eval_next:
                    eval_metrics = self.eval()

                    if self.cfg.eval_value:
                        eval_metrics.update(self.eval_value())

                    eval_metrics.update(self.common_metrics())
                    self.logger.log(eval_metrics, "eval")
                    eval_next = False

                if self._step > 0:
                    train_metrics.update(
                        episode_reward=torch.tensor(
                            [td["reward"] for td in self._tds[1:]]
                        ).sum(),
                        episode_success=True,
                    )
                    train_metrics.update(self.common_metrics())

                    results_metrics = {
                        'return': train_metrics['episode_reward'],
                        'episode_length': len(self._tds[1:]),
                        'success': train_metrics['episode_success'],
                        'step': self._step,}

                    self.logger.log(train_metrics, "train")
                    self.logger.log(results_metrics, "results")
                    self._ep_idx = self.buffer.add(torch.cat(self._tds))

                obs = self.env.reset()
                self._tds = [self.to_td(obs)]

            # Collect experience
            if self._step > self.cfg.seed_steps:
                t0 = len(self._tds) == 1
                action, mu, std = self.agent.act(obs, t0=t0)
            else:
                action = self.env.rand_act()
                mu, std = action.detach().clone(), torch.full_like(action, math.exp(self.cfg.log_std_max)) # torch.full_like(action, float('nan')), torch.full_like(action, float('nan')) #  # noqa
            obs, reward, done, truncated, _ = self.env.step(action)
            done = done or truncated
            self._tds.append(self.to_td(obs, action, mu, std, reward))

            # Update agent
            if self._step >= self.cfg.seed_steps:
                if self._step == self.cfg.seed_steps:
                    num_updates = self.cfg.seed_steps
                    print("Pretraining agent on seed data...")
                else:
                    num_updates = 1
                for _ in range(num_updates):
                    _train_metrics = self.agent.update(self.buffer) # 少なくとも1回はdoneをする前に，updateを呼び出すとエラー self.seed_stepsを大きくすると良い
                train_metrics.update(_train_metrics)

            self._step += 1

        self.logger.finish(self.agent)

In [19]:
# エラー回避
try:
    env.reset()
    env.step(torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0]))
except Exception as e:
    print(e)
try:
    eval_env.reset()
    eval_env.step(torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0]))
except Exception as e:
    print(e)

module 'numpy' has no attribute 'bool8'
module 'numpy' has no attribute 'bool8'


In [None]:
# 学習
assert torch.cuda.is_available()
assert cfg.steps > 0, "Must train for at least 1 step."
print("Work dir:", cfg.work_dir)

trainer = OnlineTrainer(
    cfg=cfg,
    env=env,
    eval_env=eval_env,
    agent=TDMPC2(cfg),
    buffer=Buffer(cfg),
    logger=Logger(cfg),
)
trainer.train()
print("\nTraining completed successfully")

Work dir: log1
------------------------------------
  Task:          
  Steps:         
  Observations:  
  Actions:       
  Experiment:    
------------------------------------


  fn, params, _ = combine_state_for_ensemble(modules)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhirekatsu0523[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Logs will be synced with wandb.
Architecture: WorldModel(
  (_encoder): ModuleDict(
    (rgb): Sequential(
      (0): ShiftAug()
      (1): PixelPreprocess()
      (2): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2))
      (3): ReLU(inplace=True)
      (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (5): ReLU(inplace=True)
      (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
      (7): ReLU(inplace=True)
      (8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (9): Flatten(start_dim=1, end_dim=-1)
      (10): SimNorm(dim=8)
    )
  )
  (_dynamics): Sequential(
    (0): NormedLinear(in_features=521, out_features=512, bias=True, act=Mish)
    (1): NormedLinear(in_features=512, out_features=512, bias=True, act=Mish)
    (2): NormedLinear(in_features=512, out_features=512, bias=True, act=SimNorm)
  )
  (_reward): Sequential(
    (0): NormedLinear(in_features=521, out_features=512, bias=True, act=Mish)
    (1): NormedLinear(in_features=512, out_features=512, 

  logger.warn(


 eval             E: 0                     I: 0                     R: 280.0                 S: 1.0                   T: 0:00:47            
 train            E: 0                     I: 92                    R: 230.0                 S: 1.0                   T: 0:00:47            
Buffer capacity: 100,000
Storage required: 1.24 GB
Using CUDA memory for storage.
 train            E: 1                     I: 202                   R: 290.0                 S: 1.0                   T: 0:00:47            
 train            E: 2                     I: 333                   R: 270.0                 S: 1.0                   T: 0:00:47            
 train            E: 3                     I: 477                   R: 370.0                 S: 1.0                   T: 0:00:48            
 train            E: 4                     I: 663                   R: 580.0                 S: 1.0                   T: 0:00:48            
 train            E: 5                     I: 799                   R: 3

これ以下はいったん無視

## 6. モデルの保存

In [None]:
# モデルの保存(Google Driveの場合）
from google.colab import drive
drive.mount('/content/drive')

trained_models.save("drive/MyDrive/Colab Notebooks/")

## 7. 学習済みパラメータで評価  
- こちらの評価に用いている環境は，Omnicampus上で評価する際に用いる環境と同じになっています．
- 今回のコンペティションではPublic / Privateの分類はないため，基本的には以下の実装の評価を性能の目安としていただくと良いと思います．  

In [None]:
# 環境の読み込み
env = make_env()
device = "cuda" if torch.cuda.is_available() else "cpu"

# 学習済みモデルの読み込み
rssm = RSSM(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes, action_dim).to(device)
encoder = Encoder().to(device)
decoder = Decoder(cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
reward_model =  RewardModel(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
discount_model = DiscountModel(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
actor = Actor(action_dim, cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
critic = Critic(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)

trained_models = TrainedModels(
    rssm,
    encoder,
    decoder,
    reward_model,
    discount_model,
    actor,
    critic
)

trained_models.load("./", device)

In [None]:
# 結果を動画で観てみるための関数
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML


def display_video(frames):
    plt.figure(figsize=(8, 8), dpi=50)
    patch = plt.imshow(frames[0], cmap="gray")
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])
        plt.title("Step %d" % (i))

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    display(HTML(anim.to_jshtml(default_mode='once')))
    plt.close()

**環境のシードを固定して評価を行います．シードを変更しないでください．**
- 変更した場合，Omnicampus上での評価と結果が異なります．  

In [None]:
env = make_env(seed=1234, max_steps=None)

policy = Agent(encoder, rssm, actor)

obs = env.reset()
done = False
total_reward = 0
frames = [obs]
actions = []

while not done:
    action = policy(obs, eval=True)
    action_int = np.argmax(action)  # 環境に渡すときはint型

    obs, reward, done, _ = env.step(action_int)

    total_reward += reward
    frames.append(obs)
    actions.append(action_int)

print('Total Reward:', total_reward)

In [None]:
display_video(frames)

今回，評価を行う際のrepeat actionは1に設定しています．  
そのため，repeat actionをそれ以外に設定している場合，repeat actionの分だけ繰り返した行動を提出する形にしています．

In [None]:
# repeat actionに対応した行動に変換する
submission_actions = np.zeros(len(actions) * env._skip)
for start_idx in range(env._skip):
    submission_actions[start_idx::env._skip] = np.array(actions)

np.save("drive/MyDrive/submission", submission_actions)