In [None]:
import random
from dataclasses import dataclass

import numpy as np
import torch

from datasets import load_dataset
from huggingface_hub import interpreter_login, list_models
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

In [None]:
# log in to the Hugging Face Hub up front: the training cell sets push_to_hub=True
interpreter_login()

## Configs

In [None]:
# dataset configuration names selectable in this notebook, keyed by a short integer id
envs = dict(enumerate(("walker2d-expert-v2", "halfcheetah-expert-v2")))

# the configuration used for loading the dataset and naming the model
chosen_env = envs[0]

## Loading and exploring the dataset: halfcheetah (expert)

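Some notes (these describe halfcheetah; the environment actually loaded is whichever one `chosen_env` selects in the Configs cell):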
* This is a multi-dimensional, continuous environment. States are represented by 17 continuous dimensions; actions are represented by 6 continuous dimensions.
* The state space includes the positions and velocities of multiple body parts of the robotic cheetah, which are continuous, unbounded, real-valued quantities.
* The action space consists of torques applied to the joints, which are real-valued and thus continuous. They are however limited to the interval [-1, 1]. 

For more details: https://www.gymlibrary.dev/environments/mujoco/half_cheetah/

In [None]:
# load the replay dataset configuration selected by `chosen_env`
dataset = load_dataset("edbeeching/decision_transformer_gym_replay", chosen_env)

In [None]:
# fetch the first training trajectory once, then summarise its structure
first_traj = dataset['train'][0]
first_obs_seq = first_traj['observations']

print("Dataset elements: ", first_traj.keys())
print("Number of steps: ", len(first_obs_seq))
print("Size of state representation: ", len(first_obs_seq[0]))
print("Size of action representation: ", len(first_traj['actions'][0]))
print("Reward type: ", type(first_traj['rewards'][0]))
print("Done flag: ", type(first_traj['dones'][0]))

In [None]:
# display the first observation (state vector) of the first training trajectory
dataset['train'][0]['observations'][0]

In [None]:
# display the first action vector of the first training trajectory
dataset['train'][0]['actions'][0]

## Processing the dataset

In [None]:
RETURNS_SCALE = 1000.0  # passed to the data collator as `returns_scale`
CONTEXT_SIZE = 20  # passed to the data collator as `context_size` (length of the sub-sequences used in training)

While most datasets on the hub are ready to use out of the box, sometimes we wish to perform some additional processing or modification of the dataset. 

In this case we wish to match the author's implementation (from the original paper), that is we need to:
* Normalize each feature by subtracting the mean and dividing by the standard deviation.
* Pre-compute discounted returns for each trajectory.
* Scale the rewards and returns by a factor of 1000.
* Augment the dataset sampling distribution so it takes into account the length of the expert agent’s trajectories.

In order to perform this dataset preprocessing, we will use a custom Data Collator.

In [None]:
@dataclass
class DecisionTransformerGymDataCollator:
    """Data collator that turns stored trajectories into fixed-length Decision Transformer batches.

    On construction it gathers dataset statistics (state mean/std, trajectory
    lengths). On every call it re-samples trajectories (weighted by length),
    cuts a random window of at most `context_size` steps out of each one,
    normalises the states, computes scaled returns-to-go and left-pads every
    sequence to `context_size`.

    Args:
        dataset: trajectory dataset. It must support integer indexing
            (returning a dict with "observations", "actions" and "rewards"),
            iteration over trajectories, and `dataset["observations"]`
            returning the observation sequences of all trajectories.
        context_size: number of timesteps in each training sequence.
        returns_scale: divisor applied to the returns-to-go; stored both as
            `scale` (the value used when batching) and as `returns_scale`.
    """

    return_tensors: str = "pt"  # pytorch tensors
    context_size: int = 1  # length of trajectories we use in training
    state_dim: int = 1  # size of state space
    act_dim: int = 1  # size of action space
    max_ep_len: int = 9999999  # max episode length in the dataset
    scale: float = 1  # normalisation of rewards/returns
    state_mean: np.array = None  # to store state means
    state_std: np.array = None  # to store state stds
    p_sample: np.array = None  # a distribution weighing episodes by trajectory lengths
    n_traj: int = 0  # to store the number of trajectories in the dataset

    def __init__(self, dataset, context_size, returns_scale):
        self.dataset = dataset
        # get dataset-specific features
        self.max_ep_len = max(len(traj["rewards"]) for traj in dataset)
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.context_size = context_size
        # `scale` is what __call__ divides the returns-to-go by, so it must be set
        # from the constructor argument; otherwise the class default of 1 is used.
        self.scale = returns_scale
        self.returns_scale = returns_scale
        # collect some statistics about the dataset
        states = []
        traj_lens = []
        for obs in dataset["observations"]:
            states.extend(obs)
            traj_lens.append(len(obs))
        traj_lens = np.array(traj_lens)
        states = np.vstack(states)
        # use stats to produce normalisation constants (epsilon avoids division by zero)
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-8
        self.n_traj = traj_lens.shape[0]
        self.p_sample = traj_lens / traj_lens.sum()

    @staticmethod
    def _discount_cumsum(x, gamma):
        """Return the discounted return-to-go at every step of the 1-D reward array `x`."""
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        """Build one training batch.

        Args:
            features: the batch handed over by the caller; only its length is
                used, the trajectories themselves are re-sampled from the dataset.

        Returns:
            dict of tensors, each with leading shape (len(features), context_size):
            "states" (..., state_dim), "actions" (..., act_dim), "rewards" (..., 1),
            "returns_to_go" (..., 1), "timesteps" (long) and "attention_mask"
            (1.0 on real steps, 0.0 on left-padding).
        """
        # FIXME this is a bit of a hack to be able to sample from a non-uniform distribution
        # the idea is that we re-sample with replacement from the dataset rather than just taking the batch
        # this also means we can sample according to the length of the trajectories
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=len(features),
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )

        # a batch of dataset features
        s, a, r, rtg, timesteps, mask = [], [], [], [], [], []

        for ind in batch_inds:
            traj = self.dataset[int(ind)]
            # random window start; the window is shorter than context_size near the episode end
            start = random.randint(0, len(traj["rewards"]) - 1)
            stop = start + self.context_size

            # get sequences from the dataset
            states = np.array(traj["observations"][start:stop]).reshape(1, -1, self.state_dim)
            actions = np.array(traj["actions"][start:stop]).reshape(1, -1, self.act_dim)
            rewards = np.array(traj["rewards"][start:stop]).reshape(1, -1, 1)

            tlen = states.shape[1]
            padlen = self.context_size - tlen
            steps = np.arange(start, start + tlen).reshape(1, -1)

            # undiscounted return-to-go over the rest of the episode, cut to the window
            # and scaled; computed in float so the division is valid for integer rewards
            returns = self._discount_cumsum(
                np.asarray(traj["rewards"][start:], dtype=np.float64), gamma=1.0
            )
            returns = returns[:tlen].reshape(1, -1, 1) / self.scale

            # normalising and left-padding up to context_size
            states = (states - self.state_mean) / self.state_std
            s.append(np.concatenate([np.zeros((1, padlen, self.state_dim)), states], axis=1))
            # padded actions use -10.0 as filler value
            a.append(np.concatenate([np.ones((1, padlen, self.act_dim)) * -10.0, actions], axis=1))
            r.append(np.concatenate([np.zeros((1, padlen, 1)), rewards], axis=1))
            rtg.append(np.concatenate([np.zeros((1, padlen, 1)), returns], axis=1))
            timesteps.append(np.concatenate([np.zeros((1, padlen)), steps], axis=1))

            # masking: disregard padded values
            mask.append(np.concatenate([np.zeros((1, padlen)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

## Create a trainable Decision Transformer (HF is not trainable by default)

In [None]:
class TrainableDT(DecisionTransformerModel):
    """Decision Transformer whose `forward` returns a training loss.

    The loss is the mean squared error between predicted and target actions,
    computed only on positions where the attention mask is positive.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        """Run the parent forward pass and return `{"loss": mse}`.

        Expects `actions` and `attention_mask` among the keyword arguments;
        all keyword arguments are forwarded unchanged to the parent model.
        """
        model_out = super().forward(**kwargs)

        # second element of the parent output holds the action predictions
        predicted = model_out[1]
        n_act = predicted.shape[2]

        # flatten (batch, time) and keep only the non-padding positions
        keep = kwargs["attention_mask"].reshape(-1) > 0
        flat_pred = predicted.reshape(-1, n_act)[keep]
        flat_target = kwargs["actions"].reshape(-1, n_act)[keep]

        squared_error = (flat_pred - flat_target) ** 2
        return {"loss": torch.mean(squared_error)}

    def original_forward(self, **kwargs):
        """Expose the parent model's unmodified forward pass."""
        return super().forward(**kwargs)

## Setting up the model

In [None]:
# build the collator first: the model config is derived from the statistics it gathers
collator = DecisionTransformerGymDataCollator(
    dataset["train"],
    context_size=CONTEXT_SIZE,
    returns_scale=RETURNS_SCALE,
)

# mirror the collator's dimensions and normalisation constants in the model config
config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    max_ep_len=collator.max_ep_len,
    context_size=collator.context_size,
    state_mean=list(collator.state_mean),
    state_std=list(collator.state_std),
    scale=collator.scale,
)

model = TrainableDT(config)

## Training the model

In [None]:
# pick the next free "dt-<env>-v<N>" model name among the author's models on the Hub
my_env_name = "dt-" + chosen_env.split("-")[0]
models = sorted([m.modelId.split("/")[-1] for m in list_models(author="afonsosamarques")])
models = [m for m in models if my_env_name in m]

# compare versions numerically: a lexicographic sort would rank "v10" below "v9"
# and hand out a version number that already exists
versions = []
for m in models:
    suffix = m.split("-")[-1]
    # skip matching repos that do not end in a "-v<number>" suffix
    if suffix[:1] == "v" and suffix[1:].isdecimal():
        versions.append(int(suffix[1:]))

new_version = "v" + str(max(versions) + 1) if versions else "v0"
model_name = my_env_name + "-" + new_version

In [None]:
# we use the same hyperparameters as in the authors' original implementation, but train for fewer iterations

# only request the Apple MPS backend when this machine actually provides it,
# so the notebook also runs on CPU/CUDA machines
mps_backend = getattr(torch.backends, "mps", None)
use_mps = mps_backend is not None and mps_backend.is_available()

training_args = TrainingArguments(
    output_dir=model_name,
    remove_unused_columns=False,  # the collator reads the raw trajectory columns itself
    num_train_epochs=250,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    use_mps_device=use_mps,
    report_to="none",
    push_to_hub=True,
    hub_model_id=model_name,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collator,
)

trainer.train()
trainer.save_model()
trainer.push_to_hub()