We want to train an agent on an Atari game by fine-tuning a Multi-Game Decision Transformer (DT).

In [None]:
# Some training runs may need a Colab runtime with "High-RAM" enabled (Colab Pro).

In [None]:
!pip install git+https://github.com/takuseno/d4rl-atari

# WARNING: you will get an error the first time, but don't worry. Press the 'Restart session' button and run this cell again.

In [None]:
# TODO: check whether these two installs are actually necessary
!pip install gym[atari]
!pip install autorom[accept-rom-license]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch import Tensor

from typing import Mapping, Optional, Tuple
import math
import numpy as np
import scipy
import seaborn as sns
import matplotlib.pylab as plt

import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# CONVERT action numbers in a game into all_game numbers

# Ms. Pac-Man action set: action name -> index used by the single-game data.
GAME_ACTIONS = {
    name: index
    for index, name in enumerate(
        (
            "NOOP",
            "UP",
            "RIGHT",
            "LEFT",
            "DOWN",
            "UPRIGHT",
            "UPLEFT",
            "DOWNRIGHT",
            "DOWNLEFT",
        )
    )
}

# Full 18-action set: action name -> index used by the multi-game model.
ALL_GAME_ACTION = {
    name: index
    for index, name in enumerate(
        (
            "NOOP",
            "FIRE",
            "UP",
            "RIGHT",
            "LEFT",
            "DOWN",
            "UPRIGHT",
            "UPLEFT",
            "DOWNRIGHT",
            "DOWNLEFT",
            "UPFIRE",
            "RIGHTFIRE",
            "LEFTFIRE",
            "DOWNFIRE",
            "UPRIGHTFIRE",
            "UPLEFTFIRE",
            "DOWNRIGHTFIRE",
            "DOWNLEFTFIRE",
        )
    )
}


def action_to_allgameaction(action):
    """Translate a game-specific action index into the full action-set index.

    Raises ValueError when `action` is not a value of GAME_ACTIONS.
    """
    game_names = list(GAME_ACTIONS.keys())
    game_indices = list(GAME_ACTIONS.values())
    # Reverse lookup: position of the index gives the action name.
    name = game_names[game_indices.index(action)]
    return ALL_GAME_ACTION[name]


def allgameaction_to_action(all_action):
    """Translate a full action-set index into the game-specific action index.

    Raises ValueError when `all_action` is not a value of ALL_GAME_ACTION and
    KeyError when the action exists in the full set but not in GAME_ACTIONS
    (e.g. any FIRE variant).
    """
    all_names = list(ALL_GAME_ACTION.keys())
    all_indices = list(ALL_GAME_ACTION.values())
    name = all_names[all_indices.index(all_action)]
    return GAME_ACTIONS[name]



In [None]:
# Per-game presets: exactly one preset should be active (uncommented) at a time.

# Breakout
#act_voc = 4
#context_length = 30
#batch_size = 128
#game_name = 'breakout-mixed-v4'
#target_reward = 90

# Pong
#act_voc = 6
#context_length = 50
#batch_size = 512
#game_name = 'pong-expert-v4'
#target_reward = 20

# Ms. Pac-Man (active preset)
act_voc = 18  # 18 = full multi-game action set, 9 = Ms. Pac-Man only
context_length = 4  # default used by the multi-game DT
batch_size = 1  # 256 in the paper, but we don't have enough GPU RAM
game_name = 'ms-pacman-expert-v4'
target_reward = 1150  # 180?

# Qbert
#act_voc = 6
#context_length = 30
#batch_size = 128
#game_name = 'qbert-expert-v4'
#target_reward = 14000

# Hyper-parameters shared by the rest of the notebook.
hparams = dict(
    game_name=game_name,
    n_layer=6,  # number of transformer blocks?
    n_head=8,
    n_embd=128,
    context_length=context_length,  # 30 in Breakout, 50 in Pong
    dropout=0.1,  # dropout value
    act_voc=act_voc,  # 4 in Breakout, 6 in Pong
    batch_size=batch_size,
    state_dim=1 * 84 * 84,  # no stacked images
    max_timestep=4096,  # hardcoded default 4096
    lr=1e-4,  # TODO adjust the lr for fine-tuning
    wt_decay=1e-4,
    warmup_steps=10000,
    target_reward=target_reward,  # 90 = Breakout, 20 = Pong
    max_epochs=5,
)

MULTI-GAME DECISION TRANSFORMER


In [None]:
# https://github.com/rwightman/pytorch-image-models/blob/29fda20e6d428bf636090ab207bbcf60617570ca/timm/layers/weight_init.py#L99
def variance_scaling_(tensor: Tensor, scale=1.0, mode="fan_in", distribution="trunc_normal") -> Tensor:
    r"""Fill `tensor` in place with a variance-scaling initialisation.

    Args:
        tensor: Tensor to initialise (modified in place and returned).
        scale: Variance scale, divided by the fan selected through `mode`.
        mode: "fan_in", "fan_out" or "fan_avg"; any other value leaves
            `scale` untouched.
        distribution: "trunc_normal", "normal" or "uniform".

    Raises:
        ValueError: If `distribution` is not one of the supported names.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    fan_by_mode = {
        "fan_in": fan_in,
        "fan_out": fan_out,
        "fan_avg": (fan_in + fan_out) / 2.0,
    }
    if mode in fan_by_mode:
        scale /= max(1.0, fan_by_mode[mode])

    if distribution == "uniform":
        limit = np.sqrt(3.0 * scale)
        return nn.init.uniform_(tensor, -limit, limit)
    if distribution not in ("trunc_normal", "normal"):
        raise ValueError(f"Invalid distribution: {distribution}")

    stddev = np.sqrt(scale)
    if distribution == "normal":
        return nn.init.normal_(tensor, std=stddev)

    # Truncated normal: compensate for the variance lost to truncation.
    # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
    stddev = stddev / 0.87962566103423978
    return nn.init.trunc_normal_(tensor, std=stddev)


def sample_from_logits(
    logits: Tensor,
    generator: Optional[torch.Generator] = None,
    deterministic: Optional[bool] = False,
    temperature: Optional[float] = 1e0,
    top_k: Optional[int] = None,
    top_percentile: Optional[float] = None,
) -> Tensor:
    r"""Generate a categorical sample from given logits.

    Args:
        logits: Unnormalised log-probabilities; categories are on the last axis.
        generator: Accepted for interface compatibility; currently unused
            because `torch.distributions.Categorical.sample` takes no generator.
        deterministic: If true, return the arg-max instead of sampling.
        temperature: Multiplier applied to the logits before sampling.
        top_k: If given, sample only among the `top_k` largest logits.
        top_percentile: If given (0 to 100), logits that are not strictly above
            this percentile are masked out before sampling.

    Returns:
        Tensor of sampled category indices with the shape of `logits` minus
        its last axis.
    """
    if deterministic:
        return torch.argmax(logits, dim=-1)

    if top_percentile is not None:
        # percentile: 0 to 100, quantile: 0 to 1
        percentile = torch.quantile(logits, top_percentile / 100, dim=-1)
        logits = torch.where(
            logits > percentile[..., None],
            logits,
            torch.full_like(logits, float("-inf")),
        )

    top_indices = None
    if top_k is not None:
        logits, top_indices = torch.topk(logits, top_k)

    # The original code referenced an undefined alias `D`; use the fully
    # qualified distribution class, which is available through `import torch`.
    sample = torch.distributions.Categorical(logits=temperature * logits).sample()

    if top_indices is not None:
        sample_shape = sample.shape
        # Flatten top-k indices and samples for easy indexing.
        top_indices = torch.reshape(top_indices, [-1, top_k])
        sample = sample.flatten()
        # Map positions within the top-k slice back to original category ids.
        sample = top_indices[torch.arange(len(sample)), sample]
        # Reshape samples back to original dimensions.
        sample = torch.reshape(sample, sample_shape)
    return sample


def encode_reward(rew: Tensor) -> Tensor:
    r"""Map raw rewards to the discrete reward tokens expected by the model.

    Token ids: 0 = no reward, 1 = positive reward, 3 = negative reward
    (id 2, the terminal reward, is never produced by this function).
    """
    is_positive = (rew > 0).to(dtype=torch.int32)
    is_negative = (rew < 0).to(dtype=torch.int32)
    return is_positive + 3 * is_negative


def encode_return(ret: Tensor, ret_range: Tuple[int]) -> Tensor:
    r"""Turn (possibly negative) returns into non-negative discrete return tokens.

    Returns are cast to int32, clipped to [ret_range[0], ret_range[1] - 1] and
    shifted so that ret_range[0] maps to token 0.
    """
    lowest, highest = ret_range[0], ret_range[1] - 1
    clipped = torch.clamp(ret.to(dtype=torch.int32), lowest, highest)
    return clipped - lowest

def cross_entropy(logits, labels):
    r"""Sparse cross entropy between logits and integer target labels.

    The per-class losses are averaged over every element of the one-hot
    product (including the class axis), not only over the batch.
    """
    num_classes = logits.shape[-1]
    log_probs = F.log_softmax(logits, dim=-1)
    targets = F.one_hot(labels.long(), num_classes).to(dtype=logits.dtype)
    per_class_loss = -targets * log_probs
    return per_class_loss.mean()


def decode_return(ret: torch.Tensor, ret_range: Tuple[int]) -> torch.Tensor:
    r"""Convert discrete return tokens back into return values.

    Inverse of the shift applied by `encode_return`: token 0 maps to
    ret_range[0]. The result is int32.
    """
    offset = ret_range[0]
    return ret.to(dtype=torch.int32) + offset

In [None]:
# DT code from https://github.com/nikhilbarhate99/min-decision-transformer
class MLP(nn.Module):
    r"""Two-layer feed-forward network: widen by `widening_factor`, GELU, narrow back."""

    def __init__(
        self,
        in_dim: int,
        init_scale: float,
        widening_factor: int = 4,
    ):
        super().__init__()
        self._init_scale = init_scale
        self._widening_factor = widening_factor

        hidden_dim = self._widening_factor * in_dim
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(hidden_dim, in_dim)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Variance-scaling init for both weight matrices, zeros for the biases."""
        for layer in (self.fc1, self.fc2):
            variance_scaling_(layer.weight, scale=self._init_scale)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        r"""Apply fc1 -> GELU -> fc2 on the last axis of `x`."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)


class Attention(nn.Module):
    r"""Multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: Embedding width of the input; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
        w_init_scale: Scale passed to `variance_scaling_` for the projection
            weights. `None` falls back to 1.0 (the default scale of
            `variance_scaling_`) instead of failing at construction time.
        qkv_bias: Whether the fused qkv projection has a bias.
        proj_bias: Whether the output projection has a bias.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        w_init_scale: Optional[float] = None,
        qkv_bias: bool = True,
        proj_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.w_init_scale = w_init_scale

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Variance-scaling init for the projection weights, zeros for the biases."""
        # `variance_scaling_` divides the scale in place, which fails for None.
        init_scale = 1.0 if self.w_init_scale is None else self.w_init_scale
        variance_scaling_(self.qkv.weight, scale=init_scale)
        if self.qkv.bias is not None:
            nn.init.zeros_(self.qkv.bias)
        variance_scaling_(self.proj.weight, scale=init_scale)
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

    def forward(self, x, mask: Optional[Tensor] = None) -> Tensor:
        r"""Attend over `x` of shape [B, T, C].

        Args:
            x: Input of shape [B, T, C].
            mask: Optional mask broadcastable to the [B, heads, T, T] attention
                scores; positions where it is false/zero are masked out.

        Returns:
            Tensor of shape [B, T, C].
        """
        B, T, C = x.shape
        # [B, T, 3C] -> [3, B, heads, T, head_dim]
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Use the most negative finite value so softmax gives ~0 weight.
            mask_value = -torch.finfo(attn.dtype).max  # max_neg_value
            attn = attn.masked_fill(~mask.to(dtype=torch.bool), mask_value)

        attn = attn.softmax(dim=-1)
        # [B, heads, T, head_dim] -> [B, T, C]
        x = (attn @ v).transpose(1, 2).reshape(B, T, C)
        x = self.proj(x)
        return x


class CausalSelfAttention(Attention):
    r"""Self attention with a causal mask applied."""

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        custom_causal_mask: Optional[Tensor] = None,
        prefix_length: Optional[int] = 0,
    ) -> Tensor:
        r"""Run attention with a causal mask combined with an optional extra mask.

        Args:
            x: Input of shape [B, T, D].
            mask: Optional mask multiplied with the causal mask.
            custom_causal_mask: Optional [T, T] causal mask; defaults to a
                lower-triangular mask. It is never modified by this call.
            prefix_length: Number of leading tokens every position may attend
                to. `None` or 0 means no prefix.

        Raises:
            ValueError: If `x` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError("Expect queries of shape [B, T, D].")

        seq_len = x.shape[1]
        # If custom_causal_mask is None, the default causality assumption is
        # sequential (a lower triangular causal mask).
        if custom_causal_mask is None:
            causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=x.device))
        else:
            causal_mask = custom_causal_mask
        causal_mask = causal_mask[None, None, :, :]

        # Similar to T5, tokens up to prefix_length can all attend to each other.
        # Guard on truthiness: slicing with `:None` would unmask every column,
        # and slicing with `:0` is a no-op anyway.
        if prefix_length:
            # Indexing above returns a view; clone so the caller's mask is
            # not overwritten in place.
            causal_mask = causal_mask.clone()
            causal_mask[:, :, :, :prefix_length] = 1
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().forward(x, mask)


class Block(nn.Module):
    r"""Pre-LayerNorm transformer block: causal attention then MLP, each with a residual."""

    def __init__(self, embed_dim: int, num_heads: int, init_scale: float, dropout_rate: float):
        super().__init__()
        # Attention sub-layer.
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attn = CausalSelfAttention(embed_dim, num_heads=num_heads, w_init_scale=init_scale)
        self.dropout_1 = nn.Dropout(dropout_rate)

        # Feed-forward sub-layer.
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, init_scale)
        self.dropout_2 = nn.Dropout(dropout_rate)

    def forward(self, x, **kwargs):
        r"""Apply both sub-layers; `kwargs` are forwarded to the attention module."""
        attn_out = self.attn(self.ln_1(x), **kwargs)
        x = x + self.dropout_1(attn_out)
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.dropout_2(mlp_out)


class Transformer(nn.Module):
    r"""A stack of transformer blocks followed by a final LayerNorm."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_layers: int,
        dropout_rate: float,
    ):
        super().__init__()
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

        # Every block shares the same init scale, which shrinks with depth.
        init_scale = 2.0 / self._num_layers
        self.layers = nn.ModuleList(
            [Block(embed_dim, num_heads, init_scale, dropout_rate) for _ in range(self._num_layers)]
        )
        self.norm_f = nn.LayerNorm(embed_dim)

    def forward(
        self,
        h: Tensor,
        mask: Optional[Tensor] = None,
        custom_causal_mask: Optional[Tensor] = None,
        prefix_length: Optional[int] = 0,
    ) -> Tensor:
        r"""Run the inputs through every block and the final norm.

        Args:
            h: Inputs, [B, T, D].
            mask: Padding mask, [B, T].
            custom_causal_mask: Customized causal mask, [T, T].
            prefix_length: Number of prefix tokens that can all attend to each other.

        Returns:
            Tensor of shape [B, T, D].
        """
        if mask is not None:
            # Zero out masked positions so no information leaks from them,
            # then make the mask broadcastable over heads and query positions.
            h = h * mask[:, :, None]
            mask = mask[:, None, None, :]

        attn_kwargs = dict(
            mask=mask,
            custom_causal_mask=custom_causal_mask,
            prefix_length=prefix_length,
        )
        for layer in self.layers:
            h = layer(h, **attn_kwargs)
        return self.norm_f(h)


class MultiGameDecisionTransformer(nn.Module):
    r"""Multi-game Decision Transformer.

    Embeds image patches, returns, actions and (optionally) rewards as tokens,
    runs them through a causal transformer and predicts return, action and
    (optionally) reward logits for every timestep.

    Args:
        img_size: Observation (height, width), e.g. (84, 84).
        patch_size: Patch (height, width) used by the patch-embedding conv.
        num_actions: Size of the action vocabulary.
        num_rewards: Size of the reward-token vocabulary.
        return_range: (low, high) range of returns; `high - low` return tokens.
        d_model: Embedding width; must be divisible by 64.
        num_layers: Number of transformer blocks.
        dropout_rate: Dropout rate inside the transformer blocks.
        predict_reward: Whether reward tokens are embedded and predicted.
        single_return_token: Stored on the instance; not read by this class.
        conv_dim: Stored on the instance; not read by this class.
        num_steps: Number of timesteps per sequence (fixes the size of the
            learned positional embedding).
    """

    def __init__(
        self,
        img_size: Tuple[int],
        patch_size: Tuple[int],
        num_actions: int,
        num_rewards: int,
        return_range: Tuple[int],
        d_model: int,
        num_layers: int,
        dropout_rate: float,
        predict_reward: bool,
        single_return_token: bool,
        conv_dim: int,
        num_steps: int,
    ):
        super().__init__()

        # Expected by the transformer model.
        if d_model % 64 != 0:
            raise ValueError(f"Model size {d_model} must be divisible by 64")

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_actions = num_actions
        self.num_rewards = num_rewards
        self.num_returns = return_range[1] - return_range[0]  # e.g. [-20, 100] -> 120
        self.return_range = return_range
        self.d_model = d_model
        self.predict_reward = predict_reward
        self.conv_dim = conv_dim
        self.single_return_token = single_return_token
        self.spatial_tokens = True
        self.num_steps = num_steps

        self.transformer = Transformer(
            embed_dim=self.d_model,
            num_heads=self.d_model // 64,
            num_layers=num_layers,
            dropout_rate=dropout_rate,
        )

        patch_height, patch_width = self.patch_size[0], self.patch_size[1]
        # If img_size=(84, 84), patch_size=(14, 14), then P = 84 / 14 = 6.
        # Non-overlapping conv (kernel == stride) embeds each patch to d_model.
        self.image_emb = nn.Conv2d(
            in_channels=1,
            out_channels=self.d_model,
            kernel_size=(patch_height, patch_width),
            stride=(patch_height, patch_width),
            padding="valid",
        )  # image_emb is now [BT x D x P x P].
        patch_grid = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])
        num_patches = patch_grid[0] * patch_grid[1]
        # Learned per-patch positional encoding.
        self.image_pos_enc = nn.Parameter(torch.randn(1, 1, num_patches, self.d_model))

        self.ret_emb = nn.Embedding(self.num_returns+1, self.d_model)
        self.act_emb = nn.Embedding(self.num_actions, self.d_model)
        if self.predict_reward:
            self.rew_emb = nn.Embedding(self.num_rewards, self.d_model)

        num_obs_tokens = num_patches if self.spatial_tokens else 1
        if self.predict_reward:
            tokens_per_step = num_obs_tokens + 3  # obs patches + return + action + reward
        else:
            tokens_per_step = num_obs_tokens + 2  # obs patches + return + action
        # Learned positional embedding over the whole flattened token sequence.
        self.positional_embedding = nn.Parameter(torch.randn(tokens_per_step * self.num_steps, self.d_model))

        # prediction heads
        self.ret_linear = nn.Linear(self.d_model, self.num_returns+1)
        self.act_linear = nn.Linear(self.d_model, self.num_actions)
        if self.predict_reward:
            self.rew_linear = nn.Linear(self.d_model, self.num_rewards)

    def sequence_loss(self, inputs: Mapping[str, Tensor], model_outputs: Mapping[str, Tensor]) -> Tensor:
        r"""Compute the loss on data wrt model outputs.

        The loss is the unweighted mean of the cross entropies of every
        (logits, target) pair returned by `_objective_pairs`.
        """
        obj_pairs = self._objective_pairs(inputs, model_outputs)
        obj = [cross_entropy(logits, target) for logits, target in obj_pairs]
        return sum(obj) / len(obj)

    def reset_parameters(self):
        r"""Re-initialise embeddings, positional encodings and prediction heads.

        Not called by `__init__`; the transformer stack is not touched here.
        """
        nn.init.trunc_normal_(self.image_emb.weight, std=0.02)
        nn.init.zeros_(self.image_emb.bias)
        nn.init.normal_(self.image_pos_enc, std=0.02)

        nn.init.trunc_normal_(self.ret_emb.weight, std=0.02)
        nn.init.trunc_normal_(self.act_emb.weight, std=0.02)
        if self.predict_reward:
            nn.init.trunc_normal_(self.rew_emb.weight, std=0.02)

        nn.init.trunc_normal_(self.positional_embedding, std=0.02)

        variance_scaling_(self.ret_linear.weight)
        nn.init.zeros_(self.ret_linear.bias)
        variance_scaling_(self.act_linear.weight)
        nn.init.zeros_(self.act_linear.bias)
        if self.predict_reward:
            variance_scaling_(self.rew_linear.weight)
            nn.init.zeros_(self.rew_linear.bias)

    def _image_embedding(self, image: Tensor):
        r"""Embed [B x T x C x W x H] images to tokens [B x T x output_dim] tokens.

        Args:
            image: [B x T x C x W x H] image to embed.

        Returns:
            Image embedding of shape [B x T x output_dim] or [B x T x _ x output_dim].
        """
        assert len(image.shape) == 5  # must have exactly 5 dimensions
        image_dims = image.shape[-3:]  # -> [C x W x H]
        batch_dims = image.shape[:2]  # -> [B x T]

        # Reshape to [BT x C x H x W].
        image = torch.reshape(image, (-1,) + image_dims)
        # Perform any-image specific processing: scale pixel values by 1/255.
        image = image.to(dtype=torch.float32) / 255.0

        # split in patches
        image_emb = self.image_emb(image)  # [BT x D x P x P]
        # haiku.Conv2D is channel-last, so permute before reshape below for consistency
        image_emb = image_emb.permute(0, 2, 3, 1)  # [BT x P x P x D]
        # Reshape to [B x T x P*P x D].
        image_emb = torch.reshape(image_emb, batch_dims + (-1, self.d_model))
        image_emb = image_emb + self.image_pos_enc
        return image_emb

    def _embed_inputs(self, obs: Tensor, ret: Tensor, act: Tensor, rew: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Embed observations, returns, actions and rewards.

        Returns:
            Tuple (obs_emb, ret_emb, act_emb, rew_emb); `rew_emb` is None when
            `predict_reward` is false.
        """
        # obs are [B x T x C x H x W].
        obs_emb = self._image_embedding(obs)
        # Discretise returns and rewards into token ids before embedding.
        ret = encode_return(ret, self.return_range)
        rew = encode_reward(rew)
        act_emb = self.act_emb(act)
        ret_emb = self.ret_emb(ret)
        if self.predict_reward:
            rew_emb = self.rew_emb(rew)
        else:
            rew_emb = None
        return obs_emb, ret_emb, act_emb, rew_emb

    def forward(self, inputs: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
        r"""Process sequence.

        Args:
            inputs: Mapping with keys "observations", "returns-to-go",
                "actions" and "rewards".

        Returns:
            Mapping with "embeds", "action_logits", "return_logits", "loss"
            and, when `predict_reward` is set, "reward_logits".
        """
        num_batch = inputs["actions"].shape[0]
        num_steps = inputs["actions"].shape[1]
        # Embed inputs.
        obs_emb, ret_emb, act_emb, rew_emb = self._embed_inputs(
            inputs["observations"],
            inputs["returns-to-go"],
            inputs["actions"],
            inputs["rewards"],
        )
        device = obs_emb.device

        if self.spatial_tokens:
            # obs is [B x T x W x D]
            num_obs_tokens = obs_emb.shape[2]
            obs_emb = torch.reshape(obs_emb, obs_emb.shape[:2] + (-1,))
            # obs is [B x T x W*D]
        else:
            num_obs_tokens = 1
        # Collect sequence.
        # Embeddings are [B x T x D].

        if self.predict_reward:
            token_emb = torch.cat([obs_emb, ret_emb, act_emb, rew_emb], dim=-1)
            tokens_per_step = num_obs_tokens + 3
            # sequence is [obs ret act rew ... obs ret act rew]
        else:
            token_emb = torch.cat([obs_emb, ret_emb, act_emb], dim=-1)
            tokens_per_step = num_obs_tokens + 2
            # sequence is [obs ret act ... obs ret act]
        token_emb = torch.reshape(token_emb, [num_batch, tokens_per_step * num_steps, self.d_model])
        # Add position embeddings.
        token_emb = token_emb + self.positional_embedding
        # Run the transformer over the inputs.

        # Token dropout.
        batch_size = token_emb.shape[0]
        obs_mask = np.ones([batch_size, num_steps, num_obs_tokens], dtype=bool)
        ret_mask = np.ones([batch_size, num_steps, 1], dtype=bool)
        act_mask = np.ones([batch_size, num_steps, 1], dtype=bool)
        rew_mask = np.ones([batch_size, num_steps, 1], dtype=bool)

        # Mask out all return tokens except the first one.
        ret_mask[:, 1:] = 0

        if self.predict_reward:
            mask = [obs_mask, ret_mask, act_mask, rew_mask]
        else:
            mask = [obs_mask, ret_mask, act_mask]
        mask = np.concatenate(mask, axis=-1)
        mask = np.reshape(mask, [batch_size, tokens_per_step * num_steps])
        mask = torch.tensor(mask, dtype=torch.bool, device=device)

        custom_causal_mask = None
        if self.spatial_tokens:
            # Temporal transformer by default assumes sequential causal relation.
            # This makes the transformer causal mask a lower triangular matrix.
            #     P1 P2 R  a  P1 P2 ... (Ps: image patches)
            # P1  1  0* 0  0  0  0
            # P2  1  1  0  0  0  0
            # R   1  1  1  0  0  0
            # a   1  1  1  1  0  0
            # P1  1  1  1  1  1  0*
            # P2  1  1  1  1  1  1
            # ... (0*s should be replaced with 1s in the ideal case)
            # But, when we have multiple tokens for an image (e.g. patch tokens, conv
            # feature map tokens, etc) as inputs to transformer, this assumption does
            # not hold, because there is no sequential dependencies between tokens.
            # Therefore, the ideal causal mask should not mask out tokens that belong
            # to the same images from each others.
            seq_len = token_emb.shape[1]
            sequential_causal_mask = np.tril(np.ones((seq_len, seq_len)))
            num_timesteps = seq_len // tokens_per_step
            num_non_obs_tokens = tokens_per_step - num_obs_tokens
            # Alternate all-ones blocks (patches of one image) with all-zero
            # blocks (non-observation tokens) along the diagonal.
            diag = [
                np.ones((num_obs_tokens, num_obs_tokens)) if i % 2 == 0 else np.zeros((num_non_obs_tokens, num_non_obs_tokens))
                for i in range(num_timesteps * 2)
            ]
            block_diag = scipy.linalg.block_diag(*diag)
            custom_causal_mask = np.logical_or(sequential_causal_mask, block_diag)
            custom_causal_mask = torch.tensor(custom_causal_mask, dtype=torch.bool, device=device)

        output_emb = self.transformer(token_emb, mask, custom_causal_mask)

        # Output embeddings are [B x tokens_per_step*T x D].
        # Next token predictions (tokens one before their actual place).
        ret_pred = output_emb[:, (num_obs_tokens - 1) :: tokens_per_step, :]
        act_pred = output_emb[:, (num_obs_tokens - 0) :: tokens_per_step, :]
        embeds = torch.cat([ret_pred, act_pred], dim=-1)
        # Project to appropriate dimensionality.
        ret_pred = self.ret_linear(ret_pred)
        act_pred = self.act_linear(act_pred)
        # Return logits as well as pre-logits embedding.
        result_dict = {
            "embeds": embeds,
            "action_logits": act_pred,
            "return_logits": ret_pred,
        }
        if self.predict_reward:
            rew_pred = output_emb[:, (num_obs_tokens + 1) :: tokens_per_step, :]
            rew_pred = self.rew_linear(rew_pred)
            result_dict["reward_logits"] = rew_pred
        # Return evaluation metrics.
        result_dict["loss"] = self.sequence_loss(inputs, result_dict)
        #result_dict["accuracy"] = self.sequence_accuracy(inputs, result_dict)
        return result_dict

    def _objective_pairs(self, inputs: Mapping[str, Tensor], model_outputs: Mapping[str, Tensor]) -> Tensor:
        r"""Get logit-target pairs for the model objective terms.

        Returns a list of (logits, target) tuples: actions, the first return
        token only, and rewards when `predict_reward` is set.
        """
        act_target = inputs["actions"]
        ret_target = encode_return(inputs["returns-to-go"], self.return_range)
        act_logits = model_outputs["action_logits"]
        ret_logits = model_outputs["return_logits"]

        # Single return token: only the first timestep's return is a target.
        ret_target = ret_target[:, :1]
        ret_logits = ret_logits[:, :1, :]

        obj_pairs = [(act_logits, act_target), (ret_logits, ret_target)]
        if self.predict_reward:
            rew_target = encode_reward(inputs["rewards"])
            rew_logits = model_outputs["reward_logits"]
            obj_pairs.append((rew_logits, rew_target))
        return obj_pairs



    def optimal_action(
        self,
        inputs: Mapping[str, Tensor],
        return_range: Tuple[int] = (-100, 100),
        single_return_token: bool = True,
        opt_weight: Optional[float] = 0.0,
        num_samples: Optional[int] = 128,
        action_temperature: Optional[float] = 1.0,
        return_temperature: Optional[float] = 1.0,
        action_top_percentile: Optional[float] = None,
        return_top_percentile: Optional[float] = None,
        rng: Optional[torch.Generator] = None,
        deterministic: bool = False,
    ):
        r"""Calculate optimal action for the given sequence model.

        First samples an optimistic return-to-go for the first timestep (the
        maximum over `num_samples` draws), writes it into the inputs, then
        samples an action from the last timestep's action logits.

        Args:
            inputs: Mapping with "observations" [B x T x C x H x W], "actions"
                [B x T] and "rewards"; any "returns-to-go" entry is ignored.
            return_range: Range used to decode the sampled return token.
            single_return_token: Unused by this method.
            opt_weight: Weight of the bias towards larger returns.
            num_samples: Number of return samples to take the maximum over.
            action_temperature: Temperature for action sampling.
            return_temperature: Temperature for return sampling.
            action_top_percentile: Optional percentile cut for action logits.
            return_top_percentile: Optional percentile cut for return logits.
            rng: Generator forwarded to `sample_from_logits`.
            deterministic: If true, use arg-max instead of sampling.

        Returns:
            Tensor of sampled actions, one per batch element.
        """
        logits_fn = self.forward
        obs, act, rew = inputs["observations"], inputs["actions"], inputs["rewards"]
        assert len(obs.shape) == 5
        assert len(act.shape) == 2
        inputs = {
            "observations": obs,
            "actions": act,
            "rewards": rew,
            "returns-to-go": torch.zeros_like(act),
        }
        # Use samples from the last timestep.
        timestep = -1
        # A biased sampling function that prefers sampling larger returns.
        def ret_sample_fn(rng, logits):
            assert len(logits.shape) == 2
            # Add optimality bias.
            if opt_weight > 0.0:
                # Calculate log of P(optimality=1|return) := exp(return) / Z.
                # Create the bias on the logits' device/dtype; a CPU tensor
                # here would fail when the model runs on an accelerator.
                logits_opt = torch.linspace(
                    0.0, 1.0, logits.shape[1], device=logits.device, dtype=logits.dtype
                )
                logits_opt = torch.repeat_interleave(logits_opt[None, :], logits.shape[0], dim=0)
                # Sample from log[P(optimality=1|return)*P(return)].
                logits = logits + opt_weight * logits_opt
            logits = torch.repeat_interleave(logits[None, ...], num_samples, dim=0)
            ret_sample = sample_from_logits(
                logits,
                generator=rng,
                deterministic=deterministic,
                temperature=return_temperature,
                top_percentile=return_top_percentile,
            )
            # Pick the highest return sample.
            ret_sample, _ = torch.max(ret_sample, dim=0)
            # Convert return tokens into return values.
            ret_sample = decode_return(ret_sample, return_range)
            return ret_sample

        # Set returns-to-go with an (optimistic) autoregressive sample.
        # Since only first return is used by the model, only sample that (faster).
        ret_logits = logits_fn(inputs)["return_logits"][:, 0, :]
        ret_sample = ret_sample_fn(rng, ret_logits)
        inputs["returns-to-go"][:, 0] = ret_sample

        # Generate a sample from action logits.
        act_logits = logits_fn(inputs)["action_logits"][:, timestep, :]
        act_sample = sample_from_logits(
            act_logits,
            generator=rng,
            deterministic=deterministic,
            temperature=action_temperature,
            top_percentile=action_top_percentile,
        )
        return act_sample


In [None]:
# MODEL TESTER
# Builds the multi-game Decision Transformer that the following cells
# smoke-test and load pretrained weights into.

# CREATE THE MODEL
# maxTimestep = max steps in a trajectory
# NOTE: the triple-quoted string below is dead code kept from an earlier
# single-game DecisionTransformer; it is evaluated as a string and discarded.
'''model_gpt = DecisionTransformer(
                state_dim=hparams['state_dim'],
                act_voc=hparams['act_voc'],
                n_blocks=hparams['n_layer'],
                h_dim=hparams['n_embd'],
                context_len=3,
                n_heads=hparams['n_head'],
                drop_p=hparams['dropout']
            )
model_gpt.to(device)'''

OBSERVATION_SHAPE = (84, 84)
PATCH_SHAPE = (14, 14)  # 84 / 14 = 6 patches per side -> 36 patch tokens per frame
NUM_ACTIONS = 18  # Maximum number of actions in the full dataset.
# rew=0: no reward, rew=1: score a point, rew=2: end game rew=3: lose a point
NUM_REWARDS = 4
RETURN_RANGE = [-20, 100]  # A reasonable range of returns identified in the dataset, quantized in 120 discrete values

# These sizes are hard-coded rather than read from `hparams`
# (presumably to match the pretrained checkpoint - TODO confirm).
model = MultiGameDecisionTransformer(
    img_size=OBSERVATION_SHAPE,
    patch_size=PATCH_SHAPE,
    num_actions=NUM_ACTIONS,
    num_rewards=NUM_REWARDS,
    return_range=RETURN_RANGE,
    d_model=1280,
    num_layers=10,
    dropout_rate=0.1,
    predict_reward=True,
    single_return_token=True,
    conv_dim=256,
    num_steps=4, # fixed sequence len in training
).to(device)


In [None]:

# Smoke test: run one forward pass on a dummy batch (B=1, T=4) to check shapes.
inputs = {}
# Shapes seen for a real training batch (B=256, T=4):
#states_d  torch.Size([256, 4, 1, 84, 84])
#actions_d  torch.Size([256, 4])
#rtgs_d  torch.Size([256, 4])
#rewards_d  torch.Size([256, 4])

inputs["observations"] = torch.zeros((1,4, 1, 84, 84), dtype=torch.float32).to(device)
inputs["actions"] = torch.zeros((1,4), dtype=torch.long).to(device)
# 130 lies above RETURN_RANGE's upper bound, so encode_return clips it.
inputs["returns-to-go"] = 130*torch.ones((1,4), dtype=torch.float32).to(device)
inputs["rewards"] = torch.zeros((1,4), dtype=torch.float32).to(device)
#timesteps=torch.zeros((1, 3), dtype=torch.int64).to(device)
model.to(device)
# The returned dict (logits, embeds and loss) is displayed as the cell output.
model.forward(inputs)


LOAD MULTI-GAME WEIGHTS

In [None]:
def load_jax_weights(model, model_params):
    r"""Copy pretrained JAX/Haiku parameters into the PyTorch model.

    Every loaded tensor is a fresh, contiguous copy placed on the device of
    the parameter it replaces. The previous implementation installed
    transposed/permuted views created with `torch.from_numpy`, which shared
    memory with `model_params` (in-place updates on CPU would silently modify
    the checkpoint arrays) and were non-contiguous.

    Args:
        model: The `MultiGameDecisionTransformer` instance to fill.
        model_params: Mapping from Haiku module path to a mapping of array
            names ("w", "b", "scale", "offset", "embeddings", ...) to arrays.
    """

    def as_tensor(array, like, layout=None):
        # Convert a checkpoint array into an owned, contiguous tensor on the
        # device of `like`; `layout` optionally transposes/permutes first.
        tensor = torch.as_tensor(np.asarray(array))
        if layout is not None:
            tensor = layout(tensor)
        tensor = tensor.clone(memory_format=torch.contiguous_format)
        return tensor.to(device=like.device)

    def load_ln(m, k):
        # Haiku LayerNorm: scale -> weight, offset -> bias.
        m.weight.data = as_tensor(model_params[k]["scale"], m.weight)
        m.bias.data = as_tensor(model_params[k]["offset"], m.bias)

    def load_linear(m, k):
        # Haiku stores linear weights as [in, out]; nn.Linear expects [out, in].
        m.weight.data = as_tensor(model_params[k]["w"], m.weight, layout=lambda t: t.t())
        m.bias.data = as_tensor(model_params[k]["b"], m.bias)

    def load_attn(attn, k):
        # Fuse the separate query/key/value projections into the single qkv layer.
        qkv_w = np.concatenate(
            [
                model_params[k + "/query"]["w"],
                model_params[k + "/key"]["w"],
                model_params[k + "/value"]["w"],
            ],
            axis=-1,
        )
        attn.qkv.weight.data = as_tensor(qkv_w, attn.qkv.weight, layout=lambda t: t.t())

        qkv_b = np.concatenate(
            [
                model_params[k + "/query"]["b"],
                model_params[k + "/key"]["b"],
                model_params[k + "/value"]["b"],
            ]
        )
        attn.qkv.bias.data = as_tensor(qkv_b, attn.qkv.bias)

        load_linear(attn.proj, k + "/linear")

    def load_mlp(mlp, k):
        load_linear(mlp.fc1, k + "/linear")
        load_linear(mlp.fc2, k + "/linear_1")

    def load_transformer(transformer):
        prefix = "decision_transformer/~/sequence"
        for i in range(transformer._num_layers):
            block = transformer.layers[i]

            load_ln(block.ln_1, f"{prefix}/h{i}_ln_1")
            load_attn(block.attn, f"{prefix}/h{i}_attn")

            load_ln(block.ln_2, f"{prefix}/h{i}_ln_2")
            load_mlp(block.mlp, f"{prefix}/h{i}_mlp")
        load_ln(transformer.norm_f, f"{prefix}/ln_f")

    def load_embedding(m, k):
        m.weight.data = as_tensor(model_params[k]["embeddings"], m.weight)

    def load_image_emb(m, k):
        # [H x W x Cin x Cout] -> [Cout, Cin, H, W]
        m.weight.data = as_tensor(model_params[k]["w"], m.weight, layout=lambda t: t.permute(3, 2, 0, 1))
        m.bias.data = as_tensor(model_params[k]["b"], m.bias)

    # --- Load transformer

    load_transformer(model.transformer)

    # --- Load model

    load_linear(model.act_linear, "decision_transformer/act_linear")
    load_linear(model.ret_linear, "decision_transformer/ret_linear")
    if model.predict_reward:
        load_linear(model.rew_linear, "decision_transformer/rew_linear")

    # The positional parameters are replaced wholesale (their shape follows
    # the checkpoint), on the device of the parameters they replace.
    model.image_pos_enc = nn.Parameter(
        as_tensor(model_params["decision_transformer"]["image_pos_enc"], model.image_pos_enc)
    )
    model.positional_embedding = nn.Parameter(
        as_tensor(model_params["decision_transformer"]["positional_embeddings"], model.positional_embedding)
    )

    load_image_emb(model.image_emb, "decision_transformer/~_embed_inputs/image_emb")

    load_embedding(model.ret_emb, "decision_transformer/~_embed_inputs/embed")
    load_embedding(model.act_emb, "decision_transformer/~_embed_inputs/embed_1")
    if model.predict_reward:
        load_embedding(model.rew_emb, "decision_transformer/~_embed_inputs/embed_2")


In [None]:
# Download the Multi-Game DT checkpoints from our Google Drive
!pip install --upgrade gdown
!gdown "https://drive.google.com/uc?id=1tFB_hmhztUtm9SQdIEMgipkI6SJLqJP3"
!gdown "https://drive.google.com/uc?id=1EXx38g6p5J-XH_zP0Km1mCjKxCZbQ6M_"

In [None]:
# Load the multi-game checkpoints and copy the weights into the model.
# NOTE(review): these paths point at a mounted Google Drive folder, whereas the
# gdown cell above downloads into the current working directory -- confirm the
# files are really read from where they were saved.
# NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
model_params = torch.load("/content/drive/MyDrive/TMP/model_params.pt")
# NOTE(review): `model_state` does not appear to be used by the cells that
# follow -- confirm whether it is still needed.
model_state = torch.load("/content/drive/MyDrive/TMP/model_state.pt")

# `model` must already exist when this cell runs.
# NOTE(review): a later cell ("INSTANTIATE THE MODEL") creates `model` again;
# make sure the weights end up in the instance that is actually trained.
load_jax_weights(model, model_params)
model.to(device)

SET THE ENVIRONMENT

In [None]:
import gym
# d4rl_atari is never referenced by name below; it is imported for its side
# effects (presumably registering the offline-dataset environments with gym
# -- TODO confirm).
import d4rl_atari

# The environment id comes from the hyper-parameter dict defined earlier.
env = gym.make(hparams['game_name'])

# Show the game's native action space (before remapping to the shared action set).
print(env.action_space)

env.reset() # observation is expected to have shape (1, 84, 84)



SETUP THE DATASET of our game

In [None]:
def create_dataset():
    """Build the flat training arrays for the game wrapped by the global ``env``.

    Reads the offline dataset through ``env.get_dataset()``, displays one
    sample frame, and derives the per-step return-to-go and the in-episode
    timestep. Actions are re-indexed from the game's own numbering to the
    shared all-game numbering with ``action_to_allgameaction``.

    Episode boundaries: ``terminals[i] == 1`` flags the *last* transition of an
    episode (replay-buffer convention of the underlying data), so the
    exclusive end of that episode is ``i + 1``. Using ``i`` itself as the end
    would push every terminal transition, and its reward, into the following
    episode.

    Steps after the last terminal flag (an unfinished episode) keep a
    return-to-go and timestep of 0.

    Returns:
        tuple: ``(obs_data, action_data, terminal_pos, rtg, timesteps,
        max_timestep, reward_data)`` where

        * ``obs_data`` / ``reward_data`` are the arrays delivered by the dataset,
        * ``action_data`` is a new array holding the remapped actions (the
          dataset's own array is left untouched),
        * ``terminal_pos`` holds the exclusive end index of every episode,
        * ``rtg`` is the return-to-go of every step,
        * ``timesteps`` is the position of every step inside its episode,
        * ``max_timestep`` is the largest value in ``timesteps``.
    """
    # GET THE DATA !!
    dataset = env.get_dataset()

    obs_data = dataset['observations']  # observation data in (1000000, 1, 84, 84)
    action_data = dataset['actions']  # action data in (1000000,)
    reward_data = dataset['rewards']  # reward data in (1000000,)
    terminal_data = dataset['terminals']  # terminal flags in (1000000,)

    # Visual sanity check on one frame.
    plt.imshow(obs_data[1000][0])
    plt.show()

    # Exclusive end index of each episode (index of the terminal step + 1).
    terminal_pos = np.where(terminal_data == 1)[0] + 1
    del terminal_data
    print("num episodes ", terminal_pos.shape)

    # -- create reward-to-go and timestep arrays, one episode at a time
    rtg = np.zeros_like(reward_data)
    timesteps = np.zeros(len(action_data), dtype=int)
    start_index = 0
    for end_index in terminal_pos:
        end_index = int(end_index)
        episode_rewards = reward_data[start_index:end_index]
        # Reverse cumulative sum: rtg[t] = sum of the rewards from t to the episode end.
        rtg[start_index:end_index] = np.cumsum(episode_rewards[::-1], dtype=np.float64)[::-1]
        timesteps[start_index:end_index] = np.arange(end_index - start_index)
        start_index = end_index
    print('max rtg is %d' % rtg.max())

    # -- convert game action numbers into all_game action numbers
    print(len(action_data))
    print(action_data[0])
    # Map each distinct action once and write into a fresh array: fast, and
    # calling this function again can never remap already-remapped actions.
    mapped_actions = np.empty_like(action_data)
    for game_action in np.unique(action_data):
        mapped_actions[action_data == game_action] = action_to_allgameaction(game_action)
    action_data = mapped_actions

    max_timestep = timesteps.max()
    print('max timestep is %d' % max_timestep)
    print("***** data loaded **********")

    return obs_data, action_data, terminal_pos, rtg, timesteps, max_timestep, reward_data



In [None]:
# Class that picks up a block of data from the dataset
class StateActionReturnDataset(Dataset):
    """Serves fixed-length windows of consecutive steps from the flat dataset.

    Every item is a window of ``context_length`` steps. ``done_idxs`` holds
    the exclusive end index of every episode; a window that would run past the
    end of the episode containing its first step is shifted back so that it
    ends exactly at that episode end.

    Args:
        obs: per-step observations (indexable, sliceable).
        actions: per-step actions.
        done_idxs: exclusive end index of every episode.
        rtgs: per-step returns-to-go.
        timesteps: per-step position inside the episode (stored, not returned).
        rewards: per-step rewards.
    """

    def __init__(self, obs, actions, done_idxs, rtgs, timesteps, rewards):
        self.context_length = hparams['context_length']
        self.obs = obs
        self.actions = actions
        self.rewards = rewards
        # Kept as a sorted array so __getitem__ can binary-search it instead
        # of scanning every episode end in Python for every sample.
        self.done_idxs = np.sort(np.asarray(done_idxs))
        self.rtgs = rtgs
        self.timesteps = timesteps

    def __len__(self):
        # The "* 3" margin is kept from the original; clamp at 0 so a tiny
        # dataset yields an empty dataset rather than an invalid negative length.
        return max(len(self.obs) - self.context_length * 3, 0)

    def __getitem__(self, idx):
        """Return ``(states, actions, rtgs, rewards)`` for the window at ``idx``.

        ``states`` is float32 and divided by 255, ``actions`` is int64,
        ``rtgs`` and ``rewards`` are float32; each has ``context_length``
        entries along its first dimension.
        """
        # To avoid windows that straddle two trajectories: if idx is too close
        # to the end of its trajectory, re-position the window so that it ends
        # exactly at the trajectory end.
        done_idx = idx + self.context_length
        # Position of the first episode end strictly greater than idx.
        pos = np.searchsorted(self.done_idxs, idx, side='right')
        if pos < len(self.done_idxs):
            done_idx = min(int(self.done_idxs[pos]), done_idx)
        idx = done_idx - self.context_length
        if idx < 0:
            # First episode shorter than the context: fall back to the first window.
            idx = 0
            done_idx = idx + self.context_length
        states = torch.tensor(np.array(self.obs[idx:done_idx]), dtype=torch.float32)
        states = states / 255.  # normalize obs
        actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long)
        rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32)
        rewards = torch.tensor(self.rewards[idx:done_idx], dtype=torch.float32)

        return states, actions, rtgs, rewards

In [None]:
# CREATE THE DATASET
# Unpack the flat arrays produced by create_dataset(); `done_idxs` receives the
# episode-boundary indices and `maxTimestep` the largest in-episode timestep.
obss, actions, done_idxs, rtgs, timesteps, maxTimestep, rewards = create_dataset()
#hparams['max_timestep'] = maxTimestep

# WRAP THE ARRAYS IN A Dataset SO THAT A DataLoader CAN DRAW FIXED-LENGTH WINDOWS
train_dataset = StateActionReturnDataset(obss, actions, done_idxs, rtgs, timesteps, rewards)


TRAINING

In [None]:
def get_action(
    inputs,
    model,
    return_range,
    single_return_token,
    opt_weight: Optional[float] = 0.0,
    num_samples: Optional[int] = 128,
    action_temperature: Optional[float] = 1.0,
    return_temperature: Optional[float] = 1.0,
    action_top_percentile: Optional[float] = None,
    return_top_percentile: Optional[float] = None,
):
    """Sample an action from the model after sampling a target return.

    The model is first asked for return logits; ``num_samples`` returns are
    drawn and the highest one per batch element is written into
    ``inputs['returns-to-go']``. The action is then sampled from the action
    logits of the last position.

    Note that ``inputs`` is modified in place: 'actions' and 'rewards' are cut
    down to their last step and 'returns-to-go' is overwritten.

    Args:
        inputs: dict with 'observations' (5-D tensor), 'actions' (2-D tensor)
            and 'rewards'.
        model: callable returning a dict with 'return_logits' and
            'action_logits'.
        return_range: passed to ``decode_return`` to turn return tokens into
            return values.
        single_return_token: if true, only the first return position is
            predicted; otherwise all return tokens are regenerated
            auto-regressively.
        opt_weight: strength of the optimality bias added to the return
            logits; ``None`` or a value <= 0 disables it.
        num_samples: number of return samples drawn per batch element.
        action_temperature, return_temperature: sampling temperatures
            forwarded to ``sample_from_logits``.
        action_top_percentile, return_top_percentile: forwarded to
            ``sample_from_logits``.

    Returns:
        The sampled action(s), as returned by ``sample_from_logits``.
    """
    obs, act, rew = inputs['observations'], inputs['actions'], inputs['rewards']
    assert len(obs.shape) == 5
    assert len(act.shape) == 2
    # Only the most recent action / reward are kept.
    act = act[:, -1].unsqueeze(1)
    inputs['actions'] = act
    inputs['rewards'] = rew[:, -1].unsqueeze(1)
    inputs['returns-to-go'] = torch.zeros_like(act)
    seq_len = obs.shape[1]
    # Use the action logits of the last position.
    timesteps = -1

    def ret_sample_fn(logits):
        """Draw ``num_samples`` returns per batch element and keep the best one."""
        assert len(logits.shape) == 2
        # Add optimality bias
        if opt_weight is not None and opt_weight > 0.0:
            # Calculate log of P(optimality|return) = exp(return)/Z.
            # Created on the logits' device: a CPU tensor here would fail as
            # soon as the model runs on the GPU.
            logits_opt = torch.linspace(0., 1., logits.shape[1], device=logits.device)
            # Sample from log[P(optimality=1|return)*P(return)]
            # (the bias row broadcasts over the batch dimension).
            logits = logits + opt_weight * logits_opt.unsqueeze(0)
        # New leading dimension of size num_samples.
        logits = torch.repeat_interleave(
            logits.unsqueeze(0), num_samples, dim=0)
        ret_sample = sample_from_logits(
            logits, temperature=return_temperature, top_percentile=return_top_percentile)
        # Pick the highest return sample for each batch element by reducing
        # over the sample dimension only. A global max would hand one batch
        # element's best return to every other element.
        # (assumes sample_from_logits keeps the leading sample dimension -- confirm)
        ret_sample = torch.max(ret_sample, dim=0).values
        # Convert return tokens into return values
        ret_sample = decode_return(ret_sample, return_range)
        return ret_sample

    with torch.no_grad():
        if single_return_token:
            ret_logits = model(inputs)['return_logits'][:, 0, :]
            ret_sample = ret_sample_fn(ret_logits)
            inputs['returns-to-go'][:, 0] = ret_sample
        else:
            # Auto-regressively regenerate all return tokens in a sequence
            def ret_logits_fn(ipts): return model(ipts)['return_logits']
            ret_sample = autoregressive_generate(inputs, ret_logits_fn, 'returns-to-go', seq_len, ret_sample_fn)
            inputs['returns-to-go'] = ret_sample

        # Generate a sample from action logits
        act_logits = model(inputs)['action_logits'][:, timesteps, :]
        act_sample = sample_from_logits(
            act_logits, temperature=action_temperature,
            top_percentile=action_top_percentile)
    return act_sample


In [None]:
from tqdm import tqdm

def trainer(model, dataloader):
    """Fine-tune ``model`` and return the loss of every optimisation step.

    Hyper-parameters are read from the global ``hparams`` dict ('lr',
    'wt_decay', 'warmup_steps', 'max_epochs', 'batch_size'). The mean loss of
    each group of 1000 mini-batches is written to the global tensorboard
    ``writer`` under the tag 'training loss'.

    Args:
        model: module whose forward pass takes ``inputs=<dict>`` and returns a
            dict containing the scalar 'loss'.
        dataloader: despite the name, a ``Dataset`` yielding
            ``(states, actions, rtgs, rewards)``; a shuffled ``DataLoader`` is
            built around it at the start of every epoch.

    Returns:
        list[float]: the training loss of every step, in order.
    """
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hparams['lr'],
        weight_decay=hparams['wt_decay']
    )
    # Linear warm-up from lr / warmup_steps to lr, constant afterwards.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda steps: min((steps+1)/hparams['warmup_steps'], 1)
    )

    counter = 0  # x-axis value for the tensorboard curve

    losses = []
    losses_1000 = []
    for epoch in range(hparams['max_epochs']):

        model.to(device)
        model.train()

        loader = DataLoader(dataloader, shuffle=True, pin_memory=True,
                            batch_size=hparams['batch_size'])

        pbar = tqdm(enumerate(loader), total=len(loader))
        for it, (states, actions, rtgs, rewards) in pbar:

            # place data on the correct device
            # (each tensor is batch-first with the context window as the
            # second dimension -- exact shapes depend on the dataset)
            states = states.to(device)
            actions = actions.to(device)
            rtgs = rtgs.to(device)
            rewards = rewards.to(device)

            inputs = {'observations': states,
                      'returns-to-go': rtgs,
                      'actions': actions,
                      'rewards': rewards}

            # forward the model (called through __call__ so that hooks run)
            result_dict = model(inputs=inputs)
            loss = result_dict['loss']

            # backprop and update the parameters
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optimizer.step()
            # Learning rate actually applied in this step; hparams['lr'] is
            # only the post-warm-up target.
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step()

            # Read the loss back once: every .item() forces a device sync.
            loss_value = loss.item()
            losses.append(loss_value)
            losses_1000.append(loss_value)

            # report progress
            pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss_value:.5f}. lr {current_lr:e}")

            # loss info -> tensorboard
            if it % 1000 == 999:    # every 1000 mini-batches...
                counter += 1
                writer.add_scalar('training loss', float(np.mean(losses_1000)), counter)
                losses_1000 = []

        # evaluate action accuracy
        #evaluate_on_env(model, hparams['target_reward'])

    return losses


In [None]:
# INSTANTIATE THE MODEL
OBSERVATION_SHAPE = (84, 84)
PATCH_SHAPE = (14, 14)
NUM_ACTIONS = 18  # Maximum number of actions in the full dataset.
# rew=0: no reward, rew=1: score a point, rew=2: end game rew=3: lose a point
NUM_REWARDS = 4
RETURN_RANGE = [-20, 100]  # A reasonable range of returns identified in the dataset, quantized in 120 discrete values

model = MultiGameDecisionTransformer(
    img_size=OBSERVATION_SHAPE,
    patch_size=PATCH_SHAPE,
    num_actions=NUM_ACTIONS,
    num_rewards=NUM_REWARDS,
    return_range=RETURN_RANGE,
    d_model=1280,
    num_layers=10,
    dropout_rate=0.1,
    predict_reward=True,
    single_return_token=True,
    conv_dim=256,
    num_steps=4, # fixed sequence len in training
)

# This creates a brand-new, randomly initialised network. Load the pretrained
# multi-game weights into THIS instance, otherwise the training cell below
# would train from scratch instead of fine-tuning.
load_jax_weights(model, model_params)
model = model.to(device)

In [None]:
# TRAIN THE MODEL
# `losses` receives the list returned by trainer() for the dataset built above.
losses = trainer(model, train_dataset)
