# DRL 2025 - Decision Transformers Midterm
**Authors:**  Aman Garg, Niki Vasan, Ethan Haarer, Karthik Kothuri, Yee Ching Lau


## Setup and Installation

We first install the dependencies needed to run the **Decision Transformer** on a MuJoCo D4RL dataset. Because this notebook is self-contained, no additional environment or compute configuration is required when using Google Colab. The primary packages are `gymnasium`, `minari`, and `torch`.

**Note**: *This notebook was run on Colab Pro (Student License) with a T4 GPU. Colab may already ship with some of these packages; we cannot guarantee the same results if the notebook is run locally or in a different environment.*

In [None]:
!pip install gymnasium[mujoco] > /dev/null 2>&1
!pip install minari[all] > /dev/null 2>&1
!pip install torch tqdm numpy opencv-python-headless > /dev/null 2>&1
!pip install moviepy pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

Next, we import the necessary libraries and set up a basic logger to track our progress.

In [None]:
# Import necessary libraries
import math
import logging
import random
import numpy as np
from tqdm import tqdm
from collections import deque
import cv2 # Import OpenCV

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import minari
import gymnasium as gym

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

## Dataset Selection

Next, we select a dataset. For this walkthrough we use a MuJoCo dataset. MuJoCo is a physics simulator that provides continuous control tasks such as Hopper, Walker2d, and HalfCheetah. In this cell, we print the MuJoCo datasets available through `minari`. The filter matches on substrings, so AntMaze datasets also appear in the list because their names contain `ant`.


In [None]:
# List all available datasets
all_datasets = minari.list_remote_datasets()

# Filter for MuJoCo datasets
mujoco_datasets = [d for d in all_datasets if 'mujoco' in d.lower() or any(env in d.lower() for env in ['hopper', 'walker', 'halfcheetah', 'ant', 'humanoid'])]

print("Available MuJoCo datasets:")
for dataset in mujoco_datasets:
    print(f"  - {dataset}")

Available MuJoCo datasets:
  - D4RL/antmaze/medium-play-v1
  - D4RL/antmaze/umaze-diverse-v1
  - D4RL/antmaze/large-diverse-v1
  - D4RL/antmaze/large-play-v1
  - D4RL/antmaze/medium-diverse-v1
  - D4RL/antmaze/umaze-v1
  - mujoco/hopper/expert-v0
  - mujoco/halfcheetah/simple-v0
  - mujoco/walker2d/simple-v0
  - mujoco/ant/medium-v0
  - mujoco/pusher/expert-v0
  - mujoco/invertedpendulum/expert-v0
  - mujoco/inverteddoublependulum/expert-v0
  - mujoco/swimmer/medium-v0
  - mujoco/humanoidstandup/medium-v0
  - mujoco/walker2d/medium-v0
  - mujoco/pusher/medium-v0
  - mujoco/reacher/expert-v0
  - mujoco/humanoid/medium-v0
  - mujoco/reacher/medium-v0
  - mujoco/hopper/simple-v0
  - mujoco/halfcheetah/expert-v0
  - mujoco/invertedpendulum/medium-v0
  - mujoco/humanoid/simple-v0
  - mujoco/humanoidstandup/expert-v0
  - mujoco/ant/expert-v0
  - mujoco/walker2d/expert-v0
  - mujoco/humanoid/expert-v0
  - mujoco/humanoidstandup/simple-v0
  - mujoco/halfcheetah/medium-v0
  - mujoco/ant/simple-

Here, we load trajectories of states, actions, and rewards `(s_t, a_t, r_t)` from the `mujoco/hopper/medium-v0` dataset.

We can see the total number of episodes (or trajectories) included. We also see the total number of steps, which is the sum of per-episode timesteps across the entire dataset.

In [None]:
dataset_id = "mujoco/hopper/medium-v0"

# Download the dataset
minari.download_dataset(dataset_id)

# Load the dataset
dataset = minari.load_dataset(dataset_id)

print(f"Dataset: {dataset_id}")
print(f"Number of episodes: {len(dataset)}")
print(f"Total steps: {dataset.total_steps}")

Downloading mujoco/hopper/medium-v0 from Farama servers...
Dataset mujoco/hopper/medium-v0 downloaded to /root/.minari/datasets/mujoco/hopper/medium-v0
Dataset: mujoco/hopper/medium-v0
Number of episodes: 1327
Total steps: 999404


In [None]:
episode = dataset[0]

print(f"Observations shape: {episode.observations.shape}")
print(f"Actions shape: {episode.actions.shape}")
print(f"Rewards shape: {episode.rewards.shape}")

# Check observation and action dimensions
print(f"\nObservation dimension: {episode.observations.shape[-1]}")
print(f"Action dimension: {episode.actions.shape[-1]}")

Observations shape: (1001, 11)
Actions shape: (1000, 3)
Rewards shape: (1000,)

Observation dimension: 11
Action dimension: 3


## Model Setup

Let's set up the model.

### Summary of Architecture

The central idea of the paper is to treat RL as a *sequence modeling* task. Each timestep is tokenized as a (return-to-go, state, action) triple, and a Transformer is trained to predict actions from past tokens, with a causal mask hiding future tokens so the model cannot look ahead. A learned embedding of the episode timestep is added to each token as a positional encoding. Returns-to-go, states, and actions each have their own encoder. Much of the architecture is similar to the vanilla Transformer introduced by Vaswani et al.

![architecture](https://drive.google.com/uc?export=view&id=1sKjqP9VyJQyZOI9bUEm9X9SXFvJELF_m)
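As a rough illustration of the tokenization described above, the sketch below interleaves per-timestep return-to-go, state, and action embeddings into one sequence of length $3K$. The context length `K`, the width `n_embd`, and the random arrays standing in for encoder outputs are illustrative assumptions, not values from the paper. The strided assignment mirrors the one used in this notebook's `GPT.forward`.

```python
import numpy as np

K, n_embd = 4, 8  # illustrative context length and embedding width (assumed values)
rng = np.random.default_rng(0)

# Random stand-ins for the outputs of the return, state, and action encoders.
rtg_emb = rng.normal(size=(K, n_embd))
state_emb = rng.normal(size=(K, n_embd))
action_emb = rng.normal(size=(K, n_embd))

# Interleave as [R_1, s_1, a_1, R_2, s_2, a_2, ...]: three tokens per timestep.
tokens = np.zeros((3 * K, n_embd))
tokens[0::3] = rtg_emb
tokens[1::3] = state_emb
tokens[2::3] = action_emb

# The action for timestep t is read from the output at the state token s_t,
# i.e. sequence positions 1, 4, 7, ...
state_positions = np.arange(1, 3 * K, 3)
print(tokens.shape, state_positions)
```

Reading predictions only at the state positions means each action prediction can attend to the current return-to-go and state but not to the action it is supposed to produce.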

### Trajectory Representation

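DTs condition on *future* rewards rather than past rewards. To do so, the authors compute, at each timestep, the sum of rewards from that timestep to the end of the episode, known as the **return-to-go**. A trajectory is then represented as an interleaved sequence of returns-to-go, states, and actions: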

$$\hat{R}_t=\sum_{t'=t}^{T} r_{t'} \quad\text{and}\quad \tau=(\hat{R}_1,s_1,a_1,\hat{R}_2,s_2,a_2,\ldots,\hat{R}_T,s_T,a_T).$$
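The return-to-go in the formula above is a suffix sum over the episode's rewards. A minimal sketch, assuming undiscounted rewards stored in a 1-D array (the helper name `returns_to_go` is hypothetical and is not part of this notebook or of `minari`):

```python
import numpy as np

def returns_to_go(rewards):
    """Return R_hat[t] = sum of rewards[t:] for every timestep t (undiscounted)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    # Reverse, take a running sum, then reverse back to obtain suffix sums.
    return np.cumsum(rewards[::-1])[::-1]

rtg = returns_to_go([1.0, 2.0, 3.0])
print(rtg)  # [6. 5. 3.]
```

The first entry equals the total episode return, and each later entry drops the rewards that have already been collected.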

### Credit Assignment
Transformers consist of stacked self-attention layers with residual connections. Each self-attention layer receives $n$ embeddings, one per input token, and outputs $n$ embeddings. The $i^{th}$ token is mapped via linear transformations to its own key, query, and value ($k_i$, $q_i$, $v_i$). The output for token $i$ is a weighted sum of the values $v_j$, with weights given by the softmax-normalized dot products between $q_i$ and each $k_j$.

$$ z_i = \sum^n_{j = 1} \text{softmax} (\{\langle q_i, k_{j'}\rangle\}^n_{j'=1})_j \cdot v_j $$

This lets the architecture assign credit implicitly, by forming associations between states and returns through the similarity of their query and key vectors. For autoregressive generation, the authors use a causal self-attention mask: the sum and the softmax over all $n$ tokens are restricted to the tokens at or before position $i$ in the sequence ($j \le i$).
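The causal variant of the attention formula can be sketched for a single head as follows. This is an illustrative NumPy version, not the notebook's implementation: the function name `causal_attention` and the random inputs are assumptions, and the $1/\sqrt{d}$ scaling is included to match the `CausalSelfAttention` module defined later rather than the unscaled formula above.

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v have shape (n, d). Row i of the output only mixes values v_j with j <= i.
    """
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    # Mask out future positions (j > i) so they receive zero weight after the softmax.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over each row.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 4, 8))  # three (n=4, d=8) arrays
z, weights = causal_attention(q, k, v)
print(np.round(weights, 3))  # lower-triangular matrix whose rows sum to 1
```

Because the first token can only attend to itself, its output is exactly its own value vector.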

In [None]:
def set_seed(seed):
    """Sets the seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def top_k_logits(logits, k):
    """Filters logits to only the top k values."""
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, actions=None, rtgs=None, timesteps=None):
    block_size = model.get_block_size()
    model.eval()

    for k in range(steps):
        x_cond = x if x.size(1) <= block_size // 3 else x[:, -block_size // 3:]
        if actions is not None:
            actions = actions if actions.size(1) <= block_size // 3 else actions[:, -block_size // 3:]
        rtgs = rtgs if rtgs.size(1) <= block_size // 3 else rtgs[:, -block_size // 3:]

        action_preds, _ = model(x_cond, actions=actions, targets=None, rtgs=rtgs, timesteps=timesteps)
        action = action_preds[:, -1, :]

        x = action

    return x

In [None]:
class GELU(nn.Module):
    def forward(self, input):
        return F.gelu(input)

class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)

class CausalSelfAttention(nn.Module):
    """ A vanilla multi-head masked self-attention layer """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + 1, config.block_size + 1))
                             .view(1, 1, config.block_size + 1, config.block_size + 1))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    """ An unassuming Transformer block """
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """ The full GPT language model, with a context size of block_size """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.block_size = config.block_size

        self.state_encoder = nn.Sequential(
            nn.Linear(config.state_dim, config.n_embd),
            nn.Tanh()
        )

        self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh())

        self.action_encoder = nn.Sequential(
            nn.Linear(config.action_dim, config.n_embd),
            nn.Tanh()
        )

        self.action_predictor = nn.Sequential(
            nn.Linear(config.n_embd, config.action_dim),
            nn.Tanh()  # Assumes normalized actions [-1, 1]
        )

        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep + 1, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
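        self.ln_f = nn.LayerNorm(config.n_embd)

        # Apply the custom initialization defined in _init_weights to all submodules
        self.apply(self._init_weights)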

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        decay, no_decay = set(), set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn
                if pn.endswith('bias') or pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
        no_decay.add('pos_emb')
        no_decay.add('global_pos_emb')
        param_dict = {pn: p for pn, p in self.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)

    def forward(self, states, actions, targets=None, rtgs=None, timesteps=None):
        state_embeddings = self.state_encoder(states.type(torch.float32))

        if actions is not None and self.model_type == 'reward_conditioned':
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
            action_embeddings = self.action_encoder(actions)

            # print(f"actions shape: {actions.shape}")
            # print(f"action_embeddings shape: {action_embeddings.shape}")

            # Interleave: [rtg, state, action, rtg, state, action, ...]
            token_embeddings = torch.zeros(
                (states.shape[0], states.shape[1] * 3 - int(targets is None), self.config.n_embd),
                dtype=torch.float32, device=state_embeddings.device
            )
            token_embeddings[:, ::3, :] = rtg_embeddings
            token_embeddings[:, 1::3, :] = state_embeddings
            token_embeddings[:, 2::3, :] = action_embeddings[:, -states.shape[1] + int(targets is None):, :]
        elif actions is None and self.model_type == 'reward_conditioned':
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
            token_embeddings = torch.zeros(
                (states.shape[0], states.shape[1] * 2, self.config.n_embd),
                dtype=torch.float32, device=state_embeddings.device
            )
            token_embeddings[:, ::2, :] = rtg_embeddings
            token_embeddings[:, 1::2, :] = state_embeddings

            # print(f"states.shape: {states.shape}")
            # print(f"rtgs.shape: {rtgs.shape}")
            # print(f"rtg_embeddings.shape: {rtg_embeddings.shape}")
            # print(f"state_embeddings.shape: {state_embeddings.shape}")
            # print(f"token_embeddings.shape: {token_embeddings.shape}")

        # Positional embeddings
        batch_size = states.shape[0]
        all_global_pos_emb = self.global_pos_emb.repeat(batch_size, 1, 1)
        position_embeddings = torch.gather(
            all_global_pos_emb, 1, timesteps.repeat(1, 1, self.config.n_embd)
        ) + self.pos_emb[:, :token_embeddings.shape[1], :]

        # Transformer
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)

        action_preds = self.action_predictor(x)

        if actions is not None and self.model_type == 'reward_conditioned':
            action_preds = action_preds[:, 1::3, :]
        elif actions is None and self.model_type == 'reward_conditioned':
            action_preds = action_preds[:, 1:, :]

        loss = None
        if targets is not None:
            loss = F.mse_loss(action_preds, targets)

        return action_preds, loss

We use the training pipeline and configuration that the authors of the paper use with no changes.
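The trainer configuration that follows exposes an optional token-based learning-rate schedule (`lr_decay`, `warmup_tokens`, `final_tokens`), which is disabled by default in this notebook. As a sketch of what that schedule does when enabled, the function below reproduces the multiplier computed in `Trainer.run_epoch`: a linear warmup followed by cosine decay with a floor of 0.1. The standalone function name `lr_multiplier` and the small token counts in the usage lines are illustrative assumptions.

```python
import math

def lr_multiplier(tokens, warmup_tokens, final_tokens):
    """Scale factor applied to the base learning rate after `tokens` processed tokens."""
    if tokens < warmup_tokens:
        # Linear warmup from 0 to 1.
        return float(tokens) / float(max(1, warmup_tokens))
    # Cosine decay from 1 toward 0, floored at 10% of the base learning rate.
    progress = float(tokens - warmup_tokens) / float(max(1, final_tokens - warmup_tokens))
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

base_lr = 3e-4
for tokens in (0, 50, 100, 550, 1000):
    print(tokens, base_lr * lr_multiplier(tokens, warmup_tokens=100, final_tokens=1000))
```

With `lr_decay = False`, as in the default configuration, the learning rate stays at its base value for the whole run.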

In [None]:
class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1  # only applied on matmul weights
    lr_decay = False
    warmup_tokens = 375e6
    final_tokens = 260e9
    ckpt_path = None
    num_workers = 0

    fast_dev_run = False
    max_train_batches = 10000


    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

class Trainer:
    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        self.optimizer = raw_model.configure_optimizers(config)
        self.tokens = 0

        for epoch in range(config.max_epochs):
            self.run_epoch(epoch, 'train')
            if self.test_dataset is not None:
                self.run_epoch(epoch, 'test')

            # Save a checkpoint every 10 epochs, after that epoch's updates have been applied
            if (epoch + 1) % 10 == 0:
                torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pt')

            # Evaluate model performance
            self.get_returns(config.target_return)

    def run_epoch(self, epoch_num, split):
        is_train = split == 'train'
        model, config = self.model, self.config
        model.train(is_train)
        data = self.train_dataset if is_train else self.test_dataset
        loader = DataLoader(
            data,
            shuffle=True,
            pin_memory=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers
        )

        losses = []
        pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
        for it, (x, y, r, t) in pbar:
            if self.config.fast_dev_run and it >= self.config.max_train_batches:
                break

            x, y, r, t = x.to(self.device), y.to(self.device), r.to(self.device), t.to(self.device)

            with torch.set_grad_enabled(is_train):
                logits, loss = model(x, y, y, r, t)
                loss = loss.mean()
                losses.append(loss.item())

            if is_train:
                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                self.optimizer.step()

                if config.lr_decay:
                    self.tokens += (y >= 0).sum()
                    if self.tokens < config.warmup_tokens:
                        lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                    else:
                        progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                    lr = config.learning_rate * lr_mult
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr
                else:
                    lr = config.learning_rate

                pbar.set_description(f"epoch {epoch_num+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

    def get_returns(self, ret):
        self.model.train(False)
        env = gym.make(self.config.game)

        T_rewards = []
        for i in range(10):
            state, _ = env.reset()
            state = torch.from_numpy(state).to(self.device).unsqueeze(0).unsqueeze(0)
            rtgs = [ret]

            sampled_action = sample(
                getattr(self.model, "module", self.model),  # unwrap DataParallel if present
                state, 1, temperature=1.0, actions=None,
                rtgs=torch.tensor(rtgs, dtype=torch.float32).to(self.device).reshape(1,-1,1),
                timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(self.device)
            )

            j = 0
            reward_sum = 0
            actions = []
            all_states = state
            done = False

            while not done:
                action = sampled_action.squeeze().cpu().numpy()
                actions.append(sampled_action)

                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                reward_sum += reward
                j += 1

                if done:
                    T_rewards.append(reward_sum)
                    break

                state = torch.from_numpy(state).to(self.device).unsqueeze(0).unsqueeze(0)
                all_states = torch.cat([all_states, state], dim=1)
                rtgs.append(rtgs[-1] - reward)

                sampled_action = sample(
                    getattr(self.model, "module", self.model), all_states, 1, temperature=1.0,
                    actions=torch.stack(actions, dim=1).to(self.device),
                    rtgs=torch.FloatTensor(rtgs).to(self.device).reshape(1, -1, 1),
                    timesteps=(min(j, self.config.max_timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(self.device))
                )

        env.close()
        eval_return = sum(T_rewards) / len(T_rewards) if T_rewards else 0
        print(f"Target return: {ret}, Average evaluation return: {eval_return:.2f}")
        self.model.train(True)
        return eval_return


### Dataset Configuration

We create a function to download the dataset, iterate through its episodes to compute the return-to-go (RTG) at each timestep, and flatten the observations, actions, RTGs, timesteps, and episode-end indices into arrays.
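
As a small worked example of the RTG computation described above, the sketch below uses a made-up four-step reward sequence. The helper name `returns_to_go` is hypothetical and is not part of the notebook; it computes the same quantity as the backward loop in `create_mujoco_dataset`, using a reversed running sum.

```python
from itertools import accumulate

def returns_to_go(rewards):
    """Return-to-go at step t is the sum of rewards from t to the episode end."""
    # Accumulate over the reversed rewards, then flip back into time order.
    return list(accumulate(reversed(rewards)))[::-1]

rewards = [1.0, 0.5, 2.0, 0.0]  # illustrative values only
print(returns_to_go(rewards))   # [3.5, 2.5, 2.0, 0.0]
```

The first entry equals the total episode return, and each later entry drops the rewards that have already been collected.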

Then, the `StateActionReturnDataset` provides K-step slices for training. Using the episode-end indices stored earlier, each slice is clamped so that it does not run past the end of an episode, and the sliced arrays are converted to tensors. Each item is returned as a `(states, actions, rtgs, timesteps)` tuple.
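
The clamping step can be illustrated in isolation. The sketch below is a simplified stand-in, not the notebook's implementation: the function `clamp_window` is hypothetical, it assumes `episode_ends` holds the exclusive end index of each episode in ascending order, and it assumes every episode has at least `K` steps. It uses a binary search from the standard library where the dataset class uses a linear scan.

```python
from bisect import bisect_right

def clamp_window(idx, K, episode_ends):
    """Return (start, end) of a K-step slice that stays inside the episode containing idx.

    episode_ends: sorted exclusive end indices of the episodes in the flattened arrays.
    """
    # First episode end strictly greater than idx, i.e. the end of idx's own episode.
    end_of_episode = episode_ends[bisect_right(episode_ends, idx)]
    end = min(idx + K, end_of_episode)
    start = end - K  # shift the window back so it always covers K steps
    return start, end

episode_ends = [5, 12]  # two toy episodes: steps 0-4 and steps 5-11
print(clamp_window(1, 3, episode_ends))   # (1, 4)
print(clamp_window(4, 3, episode_ends))   # (2, 5)
print(clamp_window(10, 3, episode_ends))  # (9, 12)
```

In the second call the window would otherwise spill into the next episode, so it is shifted back to end exactly at the boundary.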



In [None]:
def create_mujoco_dataset(dataset_id):
    minari.download_dataset(dataset_id)
    dataset = minari.load_dataset(dataset_id)

    obss, actions, rtgs, timesteps, done_idxs = [], [], [], [], []

    print(f"Processing {len(dataset)} episodes from {dataset_id}...")

    current_len = 0
    for episode in tqdm(dataset.iterate_episodes()):
        ep_actions = episode.actions
        ep_rewards = episode.rewards
        ep_observations = episode.observations
        ep_len = len(ep_rewards)

        # Calculate returns-to-go (RTG): sum of rewards from each step to the episode end
        ep_rtgs = np.zeros_like(ep_rewards, dtype=np.float32)
        ep_rtgs[-1] = ep_rewards[-1]
        for i in reversed(range(ep_len - 1)):
            ep_rtgs[i] = ep_rewards[i] + ep_rtgs[i+1]

        obss.extend(ep_observations[:-1])
        actions.extend(ep_actions)
        rtgs.extend(ep_rtgs)
        timesteps.extend(np.arange(ep_len))
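        # Store the exclusive end index of this episode so that slices [idx:done_idx]
        # in StateActionReturnDataset never start in one episode and end in the next
        done_idxs.append(current_len + ep_len)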
        current_len += ep_len

    return (
        np.array(obss, dtype=np.float32),
        np.array(actions, dtype=np.float32),
        np.array(rtgs, dtype=np.float32),
        np.array(timesteps, dtype=np.int64),
        np.array(done_idxs, dtype=np.int64)
    )


class StateActionReturnDataset(Dataset):
    def __init__(self, data, actions, rtgs, timesteps, done_idxs, block_size):
        self.data = data
        self.actions = actions
        self.rtgs = rtgs
        self.timesteps = timesteps
        self.done_idxs = done_idxs
        self.block_size = block_size

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        block_size = self.block_size // 3
        done_idx = idx + block_size

        for i in self.done_idxs:
            if i > idx:
                done_idx = min(int(i), done_idx)
                break
        idx = done_idx - block_size

        states = torch.tensor(self.data[idx:done_idx], dtype=torch.float32)
        actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.float32)
        rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1)
        timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1)

        return states, actions, rtgs, timesteps

### Model Configuration

Finally, we are able to set up and train the full Decision Transformer! This cell loads the offline MuJoCo dataset and uses a context length of 20, meaning the model conditions on the 20 most recent timesteps.

The observations, actions, RTGs (returns-to-go), and timesteps are then wrapped in a `StateActionReturnDataset` with a block size of 3K tokens (60 tokens for K = 20) since, as mentioned earlier, each timestep contributes one trajectory triplet.
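
To make the block-size arithmetic concrete, here is a toy sketch of how one triplet per timestep adds up to 3K tokens. The string labels stand in for embedded tokens, and the (RTG, state, action) ordering within each triplet is an assumption based on the usual Decision Transformer layout rather than something read from the model code in this cell.

```python
K = 3  # toy context length (the notebook uses 20)

# Hypothetical labels standing in for the embedded tokens at each timestep.
rtg_tokens = [f"R{t}" for t in range(K)]
state_tokens = [f"s{t}" for t in range(K)]
action_tokens = [f"a{t}" for t in range(K)]

# Interleave one (RTG, state, action) triplet per timestep into a single sequence.
sequence = [
    token
    for triplet in zip(rtg_tokens, state_tokens, action_tokens)
    for token in triplet
]

print(sequence)       # ['R0', 's0', 'a0', 'R1', 's1', 'a1', 'R2', 's2', 'a2']
print(len(sequence))  # 9, i.e. 3 * K
```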

Then, the GPT is configured with 6 layers, 8 attention heads, and an embedding dimension of 128. The argument that would normally be the vocabulary size of a language model is replaced by the dimension of the action space. Finally, training begins with a learning rate of 6e-4 (set higher for continuous actions) that is warmed up and then decayed over training, and the supervised objective minimizes the MSE loss between the predicted action and the action in the dataset.
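
The learning-rate schedule used by `run_epoch` (linear warmup followed by cosine decay with a floor of 0.1) can be inspected on its own. The sketch below copies that formula into a standalone function; the name `lr_multiplier` and the `final_tokens` value are hypothetical and chosen only to produce round numbers, while `warmup_tokens` matches the `512*20` used in the training config.

```python
import math

def lr_multiplier(tokens, warmup_tokens, final_tokens):
    """Scale factor applied to the base learning rate after `tokens` tokens."""
    if tokens < warmup_tokens:
        # Linear warmup from 0 to 1.
        return tokens / max(1, warmup_tokens)
    # Cosine decay from 1 towards 0, floored at 0.1.
    progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

base_lr = 6e-4
warmup_tokens = 512 * 20
final_tokens = 1_000_000  # hypothetical value for illustration

for tokens in (0, 5_120, 10_240, 505_120, 1_000_000):
    lr = base_lr * lr_multiplier(tokens, warmup_tokens, final_tokens)
    print(f"{tokens:>9} tokens -> lr {lr:.2e}")
```

The floor of 0.1 is why the learning rate in the final epoch of the training log settles at 6e-05, one tenth of the base rate.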

In [None]:
# --- Configuration ---
class Config:
    seed = 123
    context_length = 20
    epochs = 5
    model_type = 'reward_conditioned'
    batch_size = 64

    # MuJoCo specific
    env_name = 'Hopper-v5'
    dataset_id = 'mujoco/hopper/medium-v0'

    # Target return (can update after inspecting dataset stats)
    target_return = 3000.0

    # Derived dims/bounds (filled after binding env or dataset)
    state_dim = None
    action_dim = None
    act_low = None
    act_high = None

set_seed(Config.seed)

# Load MuJoCo dataset (using create_mujoco_dataset defined above)
obss, actions, rtgs, timesteps, done_idxs = create_mujoco_dataset(Config.dataset_id)

# Set dimensions from data
Config.state_dim = obss.shape[-1]
Config.action_dim = actions.shape[-1]
max_timesteps = int(timesteps.max())  # vectorized max over the NumPy array, as a plain int

print(f"State dim: {Config.state_dim}, Action dim: {Config.action_dim}")
assert Config.state_dim > 0 and Config.action_dim > 0, "Bad dims from dataset"
assert np.isfinite(obss).all() and np.isfinite(actions).all(), "NaN/Inf in dataset"

# Create dataset (3 tokens per timestep: RTG, state, action)
block_size = Config.context_length * 3
train_dataset = StateActionReturnDataset(
    obss, actions, rtgs, timesteps, done_idxs,
    block_size=block_size
)

# Create model config
mconf = GPTConfig(
    Config.action_dim,  # Not vocab_size anymore
    Config.context_length * 3,
    n_layer=6,
    n_head=8,
    n_embd=128,
    model_type=Config.model_type,
    max_timestep=max_timesteps,
    state_dim=Config.state_dim,
    action_dim=Config.action_dim
)
model = GPT(mconf)

if not hasattr(model, "block_size"):
    model.block_size = block_size

# Training config
tconf = TrainerConfig(
    max_epochs=Config.epochs,
    batch_size=Config.batch_size,
    learning_rate=6e-4,  # Higher LR for continuous actions
    lr_decay=True,
    warmup_tokens=512*20,
    final_tokens=2*len(train_dataset)*Config.context_length*3,
    num_workers=0,
    seed=Config.seed,
    model_type=Config.model_type,
    game=Config.env_name,  # Now 'Hopper-v5' not 'Breakout'
    max_timestep=max_timesteps,
    target_return=Config.target_return,
    fast_dev_run=True  # stop each epoch after config.max_train_batches batches
)

trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()

print("\n--- Training Complete ---")




Processing 1327 episodes from mujoco/hopper/medium-v0...


1327it [00:05, 261.79it/s]


State dim: 11, Action dim: 3


epoch 1 iter 9999: train loss 0.05292. lr 5.617413e-04:  64%|██████▍   | 10000/15615 [05:44<03:13, 29.01it/s]


Target return: 3000.0, Average evaluation return: 1203.04


epoch 2 iter 9999: train loss 0.04728. lr 4.566207e-04:  64%|██████▍   | 10000/15615 [05:55<03:19, 28.12it/s]


Target return: 3000.0, Average evaluation return: 2379.35


epoch 3 iter 9999: train loss 0.05243. lr 3.115114e-04:  64%|██████▍   | 10000/15615 [05:41<03:11, 29.26it/s]


Target return: 3000.0, Average evaluation return: 3082.63


epoch 4 iter 9999: train loss 0.05025. lr 1.634881e-04:  64%|██████▍   | 10000/15615 [05:41<03:11, 29.28it/s]


Target return: 3000.0, Average evaluation return: 3550.61


epoch 5 iter 9999: train loss 0.04446. lr 6.000000e-05:  64%|██████▍   | 10000/15615 [05:41<03:11, 29.29it/s]


Target return: 3000.0, Average evaluation return: 3246.11

--- Training Complete ---


### Rollout Video

Now that we've successfully trained our model, let's see it in action! To generate a rollout video, we reset the environment to get the initial state, initialize the RTG to the target return, and set the initial timestep to 0. The model then samples the first action from this context, and the rollout loop begins!

Within the rollout loop, each sampled action is passed to the MuJoCo environment. The environment moves the robot leg according to that action and returns the new state of the robot, the reward for the action (which is accumulated throughout the rollout), and flags indicating whether the episode has ended (i.e. has our robot fallen? no? then keep it going!).

If the episode has not ended, the newest state is converted to a tensor and appended to the context history. Additionally, the reward from the action is subtracted from the remaining RTG, informing the model of how much return is still left to collect in the rollout. The loop then repeats, using the updated context to sample the next action.

If the robot falls (terminated) or the time limit is reached (truncated), the loop exits, the environment is closed, and the final video is written.
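
The RTG bookkeeping in the rollout loop can be traced with a few made-up rewards. This is a simplified sketch: `toy_rollout` and its `rewards` argument are hypothetical, no model or environment is involved, and unlike the real loop it also records the RTG after the final step.

```python
def toy_rollout(target_return, rewards):
    """Track the return-to-go that would condition the model at each step.

    `rewards` is a hypothetical list of per-step rewards; a real rollout
    would obtain each one from env.step(action).
    """
    rtgs = [target_return]
    reward_sum = 0.0
    for reward in rewards:
        reward_sum += reward
        # The next conditioning value is whatever return is still left to collect.
        rtgs.append(rtgs[-1] - reward)
    return rtgs, reward_sum

rtgs, total = toy_rollout(10.0, [3.0, 3.0, 5.0])
print(rtgs)   # [10.0, 7.0, 4.0, -1.0]
print(total)  # 11.0
```

The RTG turns negative once the accumulated reward exceeds the target, which is what happens in the training log whenever the evaluation return is above the target of 3000.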


In [None]:
from pyvirtualdisplay import Display

# Start virtual display
display = Display(visible=0, size=(1400, 900))
display.start()

<pyvirtualdisplay.display.Display at 0x7898cff3b440>

In [None]:
def generate_rollout_video(model, config, video_folder='videos'):
    """
    Generates a video of the agent's performance for one episode.
    """
    eval_model = model.module if hasattr(model, "module") else model
    eval_model.eval()
    device = next(eval_model.parameters()).device

    print("Setting up environment for video recording...")
    env = gym.make(config.game, render_mode='rgb_array')

    env = gym.wrappers.RecordVideo(
        env,
        video_folder,
        episode_trigger=lambda x: x == 0,
        name_prefix=f"{config.game.lower()}-decision-transformer"
    )

    state, _ = env.reset()
    state = torch.from_numpy(state).to(device).unsqueeze(0).unsqueeze(0)
    rtgs = [config.target_return]

    sampled_action = sample(
        eval_model, state, 1, temperature=1.0, actions=None,
        rtgs=torch.FloatTensor(rtgs).to(device).reshape(1, -1, 1),
        timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(device)
    )

    j = 0
    reward_sum = 0
    actions = []
    all_states = state
    done = False

    print("Starting rollout...")
    while not done:
        action = sampled_action.squeeze().cpu().numpy()
        actions.append(sampled_action)

        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        reward_sum += reward
        j += 1

        if done:
            break

        state = torch.from_numpy(state).to(device).unsqueeze(0).unsqueeze(0)
        all_states = torch.cat([all_states, state], dim=1)
        rtgs.append(rtgs[-1] - reward)

        sampled_action = sample(
            eval_model, all_states, 1, temperature=1.0,
            actions=torch.stack(actions, dim=1).to(device),
            rtgs=torch.FloatTensor(rtgs).to(device).reshape(1, -1, 1),
            timesteps=(min(j, config.max_timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(device))
        )

    env.close()
    print(f"Rollout complete! Total Reward: {reward_sum}")
    print(f"Video saved in the '{video_folder}' directory.")

In [None]:
# Generate the video using your trained model and config
generate_rollout_video(trainer.model, tconf)

Setting up environment for video recording...
Starting rollout...


Rollout complete! Total Reward: 3373.2724897746543
Video saved in the 'videos' directory.


In [None]:
import glob
import io
import base64
from IPython.display import HTML, display
import os

def show_video(video_path, video_width=600):
    """
    Shows a video in a Colab notebook.
    """
    # Read the file in binary mode; the context manager closes it afterwards
    with open(video_path, "rb") as f:
        video_file = f.read()
    video_url = f"data:video/mp4;base64,{base64.b64encode(video_file).decode()}"

    return HTML(f"""
    <video width="{video_width}" controls>
      <source src="{video_url}" type="video/mp4">
    </video>
    """)

# Find the most recent video file
video_files = glob.glob('videos/*.mp4')
latest_video = max(video_files, key=os.path.getctime)

# Display the video
print(f"Showing video from: {latest_video}")
display(show_video(latest_video))

Showing video from: videos/hopper-v5-decision-transformer-episode-0.mp4


Look at that robot go! We have now successfully trained a GPT to continually generate actions conditioned on the previous $n$ states, actions, and returns-to-go in its context!