# Training Decision Transformers with 🤗 transformers

In this tutorial, **you’ll learn to train your first Offline Decision Transformer model from scratch to make a half-cheetah run.** 🏃

❓ If you have questions, please post them on #study-group discord channel 👉 https://discord.gg/aYka4Yhff9

🎮 Environments:
- [Half Cheetah](https://www.gymlibrary.dev/environments/mujoco/half_cheetah/)

⬇️ Here's what you'll achieve at the end of this tutorial. ⬇️

### Prerequisites 🏗️
Before diving into the notebook, you need to:

🔲 📚 [Read the tutorial](https://huggingface.co/blog/train-decision-transformers)

In [1]:
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import DecisionTransformerConfig, Trainer, TrainingArguments, TrainerCallback
from models.decision_mamba import TrainableDT, TrainableDM
from evaluation.evaluate_episodes import evaluate_episode_rtg


### Step 3: Loading the dataset from the 🤗 Hub and instantiating the model

We host a number of Offline RL Datasets on the hub. Today we will be training with the halfcheetah “expert” dataset, hosted here on hub.

First we need to import the load_dataset function from the 🤗 datasets package and download the dataset to our machine.

In [9]:
os.environ["WANDB_DISABLED"] = "true" # we diable weights and biases logging for this tutorial
DATASET_NAME = "halfcheetah-medium-expert-v2"

dataset = load_dataset("edbeeching/decision_transformer_gym_replay", DATASET_NAME, )


### Step 4: Defining a custom DataCollator for the transformers Trainer class

In [3]:
@dataclass
class DecisionTransformerGymDataCollator:
    return_tensors: str = "pt"
    max_len: int = 20 #subsets of the episode we use for training
    state_dim: int = 17  # size of state space
    act_dim: int = 6  # size of action space
    max_ep_len: int = 1000 # max episode length in the dataset
    scale: float = 1000.0  # normalization of rewards/returns
    state_mean: np.array = None  # to store state means
    state_std: np.array = None  # to store state stds
    p_sample: np.array = None  # a distribution to take account trajectory lengths
    n_traj: int = 0 # to store the number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset
        # calculate dataset stats for normalization of states
        states = []
        traj_lens = []
        for obs in dataset["observations"]:
            states.extend(obs)
            traj_lens.append(len(obs))
        self.n_traj = len(traj_lens)
        states = np.vstack(states)
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

        traj_lens = np.array(traj_lens)
        self.p_sample = traj_lens / sum(traj_lens)

    def _discount_cumsum(self, x, gamma):
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        # a batch of dataset features
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

        for ind in batch_inds:
            # for feature in features:
            feature = self.dataset[int(ind)]
            si = random.randint(0, len(feature["rewards"]) - 1)

            # get sequences from dataset
            s.append(np.array(feature["observations"][si : si + self.max_len]).reshape(1, -1, self.state_dim))
            a.append(np.array(feature["actions"][si : si + self.max_len]).reshape(1, -1, self.act_dim))
            r.append(np.array(feature["rewards"][si : si + self.max_len]).reshape(1, -1, 1))

            d.append(np.array(feature["dones"][si : si + self.max_len]).reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff
            rtg.append(
                self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[
                    : s[-1].shape[1]   # TODO check the +1 removed here
                ].reshape(1, -1, 1)
            )
            if rtg[-1].shape[1] < s[-1].shape[1]:
                print("if true")
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            s[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, self.state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - self.state_mean) / self.state_std
            a[-1] = np.concatenate(
                [np.ones((1, self.max_len - tlen, self.act_dim)) * -10.0, a[-1]],
                axis=1,
            )
            r[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, self.max_len - tlen)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, self.max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, self.max_len - tlen)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        d = torch.from_numpy(np.concatenate(d, axis=0))
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

### Step 5: Extending the Decision Transformer Model to include a loss function

In order to train the model with the 🤗 trainer class, we first need to ensure the dictionary it returns contains a loss, in this case L-2 norm of the models action predictions and the targets.

In [4]:
collator = DecisionTransformerGymDataCollator(dataset['train'])

config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDT(config)

def num_params(model):
    return sum(p.numel() for p in model.parameters())

num_params(model)

1257368

### Step 6: Defining the training hyperparameters and training the model
Here, we define the training hyperparameters and our Trainer class that we'll use to train our Decision Transformer model.

This step takes about an hour, so you may leave it running. Note the authors train for at least 3 hours, so the results presented here are not as performant as the models hosted on the 🤗 hub.

In [5]:
import mujoco
import gymnasium as gym

# build the environment
directory = './video'
device = "cuda"

model = model.to(device)
env = gym.make("HalfCheetah-v4", render_mode='rgb_array')

#env = gym.wrappers.RecordVideo(env, directory)
max_ep_len = 1000
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 12000 / scale  # evaluation is conditioned on a return of 12000, scaled accordingly

state_mean = collator.state_mean.astype(np.float32)
state_std = collator.state_std.astype(np.float32)
print(state_mean)

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
# Create the decision transformer model

# state_mean = torch.from_numpy(state_mean).to(device=device)
# state_std = torch.from_numpy(state_std).to(device=device)


[-0.04489212  0.03232612  0.06034821 -0.17081618 -0.19477023 -0.05751681
  0.0970142   0.03239178 11.0473385  -0.07997213 -0.32363245  0.3629689
  0.42323524  0.40836537  1.1085011  -0.48743752 -0.07375081]


In [6]:
def evaluate_episodes(num_eval_episodes, model):
    returns, lengths = [], []

    model.eval()
    
    with torch.no_grad():
        for _ in tqdm(range(num_eval_episodes)):
            ret, length = evaluate_episode_rtg(
                env=env,
                state_dim=state_dim,
                act_dim=act_dim,
                model=model,
                scale=scale,
                state_mean=state_mean,
                state_std=state_std,
                device=device,
                target_return=TARGET_RETURN,
            )

            returns.append(ret)
            lengths.append(ret)

    return {
        f'target_{TARGET_RETURN}_return_mean': np.mean(returns),
        f'target_{TARGET_RETURN}_return_std': np.std(returns),
        f'target_{TARGET_RETURN}_length_mean': np.mean(lengths),
        f'target_{TARGET_RETURN}_length_std': np.std(lengths),
    }

class EvaluateCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        if int(state.epoch) % 500 == 0:
            print('Epoch', state.epoch, 'eval:', evaluate_episodes(10, model))


In [7]:
# these params more or less match the ones used by the original DT paper

training_args = TrainingArguments(
    output_dir="output/",
    remove_unused_columns=False,
    max_steps=50_000, # we only need about 50k steps to reach the highest score for DT and DM
    logging_strategy='steps',
    save_strategy='no',
    logging_steps=500,
    warmup_steps=0,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    #warmup_ratio=0.1,
    optim="adamw_torch",
    dataloader_num_workers=16,
    dataloader_persistent_workers=True,
    max_grad_norm=0.25,
    tf32=True,
    bf16=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    data_collator=collator,
    callbacks=[EvaluateCallback()],
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss
500,0.0847
1000,0.0456
1500,0.0412
2000,0.0388
2500,0.0373
3000,0.036
3500,0.035
4000,0.0347
4500,0.034
5000,0.0336


100%|██████████| 10/10 [00:22<00:00,  2.23s/it]


Epoch 500.0 eval: {'target_12.0_return_mean': 10989.504561448835, 'target_12.0_return_std': 162.11139303138572, 'target_12.0_length_mean': 10989.504561448835, 'target_12.0_length_std': 162.11139303138572}


100%|██████████| 10/10 [00:22<00:00,  2.22s/it]


Epoch 1000.0 eval: {'target_12.0_return_mean': 10876.45982045702, 'target_12.0_return_std': 122.30480176626415, 'target_12.0_length_mean': 10876.45982045702, 'target_12.0_length_std': 122.30480176626415}


100%|██████████| 10/10 [00:22<00:00,  2.23s/it]


Epoch 1500.0 eval: {'target_12.0_return_mean': 11015.533469758262, 'target_12.0_return_std': 95.93927091062784, 'target_12.0_length_mean': 11015.533469758262, 'target_12.0_length_std': 95.93927091062784}


100%|██████████| 10/10 [00:22<00:00,  2.23s/it]


Epoch 2000.0 eval: {'target_12.0_return_mean': 11128.097160683043, 'target_12.0_return_std': 114.82603171683121, 'target_12.0_length_mean': 11128.097160683043, 'target_12.0_length_std': 114.82603171683121}


100%|██████████| 10/10 [00:22<00:00,  2.21s/it]


Epoch 2500.0 eval: {'target_12.0_return_mean': 11329.624658365869, 'target_12.0_return_std': 146.35758150366172, 'target_12.0_length_mean': 11329.624658365869, 'target_12.0_length_std': 146.35758150366172}


100%|██████████| 10/10 [00:22<00:00,  2.22s/it]


Epoch 3000.0 eval: {'target_12.0_return_mean': 11236.52906299337, 'target_12.0_return_std': 110.82619620657644, 'target_12.0_length_mean': 11236.52906299337, 'target_12.0_length_std': 110.82619620657644}


100%|██████████| 10/10 [00:22<00:00,  2.24s/it]


Epoch 3500.0 eval: {'target_12.0_return_mean': 11198.178584461948, 'target_12.0_return_std': 87.11778150438742, 'target_12.0_length_mean': 11198.178584461948, 'target_12.0_length_std': 87.11778150438742}


100%|██████████| 10/10 [00:22<00:00,  2.21s/it]


Epoch 4000.0 eval: {'target_12.0_return_mean': 11149.90722964233, 'target_12.0_return_std': 143.04270196743954, 'target_12.0_length_mean': 11149.90722964233, 'target_12.0_length_std': 143.04270196743954}


100%|██████████| 10/10 [00:22<00:00,  2.22s/it]


Epoch 4500.0 eval: {'target_12.0_return_mean': 11255.579453085767, 'target_12.0_return_std': 145.01454014693073, 'target_12.0_length_mean': 11255.579453085767, 'target_12.0_length_std': 145.01454014693073}


100%|██████████| 10/10 [00:22<00:00,  2.22s/it]


Epoch 5000.0 eval: {'target_12.0_return_mean': 11211.59157367994, 'target_12.0_return_std': 134.117961524067, 'target_12.0_length_mean': 11211.59157367994, 'target_12.0_length_std': 134.117961524067}


100%|██████████| 10/10 [00:22<00:00,  2.23s/it]


Epoch 5500.0 eval: {'target_12.0_return_mean': 11235.584381126067, 'target_12.0_return_std': 121.56054460758341, 'target_12.0_length_mean': 11235.584381126067, 'target_12.0_length_std': 121.56054460758341}


100%|██████████| 10/10 [00:22<00:00,  2.23s/it]


Epoch 6000.0 eval: {'target_12.0_return_mean': 11209.636695802928, 'target_12.0_return_std': 100.82104355519274, 'target_12.0_length_mean': 11209.636695802928, 'target_12.0_length_std': 100.82104355519274}


TrainOutput(global_step=100000, training_loss=0.028459468364715575, metrics={'train_runtime': 11641.737, 'train_samples_per_second': 549.746, 'train_steps_per_second': 8.59, 'total_flos': 7673970000000000.0, 'train_loss': 0.028459468364715575, 'epoch': 6250.0})

### Decision Mamba Training

In [8]:
config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDM(config)

print(num_params(model))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collator,
    callbacks=[EvaluateCallback()],
)

trainer.train()

evaluate_episodes(10, model)

1230872


Step,Training Loss
500,0.0812
1000,0.0375
1500,0.0346
2000,0.0327
2500,0.0317
3000,0.0307
3500,0.0299
4000,0.0297
4500,0.0292
5000,0.0288


100%|██████████| 10/10 [00:26<00:00,  2.70s/it]


Epoch 500.0 eval: {'target_12.0_return_mean': 11108.234626054067, 'target_12.0_return_std': 73.84636351995347, 'target_12.0_length_mean': 11108.234626054067, 'target_12.0_length_std': 73.84636351995347}


100%|██████████| 10/10 [00:27<00:00,  2.71s/it]


Epoch 1000.0 eval: {'target_12.0_return_mean': 11079.969558672155, 'target_12.0_return_std': 108.33909028755076, 'target_12.0_length_mean': 11079.969558672155, 'target_12.0_length_std': 108.33909028755076}


100%|██████████| 10/10 [00:26<00:00,  2.70s/it]


Epoch 1500.0 eval: {'target_12.0_return_mean': 11083.025185320897, 'target_12.0_return_std': 119.13317300046135, 'target_12.0_length_mean': 11083.025185320897, 'target_12.0_length_std': 119.13317300046135}


100%|██████████| 10/10 [00:27<00:00,  2.71s/it]


Epoch 2000.0 eval: {'target_12.0_return_mean': 11194.317992716893, 'target_12.0_return_std': 115.86805905663459, 'target_12.0_length_mean': 11194.317992716893, 'target_12.0_length_std': 115.86805905663459}


100%|██████████| 10/10 [00:27<00:00,  2.72s/it]


Epoch 2500.0 eval: {'target_12.0_return_mean': 11190.49171754849, 'target_12.0_return_std': 98.76984893962259, 'target_12.0_length_mean': 11190.49171754849, 'target_12.0_length_std': 98.76984893962259}


100%|██████████| 10/10 [00:27<00:00,  2.71s/it]


Epoch 3000.0 eval: {'target_12.0_return_mean': 11301.625097235714, 'target_12.0_return_std': 77.62978415697798, 'target_12.0_length_mean': 11301.625097235714, 'target_12.0_length_std': 77.62978415697798}


100%|██████████| 10/10 [00:27<00:00,  2.72s/it]


Epoch 3500.0 eval: {'target_12.0_return_mean': 11185.843814265307, 'target_12.0_return_std': 149.90863085979038, 'target_12.0_length_mean': 11185.843814265307, 'target_12.0_length_std': 149.90863085979038}


100%|██████████| 10/10 [00:27<00:00,  2.72s/it]


Epoch 4000.0 eval: {'target_12.0_return_mean': 11119.95143462426, 'target_12.0_return_std': 149.31119889729456, 'target_12.0_length_mean': 11119.95143462426, 'target_12.0_length_std': 149.31119889729456}


100%|██████████| 10/10 [00:27<00:00,  2.72s/it]


Epoch 4500.0 eval: {'target_12.0_return_mean': 11193.384980876936, 'target_12.0_return_std': 73.01928614082368, 'target_12.0_length_mean': 11193.384980876936, 'target_12.0_length_std': 73.01928614082368}


100%|██████████| 10/10 [00:27<00:00,  2.73s/it]


Epoch 5000.0 eval: {'target_12.0_return_mean': 11173.847391281797, 'target_12.0_return_std': 96.17459660286931, 'target_12.0_length_mean': 11173.847391281797, 'target_12.0_length_std': 96.17459660286931}


100%|██████████| 10/10 [00:27<00:00,  2.72s/it]


Epoch 5500.0 eval: {'target_12.0_return_mean': 11231.128363471944, 'target_12.0_return_std': 85.14632933250516, 'target_12.0_length_mean': 11231.128363471944, 'target_12.0_length_std': 85.14632933250516}


100%|██████████| 10/10 [00:27<00:00,  2.74s/it]


Epoch 6000.0 eval: {'target_12.0_return_mean': 11226.860700096828, 'target_12.0_return_std': 65.14456825745998, 'target_12.0_length_mean': 11226.860700096828, 'target_12.0_length_std': 65.14456825745998}


100%|██████████| 10/10 [00:27<00:00,  2.76s/it]


{'target_12.0_return_mean': 11115.0773040795,
 'target_12.0_return_std': 118.08435197662368,
 'target_12.0_length_mean': 11115.0773040795,
 'target_12.0_length_std': 118.08435197662368}

### Step 7: Visualize the performance of the agent

With mujoco_py, it'll take a little while to compile the first 

In [11]:
trainer.save_model('trained_models/dm_' + DATASET_NAME)