# Lunar Lander Landing with Decision Transformer

## Project Overview

This project applies a **Decision Transformer (DT)** model to the Lunar Lander environment from Gymnasium. The Decision Transformer frames reinforcement learning as a sequence modeling task, using the transformer architecture that is widely used in natural language processing. Instead of predicting an action from the current state alone, the DT predicts actions conditioned on a desired return-to-go together with the recent history of states and actions, so it effectively acts as a return-conditioned policy.

The primary goal is to train an agent to land the Lunar Lander module on the designated landing pad between two flags. This requires precise control over the lander's thrusters to manage its descent rate and horizontal movement.

### Key Features:
-   **Decision Transformer Implementation:** A custom PyTorch implementation of the Decision Transformer architecture.
-   **Expert Trajectory Learning:** The model is trained on a dataset of expert trajectories, allowing it to learn optimal behaviors through imitation learning.
-   **Optimized Training:** Incorporates several optimizations for efficient and stable training, including:
    -   State normalization.
    -   Optimized trajectory processing and sampling strategy.
    -   Mixed precision training for faster computation.
    -   Gradient clipping and learning rate scheduling.
    -   Early stopping mechanism.
-   **Pygame Visualization:** Real-time visualization of the agent's performance in the Lunar Lander environment during evaluation.

### Environment: LunarLander-v3

The Lunar Lander environment is a classic control problem in reinforcement learning. The agent controls a lander that needs to land safely between two flags on a landing pad. The state space consists of 8 continuous values (x/y position, x/y velocity, angle, angular velocity, and two flags indicating whether each leg is in contact with the ground). The action space consists of 4 discrete actions (0: do nothing, 1: fire left orientation engine, 2: fire main engine, 3: fire right orientation engine).
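
As a quick orientation, the sketch below names the eight observation components and the four action indices in the order Gymnasium documents them for Lunar Lander. The field labels, the `describe_observation` helper, and the sample observation are illustrative and not part of this project's code.

```python
from typing import Dict, Sequence

# Observation layout as documented by Gymnasium for Lunar Lander.
# The field names are illustrative labels, not identifiers from the library.
OBS_FIELDS = (
    "x", "y",                 # position relative to the landing pad
    "vx", "vy",               # linear velocity
    "angle", "angular_vel",   # orientation and its rate of change
    "left_leg", "right_leg",  # 1.0 if the leg touches the ground, else 0.0
)

# Discrete action indices, in the order the environment expects them.
ACTIONS = {
    0: "do nothing",
    1: "fire left orientation engine",
    2: "fire main engine",
    3: "fire right orientation engine",
}

def describe_observation(obs: Sequence[float]) -> Dict[str, float]:
    """Pair each raw observation value with a readable label."""
    if len(obs) != len(OBS_FIELDS):
        raise ValueError(f"expected {len(OBS_FIELDS)} values, got {len(obs)}")
    return dict(zip(OBS_FIELDS, obs))

# A made-up observation: slightly left of the pad, descending, both legs in the air.
sample = [-0.1, 1.2, 0.05, -0.4, 0.02, 0.0, 0.0, 0.0]
labelled = describe_observation(sample)
print(labelled["vy"], ACTIONS[2])
```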


## 1. Setup and Configuration

This section imports all necessary libraries and defines the global configuration parameters for the environment, model, training process, and evaluation. These parameters can be easily adjusted to experiment with different settings.

In [None]:
import os
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import collections
from collections import deque
import pygame
import time
import math
import pickle
from tqdm import tqdm
import torch.nn.functional as F

# --- Configuration ---
ENV_ID = 'LunarLander-v3'
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trajectory Data
TRAJECTORY_FILE = '../datasets/trajectories_5000.pkl' # Path to expert trajectories
MIN_EXPERT_RETURN = 250 # Minimum return for a trajectory to be considered 'expert'

# Optimized Decision Transformer Hyperparameters
DT_CONTEXT_LEN = 20 # Length of the sequence context for the transformer
DT_N_HEADS = 8 # Number of attention heads in the transformer
DT_N_LAYERS = 6 # Number of transformer encoder layers
DT_EMBED_DIM = 256 # Dimension of the embeddings
DT_DROPOUT = 0.1 # Dropout rate for regularization
DT_LR = 3e-4 # Learning rate for the optimizer
DT_WEIGHT_DECAY = 1e-4 # Weight decay for L2 regularization
DT_BATCH_SIZE = 256 # Batch size for training
DT_NUM_EPOCHS = 50 # Number of training epochs
DT_TRAJECTORIES_TO_USE = 5000 # Number of expert trajectories to use for training

# Enhanced Early Stopping
EARLY_STOPPING_PATIENCE = 5 # Number of epochs to wait for improvement before stopping
EARLY_STOPPING_MIN_DELTA = 0.001 # Minimum change in loss to qualify as an improvement

# Learning Rate Scheduling
USE_LR_SCHEDULER = True # Whether to use a learning rate scheduler
SCHEDULER_FACTOR = 0.8 # Factor by which the learning rate will be reduced
SCHEDULER_PATIENCE = 3 # Number of epochs with no improvement after which learning rate will be reduced

# Data Augmentation
USE_DATA_AUGMENTATION = True # Whether to apply data augmentation to states
AUGMENTATION_NOISE_STD = 0.01 # Standard deviation of noise for data augmentation

# Pygame Visualization
PYGAME_FPS = 60 # Frames per second for Pygame visualization
NUM_EVAL_EPISODES = 10 # Number of episodes to run for evaluation

# Global Normalization Parameters (will be calculated from data)
state_mean = None
state_std = None


## 2. Utility Functions

This section defines helper functions essential for data preparation and overall system stability.
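
One of the helpers in the next cell computes discounted returns-to-go with the recurrence G_t = R_t + gamma * G_{t+1}. A minimal plain-Python sketch of that recurrence, using made-up rewards and a hypothetical `discounted_returns_to_go` helper, may make the backward pass easier to follow:

```python
from typing import List, Sequence

def discounted_returns_to_go(rewards: Sequence[float], gamma: float = 0.995) -> List[float]:
    """Return G_t = R_t + gamma * G_{t+1} for every timestep, computed back to front."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

rewards = [1.0, 0.0, 2.0]  # illustrative values only
rtg = discounted_returns_to_go(rewards, gamma=0.5)
print(rtg)  # [1.5, 1.0, 2.0]
```

The last entry equals the final reward, and each earlier entry adds its own reward to the discounted value of everything that follows.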

In [None]:
def set_seed(seed):
    """Sets the random seed for reproducibility across different libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

def load_trajectories_from_pickle(file_path):
    """Loads expert trajectories from a pickle file."""
    print(f"\n--- Loading Expert Trajectories from {file_path} ---")
    if not os.path.exists(file_path):
        print(f"Error: Trajectory file not found at {file_path}")
        return None
    try:
        with open(file_path, 'rb') as f:
            trajectories = pickle.load(f)
        print(f"Successfully loaded {len(trajectories)} trajectories.")
        return trajectories
    except Exception as e:
        print(f"Error loading trajectories from pickle file: {e}")
        return None

def calculate_state_normalization_params(trajectories):
    """Calculates mean and standard deviation for state normalization across all trajectories.
    Normalization centers each state feature around zero and scales it to unit variance,
    which keeps features with different magnitudes on a comparable scale and helps
    avoid poor gradient flow and slow convergence.
    """
    print("Calculating state normalization parameters...")
    
    all_states = []
    for traj in trajectories:
        all_states.append(np.array(traj['states']))
    
    all_states = np.concatenate(all_states, axis=0).astype(np.float32)
    
    mean = np.mean(all_states, axis=0, keepdims=True)
    std = np.std(all_states, axis=0, keepdims=True)
    std = np.maximum(std, 1e-8)  # Prevent division by zero
    
    print(f"Processed {len(all_states)} states for normalization")
    return mean.flatten(), std.flatten()

def normalize_state(state, mean, std):
    """Normalizes a single state vector using the pre-calculated mean and standard deviation."""
    return (state - mean) / std

def process_trajectories_optimized(raw_trajectories, context_len, gamma=0.995, state_mean=None, state_std=None):
    """Processes raw trajectories into a format suitable for Decision Transformer training.
    This includes normalizing states, calculating returns-to-go, and extracting sequences
    of a specified context length. It also implements a smart sampling strategy to prioritize
    high-value segments of trajectories.
    """
    processed_data = []
    
    print("Processing trajectories with optimized sampling...")
    
    for trajectory in tqdm(raw_trajectories, desc="Processing trajectories"):
        states = np.array(trajectory['states'], dtype=np.float32)
        actions = np.array(trajectory['actions'], dtype=np.int64)
        rewards = np.array(trajectory['rewards'], dtype=np.float32)
        
        # Vectorized normalization
        normalized_states = (states - state_mean) / state_std
        
        # Returns-to-go calculation (backward pass over the rewards): G_t = R_t + gamma * G_{t+1}
        returns_to_go = np.zeros_like(rewards)
        returns_to_go[-1] = rewards[-1]
        for t in range(len(rewards) - 2, -1, -1):
            returns_to_go[t] = rewards[t] + gamma * returns_to_go[t + 1]
        
        traj_len = len(normalized_states)
        
        # Smart sampling: sample more from high-value parts of trajectory
        # This helps the model focus on more critical and informative segments
        high_value_threshold = np.percentile(returns_to_go, 70)
        high_value_indices = np.where(returns_to_go >= high_value_threshold)[0]
        
        # Base sampling evenly distributes samples across the trajectory
        sample_indices = []
        for i in range(0, traj_len, max(1, traj_len // 20)): 
            sample_indices.append(i)
        
        # Add extra samples from high-value states for more focused learning
        for idx in high_value_indices[::2]: # Every other high-value state
            if idx not in sample_indices:
                sample_indices.append(idx)
        
        sample_indices = sorted(set(sample_indices))
        
        for i in sample_indices:
            end_idx = i + 1
            start_idx = max(0, end_idx - context_len)
            
            # Extract sequences for states, actions, returns-to-go, and timesteps
            s_seq = normalized_states[start_idx:end_idx]
            a_seq = actions[start_idx:end_idx]
            r_seq = returns_to_go[start_idx:end_idx]
            timesteps_seq = np.arange(start_idx, end_idx)
            
            # Pad sequences to `context_len` if they are shorter
            # Padding ensures all sequences have a consistent length for batching
            pad_len = context_len - len(s_seq)
            if pad_len > 0:
                s_seq = np.vstack([np.zeros((pad_len, s_seq.shape[1])), s_seq])
                a_seq = np.concatenate([np.zeros(pad_len, dtype=np.int64), a_seq])
                r_seq = np.concatenate([np.zeros(pad_len), r_seq])
                # Timesteps padding should be relative to the sequence start
                timesteps_seq = np.concatenate([np.zeros(pad_len, dtype=np.int64), np.arange(len(timesteps_seq)) + pad_len])
            
            processed_data.append({
                'states': s_seq.astype(np.float32),
                'actions': a_seq.astype(np.int64),
                'returns_to_go': r_seq.astype(np.float32),
                'timesteps': timesteps_seq.astype(np.int64),
                'target_action': actions[i] # The action taken at the current timestep 'i'
            })
    
    print(f"Generated {len(processed_data)} training samples")
    return processed_data


## 3. Dataset Class

The `OptimizedTrajectoryDataset` class is a PyTorch `Dataset` that prepares the processed trajectory data for the DataLoader. It handles fetching individual samples and applies data augmentation if configured.
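
The augmentation step amounts to adding zero-mean Gaussian noise to a sample's states with some probability. Here is a minimal sketch in plain Python, assuming a hypothetical `maybe_add_noise` helper and made-up values; the dataset class below does the same with NumPy and a fixed 30% probability.

```python
import random
from typing import List

def maybe_add_noise(states: List[List[float]], noise_std: float,
                    prob: float, rng: random.Random) -> List[List[float]]:
    """With probability `prob`, return a noisy copy of `states`; otherwise an unchanged copy."""
    if rng.random() >= prob:
        return [row[:] for row in states]
    return [[value + rng.gauss(0.0, noise_std) for value in row] for row in states]

rng = random.Random(0)  # seeded so the sketch is repeatable
window = [[0.0, 1.0], [0.5, -0.5]]  # two timesteps of a 2-dimensional state

untouched = maybe_add_noise(window, noise_std=0.01, prob=0.0, rng=rng)
noisy = maybe_add_noise(window, noise_std=0.01, prob=1.0, rng=rng)
print(untouched == window, noisy != window)
```

Returning a copy in both branches mirrors the `.copy()` call in `__getitem__`, so the stored sample is never modified in place.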

In [None]:
class OptimizedTrajectoryDataset(Dataset):
    """Custom Dataset for Decision Transformer training.
    It provides sequences of states, actions, returns-to-go, and timesteps.
    Supports optional data augmentation by adding noise to states.
    """
    def __init__(self, data, use_augmentation=False, noise_std=0.01):
        self.data = data
        self.use_augmentation = use_augmentation
        self.noise_std = noise_std

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        states = item['states'].copy()
        
        # Data augmentation: add small noise to states to improve robustness
        if self.use_augmentation and random.random() < 0.3: # Apply noise with 30% probability
            noise = np.random.normal(0, self.noise_std, states.shape).astype(np.float32)
            states = states + noise
        
        return (torch.tensor(states, dtype=torch.float32),
                torch.tensor(item['actions'], dtype=torch.long),
                torch.tensor(item['returns_to_go'], dtype=torch.float32),
                torch.tensor(item['timesteps'], dtype=torch.long),
                torch.tensor(item['target_action'], dtype=torch.long))


## 4. Decision Transformer Model (`OptimizedDecisionTransformer`)

This is the core of the project: a PyTorch implementation of the Decision Transformer. The model takes sequences of states, actions, returns-to-go, and timesteps as input, interleaves the first three as (return, state, action) tokens, and runs them through a causally masked transformer encoder. The action for each timestep is predicted from the encoder output at that timestep's state token.
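
To make the token layout concrete: each timestep contributes three tokens in the order (return-to-go, state, action), and the action prediction for a timestep is read from that timestep's state token. The plain-Python sketch below uses string labels in place of embeddings; `interleave` and `causal_allowed` are hypothetical helpers, not functions from the model.

```python
from typing import List

def interleave(seq_len: int) -> List[str]:
    """Lay out tokens as R_0, S_0, A_0, R_1, S_1, A_1, ... for `seq_len` timesteps."""
    tokens = []
    for t in range(seq_len):
        tokens.extend([f"R_{t}", f"S_{t}", f"A_{t}"])
    return tokens

def causal_allowed(total_len: int) -> List[List[bool]]:
    """allowed[i][j] is True when token i may attend to token j (itself or anything earlier)."""
    return [[j <= i for j in range(total_len)] for i in range(total_len)]

tokens = interleave(3)
state_tokens = tokens[1::3]            # same slicing pattern as the model's `[:, 1::3, :]`
allowed = causal_allowed(len(tokens))

s1 = tokens.index("S_1")
visible_to_s1 = [tok for tok, ok in zip(tokens, allowed[s1]) if ok]
print(state_tokens)    # ['S_0', 'S_1', 'S_2']
print(visible_to_s1)   # ['R_0', 'S_0', 'A_0', 'R_1', 'S_1']
```

Because `S_t` can see `R_t` but not `A_t`, the prediction read from it cannot peek at the action it is supposed to predict.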

In [None]:
class OptimizedDecisionTransformer(nn.Module):
    """Optimized Decision Transformer model for sequence modeling in reinforcement learning."""
    def __init__(self, state_dim, action_dim, embed_dim, context_len, n_heads, n_layers, dropout, max_timestep=4096):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.embed_dim = embed_dim
        self.context_len = context_len

        # State embedding: Projects raw state vector into the embedding dimension
        self.state_embedding = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5) # Reduced dropout for embedding
        )
        
        # Action embedding: Uses nn.Embedding for discrete actions, which is efficient for categorical inputs
        self.action_embedding = nn.Embedding(action_dim, embed_dim)

        # Return-to-go embedding: Projects scalar return-to-go into the embedding dimension
        self.return_embedding = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU()
        )
        
        # Learnable positional embedding: Adds positional information to the sequence tokens.
        # The length is `context_len * 3` because each timestep has 3 tokens: (return, state, action).
        self.pos_embedding = nn.Parameter(torch.randn(1, context_len * 3, embed_dim) * 0.02)
        
        self.dropout = nn.Dropout(dropout)
        
        # Transformer Encoder Layer: Building block of the transformer.
        # `norm_first=False` selects post-norm (LayerNorm applied after each residual connection),
        # which is PyTorch's default for this layer and is chosen here over pre-norm.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=2 * embed_dim, # Reduced from 4x for efficiency
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=False 
        )

        # Transformer Encoder: Stacks multiple `encoder_layer` instances.
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        
        # Action prediction head: Maps the transformer's output to action probabilities.
        self.action_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, action_dim)
        )
        
        # Initialize weights for better training stability
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initializes weights of linear layers, embeddings, and normalization layers for optimal training."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, states, actions, returns_to_go, timesteps):
        """Forward pass of the Decision Transformer.
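        Processes embedded states, actions, and returns-to-go through the transformer.
        Note: `timesteps` is accepted but not used here; positional information comes
        only from the learnable `pos_embedding` over the context window.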
        """
        batch_size, seq_len = states.shape[0], states.shape[1]
        
        # Embed each modality
        state_embeddings = self.state_embedding(states)
        action_embeddings = self.action_embedding(actions)
        # Unsqueeze adds a dimension (batch_size, seq_len, 1) to match linear layer input
        return_embeddings = self.return_embedding(returns_to_go.unsqueeze(-1))
        
        # Stack tokens efficiently as (batch_size, seq_len, 3, embed_dim)
        # and then reshape to (batch_size, seq_len * 3, embed_dim).
        # The order is crucial: (return, state, action) for causal masking.
        stacked_inputs = torch.stack([return_embeddings, state_embeddings, action_embeddings], dim=2)
        stacked_inputs = stacked_inputs.reshape(batch_size, seq_len * 3, self.embed_dim)
        
        # Add learnable positional embeddings to the stacked inputs.
        stacked_inputs = stacked_inputs + self.pos_embedding[:, :seq_len * 3, :]
        stacked_inputs = self.dropout(stacked_inputs)
        
        # Optimized causal mask to prevent attention to future tokens. 
        # For a given timestep `t`, the action token `a_t` should only attend to `s_t`, `r_t`, and `s_<t`, `a_<t`, `r_<t`.
        mask = self._create_optimized_causal_mask(seq_len, stacked_inputs.device)
        
        # Transformer forward pass
        transformer_outputs = self.transformer(stacked_inputs, mask=mask)
        
        # Extract state tokens: these are at indices 1, 4, 7, ... (1 + 3*k)
        # The model predicts actions based on the state representations.
        state_tokens = transformer_outputs[:, 1::3, :]
        
        # Predict actions using the action head
        action_preds = self.action_head(state_tokens)
        
        return action_preds

    def _create_optimized_causal_mask(self, seq_len, device):
        """Creates an optimized causal mask for the Decision Transformer.
        This mask ensures that the transformer adheres to the causal structure of RL sequences,
        where predictions for time `t` can only depend on information up to time `t`.
        The mask prevents tokens from attending to future states, actions, or returns.
        """
        total_len = seq_len * 3 # Total tokens in the sequence (return, state, action per timestep)
        
        # Initialize a base causal mask (upper triangular matrix with -inf)
        mask = torch.triu(torch.ones(total_len, total_len, device=device) * float('-inf'), diagonal=1)
        
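        # Re-assert visibility for the Decision Transformer token structure (R_t, S_t, A_t).
        # Because R_t and S_t precede A_t within a timestep, the base causal mask already
        # leaves these entries at 0, so this loop does not change the mask.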
        for i in range(total_len):
            timestep = i // 3 # Current timestep index
            token_type = i % 3  # 0 for Return, 1 for State, 2 for Action
            
            # A token can always see all past tokens.
            # Specific causal adjustments for current timestep:
            if token_type == 1:  # State token (S_t)
                # S_t can see R_t (its corresponding return-to-go for current timestep)
                mask[i, timestep * 3] = 0 
            elif token_type == 2:  # Action token (A_t)
                # A_t can see R_t and S_t (its corresponding return-to-go and state)
                mask[i, timestep * 3:timestep * 3 + 2] = 0
        
        return mask


## 5. Training Function (`train_optimized_decision_transformer`)

This function orchestrates the training loop for the Decision Transformer. It includes:
-   **Mixed Precision Training:** Utilizes `torch.amp.GradScaler` for faster training and reduced memory usage on compatible hardware (GPUs).
-   **Gradient Clipping:** Prevents exploding gradients by limiting their norm.
-   **Early Stopping:** Monitors the average training loss per epoch and stops if it has not improved by at least `min_delta` for `patience` consecutive epochs. No separate validation set is used, so this saves compute but does not detect overfitting directly.
-   **Learning Rate Scheduling:** Reduces the learning rate when the training loss plateaus, helping the model converge more effectively.
-   **Model Saving:** Saves the best-performing model, measured by average training loss.
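
The early-stopping rule can be sketched on its own: an epoch counts as an improvement only if the loss drops by more than `min_delta`, and training stops after `patience` epochs without one. The `EarlyStopper` class and the loss values below are illustrative and not part of this notebook.

```python
class EarlyStopper:
    """Track the best loss and signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience: int, min_delta: float) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_no_improve = 0

    def step(self, loss: float) -> bool:
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.epochs_no_improve = 0
            return False
        self.epochs_no_improve += 1
        return self.epochs_no_improve >= self.patience

stopper = EarlyStopper(patience=2, min_delta=0.01)
losses = [1.00, 0.80, 0.795, 0.792, 0.50]  # made-up per-epoch losses
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)
```

In this run the small drops at epochs 3 and 4 fall within `min_delta`, so training stops at epoch 4 and never reaches the larger improvement at epoch 5. That is the trade-off a short `patience` makes.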

In [None]:
def train_optimized_decision_transformer(model, dataloader, optimizer, scheduler, num_epochs, patience, min_delta):
    """Trains the Decision Transformer model with optimizations.
    Includes mixed precision training, gradient clipping, early stopping, and learning rate scheduling.
    """
    print("\n--- Training Optimized Decision Transformer ---")
    model.train() # Set model to training mode
    
    # Initialize early stopping parameters
    best_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None
    
    # Mixed precision training for faster computation on CUDA devices
    scaler = torch.amp.GradScaler('cuda') if DEVICE.type == 'cuda' else None

    # Path to save the best model; create its parent directory if it does not exist
    model_path = '../models/lunar_lander_model_5000.pkl'
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    for epoch in range(num_epochs):
        total_loss = 0
        total_accuracy = 0
        num_batches = 0
        
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        
        for states, actions, returns_to_go, timesteps, target_actions in progress_bar:
            # Move data to the appropriate device (CPU/GPU)
            states = states.to(DEVICE, non_blocking=True)
            actions = actions.to(DEVICE, non_blocking=True)
            returns_to_go = returns_to_go.to(DEVICE, non_blocking=True)
            timesteps = timesteps.to(DEVICE, non_blocking=True)
            target_actions = target_actions.to(DEVICE, non_blocking=True)

            optimizer.zero_grad() # Clear gradients from previous step
            
            # Mixed precision forward pass
            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    action_preds = model(states, actions, returns_to_go, timesteps)
                    # Calculate Cross-Entropy Loss for action prediction (only for the last action in sequence)
                    loss = F.cross_entropy(action_preds[:, -1, :], target_actions)
                
                scaler.scale(loss).backward() # Scale loss and perform backward pass
                scaler.unscale_(optimizer) # Unscale gradients before clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # Clip gradients
                scaler.step(optimizer) # Update model parameters
                scaler.update() # Update the scaler for the next iteration
            else:
                # Standard float32 training
                action_preds = model(states, actions, returns_to_go, timesteps)
                loss = F.cross_entropy(action_preds[:, -1, :], target_actions)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
            
            # Calculate accuracy (no gradient calculation needed)
            with torch.no_grad():
                predicted_actions = torch.argmax(action_preds[:, -1, :], dim=-1)
                accuracy = (predicted_actions == target_actions).float().mean()
            
            # Aggregate loss and accuracy for epoch statistics
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            num_batches += 1
            
            # Update progress bar with current batch metrics
            progress_bar.set_postfix({
                'Loss': f'{total_loss/num_batches:.4f}', 
                'Acc': f'{total_accuracy/num_batches:.4f}',
                'LR': f'{optimizer.param_groups[0]["lr"]:.6f}'
            })

        # Calculate average loss and accuracy for the epoch
        avg_loss = total_loss / num_batches
        avg_accuracy = total_accuracy / num_batches
        
        # Step the learning rate scheduler if enabled
        if scheduler is not None:
            scheduler.step(avg_loss)
        
        print(f"Epoch {epoch + 1}: Loss={avg_loss:.4f}, Acc={avg_accuracy:.4f}, LR={optimizer.param_groups[0]['lr']:.6f}")

        # Early stopping logic
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            epochs_no_improve = 0
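            # Snapshot the weights if they are the best so far. `state_dict().copy()` is a shallow
            # copy whose tensors keep changing as training continues, so clone each tensor instead.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}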
            torch.save({
                'model_state_dict': best_model_state,
                'state_mean': state_mean,
                'state_std': state_std,
                'model_config': {
                    'state_dim': model.state_dim,
                    'action_dim': model.action_dim,
                    'embed_dim': model.embed_dim,
                    'context_len': model.context_len,
                    'n_heads': DT_N_HEADS,
                    'n_layers': DT_N_LAYERS,
                    'dropout': DT_DROPOUT
                }
            }, model_path)
            print(f"✓ New best model saved with loss: {best_loss:.4f}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"\n🛑 Early stopping after {epoch + 1} epochs!")
                break

    # Load the best model state after training (if early stopping occurred)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"✓ Model restored to best state (loss: {best_loss:.4f})")
    
    return model


## 6. Pygame Visualization Functions

These functions handle the setup and rendering of the Lunar Lander environment using Pygame, allowing for real-time visualization of the agent's performance during evaluation.

In [None]:
def init_pygame(width, height):
    """Initializes Pygame for rendering the environment."""
    pygame.init()
    # Add extra height for displaying text (score, episode number)
    screen = pygame.display.set_mode((width, height + 50))
    pygame.display.set_caption("Optimized Lunar Lander Decision Transformer")
    font = pygame.font.Font(None, 36) # Default font, size 36
    return screen, font

def render_env_pygame(env, screen, font, current_score, episode_num, total_episodes):
    """Renders the Gymnasium environment using Pygame and displays score/episode info."""
    frame = env.render() # Get the RGB array frame from the environment
    if frame is not None:
        # Pygame expects (width, height, channels), so transpose the numpy array
        frame = np.transpose(frame, (1, 0, 2))
        pygame_surface = pygame.surfarray.make_surface(frame)
        screen.fill((0, 0, 0)) # Clear screen with black
        screen.blit(pygame_surface, (0, 0)) # Draw the environment frame
    
    # Render score and episode information
    score_text = font.render(f"Score: {current_score:.1f}", True, (255, 255, 255)) # White color
    episode_text = font.render(f"Episode: {episode_num}/{total_episodes}", True, (255, 255, 255))
    
    screen.blit(score_text, (10, screen.get_height() - 40)) # Position at bottom-left
    screen.blit(episode_text, (10, screen.get_height() - 80)) # Position above score
    pygame.display.flip() # Update the full display Surface to the screen


## 7. Main Execution Block

This block orchestrates the entire workflow: environment setup, data loading and preprocessing, model training, and evaluation with visualization. This is where all the previously defined functions are called in sequence to run the experiment.
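
During evaluation the model is conditioned on a target return that shrinks as reward is collected, while a fixed-length window keeps only the most recent steps. Here is a plain-Python sketch of that bookkeeping, with made-up rewards and a hypothetical `remaining_return` helper:

```python
from collections import deque

def remaining_return(target_return: float, collected: float) -> float:
    """Return-to-go fed to the model: what is still needed to hit the target, floored at zero."""
    return max(target_return - collected, 0.0)

context_len = 3
returns_to_go = deque([0.0] * context_len, maxlen=context_len)  # zero-padded start

target_return = 10.0
collected = 0.0
for reward in [4.0, 3.0, 5.0]:  # illustrative per-step rewards
    returns_to_go.append(remaining_return(target_return, collected))
    collected += reward  # reward arrives after the action chosen for this step

print(list(returns_to_go))  # [10.0, 6.0, 3.0]
```

The `maxlen` argument makes the deque drop its oldest entry on each append, which is how the zero padding is gradually replaced by real history.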

In [None]:
if __name__ == "__main__":
    print(f"🚀 Device: {DEVICE}")
    print(f"🎯 Using up to {DT_TRAJECTORIES_TO_USE} trajectories with return >= {MIN_EXPERT_RETURN}")

    # --- 7.1 Environment Initialization ---
    env_dt = gym.make(ENV_ID, render_mode='rgb_array')
    state_dim = env_dt.observation_space.shape[0]
    action_dim = env_dt.action_space.n
    # Get max_episode_steps if available, otherwise default to 1000
    max_timestep = env_dt.spec.max_episode_steps if env_dt.spec else 1000

    # --- 7.2 Load and Filter Trajectories ---
    raw_expert_trajectories = load_trajectories_from_pickle(TRAJECTORY_FILE)
    if not raw_expert_trajectories:
        print("❌ No trajectories loaded. Exiting.")
        exit()

    # Smart filtering: prioritize high-performing trajectories for training
    # Sort trajectories by their total return in descending order.
    trajectory_returns = [(i, sum(traj['rewards'])) for i, traj in enumerate(raw_expert_trajectories)]
    trajectory_returns.sort(key=lambda x: x[1], reverse=True)
    
    # Select a proportion of the top-performing trajectories
    top_indices = [idx for idx, ret in trajectory_returns[:int(DT_TRAJECTORIES_TO_USE * 0.8)] if ret >= MIN_EXPERT_RETURN]
    # Select some random remaining good trajectories for diversity
    remaining_good = [idx for idx, ret in trajectory_returns[int(DT_TRAJECTORIES_TO_USE * 0.8):] if ret >= MIN_EXPERT_RETURN]
    
    random_indices = []
    if len(remaining_good) > 0:
        random_indices = random.sample(remaining_good, min(len(remaining_good), DT_TRAJECTORIES_TO_USE - len(top_indices)))
    
    selected_indices = top_indices + random_indices
    # Final filtered list of trajectories, truncated to DT_TRAJECTORIES_TO_USE
    filtered_trajectories = [raw_expert_trajectories[i] for i in selected_indices[:DT_TRAJECTORIES_TO_USE]]
    
    print(f"✓ Selected {len(filtered_trajectories)} high-quality trajectories")
    print(f"  📊 Mean return of selected trajectories: {np.mean([sum(traj['rewards']) for traj in filtered_trajectories]):.1f}")

    # --- 7.3 Data Preparation ---
    # Calculate normalization parameters from the selected trajectories
    state_mean, state_std = calculate_state_normalization_params(filtered_trajectories)

    # Process trajectories for Decision Transformer input format
    processed_dt_data = process_trajectories_optimized(
        filtered_trajectories, DT_CONTEXT_LEN, gamma=0.995, 
        state_mean=state_mean, state_std=state_std
    )
    
    # Create Dataset and DataLoader for efficient batching during training
    dt_dataset = OptimizedTrajectoryDataset(processed_dt_data, USE_DATA_AUGMENTATION, AUGMENTATION_NOISE_STD)
    dt_dataloader = DataLoader(
        dt_dataset, 
        batch_size=DT_BATCH_SIZE, 
        shuffle=True, 
        num_workers=2, # Use multiple workers for faster data loading
        pin_memory=True if DEVICE.type == 'cuda' else False # Pin memory for faster GPU transfers
    )

    # --- 7.4 Model Initialization ---
    dt_model = OptimizedDecisionTransformer(
        state_dim=state_dim,
        action_dim=action_dim,
        embed_dim=DT_EMBED_DIM,
        context_len=DT_CONTEXT_LEN,
        n_heads=DT_N_HEADS,
        n_layers=DT_N_LAYERS,
        dropout=DT_DROPOUT,
        max_timestep=max_timestep
    ).to(DEVICE) # Move model to the specified device

    # Optimized AdamW optimizer with custom betas for transformers
    dt_optimizer = optim.AdamW(
        dt_model.parameters(), 
        lr=DT_LR, 
        weight_decay=DT_WEIGHT_DECAY,
        betas=(0.9, 0.95) 
    )
    
    scheduler = None
    if USE_LR_SCHEDULER:
        # ReduceLROnPlateau scheduler reduces LR when a metric has stopped improving
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            dt_optimizer, mode='min', factor=SCHEDULER_FACTOR, 
            patience=SCHEDULER_PATIENCE
        )

    print(f"🔧 Model parameters: {sum(p.numel() for p in dt_model.parameters()):,}")

    # --- 7.5 Model Training ---
    start_time = time.time()
    dt_model = train_optimized_decision_transformer(
        dt_model, dt_dataloader, dt_optimizer, scheduler, 
        DT_NUM_EPOCHS, EARLY_STOPPING_PATIENCE, EARLY_STOPPING_MIN_DELTA
    )
    training_time = time.time() - start_time
    print(f"⏱️  Training completed in {training_time:.1f} seconds")

    # --- 7.6 Evaluation and Visualization ---
    print("\n🎮 Starting evaluation...")
    
    try:
        # Prompt user for number of evaluation episodes
        NUM_EVAL_EPISODES = int(input(f"Enter number of episodes to watch (default: {NUM_EVAL_EPISODES}): ") or NUM_EVAL_EPISODES)
    except ValueError:
        print(f"Invalid input, using default {NUM_EVAL_EPISODES} episodes.")

    # Initialize pygame for rendering
    _ = env_dt.reset() # Reset env once to get initial frame for dimension
    dummy_frame = env_dt.render()
    render_width = dummy_frame.shape[1] if dummy_frame is not None else 600
    render_height = dummy_frame.shape[0] if dummy_frame is not None else 400

    pygame_screen, pygame_font = init_pygame(render_width, render_height)
    clock = pygame.time.Clock() # To control frame rate

    dt_model.eval() # Set model to evaluation mode (disables dropout, etc.)
    eval_returns = []
    
    for i_episode in range(NUM_EVAL_EPISODES):
        state, _ = env_dt.reset() # Reset environment for new episode
        state = normalize_state(state, state_mean, state_std)

        # Fixed-length context windows: each deque keeps only the most recent DT_CONTEXT_LEN entries
        states = deque(maxlen=DT_CONTEXT_LEN)
        actions = deque(maxlen=DT_CONTEXT_LEN)
        returns_to_go = deque(maxlen=DT_CONTEXT_LEN)
        timesteps = deque(maxlen=DT_CONTEXT_LEN)

        current_episode_return = 0
        # Target return used to condition the model: 1.2x MIN_EXPERT_RETURN.
        # The Decision Transformer chooses actions conditioned on the return still to be earned.
        target_return = MIN_EXPERT_RETURN * 1.2

        # Pre-populate context with zeros/defaults for the initial steps
        for _ in range(DT_CONTEXT_LEN):
            states.append(np.zeros(state_dim, dtype=np.float32))
            actions.append(0) # Default action
            returns_to_go.append(0.0)
            timesteps.append(0)

        with torch.no_grad(): # Disable gradient calculation for inference
            for t in range(max_timestep):
                # Update context with the current state, remaining return-to-go, and timestep
                states.append(state)
                # Return-to-go = target return minus reward collected so far, clamped at 0
                # (e.g. a target of 240 with 100 collected gives 140; once the target is exceeded it stays 0)
                returns_to_go.append(max(target_return - current_episode_return, 0))
                timesteps.append(t)
                
                # Prepare inputs for the model (unsqueeze to add batch dimension)
                s_input = torch.tensor(np.array(list(states)), dtype=torch.float32, device=DEVICE).unsqueeze(0)
                a_input = torch.tensor(np.array(list(actions)), dtype=torch.long, device=DEVICE).unsqueeze(0)
                r_input = torch.tensor(np.array(list(returns_to_go)), dtype=torch.float32, device=DEVICE).unsqueeze(0)
                t_input = torch.tensor(np.array(list(timesteps)), dtype=torch.long, device=DEVICE).unsqueeze(0)

                # Model inference: predict actions
                action_preds = dt_model(s_input, a_input, r_input, t_input)
                # Greedy decoding: take the highest-scoring action at the last (current) sequence position
                predicted_action = torch.argmax(action_preds[0, -1, :]).item()

                # Take a step in the environment
                next_state, reward, terminated, truncated, _ = env_dt.step(predicted_action)
                done = terminated or truncated
                
                next_state = normalize_state(next_state, state_mean, state_std)
                current_episode_return += reward
                state = next_state
                actions.append(predicted_action) # Add taken action to context

                # Render the environment using Pygame
                render_env_pygame(env_dt, pygame_screen, pygame_font, current_episode_return, i_episode + 1, NUM_EVAL_EPISODES)
                clock.tick(PYGAME_FPS) # Control rendering speed

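                # Handle Pygame events (e.g., closing the window)
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        env_dt.close()  # Release the environment before exiting
                        pygame.quit()
                        raise SystemExit  # Exit cleanly without relying on the site-provided exit()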

                if done:
                    break

        eval_returns.append(current_episode_return)
        print(f"Episode {i_episode + 1}: Return = {current_episode_return:.2f}")

    # --- 7.7 Evaluation Results Summary ---
    print("\n📊 Evaluation Results:")
    print(f"Mean Return: {np.mean(eval_returns):.2f} ± {np.std(eval_returns):.2f}")
    print(f"Best Return: {np.max(eval_returns):.2f}")
    # Success rate: Lunar Lander episodes scoring at least 200 points count as solved
    success_rate = np.mean(np.array(eval_returns) >= 200) * 100
    print(f"Success Rate: {success_rate:.1f}%")
    
    env_dt.close() # Close the Gymnasium environment
    pygame.quit() # Quit Pygame
    print("✅ Evaluation complete!")
