# Chapter 02: Decision Transformer ‚Äì Offline RL as Discrete Path Integral

> "The agent does not learn a policy or a value function.  
> It learns the partition function over all possible futures, weighted by exponential reward."

‰∏Ä‰∏™GPT-style TransformerÂÖ∂ÂÆûÂ∞±ÊòØÁ¶ªÊï£Ë∑ØÂæÑÁßØÂàÜÔºö

$$P(trajectory) \propto exp(Return(\tau))$$


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split
import gymnasium as gym
import numpy as np
import math
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm

sns.set_style("white")
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: mps


### 1. Data Preparation: Building the Discrete Path-Integral Ensemble

In the path-integral picture, the "dataset" is nothing but a Monte Carlo estimate of the trajectory partition function
$$Z = \sum_{\tau} \exp\left( \sum_t \gamma^t r_t \right) \approx \sum_{i=1}^N \exp\big(G(\tau_i)\big)$$

In [2]:
def compute_returns_to_go(rewards: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    """
    Compute return-to-go (RTG) for each timestep in a trajectory.
    
    RTG represents the expected cumulative future reward from each timestep.
    This is a KEY conditioning signal for the Decision Transformer.
    
    Args:
        rewards: (T,) tensor of per-timestep rewards
        gamma: discount factor for future rewards
    
    Returns:
        rtg: (T,) tensor where rtg[t] = sum_{k=t}^T gamma^{k-t} * r_k
    
    Example:
        If rewards = [1, 1, 1] and gamma = 0.99:
        rtg = [1 + 0.99 + 0.99^2, 1 + 0.99, 1] = [2.9701, 1.99, 1.0]
    """
    T = len(rewards)
    rtg = torch.zeros_like(rewards)
    
    # Compute RTG by backward accumulation
    # Start from the end: rtg[T-1] = rewards[T-1]
    # Work backwards: rtg[t] = rewards[t] + gamma * rtg[t+1]
    cumulative = 0.0
    for t in reversed(range(T)):
        cumulative = rewards[t] + gamma * cumulative
        rtg[t] = cumulative
    
    return rtg  # (T,)


In [3]:
def collect_trajectories(
    env_name: str = "CartPole-v1",
    n_trajectories: int = 1000,
    max_length: int = 1000,
    target_rtg_range: tuple = (0, 500),  # Target RTG range for this collection
    gamma: float = 0.99,
):
    """
    Collect trajectories with RTG-conditional behavior.
    
    KEY IDEA: To teach the model RTG conditioning, we generate data where
    the same state leads to different actions depending on the target RTG.
    
    Strategy:
    - Low RTG target (0-100): Use poor policy (high randomness)
    - Medium RTG target (100-300): Use medium policy (moderate randomness)
    - High RTG target (300-500): Use expert policy (low randomness)
    
    Args:
        env_name: Gymnasium environment name
        n_trajectories: Number of episodes to collect
        max_length: Maximum steps per episode
        target_rtg_range: (min_rtg, max_rtg) - range of target RTGs to sample from
        gamma: Discount factor
    
    Returns:
        trajectories: List of dicts with 'states', 'actions', 'rewards', 'returns', 'target_rtg'
    """
    env = gym.make(env_name)
    trajectories = []
    
    min_rtg, max_rtg = target_rtg_range
    
    for _ in tqdm(range(n_trajectories), desc=f"Collecting RTG-conditional data"):
        # Sample a policy quality level
        # We'll assign the target RTG AFTER seeing the actual return
        policy_quality = np.random.choice(['poor', 'medium', 'expert'], p=[0.3, 0.4, 0.3])
        
        # Determine randomness rate based on policy quality
        if policy_quality == 'poor':
            random_rate = 0.6  # 60% random ‚Üí expect low returns
        elif policy_quality == 'medium':
            random_rate = 0.3  # 30% random ‚Üí expect medium returns
        else:  # expert
            random_rate = 0.1  # 10% random ‚Üí expect high returns
        
        obs, _ = env.reset()
        states, actions, rewards = [], [], []
        
        for _ in range(max_length):
            # Apply RTG-conditional policy
            if np.random.random() < random_rate:
                # Random action (exploration based on target RTG)
                action = env.action_space.sample()
            else:
                # Use expert heuristic
                if "CartPole" in env_name:
                    angle = obs[2]
                    angular_velocity = obs[3]
                    # Expert heuristic: balance the pole
                    action = 1 if (angle + 0.3 * angular_velocity) > 0 else 0
                else:
                    action = env.action_space.sample()
            
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            states.append(torch.tensor(obs, dtype=torch.float32))
            
            # For discrete actions, store as integer
            if isinstance(env.action_space, gym.spaces.Discrete):
                actions.append(torch.tensor(action, dtype=torch.long))
            else:
                actions.append(torch.tensor(action, dtype=torch.float32))
            
            rewards.append(torch.tensor(reward, dtype=torch.float32))
            
            obs = next_obs
            if done or truncated:
                break
        
        states = torch.stack(states)
        if isinstance(env.action_space, gym.spaces.Discrete):
            # For discrete actions, keep as long tensor
            actions = torch.stack(actions).unsqueeze(-1)  # (T, 1)
        else:
            # For continuous actions, stack directly
            actions = torch.stack(actions)
        rewards = torch.stack(rewards)
        
        # CRITICAL FIX: Assign target_rtg based on ACTUAL achieved return
        # This ensures RTG correctly correlates with trajectory quality!
        actual_return = rewards.sum().item()
        
        trajectories.append({
            "states": states,
            "actions": actions,
            "rewards": rewards,
            "returns": actual_return,
            "target_rtg": actual_return,  # Use actual return as target RTG
            "policy_quality": policy_quality,  # Store for analysis
        })
    
    env.close()
    
    # Print statistics
    returns = [t['returns'] for t in trajectories]
    target_rtgs = [t['target_rtg'] for t in trajectories]
    poor_trajs = [t for t in trajectories if t['policy_quality'] == 'poor']
    medium_trajs = [t for t in trajectories if t['policy_quality'] == 'medium']
    expert_trajs = [t for t in trajectories if t['policy_quality'] == 'expert']
    
    print(f"\nCollected {len(trajectories)} trajectories:")
    print(f"  Policy distribution: {len(poor_trajs)} poor, {len(medium_trajs)} medium, {len(expert_trajs)} expert")
    print(f"  Poor policy:   mean return = {np.mean([t['returns'] for t in poor_trajs]):.1f}")
    print(f"  Medium policy: mean return = {np.mean([t['returns'] for t in medium_trajs]):.1f}")
    print(f"  Expert policy: mean return = {np.mean([t['returns'] for t in expert_trajs]):.1f}")
    print(f"  Overall: {np.mean(returns):.1f} ¬± {np.std(returns):.1f}, range=[{np.min(returns):.0f}, {np.max(returns):.0f}]")
    print(f"  ‚úÖ Target RTG = Actual return (perfect correlation by design)")
    
    return trajectories


So here we collect RTG-conditional data across the full spectrum, and we are going to build a model learning from mixed-quality offline data to reproduce expert behavior.

In [5]:
env_name = "CartPole-v1"

trajectories = collect_trajectories(
    env_name=env_name,
    n_trajectories=2000,  # Collect diverse data across all RTG levels
    target_rtg_range=(0, 500),
    gamma=0.99,
)

# Print dataset statistics
returns = [t['returns'] for t in trajectories]
print(f"\n{'='*70}")
print("STEP 2: Dataset Statistics")
print("="*70)
print(f"Total trajectories: {len(trajectories)}")
print(f"Return range: [{min(returns):.1f}, {max(returns):.1f}]")
print(f"Return mean: {np.mean(returns):.1f} ¬± {np.std(returns):.1f}")
print(f"Return percentiles:")
print(f"  25th: {np.percentile(returns, 25):.1f}")
print(f"  50th: {np.percentile(returns, 50):.1f}")
print(f"  75th: {np.percentile(returns, 75):.1f}")
print(f"  90th: {np.percentile(returns, 90):.1f}")

Collecting RTG-conditional data:   0%|          | 0/2000 [00:00<?, ?it/s]


Collected 2000 trajectories:
  Policy distribution: 664 poor, 753 medium, 583 expert
  Poor policy:   mean return = 94.8
  Medium policy: mean return = 377.4
  Expert policy: mean return = 489.3
  Overall: 316.2 ¬± 184.5, range=[9, 500]
  ‚úÖ Target RTG = Actual return (perfect correlation by design)

STEP 2: Dataset Statistics
Total trajectories: 2000
Return range: [9.0, 500.0]
Return mean: 316.2 ¬± 184.5
Return percentiles:
  25th: 133.8
  50th: 354.0
  75th: 500.0
  90th: 500.0


Prepare dataset

In [6]:
class TrajectoryDataset(Dataset):
    """
    Dataset for Decision Transformer training.
    
    Converts trajectories into fixed-length context windows of (state, action, RTG) tuples.
    Each sample is a sliding window of `context_length` consecutive timesteps.
    
    This design allows the transformer to attend over recent history when predicting actions.
    """
    def __init__(self, trajectories, context_length: int = 20, gamma: float = 0.99):
        """
        Args:
            trajectories: List of dicts with keys 'states', 'actions', 'rewards'
            context_length: Number of timesteps in each training sample
            gamma: Discount factor for computing RTG
        """
        self.context_length = context_length
        self.gamma = gamma
        
        # We'll store all training samples as fixed-length sequences
        self.states = []
        self.actions = []
        self.rtg = []
        self.timesteps = []
        
        # Process each trajectory into multiple training samples
        for traj in trajectories:
            states = traj['states']      # (T, state_dim)
            actions = traj['actions']    # (T, action_dim)
            rewards = traj['rewards']    # (T,) - 1D tensor
            
            T = len(states)
            # Skip trajectories that are too short
            if T <= context_length:
                continue
            
            # Compute return-to-go for this trajectory
            rtg = compute_returns_to_go(rewards, gamma)  # (T,)
            
            # Create sliding windows of context_length
            # For a trajectory of length T, we get (T - context_length + 1) samples
            for start in range(T - context_length + 1):
                end = start + context_length
                
                self.states.append(states[start:end])
                self.actions.append(actions[start:end])
                self.rtg.append(rtg[start:end])  # (context_length,)
                self.timesteps.append(torch.arange(start, end))
        
        # Stack into tensors for efficient batching
        self.states = torch.stack(self.states)              # (N, L, state_dim)
        self.actions = torch.stack(self.actions)            # (N, L, action_dim)
        self.rtg = torch.stack(self.rtg).unsqueeze(-1)      # (N, L, 1)
        self.timesteps = torch.stack(self.timesteps)        # (N, L)
        
        print(f"Built dataset with {len(self)} sequences (context_length={context_length})")
    
    def __len__(self):
        return len(self.states)
    
    def __getitem__(self, idx):
        return (
            self.states[idx],
            self.actions[idx],
            self.rtg[idx],
            self.timesteps[idx],
            torch.ones(self.context_length, dtype=torch.bool)
        )
    
    @property
    def state_dim(self):
        return self.states.shape[-1]
    
    @property
    def action_dim(self):
        return self.actions.shape[-1]


In [7]:
# Convert trajectories into fixed-length context windows
dataset = TrajectoryDataset(trajectories, context_length=20)
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

# Check RTG statistics before scaling
rtg_mean = dataset.rtg.mean()
rtg_std = dataset.rtg.std()
rtg_max = dataset.rtg.max()
rtg_min = dataset.rtg.min()
print(f"\nRTG statistics (before scaling):")
print(f"  Mean={rtg_mean:.2f}, Std={rtg_std:.2f}, Range=[{rtg_min:.2f}, {rtg_max:.2f}]")

# === RTG SCALING ===
# CartPole returns are in [0, 500]. We scale by 100 to get [0, 5].
# This normalization helps the model learn more effectively.
# IMPORTANT: Use the SAME scaling at evaluation time!
rtg_scale_factor = 100.0
dataset.rtg = dataset.rtg / rtg_scale_factor
print(f"‚úÖ RTG scaled by {rtg_scale_factor}, new range: [{dataset.rtg.min():.2f}, {dataset.rtg.max():.2f}]")

Built dataset with 594689 sequences (context_length=20)

RTG statistics (before scaling):
  Mean=76.23, Std=24.76, Range=[1.00, 99.34]
‚úÖ RTG scaled by 100.0, new range: [0.01, 0.99]


### 2. The Decision Transformer Model ‚Äì Discrete Path Integral Propagator

In [8]:
class DecisionTransformer(nn.Module):
    """
    Decision Transformer: Offline RL as Sequence Modeling
    
    Key idea: Model (state, action, return-to-go) sequences with a Transformer,
    then predict actions conditioned on desired returns.
    
    Architecture:
        1. Embed states, actions, and RTGs into hidden_dim vectors
        2. Add positional (timestep) embeddings
        3. Stack as sequence: [RTG_1, s_1, a_1, RTG_2, s_2, a_2, ...]
        4. Process with causal Transformer
        5. Predict actions from state embeddings
    """
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int = 128,
        n_layers: int = 3,
        n_heads: int = 8,
        max_timestep: int = 1024,
        dropout: float = 0.2,
        rtg_scale: float = 5.0,  # Amplification factor for RTG embeddings
    ):
        """
        Args:
            state_dim: Dimension of state observations
            action_dim: Dimension of action space (for discrete: 1, for continuous: action_space.shape[0])
            hidden_dim: Transformer hidden dimension
            n_layers: Number of transformer layers
            n_heads: Number of attention heads
            max_timestep: Maximum timestep for positional embeddings
            dropout: Dropout rate
            rtg_scale: Multiplicative factor to amplify RTG signal (helps learning)
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.rtg_scale = rtg_scale

        # CRITICAL: All embeddings should have similar complexity!
        # RTG gets a deeper network to strengthen its conditioning signal
        self.embed_rtg = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        
        # State embedding - should also be deep to learn good representations
        self.embed_state = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        
        # Action embedding - also deep for consistency
        self.embed_action = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Timestep embeddings encode position in the trajectory
        self.embed_timestep = nn.Embedding(max_timestep, hidden_dim)

        # Transformer encoder with causal masking
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=4 * hidden_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # Pre-LN for training stability
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            enable_nested_tensor=False,  # Disable for causal masking
        )

        # Action prediction head - use MLP for better expressiveness
        self.predict_action = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 2),  # CartPole has 2 discrete actions
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with small random values (GPT-style)"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, states, actions, returns_to_go, timesteps, debug=False):
        """
        Forward pass: predict actions given states, past actions, and desired RTG.
        
        Args:
            states: (B, L, state_dim) - batch of state sequences
            actions: (B, L, action_dim) - batch of action sequences
            returns_to_go: (B, L, 1) - desired future returns at each timestep
            timesteps: (B, L) - timestep indices for positional encoding
            debug: If True, print embedding statistics
        
        Returns:
            action_preds: (B, L, action_space_size) - predicted action logits
        
        Process:
            1. Embed states, actions, RTG ‚Üí (B, L, hidden_dim) each
            2. Add timestep embeddings to all
            3. Stack as [RTG_1, s_1, a_1, RTG_2, s_2, a_2, ...] ‚Üí (B, 3L, hidden_dim)
            4. Apply causal Transformer (can't look into future)
            5. Extract state embeddings (every 3rd position starting at index 2)
            6. Predict actions from state embeddings
        """
        # Ensure correct dtypes
        states = states.float()
        actions = actions.float()
        returns_to_go = returns_to_go.float()
        timesteps = timesteps.long()

        B, L = states.shape[:2]  # Batch size, sequence length

        # === STEP 1: Embed inputs ===
        # RTG embedding is amplified by rtg_scale to strengthen its signal
        rtg_embed    = self.embed_rtg(returns_to_go) * self.rtg_scale
        state_embed  = self.embed_state(states)
        action_embed = self.embed_action(actions)

        # Debug: Check embedding magnitudes
        if debug:
            print(f"\nüîç DEBUG Forward Pass:")
            print(f"  RTG input range: [{returns_to_go.min():.3f}, {returns_to_go.max():.3f}]")
            print(f"  RTG embed norm (before scale): {(self.embed_rtg(returns_to_go)).norm(dim=-1).mean():.3f}")
            print(f"  RTG embed norm (after scale): {rtg_embed.norm(dim=-1).mean():.3f}")
            print(f"  State embed norm: {state_embed.norm(dim=-1).mean():.3f}")
            print(f"  Action embed norm: {action_embed.norm(dim=-1).mean():.3f}")

        # === STEP 2: Add positional (timestep) embeddings ===
        time_embed = self.embed_timestep(timesteps)  # (B, L, hidden_dim)
        rtg_embed    += time_embed
        state_embed  += time_embed
        action_embed += time_embed

        # === STEP 3: Stack as interleaved sequence ===
        # Shape: (B, L, 3, hidden_dim) ‚Üí (B, 3*L, hidden_dim)
        # Order: [RTG_0, state_0, action_0, RTG_1, state_1, action_1, ...]
        stacked = torch.stack((rtg_embed, state_embed, action_embed), dim=2)
        stacked = stacked.reshape(B, 3 * L, -1)

        # === STEP 4: Apply causal Transformer ===
        # Causal mask ensures position i can only attend to positions <= i
        # This prevents the model from "cheating" by seeing future information
        seq_len = 3 * L
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(stacked.device)
        h = self.transformer(stacked, mask=mask)  # (B, 3*L, hidden_dim)

        # === STEP 5: Extract action predictions ===
        # CRITICAL: Predict action from STATE position (1::3), not action position!
        # Sequence order: [RTG_0, state_0, action_0, RTG_1, state_1, action_1, ...]
        # Positions:      [ 0,      1,        2,      3,      4,        5,    ...]
        # 
        # At position 1 (state_0), model has seen: RTG_0, state_0
        # This is the RIGHT place to predict action_0!
        # 
        # If we use position 2 (action_0), model would see: RTG_0, state_0, action_0
        # This is WRONG - it's cheating by seeing the answer!
        action_preds = self.predict_action(h[:, 1::3, :])  # (B, L, action_dim)

        return action_preds


### 3. Training

Initializing Decision Transformer

In [9]:
model = DecisionTransformer(
    state_dim=dataset.state_dim,      # 4 for CartPole (position, velocity, angle, angular_velocity)
    action_dim=dataset.action_dim,    # 1 (we have 2 discrete actions, but store as single integer)
    hidden_dim=128,                   
    n_layers=4,                      
    n_heads=4,
    max_timestep=1024,                 # Maximum trajectory length
    dropout=0.1,                       # Dropout for regularization
    rtg_scale=20.0,                    
).to(device)

In [10]:
# Split into train/val for monitoring overfitting
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)

# Optimizer: AdamW with small learning rate and weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# Training hyperparameters
num_epochs = 10
steps_per_epoch = len(train_loader)
total_steps = num_epochs * steps_per_epoch

print(f"Epochs: {num_epochs}")
print(f"Total steps: {total_steps:,}")

Epochs: 10
Total steps: 18,590


#### Start training

In [11]:
global_step = 0
model.train()

best_val_loss = float('inf')
patience = 3  # Early stopping: stop if no improvement for 3 epochs
early_stop_counter = 0

pbar = tqdm(total=total_steps, desc="Training")

# Debug flag to print first batch info
first_batch_debug = True

for epoch in range(num_epochs):
    # === Training Phase ===
    model.train()
    train_loss = 0.0
    num_train_batches = 0
    
    for batch in train_loader:
        # Unpack batch and move to device
        states, actions, rtgs, timesteps, mask = [x.to(device) for x in batch]

        # Debug: Print first batch information
        if first_batch_debug and epoch == 0:
            print(f"\n{'='*60}")
            print(f"üîç FIRST BATCH ANALYSIS")
            print(f"{'='*60}")
            print(f"Batch shapes: states={states.shape}, actions={actions.shape}, rtgs={rtgs.shape}")
            print(f"RTG range in batch: [{rtgs.min():.3f}, {rtgs.max():.3f}]")
            print(f"Actions in batch: unique values = {actions.unique().tolist()}")
            
            # Check if there's any RTG-action correlation in the data
            rtg_flat = rtgs.flatten().cpu().numpy()
            action_flat = actions.flatten().cpu().numpy()
            print(f"\nRTG-Action correlation check:")
            print(f"  Low RTG (<-1): actions = {action_flat[rtg_flat < -1][:10]}")
            print(f"  High RTG (>1): actions = {action_flat[rtg_flat > 1][:10]}")
            first_batch_debug = False

        # Forward pass
        pred_actions = model(
            states=states,
            actions=actions,
            returns_to_go=rtgs,
            timesteps=timesteps,
            debug=(epoch == 0 and num_train_batches == 0),  # Debug embeddings in first batch
        )

        # === Compute Loss ===
        # For discrete actions: use cross-entropy loss
        actions_discrete = actions.squeeze(-1).long()  # (B, L) - squeeze out action_dim
        loss = F.cross_entropy(
            pred_actions.transpose(1, 2),  # (B, num_actions, L) - cross_entropy expects class dim at index 1
            actions_discrete,              # (B, L) - target class indices
            reduction='none'               # Don't reduce yet, we need to apply mask
        )
        # Apply mask to ignore padded positions
        loss = (loss * mask).sum() / mask.sum()

        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping for stability
        optimizer.step()

        # Track metrics
        train_loss += loss.item()
        num_train_batches += 1

        pbar.update(1)
        global_step += 1

        # Log progress every 1000 steps
        if global_step % 1000 == 0:
            pbar.set_postfix({"train_loss": f"{loss.item():.4g}"})

    avg_train_loss = train_loss / num_train_batches
    print(f"\nEpoch {epoch+1}/{num_epochs} | Avg train loss: {avg_train_loss:.4g}")

    # === Validation Phase ===
    model.eval()
    val_loss = 0.0
    num_val_batches = 0
    
    with torch.no_grad():
        for batch in val_loader:
            states, actions, rtgs, timesteps, mask = [x.to(device) for x in batch]

            pred_actions = model(
                states=states,
                actions=actions,
                returns_to_go=rtgs,
                timesteps=timesteps,
            )

            actions_discrete = actions.squeeze(-1).long()
            loss = F.cross_entropy(
                pred_actions.transpose(1, 2),
                actions_discrete,
                reduction='none'
            )
            loss = (loss * mask).sum() / mask.sum()

            val_loss += loss.item()
            num_val_batches += 1

    avg_val_loss = val_loss / num_val_batches
    print(f"Epoch {epoch+1}/{num_epochs} | Avg val loss: {avg_val_loss:.4g}")

    # === Early Stopping Check ===
    if avg_val_loss < best_val_loss - 1e-5:
        best_val_loss = avg_val_loss
        early_stop_counter = 0
        # Save best model checkpoint
        torch.save(model.state_dict(), 'best_dt_model.pt')
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"‚ö†Ô∏è  Early stopping triggered after epoch {epoch+1} (no improvement for {patience} epochs)")
            break

pbar.close()
print("\n‚úÖ Training completed!")

Training:   0%|          | 0/18590 [00:00<?, ?it/s]


üîç FIRST BATCH ANALYSIS
Batch shapes: states=torch.Size([256, 20, 4]), actions=torch.Size([256, 20, 1]), rtgs=torch.Size([256, 20, 1])
RTG range in batch: [0.010, 0.992]
Actions in batch: unique values = [0, 1]

RTG-Action correlation check:
  Low RTG (<-1): actions = []
  High RTG (>1): actions = []

üîç DEBUG Forward Pass:
  RTG input range: [0.010, 0.992]
  RTG embed norm (before scale): 1.632
  RTG embed norm (after scale): 32.649
  State embed norm: 1.454
  Action embed norm: 0.868

Epoch 1/10 | Avg train loss: 0.3824
Epoch 1/10 | Avg val loss: 0.3629

Epoch 2/10 | Avg train loss: 0.3561
Epoch 2/10 | Avg val loss: 0.3574

Epoch 3/10 | Avg train loss: 0.3508
Epoch 3/10 | Avg val loss: 0.35

Epoch 4/10 | Avg train loss: 0.3476
Epoch 4/10 | Avg val loss: 0.3449

Epoch 5/10 | Avg train loss: 0.3453
Epoch 5/10 | Avg val loss: 0.3433

Epoch 6/10 | Avg train loss: 0.3439
Epoch 6/10 | Avg val loss: 0.3396

Epoch 7/10 | Avg train loss: 0.3421
Epoch 7/10 | Avg val loss: 0.3377

Epoch 8/

### 4. Evaluation Function

The model is trained on random + expert trajectories, it is expected to achieve expert-level performance by conditioning on high returns.

In [12]:
@torch.no_grad()
def evaluate_dt(
    model, 
    env_name: str,
    target_return: float, 
    rtg_scale_factor: float,
    context_length: int,
    device: str,
    n_eval: int = 20
):
    """
    Evaluate the trained Decision Transformer in the environment.
    
    This is the KEY test: given a desired return (target_return), can the model
    generate a trajectory that achieves it?
    
    Args:
        model: Trained DecisionTransformer
        env_name: Environment to evaluate in
        target_return: Desired cumulative return (the "goal" we condition on)
        rtg_scale_factor: Same scaling factor used in training
        context_length: Maximum history length to condition on
        device: torch device
        n_eval: Number of evaluation episodes
    
    Returns:
        (mean_return, std_return): Average and std of achieved returns
    """
    model.eval()
    env = gym.make(env_name)
    scores = []

    for _ in range(n_eval):
        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device)

        # Maintain history of (state, action, RTG, timestep) for context window
        states_hist = []
        actions_hist = []
        rtgs_hist = []
        timesteps_hist = []

        # Scale target return same way as training data
        norm_rtg = target_return / rtg_scale_factor
        episode_reward = 0.0
        done = False
        timestep = 0

        gamma = 0.99  # Must match training

        while not done:
            # === Build input for prediction ===
            # Add current state, rtg, and timestep
            states_hist.append(state.unsqueeze(0))  # (1, state_dim)
            rtgs_hist.append(torch.tensor([norm_rtg], dtype=torch.float32, device=device))
            timesteps_hist.append(torch.tensor([timestep], dtype=torch.long, device=device))
            
            # For actions: use dummy action for the current timestep (we're predicting it)
            # and real actions for all previous timesteps
            actions_in_list = actions_hist.copy() + [torch.zeros(1, 1, dtype=torch.long, device=device)]
            
            # Build context window (last context_length steps)
            hist_len = len(states_hist)
            start_idx = max(0, hist_len - context_length)
            
            # Concatenate history into model input format
            states_in = torch.cat(states_hist[start_idx:], dim=0).unsqueeze(0)      # (1, T, state_dim)
            actions_in = torch.cat(actions_in_list[start_idx:], dim=0).unsqueeze(0)  # (1, T, 1)
            rtgs_in = torch.cat(rtgs_hist[start_idx:], dim=0).unsqueeze(0).unsqueeze(-1)  # (1, T, 1)
            timesteps_in = torch.cat(timesteps_hist[start_idx:], dim=0).unsqueeze(0)  # (1, T)

            # === Left-pad if history shorter than context_length ===
            if hist_len < context_length:
                pad_len = context_length - hist_len
                # Pad with zeros (model will learn to ignore padded positions)
                pad_states = torch.zeros(1, pad_len, states_in.shape[-1], device=device)
                pad_actions = torch.zeros(1, pad_len, 1, dtype=torch.long, device=device)
                pad_rtgs = torch.zeros(1, pad_len, 1, device=device)
                pad_times = torch.zeros(1, pad_len, dtype=torch.long, device=device)

                states_in = torch.cat([pad_states, states_in], dim=1)
                actions_in = torch.cat([pad_actions, actions_in], dim=1)
                rtgs_in = torch.cat([pad_rtgs, rtgs_in], dim=1)
                timesteps_in = torch.cat([pad_times, timesteps_in], dim=1)

            # === Predict action ===
            with torch.no_grad():
                # Debug first step
                if timestep == 0:
                    print(f"\nüîç EVAL DEBUG (timestep=0):")
                    print(f"  states_in.shape: {states_in.shape}")
                    print(f"  actions_in.shape: {actions_in.shape}")
                    print(f"  rtgs_in.shape: {rtgs_in.shape}")
                    print(f"  rtgs_in value: {rtgs_in[0, -1, 0].item():.3f} (normalized)")
                
                pred = model(states=states_in, actions=actions_in, returns_to_go=rtgs_in, timesteps=timesteps_in)
                action_logit = pred[0, -1]  # Take last timestep's prediction
                action_to_take = torch.argmax(action_logit).item()  # Greedy action selection
                
                if timestep == 0:
                    print(f"  pred.shape: {pred.shape}")
                    print(f"  action_logit: {action_logit}")
                    print(f"  selected action: {action_to_take}")

            # === Execute action in environment ===
            next_state, reward, terminated, truncated, _ = env.step(action_to_take)
            done = terminated or truncated
            episode_reward += reward

            # === Update for next step ===
            # Store the action we just took (for next prediction's history)
            actions_hist.append(torch.tensor([[action_to_take]], dtype=torch.long, device=device))
            
            # Update state
            state = torch.tensor(next_state, dtype=torch.float32, device=device)
            
            # Update RTG with gamma to match training
            # Training: rtg[t] = r[t] + gamma * rtg[t+1]
            # Inference: rtg[t+1] = (rtg[t] - r[t]) / gamma
            norm_rtg = (norm_rtg - reward / rtg_scale_factor) / gamma
            
            timestep += 1

        scores.append(episode_reward)

    env.close()
    mean = np.mean(scores)
    std = np.std(scores)
    return mean, std


### 5. DEBUG & ANALYSIS - Test RTG Conditioning

In [13]:
model.eval()

DecisionTransformer(
  (embed_rtg): Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (embed_state): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (embed_action): Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (embed_timestep): Embedding(1024, 128)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQ

### 6. EVALUATION - Test in Real Environment

In [15]:
print("Training data statistics:")
all_returns = [t['returns'] for t in trajectories]
all_target_rtgs = [t['target_rtg'] for t in trajectories]
print(f"  Mean return: {np.mean(all_returns):.1f} ¬± {np.std(all_returns):.1f}")
print(f"  Return range: [{np.min(all_returns):.0f}, {np.max(all_returns):.0f}]")
print(f"  Target RTG range: [{np.min(all_target_rtgs):.0f}, {np.max(all_target_rtgs):.0f}]")
print(f"  Correlation (target_rtg vs actual_return): {np.corrcoef(all_target_rtgs, all_returns)[0,1]:.3f}")

Training data statistics:
  Mean return: 316.2 ¬± 184.5
  Return range: [9, 500]
  Target RTG range: [9, 500]
  Correlation (target_rtg vs actual_return): 1.000


#### Test 1: Low Target RTG

Ask for poor performance - should match random policy

In [18]:
low_target = 50
low_mean, low_std = evaluate_dt(
    model=model, env_name=env_name, target_return=low_target,
    rtg_scale_factor=rtg_scale_factor,
    context_length=dataset.context_length,
    device=device, n_eval=50
)
print(f"   Target RTG={low_target}  ‚Üí Achieved: {low_mean:.1f} ¬± {low_std:.1f}")


üîç EVAL DEBUG (timestep=0):
  states_in.shape: torch.Size([1, 20, 4])
  actions_in.shape: torch.Size([1, 20, 1])
  rtgs_in.shape: torch.Size([1, 20, 1])
  rtgs_in value: 0.500 (normalized)
  pred.shape: torch.Size([1, 20, 2])
  action_logit: tensor([ 0.5866, -0.5397], device='mps:0')
  selected action: 0

üîç EVAL DEBUG (timestep=0):
  states_in.shape: torch.Size([1, 20, 4])
  actions_in.shape: torch.Size([1, 20, 1])
  rtgs_in.shape: torch.Size([1, 20, 1])
  rtgs_in value: 0.500 (normalized)
  pred.shape: torch.Size([1, 20, 2])
  action_logit: tensor([-0.2815,  0.1674], device='mps:0')
  selected action: 1

üîç EVAL DEBUG (timestep=0):
  states_in.shape: torch.Size([1, 20, 4])
  actions_in.shape: torch.Size([1, 20, 1])
  rtgs_in.shape: torch.Size([1, 20, 1])
  rtgs_in value: 0.500 (normalized)
  pred.shape: torch.Size([1, 20, 2])
  action_logit: tensor([-0.4603,  0.1679], device='mps:0')
  selected action: 1

üîç EVAL DEBUG (timestep=0):
  states_in.shape: torch.Size([1, 20, 4])


#### Test 2: High Target RTG

Ask for expert performance

In [17]:
high_target = 500
high_mean, high_std = evaluate_dt(
    model=model, env_name=env_name, target_return=high_target,
    rtg_scale_factor=rtg_scale_factor,
    context_length=dataset.context_length,
    device=device, n_eval=50
)
print(f"   Target RTG={high_target} ‚Üí Achieved: {high_mean:.1f} ¬± {high_std:.1f}")



üìä Test 2: HIGH target RTG (asking for expert performance)
----------------------------------------------------------------------

üîç EVAL DEBUG (timestep=0):
  states_in.shape: torch.Size([1, 20, 4])
  actions_in.shape: torch.Size([1, 20, 1])
  rtgs_in.shape: torch.Size([1, 20, 1])
  rtgs_in value: 5.000 (normalized)
  pred.shape: torch.Size([1, 20, 2])
  action_logit: tensor([-0.3681,  0.1792], device='mps:0')
  selected action: 1

üîç EVAL DEBUG (timestep=0):
  states_in.shape: torch.Size([1, 20, 4])
  actions_in.shape: torch.Size([1, 20, 1])
  rtgs_in.shape: torch.Size([1, 20, 1])
  rtgs_in value: 5.000 (normalized)
  pred.shape: torch.Size([1, 20, 2])
  action_logit: tensor([ 0.0930, -0.3420], device='mps:0')
  selected action: 0

üîç EVAL DEBUG (timestep=0):
  states_in.shape: torch.Size([1, 20, 4])
  actions_in.shape: torch.Size([1, 20, 1])
  rtgs_in.shape: torch.Size([1, 20, 1])
  rtgs_in value: 5.000 (normalized)
  pred.shape: torch.Size([1, 20, 2])
  action_logit: tens

DT learned to reproduce expert behavior from mixed offline data!
By conditioning on high RTG (500), it achieves expert-level performance.
This demonstrates offline RL working: no environment interaction during training!