In [1]:
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

In [2]:
# Report whether a GPU is visible. This value is only displayed; it is not passed to the
# Trainer, which picks its own device (see `training_args.device` further down).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
# Load the saved dataset. The cells below index it with 'train', 'state_mean' and 'state_std'.
dataset = load_from_disk('data/dataset/')

Loading dataset from disk:   0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
# Inspect the type of the training split. This cell was never executed (In [None]).
type(dataset['train'])

In [4]:
# State-normalization statistics stored alongside the trajectories (used by the collator).
state_mean, state_std = dataset['state_mean'], dataset['state_std']

In [5]:
# Keep only the training trajectories.
# NOTE: this rebinds `dataset`, so re-running this cell by itself indexes the split rather than
# the loaded container and fails; re-run from the load cell instead.
dataset = dataset['train']

In [20]:
# Number of trajectories, and number of FIELDS in one trajectory (its keys), not its timesteps.
len(dataset), len(dataset[0])

(4999, 4)

In [21]:
# Fields stored for each trajectory.
dataset[0].keys()

dict_keys(['actions', 'dones', 'observations', 'rewards'])

In [22]:
# Action dimensionality, read from the first action of the first trajectory.
act_dim = len(dataset[0]['actions'][0])
act_dim

394

In [23]:
# State dimensionality, read from the first observation of the first trajectory.
state_dim = len(dataset[0]['observations'][0])
state_dim

3941

In [24]:
# Trim the normalization statistics to the observation size.
# NOTE(review): this assumes the first `state_dim` entries line up with the observation
# features; confirm against how the statistics were computed.
state_mean, state_std = state_mean[:state_dim], state_std[:state_dim]


In [11]:
# Number of timesteps (observations) in the first trajectory.
len(dataset[0]['observations'])

100

In [23]:
# First two actions of trajectory 0.
# NOTE(review): the output below lists 29 values per action, but `act_dim` above is 394 and
# this cell's execution count (In [23]) duplicates the state_dim cell's; it looks like a stale
# output from an earlier run. Re-run to refresh.
dataset[0]['actions'][:2]

[[-0.0671696737408638,
  -0.18361186981201172,
  0.2469455599784851,
  -0.01885118894279003,
  0.10628541558980942,
  -0.2710021436214447,
  0.07851219177246094,
  0.23499460518360138,
  0.25581544637680054,
  0.05722948908805847,
  0.2383994609117508,
  -0.11575343459844589,
  0.18378931283950806,
  -0.20906339585781097,
  -0.0037708738818764687,
  0.1894841343164444,
  0.20058195292949677,
  -0.16733631491661072,
  -0.20673605799674988,
  -0.04827669635415077,
  -0.02904391475021839,
  0.1855313777923584,
  -0.393166184425354,
  -0.083065465092659,
  -0.09834568947553635,
  -0.136605367064476,
  -0.5146928429603577,
  0.1899474561214447,
  0.007351686712354422],
 [-0.0671696737408638,
  -0.18361186981201172,
  0.2469455599784851,
  -0.01885118894279003,
  0.10628541558980942,
  -0.2710021436214447,
  0.07851219177246094,
  0.23499460518360138,
  0.25581544637680054,
  0.05722948908805847,
  0.2383994609117508,
  -0.11575343459844589,
  0.18378931283950806,
  -0.20906339585781097,
  -

### Class Overview

This is a data collator that prepares batches of RL trajectories for training a Decision Transformer model. It handles the complex task of sampling trajectory segments, computing returns-to-go, and formatting everything for the transformer architecture.

### Key Design Decisions

1. **Random segment sampling**: Instead of using full episodes, samples random segments for better generalization.
2. **Undiscounted returns**: Uses `gamma=1.0` rather than traditional RL discounting.
3. **Sentinel padding**: Pads actions with $-10.0$ and dones with $2$, distinct sentinel values that distinguish padding from real data.
4. **State normalization**: Normalizes states using dataset statistics for stable training.
5. **Length-weighted sampling**: Intended to give more weight to longer trajectories by drawing each trajectory with probability proportional to its number of timesteps.

### Purpose in Decision Transformer Training

This collator enables the Decision Transformer to learn from offline RL data by:

* Conditioning actions on states, returns-to-go, and timesteps
* Using the attention mechanism to process variable-length sequences
* Learning to predict actions given desired future returns

The resulting batch format matches exactly what the Decision Transformer model expects for training.

### Initialization

1. **Dynamic dimension detection (lines 15-16)**: Automatically detects the action dimension `act_dim` and the state dimension `state_dim` from the first sample.
2. **Store normalization stats (lines 18-19)**: Keeps the state mean `state_mean` and state standard deviation `state_std` for state normalization.
3. **Trajectory sampling weights (lines 24-28)**: Computes `p_sample`, the probability of drawing each trajectory. It is intended to be proportional to the trajectory's length (number of timesteps).

### Return-to-Go Calculation

**Lines 30-35**: Computes discounted cumulative returns:

* Takes rewards and gamma (discount factor)
* Computes a backwards cumulative sum
* Used to condition the Decision Transformer on future returns

### Main Batch Processing

#### 1. Batch Sampling (lines 38-45)

This part samples trajectory indices according to `p_sample` distribution.

#### 2. Sequence Extraction (lines 47-69)

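For each sampled trajectory:

* **Random start point (line 52)**: Picks a start index `si` uniformly at random from the trajectory's timesteps.
* **Extract sequences (lines 55-57)**: Takes up to `max_len` timesteps starting from `si`.
* **Compute returns-to-go (lines 62-66)**:
  - Uses `_discount_cumsum` with `gamma=1.0` (undiscounted) over the rewards from `si` to the end of the episode
  - Keeps only the first `s[-1].shape[1]` elements (the segment length, at most `max_len`) to match the sequence length
* **Handling edge case (lines 67-69)**: If the RTG array is shorter than the state sequence, it is padded at the end with zeros.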



In [25]:
@dataclass
class DecisionTransformerGymDataCollator:
    """Turn offline RL trajectories into padded Decision Transformer training batches.

    Each call samples ``len(features)`` trajectories (with replacement, each with probability
    proportional to its number of timesteps). It cuts a window of up to ``max_len`` steps,
    starting at a random index, from each one. Every window is left-padded to ``max_len``, and
    the batch is returned as torch tensors.

    The rows in ``features`` are not used beyond their count. Trajectories are re-drawn from
    ``self.dataset`` so that sampling can follow ``p_sample``.
    """

    return_tensors: str = "pt"
    max_len: int = 20  # length of the trajectory windows used for training
    state_dim: int = 3941  # size of state space (overwritten from the dataset in __init__)
    act_dim: int = 394  # size of action space (overwritten from the dataset in __init__)
    max_ep_len: int = 985  # timesteps >= max_ep_len are clipped to max_ep_len - 1
    scale: float = 1000.0  # returns-to-go are divided by this
    p_sample: np.ndarray = None  # per-trajectory sampling probabilities (set in __init__)
    n_traj: int = 0  # number of trajectories in the dataset (set in __init__)

    def __init__(self, dataset, state_mean, state_std) -> None:
        """
        Args:
            dataset: indexable sequence of trajectories. Each item maps "observations",
                "actions", "rewards" and "dones" to per-timestep sequences.
            state_mean: per-feature mean subtracted from states (must broadcast against
                ``state_dim``).
            state_std: per-feature standard deviation that states are divided by. Zero entries
                (constant features) are replaced by 1 to avoid producing inf/nan.

        Raises:
            ValueError: if ``dataset`` is empty.
        """
        if len(dataset) == 0:
            raise ValueError("dataset must contain at least one trajectory")
        self.dataset = dataset
        first = dataset[0]
        self.act_dim = len(first["actions"][0])
        self.state_dim = len(first["observations"][0])
        self.state_mean = np.asarray(state_mean, dtype=np.float64)
        std = np.asarray(state_std, dtype=np.float64)
        self.state_std = np.where(std == 0, 1.0, std)

        self.n_traj = len(dataset)
        # Weight trajectories by their number of timesteps. len(dataset[i]) would count the
        # trajectory's keys (always 4) and silently make sampling uniform.
        traj_lens = self._trajectory_lengths(dataset)
        self.p_sample = traj_lens / traj_lens.sum()

    @staticmethod
    def _trajectory_lengths(dataset):
        """Return the number of timesteps (rewards) of each trajectory as an integer array."""
        try:
            # Column access (supported by Hugging Face ``Dataset``) avoids materialising every
            # row's observation/action arrays just to count timesteps.
            rewards = dataset["rewards"]
        except (TypeError, KeyError, IndexError):
            # Plain sequences of dicts: fall back to row-by-row access.
            rewards = (dataset[i]["rewards"] for i in range(len(dataset)))
        return np.array([len(traj_rewards) for traj_rewards in rewards])

    def _discount_cumsum(self, x, gamma):
        """Reverse discounted cumulative sum: ``out[t] = x[t] + gamma * out[t + 1]``.

        With ``gamma=1.0`` this is the plain return-to-go. ``x`` is a 1-D numpy array and the
        result has the same dtype. An empty ``x`` yields an empty result.
        """
        discount_cumsum = np.zeros_like(x)
        if len(x) == 0:
            return discount_cumsum
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def _sample_segment(self, feature):
        """Cut a window of up to ``max_len`` steps from a random start in ``feature`` and left-pad it.

        Returns a dict of numpy arrays. Each has a leading axis of size 1 and a time axis of
        length ``max_len``, and the first ``max_len - tlen`` positions are padding.
        """
        si = random.randint(0, len(feature["rewards"]) - 1)
        end = si + self.max_len

        s = np.array(feature["observations"][si:end]).reshape(1, -1, self.state_dim)
        a = np.array(feature["actions"][si:end]).reshape(1, -1, self.act_dim)
        r = np.array(feature["rewards"][si:end]).reshape(1, -1, 1)
        d = np.array(feature["dones"][si:end]).reshape(1, -1)
        tlen = s.shape[1]

        timesteps = np.arange(si, si + tlen).reshape(1, -1)
        timesteps[timesteps >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff

        # Undiscounted return-to-go from each step of the window to the end of the episode.
        rtg = self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[:tlen].reshape(1, -1, 1)
        if rtg.shape[1] < tlen:
            # Fewer rewards than observations remain: zero-fill the whole missing tail so the
            # window stays aligned (the gap may be more than one step).
            rtg = np.concatenate([rtg, np.zeros((1, tlen - rtg.shape[1], 1))], axis=1)

        # Left-pad to max_len. Sentinels: actions -10.0, dones 2. Padded state rows are
        # normalized too; attention_mask marks them as padding.
        pad = self.max_len - tlen
        s = np.concatenate([np.zeros((1, pad, self.state_dim)), s], axis=1)
        s = (s - self.state_mean) / self.state_std
        a = np.concatenate([np.full((1, pad, self.act_dim), -10.0), a], axis=1)
        r = np.concatenate([np.zeros((1, pad, 1)), r], axis=1)
        d = np.concatenate([np.full((1, pad), 2.0), d], axis=1)
        rtg = np.concatenate([np.zeros((1, pad, 1)), rtg], axis=1) / self.scale
        timesteps = np.concatenate([np.zeros((1, pad)), timesteps], axis=1)
        mask = np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1)

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "dones": d,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

    def __call__(self, features):
        """Build a batch of ``B = len(features)`` padded windows.

        Returns float tensors ``states`` (B, max_len, state_dim), ``actions``
        (B, max_len, act_dim), ``rewards`` and ``returns_to_go`` (B, max_len, 1) and
        ``attention_mask`` (B, max_len), plus an int64 ``timesteps`` tensor (B, max_len).
        """
        batch_size = len(features)
        # Re-draw trajectories instead of using ``features`` so sampling follows the
        # non-uniform, length-weighted ``p_sample`` distribution.
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,
        )
        segments = [self._sample_segment(self.dataset[int(ind)]) for ind in batch_inds]

        def stack(key):
            return np.concatenate([segment[key] for segment in segments], axis=0)

        # "dones" are padded per segment but are not part of the returned batch.
        return {
            "states": torch.from_numpy(stack("states")).float(),
            "actions": torch.from_numpy(stack("actions")).float(),
            "rewards": torch.from_numpy(stack("rewards")).float(),
            "returns_to_go": torch.from_numpy(stack("returns_to_go")).float(),
            "timesteps": torch.from_numpy(stack("timesteps")).long(),
            "attention_mask": torch.from_numpy(stack("attention_mask")).float(),
        }

Here we are creating an instance of the data collator called `collator` that will be used by the Hugging Face `Trainer` to prepare batches of training data for the Decision Transformer model.

It uses `dataset` (the PPO-generated trajectories loaded from disk) and the state normalization statistics: the mean values (`state_mean`) and the standard deviations (`state_std`).

Then the collator will do the following:

* Sample random trajectory segments during training
* Compute returns-to-go for each sequence
* Pad sequences to a uniform length
* Normalize states using the provided statistics
* Convert everything to PyTorch tensors

Decision Transformers have specific data requirements that differ from those of standard transformer models:

* **Sequence formatting**: Need states, actions, rewards, returns-to-go, and timesteps
* **Attention masking**: Requires proper masking for variable-length sequences
* **Random sampling**: Samples random trajectory segments rather than full episodes

The `Trainer` class doesn't know how to handle RL trajectory data. The custom collator bridges this gap by converting the raw dataset into the exact format the Decision Transformer expects for training.

In essence, this line prepares the data-processing pipeline that will transform raw PPO trajectories into properly formatted training batches for the Decision Transformer model.

In [26]:
# Build the collator from the training trajectories and the trimmed normalization statistics.
collator = DecisionTransformerGymDataCollator(dataset=dataset, state_mean=state_mean, state_std=state_std)

In [27]:
import random
from dataclasses import dataclass

import numpy as np
import torch

from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

The **main purpose** is to bridge the gap between:

* **Decision Transformer**: Which normally returns model outputs (logits, hidden states, etc.)
* **Hugging Face Trainer**: Which expects models to return a loss dictionary during training


This is the key modification that makes the model trainable with Hugging Face's `Trainer`:

1. **Get base model output**: `output = super().forward(**kwargs)` calls the original Decision Transformer forward pass
2. **Extract action predictions**: `action_preds = output[1]` gets the predicted actions (the second element of the model output)
3. **Get target actions**: `action_targets = kwargs['actions']` gets the ground truth actions from the batch
4. **Get attention mask**: `attention_mask = kwargs['attention_mask']` gets the mask that identifies real vs. padding tokens
5. **Flatten and filter**:
* Reshapes predictions and targets to 2D: `(batch_size * seq_len, action_dim)`
* Uses the attention mask to select only real (non-padding) tokens: `[attention_mask.reshape(-1) > 0]`
6. **Compute MSE Loss**: `torch.mean((action_preds - action_targets) ** 2)` calculates mean squared error between predicted and target actions
7. **Return loss dict**: Returns `{'loss':loss}` which is the format expected by Hugging Face `Trainer`

In [28]:
class TrainableDT(DecisionTransformerModel):
    """DecisionTransformerModel whose ``forward`` returns ``{"loss": ...}`` for the HF Trainer.

    The loss is the mean squared error between predicted and target actions. It is computed only
    over positions where ``attention_mask`` is positive (the non-padding timesteps).
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        """Run the base model and return the masked action-prediction MSE as ``{"loss": loss}``."""
        predicted = super().forward(**kwargs)[1]  # action predictions: second output of the base model
        n_actions = predicted.shape[2]
        keep = kwargs["attention_mask"].reshape(-1) > 0
        pred_flat = predicted.reshape(-1, n_actions)[keep]
        target_flat = kwargs["actions"].reshape(-1, n_actions)[keep]
        squared_error = (pred_flat - target_flat) ** 2
        return {"loss": torch.mean(squared_error)}

    def original_forward(self, **kwargs):
        """Call the unmodified DecisionTransformerModel forward pass (returns its raw outputs)."""
        return super().forward(**kwargs)

This creates a configuration object for the Decision Transformer with:

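* `state_dim=collator.state_dim`: Sets the state space dimension, read from the data (3941 in the outputs above)
* `act_dim=collator.act_dim`: Sets the action space dimension, read from the data (394 in the outputs above). The figures of 291 state dimensions and 29 actions (one action per stock) match the 29-value action sample printed near the top, which appears to come from an earlier run; re-run the notebook to refresh.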

The configuration object contains all the hyperparameters and architectural settings needed to build the transformer model.

In [29]:
# Model dimensions come from the data (via the collator); everything else uses config defaults.
config = DecisionTransformerConfig(
    act_dim=collator.act_dim,
    state_dim=collator.state_dim,
)

This creates an instance of the custom `TrainableDT` class using the configuration:
* Instantiates the transformer architecture with the specified dimensions
* Sets up all the neural network layers (embeddings, attention blocks, output heads)
* Makes it compatible with the Hugging Face Trainer (due to the custom `TrainableDT` wrapper)

The `DecisionTransformerConfig` supplies the hyperparameters from which the model builds its token embeddings, a GPT-2-style (causal) transformer with attention blocks, and the output heads for predicting actions. Anything not set explicitly (hidden size, number of layers, etc.) uses the default hyperparameters.

This creates the trainable model that will:
1. Receive batches from the data collator (states, actions, rewards, etc.)
2. Process sequences through the transformer architecture
3. Predict actions conditioned on states and returns-to-go
4. Compute loss against ground truth actions (via the custom `forward` method)
5. Update weights through backpropagation during training

This is the core model that will learn to replicate the PPO agent's trading behavior in a conditional, goal-directed manner.

In [30]:

# Seed torch before building the model. Weight initialisation happens here, before the Trainer
# exists; the Trainer seeds its RNGs from TrainingArguments.seed (default 42) only later.
torch.manual_seed(42)
model = TrainableDT(config)

In [31]:
# Turn off Weights & Biases logging. This env var is deprecated (see the warning printed under
# the TrainingArguments cell); `report_to="none"` in TrainingArguments is the replacement.
os.environ["WANDB_DISABLED"] = "true"

In [32]:
# An "epoch" here is ceil(len(dataset) / batch_size) collator calls. The collator re-samples
# trajectories itself, so an epoch is not an exact pass over every trajectory.
training_args = TrainingArguments(
    output_dir="output/",
    # Required: TrainableDT.forward only takes **kwargs, so signature-based column filtering
    # would drop every dataset column.
    remove_unused_columns=False,
    num_train_epochs=120,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    report_to="none",  # replaces the deprecated WANDB_DISABLED environment variable
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


### Parameter Breakdown

#### Training Duration
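* `num_train_epochs=120`: Train for 120 epochs. Each epoch is `ceil(len(dataset) / 64)` = 79 steps (9480 steps in total). Because the collator re-samples trajectories itself, an epoch is not an exact pass over every trajectory.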
* `per_device_train_batch_size=64`: Process 64 trajectory segments per batch

#### Optimization
* `learning_rate=1e-4`: Learning rate of 0.0001 (conservative for transformer training)
* `weight_decay=1e-4`: Decoupled weight decay (applied by AdamW) to prevent overfitting
* `optim="adamw_torch"`: Uses the AdamW optimizer (Adam with decoupled weight decay)
* `max_grad_norm=0.25`: Clips gradients to prevent exploding gradients

#### Learning Schedule 
* `warmup_ratio=0.1`: Gradually increases learning rate from 0 to full rate over first 10% of training (12 epochs)

#### Data Handling
* `remove_unused_columns=False`: Keeps all dataset columns. This is required here: the Trainer normally drops columns that don't match the model's `forward` signature, and `TrainableDT.forward` only takes `**kwargs`.

In [33]:
# Device the Trainer will train on, chosen by TrainingArguments (independent of `device` above).
training_args.device

device(type='cuda', index=0)

In [34]:
# The collator only uses the rows the Trainer hands it to count them; it re-samples trajectories
# from its own `dataset`. So give the Trainer a view without the large observation/action columns,
# which avoids loading them for every batch only to discard them. Length, and hence the
# number of steps per epoch, is unchanged.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset.remove_columns(["observations", "actions"]),
    data_collator=collator,
)

trainer.train()

Step,Training Loss
500,0.018
1000,0.0014
1500,0.001
2000,0.0005
2500,0.0003
3000,0.0002
3500,0.0002
4000,0.0001
4500,0.0001
5000,0.0001


TrainOutput(global_step=9480, training_loss=0.001196348162187177, metrics={'train_runtime': 2378.5035, 'train_samples_per_second': 252.209, 'train_steps_per_second': 3.986, 'total_flos': 4.851370333119744e+17, 'train_loss': 0.001196348162187177, 'epoch': 120.0})

In [35]:
# Save the trained weights and config to `trained_models/`.
# NOTE(review): state_mean / state_std are not saved here, but inference must normalize states
# the same way the collator did. Consider saving them next to the model.
trainer.save_model('trained_models')