In [1]:
import os

import numpy as np
import torch
import torch.utils.data
from torch.utils.data import DataLoader

import datasets
from transformers import TrainingArguments, Trainer

from decision_transformer import DecisionTransformerConfig, DecisionTransformerModel
from preprocessed_dataset import DecisionTransformerPreprocessedDataset, UnwrapCollator
import snowietxt_processor


In [2]:
# Process-wide switches, set before any Trainer is constructed.

# Let CUDA matrix multiplications use TF32.
torch.backends.cuda.matmul.allow_tf32 = True

# Opt out of Weights & Biases logging. The training log reports this variable as
# deprecated in favour of the `report_to` option.
os.environ.update({"WANDB_DISABLED": "true"})

In [None]:
# Tunable constants consumed by the cells below.
MAX_LEN: int = 10      # passed as `max_len` to the preprocessed dataset and `max_length` to the model config
BATCH_SIZE: int = 64   # passed as `batch_size` to the preprocessed dataset
NUM_EPOCHS: int = 120  # passed as `num_train_epochs` to TrainingArguments

### Step 4: Loading and preprocessing the dataset (the custom DataCollator for the transformers Trainer class, `UnwrapCollator`, is imported from `preprocessed_dataset`)

In [3]:
# Build the processor's output and wrap it in a Hugging Face `Dataset` in a single
# expression, so `dataset` is only ever bound to the `Dataset` object in this cell.
# Presumably `create_dataset()` returns a column-name -> values mapping, since it is
# handed straight to `Dataset.from_dict` -- TODO confirm against snowietxt_processor.
dataset = datasets.Dataset.from_dict(snowietxt_processor.create_dataset())

Number of games 5105


100%|██████████| 5105/5105 [00:01<00:00, 2677.15it/s]


In [4]:
# Rebind `dataset` to the preprocessed wrapper used by the model, Trainer and
# evaluation cells below.
# NOTE: input and output share a name, so re-running only this cell would pass the
# already-preprocessed object back in; re-run the loading cell first.
dataset = DecisionTransformerPreprocessedDataset(
    dataset,
    max_len=MAX_LEN,
    batch_size=BATCH_SIZE,
)

Preprocessing dataset: 100%|██████████| 480/480 [04:20<00:00,  1.84batch/s]


### Step 5: Extending the Decision Transformer Model to include a loss function

In order to train the model with the 🤗 Trainer class, we first need to ensure that the dictionary it returns contains a loss: in this case, the L2 norm between the model's action predictions and the targets.

In [5]:
# Size the model from the dimensions exposed by the preprocessed dataset and the
# shared MAX_LEN constant, then instantiate it with freshly initialised weights.
config = DecisionTransformerConfig(
    state_dim=dataset.state_dim,
    act_dim=dataset.act_dim,
    max_length=MAX_LEN,
)
model = DecisionTransformerModel(config)

In [6]:
# Display the context length stored on the config, which was built with
# max_length=MAX_LEN in the previous cell.
# NOTE(review): the saved output of this cell differs from the MAX_LEN value in the
# constants cell, which suggests the outputs come from an earlier run -- re-run from
# a fresh kernel to confirm.
config.max_length

20

### Step 6: Defining the training hyperparameters and training the model
Here, we define the training hyperparameters and our Trainer class that we'll use to train our Decision Transformer model.

This step takes about an hour, so you may leave it running. Note the authors train for at least 3 hours, so the results presented here are not as performant as the models hosted on the 🤗 hub.

In [9]:
# Training configuration.
# The per-device batch size is 1 and UnwrapCollator is used as the collator;
# BATCH_SIZE was already given to the preprocessed dataset, so presumably each
# dataset item is itself a full batch -- TODO confirm in preprocessed_dataset.
training_args = TrainingArguments(
    output_dir="output/",
    remove_unused_columns=False,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=1,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    tf32=True,
    fp16=True,
    dataloader_pin_memory=False,
    # Supported replacement for the deprecated WANDB_DISABLED environment variable
    # (the training log recommends `report_to`). Note that "none" turns off every
    # logging integration, not only Weights & Biases.
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=UnwrapCollator(),
)

# Runs the full training loop; checkpoints are written under `output_dir`.
trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using cuda_amp half precision backend
***** Running training *****
  Num examples = 480
  Num Epochs = 20
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 9600
  Number of trainable parameters = 1349795


Step,Training Loss
500,2.6017
1000,1.3393
1500,1.0932
2000,1.0055
2500,0.9619
3000,0.9273
3500,0.8995
4000,0.8738
4500,0.8489
5000,0.8207


Saving model checkpoint to output/checkpoint-500
Configuration saved in output/checkpoint-500\config.json
Model weights saved in output/checkpoint-500\pytorch_model.bin
Saving model checkpoint to output/checkpoint-1000
Configuration saved in output/checkpoint-1000\config.json
Model weights saved in output/checkpoint-1000\pytorch_model.bin
Saving model checkpoint to output/checkpoint-1500
Configuration saved in output/checkpoint-1500\config.json
Model weights saved in output/checkpoint-1500\pytorch_model.bin
Saving model checkpoint to output/checkpoint-2000
Configuration saved in output/checkpoint-2000\config.json
Model weights saved in output/checkpoint-2000\pytorch_model.bin
Saving model checkpoint to output/checkpoint-2500
Configuration saved in output/checkpoint-2500\config.json
Model weights saved in output/checkpoint-2500\pytorch_model.bin
Saving model checkpoint to output/checkpoint-3000
Configuration saved in output/checkpoint-3000\config.json
Model weights saved in output/check

TrainOutput(global_step=9600, training_loss=0.9416190870602925, metrics={'train_runtime': 127.9303, 'train_samples_per_second': 75.041, 'train_steps_per_second': 75.041, 'total_flos': 5359189099320000.0, 'train_loss': 0.9416190870602925, 'epoch': 20.0})

In [14]:
# Sanity check on a single batch.
# NOTE: this reuses the *training* dataset, so it shows how well the model fits data
# it has already seen, not held-out performance.
# `DataLoader` is imported in the notebook's import cell.
eval_dataloader = DataLoader(
    dataset,
    batch_size=1,
    collate_fn=UnwrapCollator(),
)

# get one batch from the dataloader and run it through the model
batch = next(iter(eval_dataloader))
model.cuda()
model.eval()

with torch.no_grad():
    # Call the module itself rather than `.forward` so any registered hooks run.
    output = model(**batch)

# Compare the predicted action indices (argmax over the last dimension) with the
# target actions for the sequence at index 1 of the batch.
print(output['action_preds'][1].argmax(dim=-1))
print(batch['actions'][1])


tensor([[25, 23, 13, 10,  6,  5,  6,  4],
        [17, 21, 19, 22,  0,  0,  0,  0],
        [25, 23, 24, 22,  0,  0,  0,  0],
        [12, 16, 15, 18,  0,  0,  0,  0],
        [18, 16, 13, 11,  0,  0,  0,  0],
        [17, 20, 19, 23,  0,  0,  0,  0],
        [25, 20,  8,  5,  0,  0,  0,  0],
        [12, 14, 12, 15,  0,  0,  0,  0],
        [25, 23, 13,  5,  0,  0,  0,  0],
        [ 0,  2, 19, 23,  0,  0,  0,  0]], device='cuda:0')
tensor([[23., 21., 13., 11.,  6.,  4.,  6.,  4.],
        [17., 21., 19., 21.,  0.,  0.,  0.,  0.],
        [25., 22., 24., 20.,  0.,  0.,  0.,  0.],
        [12., 16., 16., 18.,  0.,  0.,  0.,  0.],
        [20., 14., 14., 11.,  0.,  0.,  0.,  0.],
        [18., 19., 19., 23.,  0.,  0.,  0.,  0.],
        [22., 16.,  8.,  4.,  0.,  0.,  0.,  0.],
        [12., 14., 12., 16.,  0.,  0.,  0.,  0.],
        [25., 23., 11.,  5.,  0.,  0.,  0.,  0.],
        [ 0.,  5., 19., 23.,  0.,  0.,  0.,  0.]], device='cuda:0')
