In [1]:
import os

import datetime

import datasets
import transformers

from preprocessed_dataset import DecisionTransformerPreprocessedDataset, UnwrapCollator
import torch.utils.data
from transformers import TrainingArguments, Trainer
from decision_transformer import DecisionTransformerConfig, DecisionTransformerModel
from dt_backgammon_env import RandomAgent, TDAgent, DTAgent, DecisionTransformerBackgammonEnv, BLACK, WHITE

import snowietxt_processor


In [2]:
os.environ["WANDB_DISABLED"] = "true"
torch.backends.cuda.matmul.allow_tf32 = True

In [3]:
MAX_LEN = 10
BATCH_SIZE = 64
NUM_EPOCHS = 120
LOG_DIR = os.path.join('saved_models', 'decision_transformer', datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))

### Step 4: Defining a custom DataCollator for the transformers Trainer class

In [4]:
dataset = snowietxt_processor.create_dataset()
dataset = datasets.Dataset.from_dict(dataset)

Number of games 5105


100%|██████████| 5105/5105 [00:01<00:00, 4308.02it/s]


In [5]:
dataset = DecisionTransformerPreprocessedDataset(dataset, max_len=MAX_LEN, batch_size=BATCH_SIZE)

Preprocessing dataset: 100%|██████████| 480/480 [03:11<00:00,  2.51batch/s]


### Step 5: Extending the Decision Transformer Model to include a loss function

In order to train the model with the 🤗 trainer class, we first need to ensure the dictionary it returns contains a loss, in this case L-2 norm of the models action predictions and the targets.

In [11]:
config = DecisionTransformerConfig(state_dim=dataset.state_dim, act_dim=dataset.act_dim, max_length=MAX_LEN)
model = DecisionTransformerModel(config)

In [7]:
config.max_length

10

### Step 6: Defining the training hyperparameters and training the model
Here, we define the training hyperparameters and our Trainer class that we'll use to train our Decision Transformer model.

This step takes about an hour, so you may leave it running. Note the authors train for at least 3 hours, so the results presented here are not as performant as the models hosted on the 🤗 hub.

In [12]:
class EvaluateModelCallback(transformers.integrations.TensorBoardCallback):
    def __init__(self, model, num_episodes):
        super().__init__()
        self.model = model
        self.num_episodes = num_episodes

        self.first_log = False # there are some scalars that we only want to log once, but we don't want to log them until the first time we log

        self.random_agent = RandomAgent(BLACK)
        self.beginner_agent = TDAgent(BLACK, 'beginner')
        self.intermediate_agent = TDAgent(BLACK, 'intermediate')

        self.dt_agent = DTAgent(WHITE, self.model)

        self.backgammon_env = DecisionTransformerBackgammonEnv()


    def on_log(self, args, state, control, logs=None, **kwargs):
        super().on_log(args, state, control, logs, **kwargs)

        if not self.first_log:
            # log the number of episodes we're evaluating on
            self.tb_writer.add_scalar("eval/num_episodes", self.num_episodes, 0)

        # log the number of games won by the decision transformer agent
        wins_random = self.backgammon_env.evaluate_agents({WHITE: self.dt_agent, BLACK: self.random_agent}, self.num_episodes, verbose=0)[WHITE]
        wins_beginner = self.backgammon_env.evaluate_agents({WHITE: self.dt_agent, BLACK: self.beginner_agent}, self.num_episodes, verbose=0)[WHITE]
        wins_intermediate = self.backgammon_env.evaluate_agents({WHITE: self.dt_agent, BLACK: self.intermediate_agent}, self.num_episodes, verbose=0)[WHITE]

        self.tb_writer.add_scalar("eval/wins/random", wins_random, state.epoch)
        self.tb_writer.add_scalar("eval/wins/beginner", wins_beginner, state.epoch)
        self.tb_writer.add_scalar("eval/wins/intermediate", wins_intermediate, state.epoch)



In [14]:
training_args = TrainingArguments(
    remove_unused_columns=False,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=1,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    tf32=True,
    fp16=True,
    dataloader_pin_memory=False,
    logging_dir=LOG_DIR,
    output_dir=LOG_DIR
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=UnwrapCollator(),
)

# remove the old tensorboard trainer callback
trainer.remove_callback(transformers.integrations.TensorBoardCallback)
# add our own tensorboard/evaluation callback
trainer.add_callback(EvaluateModelCallback(model, 20))

trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using cuda_amp half precision backend
  logger.warn(
***** Running training *****
  Num examples = 480
  Num Epochs = 120
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 57600
  Number of trainable parameters = 1349795


Step,Training Loss


IndexError: too many indices for tensor of dimension 1

In [10]:
from torch.utils.data import DataLoader

# create a dataloader for evaluation
eval_dataloader = DataLoader(dataset, batch_size=1, collate_fn=UnwrapCollator())

# get one batch from the dataloader and run it through the model
batch = next(iter(eval_dataloader))
model.cuda()
model.eval()

with torch.no_grad():
    output = model.forward(**batch)

print(output['action_preds'][1].argmax(dim=-1))
print(batch['actions'][1])


tensor([[ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 5,  1,  5,  4,  0,  0,  0,  0],
        [ 0,  6,  0,  0,  0,  0,  0,  0],
        [ 3,  0,  4,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 4,  2,  4,  0,  0,  0,  0,  0],
        [ 0,  6, 18, 19,  0,  0,  0,  0],
        [ 3,  0,  3,  1,  0,  0,  0,  0],
        [ 6,  9,  6,  8,  0,  0,  0,  0],
        [ 1,  0,  2,  0,  0,  0,  0,  0]], device='cuda:0')
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 5.,  1.,  5.,  4.,  0.,  0.,  0.,  0.],
        [ 0.,  6.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 3.,  0.,  4.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  2.,  4.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  6., 18., 19.,  0.,  0.,  0.,  0.],
        [ 3.,  0.,  3.,  1.,  0.,  0.,  0.,  0.],
        [ 6.,  9.,  6.,  8.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  2.,  0.,  0.,  0.,  0.,  0.]], device='cuda:0')
