# Training Decision Transformers with 🤗 transformers

In this tutorial, **you’ll learn to train your first Offline Decision Transformer model from scratch to make a half-cheetah run.** 🏃

❓ If you have questions, please post them on #study-group discord channel 👉 https://discord.gg/aYka4Yhff9

🎮 Environments:
- [Half Cheetah](https://www.gymlibrary.dev/environments/mujoco/half_cheetah/)

⬇️ Here's what you'll achieve at the end of this tutorial. ⬇️

### Prerequisites 🏗️
Before diving into the notebook, you need to:

🔲 📚 [Read the tutorial](https://huggingface.co/blog/train-decision-transformers)

In [1]:
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import DecisionTransformerConfig, Trainer, TrainingArguments, TrainerCallback
from models.decision_mamba import TrainableDT, TrainableDM
from evaluation.evaluate_episodes import evaluate_episode_rtg


### Step 3: Loading the dataset from the 🤗 Hub and instantiating the model

We host a number of Offline RL Datasets on the hub. Today we will be training with the halfcheetah “expert” dataset, hosted here on hub.

First we need to import the load_dataset function from the 🤗 datasets package and download the dataset to our machine.

In [2]:
os.environ["WANDB_DISABLED"] = "true" # we diable weights and biases logging for this tutorial
dataset = load_dataset("edbeeching/decision_transformer_gym_replay", "halfcheetah-expert-v2")


### Step 4: Defining a custom DataCollator for the transformers Trainer class

In [3]:
@dataclass
class DecisionTransformerGymDataCollator:
    return_tensors: str = "pt"
    max_len: int = 20 #subsets of the episode we use for training
    state_dim: int = 17  # size of state space
    act_dim: int = 6  # size of action space
    max_ep_len: int = 1000 # max episode length in the dataset
    scale: float = 1000.0  # normalization of rewards/returns
    state_mean: np.array = None  # to store state means
    state_std: np.array = None  # to store state stds
    p_sample: np.array = None  # a distribution to take account trajectory lengths
    n_traj: int = 0 # to store the number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset
        # calculate dataset stats for normalization of states
        states = []
        traj_lens = []
        for obs in dataset["observations"]:
            states.extend(obs)
            traj_lens.append(len(obs))
        self.n_traj = len(traj_lens)
        states = np.vstack(states)
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
        
        traj_lens = np.array(traj_lens)
        self.p_sample = traj_lens / sum(traj_lens)

    def _discount_cumsum(self, x, gamma):
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        # a batch of dataset features
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
        
        for ind in batch_inds:
            # for feature in features:
            feature = self.dataset[int(ind)]
            si = random.randint(0, len(feature["rewards"]) - 1)

            # get sequences from dataset
            s.append(np.array(feature["observations"][si : si + self.max_len]).reshape(1, -1, self.state_dim))
            a.append(np.array(feature["actions"][si : si + self.max_len]).reshape(1, -1, self.act_dim))
            r.append(np.array(feature["rewards"][si : si + self.max_len]).reshape(1, -1, 1))

            d.append(np.array(feature["dones"][si : si + self.max_len]).reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff
            rtg.append(
                self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[
                    : s[-1].shape[1]   # TODO check the +1 removed here
                ].reshape(1, -1, 1)
            )
            if rtg[-1].shape[1] < s[-1].shape[1]:
                print("if true")
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            s[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, self.state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - self.state_mean) / self.state_std
            a[-1] = np.concatenate(
                [np.ones((1, self.max_len - tlen, self.act_dim)) * -10.0, a[-1]],
                axis=1,
            )
            r[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, self.max_len - tlen)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, self.max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, self.max_len - tlen)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        d = torch.from_numpy(np.concatenate(d, axis=0))
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

### Step 5: Extending the Decision Transformer Model to include a loss function

In order to train the model with the 🤗 trainer class, we first need to ensure the dictionary it returns contains a loss, in this case L-2 norm of the models action predictions and the targets.

In [14]:
collator = DecisionTransformerGymDataCollator(dataset["train"])

config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDT(config)

def num_params(model):
    return sum(p.numel() for p in model.parameters())

num_params(model)

1257368

### Step 6: Defining the training hyperparameters and training the model
Here, we define the training hyperparameters and our Trainer class that we'll use to train our Decision Transformer model.

This step takes about an hour, so you may leave it running. Note the authors train for at least 3 hours, so the results presented here are not as performant as the models hosted on the 🤗 hub.

In [15]:
# these params more or less match the ones used by the original DT paper

training_args = TrainingArguments(
    output_dir="output/",
    remove_unused_columns=False,
    #num_train_epochs=120,
    max_steps=100_000,
    logging_strategy='steps',
    save_strategy='no',
    logging_steps=5000,
    warmup_steps=0,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    #warmup_ratio=0.1,
    optim="adamw_torch",
    dataloader_num_workers=12,
    max_grad_norm=0.25,
    tf32=True,
    bf16=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collator,
    callbacks=[EvaluateCallback()],
)

trainer.train()

evaluate_episodes(5, model)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss
5000,0.0427
10000,0.0319
15000,0.0302
20000,0.0294
25000,0.0288
30000,0.0284
35000,0.0282
40000,0.0279
45000,0.0276
50000,0.0274


100%|██████████| 5/5 [00:11<00:00,  2.25s/it]

Epoch 1.0 eval: {'target_12.0_return_mean': -24.388857590408495, 'target_12.0_return_std': 0.3785924929925912, 'target_12.0_length_mean': -24.388857590408495, 'target_12.0_length_std': 0.3785924929925912}



100%|██████████| 5/5 [00:11<00:00,  2.24s/it]

Epoch 501.0 eval: {'target_12.0_return_mean': 10184.166463494808, 'target_12.0_return_std': 1499.0420624377882, 'target_12.0_length_mean': 10184.166463494808, 'target_12.0_length_std': 1499.0420624377882}



100%|██████████| 5/5 [00:11<00:00,  2.20s/it]

Epoch 1001.0 eval: {'target_12.0_return_mean': 11110.21172932327, 'target_12.0_return_std': 186.07299326008567, 'target_12.0_length_mean': 11110.21172932327, 'target_12.0_length_std': 186.07299326008567}



100%|██████████| 5/5 [00:11<00:00,  2.24s/it]

Epoch 1501.0 eval: {'target_12.0_return_mean': 11132.929309485986, 'target_12.0_return_std': 111.95761758702164, 'target_12.0_length_mean': 11132.929309485986, 'target_12.0_length_std': 111.95761758702164}



100%|██████████| 5/5 [00:11<00:00,  2.22s/it]

Epoch 2001.0 eval: {'target_12.0_return_mean': 11063.84870537514, 'target_12.0_return_std': 226.18844467243147, 'target_12.0_length_mean': 11063.84870537514, 'target_12.0_length_std': 226.18844467243147}



100%|██████████| 5/5 [00:11<00:00,  2.24s/it]

Epoch 2501.0 eval: {'target_12.0_return_mean': 10955.150628678279, 'target_12.0_return_std': 79.64473145686183, 'target_12.0_length_mean': 10955.150628678279, 'target_12.0_length_std': 79.64473145686183}



100%|██████████| 5/5 [00:11<00:00,  2.22s/it]

Epoch 3001.0 eval: {'target_12.0_return_mean': 11088.433584288121, 'target_12.0_return_std': 117.3586700046007, 'target_12.0_length_mean': 11088.433584288121, 'target_12.0_length_std': 117.3586700046007}



100%|██████████| 5/5 [00:11<00:00,  2.23s/it]

Epoch 3501.0 eval: {'target_12.0_return_mean': 10933.639805734687, 'target_12.0_return_std': 136.21261398159734, 'target_12.0_length_mean': 10933.639805734687, 'target_12.0_length_std': 136.21261398159734}



100%|██████████| 5/5 [00:11<00:00,  2.22s/it]

Epoch 4001.0 eval: {'target_12.0_return_mean': 11271.795629340548, 'target_12.0_return_std': 219.80849500849882, 'target_12.0_length_mean': 11271.795629340548, 'target_12.0_length_std': 219.80849500849882}



100%|██████████| 5/5 [00:11<00:00,  2.22s/it]

Epoch 4501.0 eval: {'target_12.0_return_mean': 11211.743321033076, 'target_12.0_return_std': 108.22947529685659, 'target_12.0_length_mean': 11211.743321033076, 'target_12.0_length_std': 108.22947529685659}



100%|██████████| 5/5 [00:11<00:00,  2.25s/it]

Epoch 5001.0 eval: {'target_12.0_return_mean': 11363.736243529724, 'target_12.0_return_std': 92.20732713645, 'target_12.0_length_mean': 11363.736243529724, 'target_12.0_length_std': 92.20732713645}



100%|██████████| 5/5 [00:11<00:00,  2.22s/it]

Epoch 5501.0 eval: {'target_12.0_return_mean': 11090.862641595084, 'target_12.0_return_std': 146.00222129211522, 'target_12.0_length_mean': 11090.862641595084, 'target_12.0_length_std': 146.00222129211522}



100%|██████████| 5/5 [00:11<00:00,  2.24s/it]

Epoch 6001.0 eval: {'target_12.0_return_mean': 11195.28904575367, 'target_12.0_return_std': 60.4180799733843, 'target_12.0_length_mean': 11195.28904575367, 'target_12.0_length_std': 60.4180799733843}



100%|██████████| 5/5 [00:11<00:00,  2.24s/it]


{'target_12.0_return_mean': 11103.532312140353,
 'target_12.0_return_std': 44.267410736544704,
 'target_12.0_length_mean': 11103.532312140353,
 'target_12.0_length_std': 44.267410736544704}

### Decision Mamba Training

In [6]:
import mujoco
import gymnasium as gym

In [7]:
# build the environment
directory = './video'
device = "cuda"

model = model.to(device)
env = gym.make("HalfCheetah-v4", render_mode='rgb_array')

#env = gym.wrappers.RecordVideo(env, directory)
max_ep_len = 1000
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 12000 / scale  # evaluation is conditioned on a return of 12000, scaled accordingly

state_mean = collator.state_mean.astype(np.float32)
state_std = collator.state_std.astype(np.float32)
print(state_mean)

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
# Create the decision transformer model

# state_mean = torch.from_numpy(state_mean).to(device=device)
# state_std = torch.from_numpy(state_std).to(device=device)


[-0.04489212  0.03232612  0.06034821 -0.17081618 -0.19477023 -0.05751681
  0.0970142   0.03239178 11.0473385  -0.07997213 -0.32363245  0.3629689
  0.42323524  0.40836537  1.1085011  -0.48743752 -0.07375081]


In [10]:
def evaluate_episodes(num_eval_episodes, model):
    returns, lengths = [], []

    model.eval()
    
    with torch.no_grad():
        for _ in tqdm(range(num_eval_episodes)):
            ret, length = evaluate_episode_rtg(
                env=env,
                state_dim=state_dim,
                act_dim=act_dim,
                model=model,
                scale=scale,
                state_mean=state_mean,
                state_std=state_std,
                device=device,
                target_return=TARGET_RETURN,
            )

            returns.append(ret)
            lengths.append(ret)

    return {
        f'target_{TARGET_RETURN}_return_mean': np.mean(returns),
        f'target_{TARGET_RETURN}_return_std': np.std(returns),
        f'target_{TARGET_RETURN}_length_mean': np.mean(lengths),
        f'target_{TARGET_RETURN}_length_std': np.std(lengths),
    }

class EvaluateCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        if (int(state.epoch) - 1) % 500 == 0:
            print('Epoch', state.epoch, 'eval:', evaluate_episodes(5, model))


In [11]:
config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDM(config)

print(num_params(model))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collator,
    callbacks=[EvaluateCallback()],
)

trainer.train()

evaluate_episodes(5, model)

1230872


Step,Training Loss
5000,0.0362
10000,0.0274
15000,0.0261
20000,0.0253
25000,0.0247
30000,0.0242
35000,0.0238
40000,0.0234
45000,0.023
50000,0.0227


100%|██████████| 5/5 [00:12<00:00,  2.58s/it]

Epoch 1.0 eval: {'target_12.0_return_mean': -14.06919297013171, 'target_12.0_return_std': 1.1605751406392475, 'target_12.0_length_mean': -14.06919297013171, 'target_12.0_length_std': 1.1605751406392475}



100%|██████████| 5/5 [00:12<00:00,  2.60s/it]

Epoch 501.0 eval: {'target_12.0_return_mean': 11004.728973586252, 'target_12.0_return_std': 71.91044712426662, 'target_12.0_length_mean': 11004.728973586252, 'target_12.0_length_std': 71.91044712426662}



100%|██████████| 5/5 [00:13<00:00,  2.63s/it]

Epoch 1001.0 eval: {'target_12.0_return_mean': 11008.887768843582, 'target_12.0_return_std': 83.51621808539541, 'target_12.0_length_mean': 11008.887768843582, 'target_12.0_length_std': 83.51621808539541}



100%|██████████| 5/5 [00:13<00:00,  2.63s/it]

Epoch 1501.0 eval: {'target_12.0_return_mean': 11239.333516797022, 'target_12.0_return_std': 100.52138315673544, 'target_12.0_length_mean': 11239.333516797022, 'target_12.0_length_std': 100.52138315673544}



100%|██████████| 5/5 [00:13<00:00,  2.60s/it]

Epoch 2001.0 eval: {'target_12.0_return_mean': 11195.615892787908, 'target_12.0_return_std': 71.89832828904474, 'target_12.0_length_mean': 11195.615892787908, 'target_12.0_length_std': 71.89832828904474}



100%|██████████| 5/5 [00:13<00:00,  2.63s/it]

Epoch 2501.0 eval: {'target_12.0_return_mean': 11145.891078088494, 'target_12.0_return_std': 145.42430855113577, 'target_12.0_length_mean': 11145.891078088494, 'target_12.0_length_std': 145.42430855113577}



100%|██████████| 5/5 [00:13<00:00,  2.61s/it]

Epoch 3001.0 eval: {'target_12.0_return_mean': 11120.522921012129, 'target_12.0_return_std': 57.643228258727056, 'target_12.0_length_mean': 11120.522921012129, 'target_12.0_length_std': 57.643228258727056}



100%|██████████| 5/5 [00:13<00:00,  2.63s/it]

Epoch 3501.0 eval: {'target_12.0_return_mean': 11137.036670080712, 'target_12.0_return_std': 132.2095648109389, 'target_12.0_length_mean': 11137.036670080712, 'target_12.0_length_std': 132.2095648109389}



100%|██████████| 5/5 [00:13<00:00,  2.62s/it]

Epoch 4001.0 eval: {'target_12.0_return_mean': 11135.456405894121, 'target_12.0_return_std': 117.30773801670622, 'target_12.0_length_mean': 11135.456405894121, 'target_12.0_length_std': 117.30773801670622}



100%|██████████| 5/5 [00:13<00:00,  2.61s/it]

Epoch 4501.0 eval: {'target_12.0_return_mean': 11177.834986975333, 'target_12.0_return_std': 126.7074676400272, 'target_12.0_length_mean': 11177.834986975333, 'target_12.0_length_std': 126.7074676400272}



100%|██████████| 5/5 [00:13<00:00,  2.63s/it]

Epoch 5001.0 eval: {'target_12.0_return_mean': 11234.28329622281, 'target_12.0_return_std': 123.76612037424367, 'target_12.0_length_mean': 11234.28329622281, 'target_12.0_length_std': 123.76612037424367}



100%|██████████| 5/5 [00:13<00:00,  2.61s/it]

Epoch 5501.0 eval: {'target_12.0_return_mean': 11178.267474027167, 'target_12.0_return_std': 129.48385591572944, 'target_12.0_length_mean': 11178.267474027167, 'target_12.0_length_std': 129.48385591572944}



100%|██████████| 5/5 [00:13<00:00,  2.61s/it]

Epoch 6001.0 eval: {'target_12.0_return_mean': 11177.703071772648, 'target_12.0_return_std': 25.172768386687757, 'target_12.0_length_mean': 11177.703071772648, 'target_12.0_length_std': 25.172768386687757}





TrainOutput(global_step=100000, training_loss=0.02360807273864746, metrics={'train_runtime': 17297.5847, 'train_samples_per_second': 369.994, 'train_steps_per_second': 5.781, 'total_flos': 9008946000000000.0, 'train_loss': 0.02360807273864746, 'epoch': 6250.0})

### Step 7: Visualize the performance of the agent

With mujoco_py, it'll take a little while to compile the first 

In [12]:
evaluate_episodes(5, model)

100%|██████████| 5/5 [00:13<00:00,  2.61s/it]


{'target_12.0_return_mean': 11049.180788166927,
 'target_12.0_return_std': 174.65680113676865,
 'target_12.0_length_mean': 11049.180788166927,
 'target_12.0_length_std': 174.65680113676865}