In [1]:
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

In [2]:
# Show the index of the currently selected CUDA device (the recorded output is 0).
torch.cuda.current_device()

0

In [3]:
# Load the previously saved dataset from the relative directory "data/dataset/".
dataset = load_from_disk("data/dataset/")

In [4]:
# Pull the state-normalization statistics stored under the 'state_mean' and
# 'state_std' keys; they are handed to the data collator further down.
state_mean = dataset['state_mean']
state_std = dataset['state_std']

In [5]:
# Rebind `dataset` to its 'train' entry; every later cell works on this object.
# NOTE(review): this overwrites the object that the previous cell indexes with
# 'state_mean'/'state_std', so that cell presumably cannot be re-run after this
# one without reloading — run the notebook top to bottom.
dataset = dataset['train']

KeyboardInterrupt: 

In [None]:
# Number of records in the training data and number of fields in the first record
# (recorded output: 5334 and 4).
len(dataset), len(dataset[0])

(5334, 4)

In [None]:
# Field names of a single record.
dataset[0].keys()

dict_keys(['actions', 'dones', 'observations', 'rewards'])

In [None]:
# Action dimensionality: length of the first action vector of the first record.
act_dim = len(dataset[0]['actions'][0])
act_dim

393

In [None]:
# State dimensionality: length of the first observation vector of the first record.
state_dim = len(dataset[0]['observations'][0])
state_dim

4324

In [None]:
# Keep only the first `state_dim` entries of the statistics.
# NOTE(review): assumes the stored statistics are ordered like the observation
# features and are at least `state_dim` long — confirm against how they were saved.
state_mean = state_mean[:state_dim]
state_std = state_std[:state_dim]

In [None]:
# Number of observations stored in the first record (recorded output: 75).
len(dataset[0]['observations'])

75

In [None]:
# Peek at the first two action vectors of the first record.
# NOTE: each vector has `act_dim` entries, so the printed output is very long.
dataset[0]['actions'][:2]

[[-0.019129622727632523,
  -0.006513502448797226,
  0.04570562019944191,
  0.15612338483333588,
  0.041282929480075836,
  0.18546512722969055,
  -0.16300007700920105,
  0.045858871191740036,
  -0.0598205029964447,
  -0.046219564974308014,
  -0.26312240958213806,
  -0.10800012946128845,
  -0.07765176892280579,
  -0.11391003429889679,
  -0.15980571508407593,
  0.04218614101409912,
  -0.08814289420843124,
  0.15767569839954376,
  -0.17028531432151794,
  -0.04312672093510628,
  -0.14606116712093353,
  0.08254409581422806,
  -0.03896056115627289,
  -0.15955659747123718,
  -0.09140151739120483,
  -0.11390421539545059,
  0.06155223771929741,
  0.08480572700500488,
  0.0758899450302124,
  0.12660588324069977,
  0.0731610432267189,
  -0.08008327335119247,
  0.0851617380976677,
  -0.10980658233165741,
  0.05300186946988106,
  -0.1484784483909607,
  -0.10654953122138977,
  0.27709102630615234,
  -0.09234079718589783,
  0.22396980226039886,
  -0.13589535653591156,
  -0.023483429104089737,
  0.1311

In [None]:
@dataclass
class DecisionTransformerGymDataCollator:
    """Build fixed-length training batches for a Decision Transformer.

    The collator keeps a reference to the whole trajectory dataset. Each call
    ignores the *content* of the records it receives and only uses their count
    as the batch size: it draws that many trajectories (with replacement,
    weighted by trajectory length), cuts a random window of at most ``max_len``
    steps out of each one, left-pads the window to ``max_len`` and normalizes
    states and returns-to-go.

    Every trajectory record must provide the keys ``"observations"``,
    ``"actions"`` and ``"rewards"``. State mean/std are supplied by the caller
    rather than computed here.
    """

    return_tensors: str = "pt"
    max_len: int = 20  # length of the sub-sequence cut from an episode for training
    state_dim: int = 4324  # size of state space (overwritten from the data in __init__)
    act_dim: int = 393  # size of action space (overwritten from the data in __init__)
    max_ep_len: int = 985  # max episode length in the dataset; larger timesteps are clipped
    scale: float = 1000.0  # normalization of rewards/returns
    p_sample: np.array = None  # sampling distribution over trajectories (by length)
    n_traj: int = 0  # number of trajectories in the dataset

    def __init__(self, dataset, state_mean, state_std) -> None:
        """Store the dataset and statistics and derive the sampling weights.

        Args:
            dataset: indexable collection of trajectory records (``dataset[i]``
                returns a mapping with the keys listed on the class).
            state_mean: per-feature state mean, broadcastable against a
                ``(1, max_len, state_dim)`` array.
            state_std: per-feature state standard deviation, same shape rules.

        Raises:
            ValueError: if the trajectories contain no timesteps at all.
        """
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset
        self.state_mean = state_mean
        self.state_std = state_std

        self.n_traj = len(self.dataset)

        # Weight each trajectory by its own number of timesteps so that every
        # timestep in the dataset is (roughly) equally likely to be sampled.
        traj_lens = self._trajectory_lengths()
        total_steps = traj_lens.sum()
        if total_steps == 0:
            raise ValueError("dataset contains no timesteps to sample from")
        self.p_sample = traj_lens / total_steps

    def _trajectory_lengths(self):
        """Return the number of timesteps of every trajectory as an int array."""
        try:
            # Column-style access (supported by Hugging Face `Dataset`) reads
            # only the rewards instead of materializing every observation.
            reward_column = self.dataset["rewards"]
        except (TypeError, KeyError, IndexError):
            # Plain sequences of records: fall back to row-by-row access.
            reward_column = (self.dataset[i]["rewards"] for i in range(self.n_traj))
        return np.array([len(rewards) for rewards in reward_column], dtype=np.int64)

    def _discount_cumsum(self, x, gamma):
        """Return the discounted suffix sums ``out[t] = x[t] + gamma * out[t + 1]``.

        Args:
            x: 1-D array of rewards.
            gamma: discount factor.

        Returns:
            Array of the same shape as ``x``. Floating inputs keep their dtype;
            any other dtype is accumulated as float64 so that integer rewards
            are not truncated when ``gamma`` is fractional.
        """
        x = np.asarray(x)
        dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
        discount_cumsum = np.zeros(x.shape, dtype=dtype)
        if x.shape[0] == 0:
            return discount_cumsum
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        """Sample and assemble one batch.

        Args:
            features: the records handed over by the data loader; only
                ``len(features)`` is used (as the batch size).

        Returns:
            dict of tensors with keys ``"states"`` (B, max_len, state_dim),
            ``"actions"`` (B, max_len, act_dim), ``"rewards"`` (B, max_len, 1),
            ``"returns_to_go"`` (B, max_len, 1), ``"timesteps"`` (B, max_len,
            int64) and ``"attention_mask"`` (B, max_len; 1 on real steps, 0 on
            the left padding).
        """
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        # Convert the statistics once per batch instead of once per sample.
        state_mean = np.asarray(self.state_mean)
        state_std = np.asarray(self.state_std)

        s, a, r, rtg, timesteps, mask = [], [], [], [], [], []

        for ind in batch_inds:
            feature = self.dataset[int(ind)]
            si = random.randint(0, len(feature["rewards"]) - 1)
            window = slice(si, si + self.max_len)

            # get sequences from dataset
            states = np.array(feature["observations"][window]).reshape(1, -1, self.state_dim)
            actions = np.array(feature["actions"][window]).reshape(1, -1, self.act_dim)
            rewards = np.array(feature["rewards"][window]).reshape(1, -1, 1)
            tlen = states.shape[1]
            pad = self.max_len - tlen

            steps = np.arange(si, si + tlen).reshape(1, -1)
            steps[steps >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff

            # Undiscounted return-to-go from `si` to the end of the episode,
            # truncated to the window length.
            returns = self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[:tlen].reshape(1, -1, 1)
            missing = tlen - returns.shape[1]
            if missing > 0:
                # Fewer rewards than observations: fill the tail with zeros.
                returns = np.concatenate([returns, np.zeros((1, missing, 1))], axis=1)

            # padding and state + reward normalization
            # (states are padded with zeros first and normalized afterwards)
            states = np.concatenate([np.zeros((1, pad, self.state_dim)), states], axis=1)
            s.append((states - state_mean) / state_std)
            a.append(np.concatenate([np.full((1, pad, self.act_dim), -10.0), actions], axis=1))
            r.append(np.concatenate([np.zeros((1, pad, 1)), rewards], axis=1))
            rtg.append(np.concatenate([np.zeros((1, pad, 1)), returns], axis=1) / self.scale)
            timesteps.append(np.concatenate([np.zeros((1, pad)), steps], axis=1))
            mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

In [None]:
# Build the collator over the training trajectories with the truncated statistics.
collator = DecisionTransformerGymDataCollator(dataset, state_mean, state_std)

In [None]:
import random
from dataclasses import dataclass

import numpy as np
import torch

from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

In [None]:
class TrainableDT(DecisionTransformerModel):
    """DecisionTransformerModel whose ``forward`` yields a training loss.

    The loss is the mean squared error between predicted and target actions,
    taken only over positions whose attention-mask value is positive.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        """Run the base model and return ``{"loss": masked action MSE}``.

        Expects ``actions`` and ``attention_mask`` among the keyword arguments,
        in addition to whatever the base model's ``forward`` needs.
        """
        model_out = super().forward(**kwargs)
        # the second element of the base output holds the action predictions
        predicted = model_out[1]
        expected = kwargs["actions"]
        keep = kwargs["attention_mask"].reshape(-1) > 0
        n_act = predicted.shape[2]

        # flatten (batch, time) and drop the padded positions
        predicted = predicted.reshape(-1, n_act)[keep]
        expected = expected.reshape(-1, n_act)[keep]

        squared_error = (predicted - expected) ** 2
        return {"loss": torch.mean(squared_error)}

    def original_forward(self, **kwargs):
        """Expose the unmodified base-class forward pass."""
        return super().forward(**kwargs)

In [None]:
# Size the model from the dimensions the collator measured on the data; all other
# configuration options are left at whatever DecisionTransformerConfig defaults to.
config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDT(config)

In [None]:
# NOTE(review): the warning recorded under the TrainingArguments cell reports this
# variable as deprecated and points to `report_to` instead — consider migrating.
os.environ["WANDB_DISABLED"] = "true" # we disable weights and biases logging for this tutorial

In [None]:
# Hyperparameters for the Trainer run in the cells that follow.
training_args = TrainingArguments(
    output_dir="output/",  # relative directory handed to Trainer for its outputs
    remove_unused_columns=False,  # ask Trainer not to prune dataset columns
    num_train_epochs=120,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,  # presumably the fraction of steps spent on LR warmup — see TrainingArguments docs
    optim="adamw_torch",
    max_grad_norm=0.25,  # gradient-norm clipping threshold
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [None]:
# Show the device selected by the training arguments (recorded output: cuda, index 0).
training_args.device

device(type='cuda', index=0)

In [None]:
# Wire model, arguments, data and collator together and start training.
# Note: the collator only uses the number of records it is handed as the batch size
# and samples trajectories from the dataset it was constructed with; `train_dataset`
# is the same object here.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()

Step,Training Loss
500,0.0242
1000,0.0031
1500,0.0019
2000,0.0012
2500,0.0008
3000,0.0005
3500,0.0004
4000,0.0003
4500,0.0002
5000,0.0002


TrainOutput(global_step=10080, training_loss=0.001679275388642776, metrics={'train_runtime': 2402.1044, 'train_samples_per_second': 266.466, 'train_steps_per_second': 4.196, 'total_flos': 6.005605220842752e+17, 'train_loss': 0.001679275388642776, 'epoch': 120.0})

In [None]:
# Persist the trained model to the relative directory 'trained_models'.
trainer.save_model('trained_models')