In [1]:
import os

import datetime

import datasets
import transformers

from preprocessed_dataset import DecisionTransformerPreprocessedDataset, UnwrapCollator
import torch.utils.data
from transformers import TrainingArguments, Trainer
from decision_transformer import DecisionTransformerConfig, DecisionTransformerModel
from dt_backgammon_env import RandomAgent, TDAgent, DTAgent, DecisionTransformerBackgammonEnv, BLACK, WHITE

import snowietxt_processor


In [2]:
# Allow TF32 matmuls on CUDA, and switch off the Weights & Biases integration
# through its environment variable before any Trainer is created.
torch.backends.cuda.matmul.allow_tf32 = True
os.environ.update({"WANDB_DISABLED": "true"})

In [3]:
# Tunable run configuration.
MAX_LEN = 10        # passed as max_len / max_length to the dataset and model config
BATCH_SIZE = 64     # passed as batch_size to the preprocessed dataset
NUM_EPOCHS = 120    # passed as num_train_epochs to the Trainer

# Every run writes its logs and checkpoints to a fresh, timestamped directory.
LOG_DIR = os.path.join(
    'saved_models',
    'decision_transformer',
    f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}",
)

### Step 4: Loading the dataset and preprocessing it for the custom DataCollator (`UnwrapCollator`) used by the transformers Trainer class

In [4]:
# Build the raw game records and wrap them in a Hugging Face `Dataset` in one step,
# so `dataset` is never bound to the intermediate dict.
dataset = datasets.Dataset.from_dict(snowietxt_processor.create_dataset())

Number of games 5105


100%|██████████| 5105/5105 [00:01<00:00, 4137.81it/s]


In [5]:
# Replace the raw dataset with its preprocessed counterpart.
# NOTE: this rebinds `dataset`, so re-running this cell on its own would feed the
# already preprocessed dataset back in -- re-run the loading cell first.
dataset = DecisionTransformerPreprocessedDataset(
    dataset,
    max_len=MAX_LEN,
    batch_size=BATCH_SIZE,
)

Preprocessing dataset: 100%|██████████| 480/480 [03:28<00:00,  2.31batch/s]


### Step 5: Extending the Decision Transformer Model to include a loss function

In order to train the model with the 🤗 Trainer class, we first need to ensure that the dictionary it returns contains a loss, in this case the L2 norm of the difference between the model's action predictions and the targets.

In [6]:
# The model dimensions come from the preprocessed dataset; the context length
# comes from the MAX_LEN constant.
config = DecisionTransformerConfig(
    state_dim=dataset.state_dim,
    act_dim=dataset.act_dim,
    max_length=MAX_LEN,
)
model = DecisionTransformerModel(config)

In [7]:
# Sanity check: display the max_length stored on the config (set from MAX_LEN above).
config.max_length

10

### Step 6: Defining the training hyperparameters and training the model
Here, we define the training hyperparameters and our Trainer class that we'll use to train our Decision Transformer model.

This step takes about an hour, so you may leave it running. Note that the authors train for at least 3 hours, so the results presented here are not as performant as the models hosted on the 🤗 Hub.

In [8]:
class EvaluateModelCallback(transformers.integrations.TensorBoardCallback):
    """TensorBoard callback that also plays evaluation games on every log event.

    On top of whatever the parent ``TensorBoardCallback`` writes, each ``on_log``
    call lets the Decision Transformer agent (playing WHITE) face three BLACK
    opponents -- a random agent and the 'beginner' / 'intermediate' TD agents --
    and logs the WHITE entry returned by ``evaluate_agents`` under ``eval/wins/*``.

    Args:
        model: the model under training. It is wrapped in a ``DTAgent`` and is
            switched to eval mode for the games and back to train mode afterwards.
        num_episodes: number of episodes passed to ``evaluate_agents`` for each
            opponent.
    """

    def __init__(self, model, num_episodes):
        super().__init__()
        self.model = model
        self.num_episodes = num_episodes

        # Some scalars should be written only once, and only when the first log
        # event arrives. This flag flips to True after they have been written.
        self.first_log = False

        self.random_agent = RandomAgent(BLACK)
        self.beginner_agent = TDAgent(BLACK, 'beginner')
        self.intermediate_agent = TDAgent(BLACK, 'intermediate')

        self.dt_agent = DTAgent(WHITE, self.model)

        self.backgammon_env = DecisionTransformerBackgammonEnv()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write the regular training logs, then evaluate the agent and log its wins."""
        super().on_log(args, state, control, logs, **kwargs)

        # The parent may leave the writer unset (presumably when this process is
        # not the one that logs). There is nothing to write to in that case, so
        # skip the evaluation games as well.
        if self.tb_writer is None:
            return

        if not self.first_log:
            # log the number of episodes we're evaluating on (once only)
            self.tb_writer.add_scalar("eval/num_episodes", self.num_episodes, 0)
            self.first_log = True

        # Opponents in the order they are played and logged.
        opponents = {
            "random": self.random_agent,
            "beginner": self.beginner_agent,
            "intermediate": self.intermediate_agent,
        }

        self.model.eval()
        try:
            # number of games won by the decision transformer agent, per opponent
            wins = {
                label: self.backgammon_env.evaluate_agents(
                    {WHITE: self.dt_agent, BLACK: opponent},
                    self.num_episodes,
                    verbose=0,
                )[WHITE]
                for label, opponent in opponents.items()
            }
        finally:
            # Always hand the model back in training mode, even if a game raised.
            self.model.train()

        # NOTE(review): state.epoch is kept as the x-axis value as in the original
        # code; it is fractional, so confirm that is preferred over state.global_step.
        for label, won in wins.items():
            self.tb_writer.add_scalar(f"eval/wins/{label}", won, state.epoch)



In [9]:
# Number of evaluation games played against each opponent on every log event.
NUM_EVAL_EPISODES = 20

training_args = TrainingArguments(
    remove_unused_columns=False,
    num_train_epochs=NUM_EPOCHS,
    # One dataset item per step: the items are handed to the model through
    # UnwrapCollator (presumably each item is already a batch of BATCH_SIZE
    # sequences -- verify against preprocessed_dataset).
    per_device_train_batch_size=1,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    tf32=True,
    fp16=True,
    dataloader_pin_memory=False,
    # Select the logging integration explicitly instead of relying on the
    # deprecated WANDB_DISABLED environment variable.
    report_to="tensorboard",
    logging_dir=LOG_DIR,
    output_dir=LOG_DIR,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=UnwrapCollator(),
)

# Swap the stock TensorBoard callback for our subclass, which writes the same
# training logs and additionally records evaluation games.
trainer.remove_callback(transformers.integrations.TensorBoardCallback)
trainer.add_callback(EvaluateModelCallback(model, NUM_EVAL_EPISODES))

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using cuda_amp half precision backend
  logger.warn(
***** Running training *****
  Num examples = 480
  Num Epochs = 120
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 57600
  Number of trainable parameters = 1349795


Step,Training Loss
500,3.1541
1000,2.3115
1500,1.6281
2000,1.3432
2500,1.2032
3000,1.105
3500,1.0409
4000,0.9952
4500,0.958
5000,0.9253


Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-500
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-1000
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-1000\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-1000\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-1500
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-1500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-1500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\202

Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-13500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-13500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-14000
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-14000\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-14000\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-14500
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-14500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-14500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-15000
Configuration saved in saved_models\decision_transfo

Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-26500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-26500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-27000
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-27000\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-27000\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-27500
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-27500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-27500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-28000
Configuration saved in saved_models\decision_transfo

Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-39500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-39500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-40000
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-40000\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-40000\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-40500
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-40500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-40500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-41000
Configuration saved in saved_models\decision_transfo

Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-52500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-52500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-53000
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-53000\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-53000\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-53500
Configuration saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-53500\config.json
Model weights saved in saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-53500\pytorch_model.bin
Saving model checkpoint to saved_models\decision_transformer\2023-05-14-18-53-10\checkpoint-54000
Configuration saved in saved_models\decision_transfo

TrainOutput(global_step=57600, training_loss=0.37262396747867266, metrics={'train_runtime': 3125.9774, 'train_samples_per_second': 18.426, 'train_steps_per_second': 18.426, 'total_flos': 3.215513459592e+16, 'train_loss': 0.37262396747867266, 'epoch': 120.0})