In [17]:
import os
import random
from dataclasses import dataclass

import numpy as np
from tqdm import tqdm
import torch

import datasets
import torch.utils.data
from transformers import TrainingArguments, Trainer
from decision_transformer import DecisionTransformerConfig, DecisionTransformerModel

import snowietxt_processor


In [18]:
# Disable Weights & Biases reporting for the Trainer. The Trainer log further down
# reports this environment variable as deprecated in favour of `report_to`.
os.environ["WANDB_DISABLED"] = "true"
# Allow TF32 in CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

### Step 4: Defining a custom DataCollator for the transformers Trainer class

In [19]:
@dataclass
class DecisionTransformerGymDataCollator:
    """Collator that samples fixed-length training windows from whole trajectories.

    The collator keeps a reference to the full dataset. When called it ignores
    the content of the items handed over by the DataLoader (only their count is
    used as the batch size) and instead draws trajectories at random, weighted
    by trajectory length, then cuts a window of at most ``max_len`` steps out of
    each one. Windows shorter than ``max_len`` are left-padded.

    State normalisation (``state_mean`` / ``state_std``) is currently disabled:
    the statistics are never computed and states are passed through unchanged.
    """

    return_tensors: str = "pt"
    max_len: int = 10  # subsets of the episode we use for training
    state_dim: int = 17  # size of state space (overwritten from the dataset in __init__)
    act_dim: int = 6  # size of action space (overwritten from the dataset in __init__)
    max_ep_len: int = 1000  # max episode length in the dataset
    scale: float = 1.0  # normalization of rewards/returns
    state_mean: np.ndarray = None  # to store state means (unused while normalisation is disabled)
    state_std: np.ndarray = None  # to store state stds (unused while normalisation is disabled)
    p_sample: np.ndarray = None  # a distribution to take account trajectory lengths
    n_traj: int = 0  # to store the number of trajectories in the dataset
    device: str = None  # torch device the batch tensors are moved to

    def __init__(self, dataset, device=None) -> None:
        """Read dimensions and trajectory lengths from ``dataset``.

        Args:
            dataset: object indexable both by row (``dataset[i]`` returns a
                mapping with "observations", "actions" and "rewards") and by
                column name (``dataset["observations"]`` yields one sequence
                per trajectory).
            device: torch device for the returned tensors. Defaults to CUDA
                when it is available and to the CPU otherwise.
        """
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Trajectory lengths drive the sampling distribution, so that every
        # timestep (rather than every trajectory) is equally likely to be drawn.
        traj_lens = np.array([len(obs) for obs in dataset["observations"]])
        self.n_traj = len(traj_lens)
        self.p_sample = traj_lens / traj_lens.sum()

    def _discount_cumsum(self, x, gamma):
        """Return the discounted suffix sums ``out[t] = x[t] + gamma * out[t + 1]``.

        The result is always floating point, so integer rewards combined with a
        fractional ``gamma`` are not truncated. An empty input yields an empty
        result.
        """
        x = np.asarray(x)
        out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
        discount_cumsum = np.zeros(x.shape, dtype=out_dtype)
        if x.shape[0] == 0:
            return discount_cumsum
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def _to_tensor(self, chunks, dtype):
        """Stack per-sample arrays along the batch axis and move them to ``self.device``."""
        stacked = np.concatenate(chunks, axis=0)
        return torch.from_numpy(stacked).to(dtype=dtype, device=self.device)

    def __call__(self, features):
        """Build one training batch.

        Args:
            features: the items collected by the DataLoader; only ``len(features)``
                is used, as the batch size.

        Returns:
            dict with "states", "actions", "rewards", "returns_to_go",
            "timesteps" and "attention_mask" tensors, each with a leading batch
            dimension and a sequence dimension of ``max_len``.
        """
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        # a batch of dataset features
        s, a, r, rtg, timesteps, mask = [], [], [], [], [], []

        for ind in batch_inds:
            feature = self.dataset[int(ind)]
            # random start index of the window inside the trajectory
            si = random.randint(0, len(feature["rewards"]) - 1)
            window = slice(si, si + self.max_len)

            # get sequences from dataset
            states = np.array(feature["observations"][window]).reshape(1, -1, self.state_dim)
            actions = np.array(feature["actions"][window]).reshape(1, -1, self.act_dim)
            rewards = np.array(feature["rewards"][window]).reshape(1, -1, 1)
            tlen = states.shape[1]

            steps = np.arange(si, si + tlen).reshape(1, -1)
            steps = np.minimum(steps, self.max_ep_len - 1)  # padding cutoff

            # undiscounted return-to-go from the window start to the episode end
            returns = self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[
                :tlen  # TODO check the +1 removed here
            ].reshape(1, -1, 1)
            if returns.shape[1] < tlen:
                returns = np.concatenate([returns, np.zeros((1, 1, 1))], axis=1)

            # left-pad every sequence up to max_len; padded actions use the
            # sentinel -10 and are excluded through the attention mask
            pad = self.max_len - tlen
            s.append(np.concatenate([np.zeros((1, pad, self.state_dim)), states], axis=1))
            a.append(np.concatenate([np.full((1, pad, self.act_dim), -10.0), actions], axis=1))
            r.append(np.concatenate([np.zeros((1, pad, 1)), rewards], axis=1))
            rtg.append(np.concatenate([np.zeros((1, pad, 1)), returns], axis=1) / self.scale)
            timesteps.append(np.concatenate([np.zeros((1, pad)), steps], axis=1))
            mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

        return {
            "states": self._to_tensor(s, torch.float32),
            "actions": self._to_tensor(a, torch.float32),
            "rewards": self._to_tensor(r, torch.float32),
            "returns_to_go": self._to_tensor(rtg, torch.float32),
            "timesteps": self._to_tensor(timesteps, torch.long),
            "attention_mask": self._to_tensor(mask, torch.float32),
        }

class DummyCollator:
    """Pass-through collator for datasets whose items are already complete batches."""

    return_tensors: str = "pt"

    def __call__(self, features):
        """Return the first element of ``features`` unchanged.

        The dataset was created and collated ahead of time, so ``features`` is
        expected to hold exactly one pre-built batch.
        """
        prebuilt_batch = features[0]
        return prebuilt_batch

class DummyDataset(torch.utils.data.Dataset):
    """Dataset over six parallel, index-aligned sequences of pre-collated data."""

    def __init__(self, states, actions, rewards, rtgs, timesteps, attention_mask):
        """Store the six parallel sequences; nothing is copied or converted."""
        self.states, self.actions, self.rewards = states, actions, rewards
        self.rtgs, self.timesteps, self.attention_mask = rtgs, timesteps, attention_mask

    def __len__(self):
        """Number of items, taken from the ``states`` sequence."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of every sequence, keyed by model input name."""
        columns = (
            ("states", self.states),
            ("actions", self.actions),
            ("rewards", self.rewards),
            ("returns_to_go", self.rtgs),
            ("timesteps", self.timesteps),
            ("attention_mask", self.attention_mask),
        )
        return {key: column[idx] for key, column in columns}

In [20]:
# Build the data with the project's processor and wrap the resulting dict in a
# Hugging Face Dataset in a single step.
dataset = datasets.Dataset.from_dict(snowietxt_processor.create_dataset())

100%|██████████| 5105/5105 [00:01<00:00, 4177.50it/s]


In [21]:
# Pre-collate the dataset with the data collator defined above, which also does the batching.
# The collator only uses the number of items it is handed and samples its own trajectories,
# so the DataLoader here just determines how many batches of 64 are produced.
from torch.utils.data import DataLoader

collator = DecisionTransformerGymDataCollator(dataset)
dataloader = DataLoader(dataset, batch_size=64, collate_fn=collator)


In [22]:
# One growing list of batch tensors per model input.
states, actions, rewards, returns_to_go, timesteps, attention_mask = [], [], [], [], [], []
batch_fields = {
    "states": states,
    "actions": actions,
    "rewards": rewards,
    "returns_to_go": returns_to_go,
    "timesteps": timesteps,
    "attention_mask": attention_mask,
}

num_batches = 0

# Number of passes over the dataloader; every pass draws fresh random windows.
for i in range(10 // DecisionTransformerGymDataCollator.max_len): # TODO: figure out the average length of an episode
    for batch in tqdm(dataloader, total=len(dataloader)):
        for field_name, collected in batch_fields.items():
            collected.append(batch[field_name])
        num_batches += 1


100%|██████████| 80/80 [00:33<00:00,  2.35it/s]


In [23]:
# concatenate the per-batch tensors along the batch axis so the whole collated set can be inspected
states_view = torch.cat(states, dim=0)
actions_view = torch.cat(actions, dim=0)

print(f"States shape: {states_view.shape}")
print(f"Actions shape: {actions_view.shape}")

# memory footprint = bytes per element * number of elements, reported in MB (1e6 bytes)
print(f"States memory: {states_view.element_size() * states_view.nelement() / 1e6} MB")
print(f"Actions memory: {actions_view.element_size() * actions_view.nelement() / 1e6} MB")

States shape: torch.Size([5105, 10, 210])
Actions shape: torch.Size([5105, 10, 8])
States memory: 42.882 MB
Actions memory: 1.6336 MB


In [24]:
# sanity check: look at the states tensor of the first collated batch
print(states[0])

tensor([[[1., 1., 0.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 1.],
         ...,
         [1., 1., 0.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         ...,
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0., 

In [25]:
# now, create a new dataset based on the collated data
# (alternative kept for reference: building a Hugging Face Dataset from the same lists)
# dataset = Dataset.from_dict(
#     {
#         "states": states,
#         "actions": actions,
#         "rewards": rewards,
#         "returns_to_go": returns_to_go,
#         "timesteps": timesteps,
#         "attention_mask": attention_mask,
#     }
# )

# NOTE: this rebinds `dataset`; from here on it holds the pre-collated batches
# (one item per batch), no longer the Hugging Face dataset created earlier.
dataset = DummyDataset(states, actions, rewards, returns_to_go, timesteps, attention_mask)

In [26]:
# each item of the new dataset is a whole pre-collated batch, so the shape has a leading batch dimension
print(dataset[0]['states'].shape)

torch.Size([64, 10, 210])


### Step 5: Extending the Decision Transformer Model to include a loss function

In order to train the model with the 🤗 Trainer class, we first need to ensure that the dictionary the model returns contains a loss — in this case the mean squared error (squared L2 distance) between the model's action predictions and the target actions.

In [27]:
class TrainableDT(DecisionTransformerModel):
    """Decision Transformer whose ``forward`` returns a training loss for the Trainer."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        """Run the base model and return ``{"loss": ...}``.

        The loss is the mean squared error between predicted and target actions,
        taken only over positions whose attention mask is greater than zero.
        """
        output = super().forward(**kwargs)
        # add the DT loss
        predicted = output[1]
        targets = kwargs["actions"]
        # flatten (batch, seq) into one axis and keep only the unmasked steps
        keep = kwargs["attention_mask"].reshape(-1) > 0

        act_dim = predicted.shape[2]
        predicted = predicted.reshape(-1, act_dim)[keep]
        targets = targets.reshape(-1, act_dim)[keep]

        squared_error = (predicted - targets) ** 2
        return {"loss": torch.mean(squared_error)}

    def original_forward(self, **kwargs):
        """Expose the unmodified base-class forward pass (used for inference)."""
        return super().forward(**kwargs)

In [28]:
# The Trainer only needs to unwrap the pre-collated batches.
dummy_collator = DummyCollator()

# State and action sizes are the ones the sampling collator read from the dataset.
config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    action_tanh=False,
)
model = TrainableDT(config)

### Step 6: Defining the training hyperparameters and training the model
Here, we define the training hyperparameters and our Trainer class that we'll use to train our Decision Transformer model.

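This step is described as taking about an hour, so you may leave it running; the run logged below (80 pre-collated batches, 20 epochs) finished in under 20 seconds. Note that the authors train for at least 3 hours, so the results presented here are not as performant as the models hosted on the 🤗 hub.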

In [29]:
training_args = TrainingArguments(
    output_dir="output/",
    # schedule and optimiser
    num_train_epochs=20,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    max_grad_norm=0.25,
    optim="adamw_torch",
    # every dataset item is already a full batch, so one item per step
    per_device_train_batch_size=1,
    remove_unused_columns=False,
    # numeric precision
    tf32=True,
    fp16=True,
    # data loading and progress output
    #dataloader_num_workers=8
    dataloader_pin_memory=False,
    disable_tqdm=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=dummy_collator,
    train_dataset=dataset,
)

trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using cuda_amp half precision backend
***** Running training *****
  Num examples = 80
  Num Epochs = 20
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 1600
  Number of trainable parameters = 1307483
Saving model checkpoint to output/checkpoint-500
Configuration saved in output/checkpoint-500\config.json
Model weights saved in output/checkpoint-500\pytorch_model.bin


{'loss': 77.8211, 'learning_rate': 7.659722222222223e-05, 'epoch': 6.25}


Saving model checkpoint to output/checkpoint-1000
Configuration saved in output/checkpoint-1000\config.json
Model weights saved in output/checkpoint-1000\pytorch_model.bin


{'loss': 47.166, 'learning_rate': 4.1875e-05, 'epoch': 12.5}


Saving model checkpoint to output/checkpoint-1500
Configuration saved in output/checkpoint-1500\config.json
Model weights saved in output/checkpoint-1500\pytorch_model.bin


{'loss': 40.1719, 'learning_rate': 7.1527777777777775e-06, 'epoch': 18.75}




Training completed. Do not forget to share your model on huggingface.co/models =)




{'train_runtime': 17.5871, 'train_samples_per_second': 90.976, 'train_steps_per_second': 90.976, 'train_loss': 54.066346435546876, 'epoch': 20.0}


TrainOutput(global_step=1600, training_loss=54.066346435546876, metrics={'train_runtime': 17.5871, 'train_samples_per_second': 90.976, 'train_steps_per_second': 90.976, 'train_loss': 54.066346435546876, 'epoch': 20.0})

In [30]:
from torch.utils.data import DataLoader

# Evaluation loader over the pre-collated dataset: one item is one full batch.
eval_dataloader = DataLoader(dataset, batch_size=1, collate_fn=dummy_collator)

# Pull a single batch and push it through the model on the GPU in eval mode.
batch = next(iter(eval_dataloader))
model.cuda().eval()

with torch.no_grad():
    # original_forward returns the base model outputs instead of the loss dict
    output = model.original_forward(**batch)

# Compare rounded predictions with the targets for the first sequence of the batch.
first_prediction = output['action_preds'][0].round()
first_target = batch['actions'][0]
print(first_prediction)
print(first_target)


tensor([[ 4.,  1.,  2., -0.,  0., -0.,  0., -0.],
        [11., 12., 12., 13.,  2.,  2.,  2.,  3.],
        [ 3.,  0.,  1., -1.,  1.,  0.,  0.,  0.],
        [11., 12., 12., 13.,  2.,  2.,  2.,  2.],
        [12., 11., 11., 10.,  1.,  1.,  1.,  1.],
        [12., 13., 12., 13.,  2.,  2.,  2.,  2.],
        [ 3.,  1.,  1., -1.,  0., -0.,  0., -0.],
        [12., 13., 12., 13.,  2.,  2.,  2.,  2.],
        [ 2., -0.,  0., -2.,  0., -0., -0., -1.],
        [12., 13., 12., 12.,  2.,  2.,  2.,  2.]], device='cuda:0')
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [16., 18.,  4.,  5.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  7., 19., 20.,  0.,  0.,  0.,  0.],
        [25., 24., 13.,  7.,  0.,  0.,  0.,  0.],
        [ 0.,  2.,  2.,  7.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [20., 24., 17., 19.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [18., 24., 19., 20.,  0.