## 1. Installs and Imports
This notebook implements a Decision Transformer for the `LunarLander-v3` environment.

In [1]:
import os
import random
import csv
from datetime import datetime
import collections
import math

import numpy as np
import gymnasium as gym
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

## 2. Configuration Parameters
Parameters for the environment, training, and model architecture. These are adapted for `LunarLander-v3` while keeping the core model parameters the same as in the original notebook.

In [None]:
# --- Reproducibility ---
# Seeds the Python, NumPy and PyTorch global RNGs so that window sampling,
# weight initialisation and dropout are repeatable after Restart & Run All.
# The Gymnasium environment keeps its own RNG and is not seeded here.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Environment and Dataset Parameters ---
env_name = 'LunarLander-v3'
rtg_target = 100                # Unscaled return-to-go the policy is conditioned on at evaluation
rtg_scale = 100                 # Scale to normalize returns to go
num_episodes_dataset = 500      # Number of episodes to generate for the dataset

# --- Evaluation Parameters ---
max_eval_ep_len = 1000      # Max length of one evaluation episode
num_eval_ep = 10            # Number of evaluation episodes per iteration

# --- Training Parameters ---
batch_size = 64             # Training batch size
lr = 1e-4                   # Learning rate
wt_decay = 1e-4             # Weight decay
warmup_steps = 5000         # Warmup steps for lr scheduler
max_train_iters = 200       # Number of train/evaluate/save iterations
num_updates_per_iter = 100  # Gradient updates per iteration

# --- Model Parameters ---
context_len = 20        # K in decision transformer
n_blocks = 3            # Num of transformer blocks
embed_dim = 128         # Embedding (hidden) dim of transformer
n_heads = 1             # Num of transformer heads
dropout_p = 0.1         # Dropout probability

# --- Logging and Device ---
log_dir = "./dt_runs_lunarlander/"
# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(log_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 3. Dataset Generation
Since `LunarLander-v3` doesn't have a standard offline dataset like D4RL, we generate one by running a random agent and collecting its trajectories.

In [None]:
def generate_dataset(env_name, num_episodes, seed=None):
    """Generates a dataset of trajectories using a random policy.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        num_episodes: Number of episodes to roll out.
        seed: Optional integer. When given, the action space is seeded with it
            and episode ``i`` is reset with ``seed + i`` so that the collected
            data is repeatable. ``None`` (the default) leaves seeding to the
            environment.

    Returns:
        A list with one dict per episode holding the NumPy arrays
        ``'observations'`` (float32), ``'actions'`` (int64), ``'rewards'``
        (float32) and ``'terminals'`` (bool). ``'terminals'`` records only the
        ``terminated`` flag, so a step that merely hit the time limit
        (``truncated``) is stored as ``False``.
    """
    env = gym.make(env_name)
    trajectories = []
    print(f"Generating a dataset of {num_episodes} episodes...")

    # try/finally guarantees the environment is released even if a rollout raises.
    try:
        if seed is not None:
            env.action_space.seed(seed)

        for episode_idx in tqdm(range(num_episodes)):
            obs_list, act_list, rew_list, done_list = [], [], [], []
            done = False
            reset_seed = None if seed is None else seed + episode_idx
            obs, _ = env.reset(seed=reset_seed)

            while not done:
                action = env.action_space.sample()  # Random action
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Store the observation the action was taken from, not the successor.
                obs_list.append(obs)
                act_list.append(action)
                rew_list.append(reward)
                done_list.append(terminated)  # Use termination flag

                obs = next_obs

            trajectories.append({
                'observations': np.array(obs_list, dtype=np.float32),
                'actions': np.array(act_list, dtype=np.int64),
                'rewards': np.array(rew_list, dtype=np.float32),
                'terminals': np.array(done_list, dtype=np.bool_)
            })
    finally:
        env.close()

    return trajectories

# Generate the dataset once with the random policy; the training cell below
# reads `trajectories` to compute normalisation statistics and build the
# training dataset.
trajectories = generate_dataset(env_name, num_episodes_dataset)

## 4. Decision Transformer Model
The model architecture is adapted to handle discrete actions by using an `nn.Embedding` layer for actions and predicting action logits.

In [None]:
class MaskedCausalAttention(nn.Module):
    """Multi-head self-attention in which position ``i`` only attends to positions ``<= i``.

    Args:
        h_dim: Width of the input and output features; must be divisible by ``n_heads``.
        max_T: Longest sequence length supported (size of the causal mask).
        n_heads: Number of attention heads.
        drop_p: Dropout probability used after the attention product and after
            the output projection.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``h_dim``.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        # Fail at construction time instead of with an opaque view() error in forward().
        if n_heads <= 0 or h_dim % n_heads != 0:
            raise ValueError(
                f"h_dim ({h_dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.max_T = max_T
        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)
        self.proj_net = nn.Linear(h_dim, h_dim)
        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)
        # Lower-triangular matrix of ones, shaped to broadcast over (batch, head).
        ones = torch.ones((max_T, max_T))
        mask = torch.tril(ones).view(1, 1, max_T, max_T)
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Applies causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``T <= max_T``.

        Returns:
            Tensor of shape ``(B, T, C)``.

        Raises:
            ValueError: If ``T`` exceeds ``max_T``.
        """
        B, T, C = x.shape
        if T > self.max_T:
            raise ValueError(
                f"sequence length {T} exceeds the maximum supported length {self.max_T}"
            )
        N, D = self.n_heads, C // self.n_heads

        # Split channels into heads: (B, T, C) -> (B, N, T, D).
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        # Scaled dot-product scores, shape (B, N, T, T).
        weights = q @ k.transpose(2, 3) / math.sqrt(D)
        # Future positions get -inf so that softmax assigns them zero weight.
        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
        normalized_weights = F.softmax(weights, dim=-1)

        # Dropout is applied to the attended values, i.e. after multiplying by v.
        attention = self.att_drop(normalized_weights @ v)
        # Merge the heads back: (B, N, T, D) -> (B, T, N*D).
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)
        out = self.proj_drop(self.proj_net(attention))
        return out

class Block(nn.Module):
    """One transformer layer: causal attention then an MLP.

    Each sub-layer is wrapped in a residual connection and followed by its own
    LayerNorm (normalisation happens after the residual add).
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)

        # Position-wise feed-forward network with a 4x wider hidden layer.
        hidden_dim = 4 * h_dim
        feed_forward_layers = [
            nn.Linear(h_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, h_dim),
            nn.Dropout(drop_p),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x):
        """Maps ``x`` through the attention and MLP sub-layers; the shape is unchanged."""
        attended = self.ln1(x + self.attention(x))
        return self.ln2(attended + self.mlp(attended))

class DecisionTransformer(nn.Module):
    """Decision Transformer for discrete action spaces.

    Every timestep contributes three tokens, in the order
    (return-to-go, state, action). The interleaved sequence of length
    ``3 * T`` is passed through a stack of causal transformer blocks.

    Args:
        state_dim: Size of a state vector.
        act_dim: Number of discrete actions.
        n_blocks: Number of transformer blocks.
        h_dim: Embedding / hidden width.
        context_len: Maximum number of timesteps per input sequence.
        n_heads: Attention heads per block.
        drop_p: Dropout probability.
        max_timestep: Size of the timestep embedding table.
    """

    def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len,
                 n_heads, drop_p, max_timestep=4096):
        super().__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim

        # Three tokens per timestep, so the blocks must handle 3 * context_len positions.
        tokens_per_sequence = 3 * context_len
        layers = []
        for _ in range(n_blocks):
            layers.append(Block(h_dim, tokens_per_sequence, n_heads, drop_p))
        self.transformer = nn.Sequential(*layers)

        # Input embeddings.
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.embed_rtg = nn.Linear(1, h_dim)
        self.embed_state = nn.Linear(state_dim, h_dim)
        # Discrete actions are looked up in an embedding table.
        self.embed_action = nn.Embedding(act_dim, h_dim)

        # Output heads; the action head emits one logit per discrete action.
        self.predict_rtg = nn.Linear(h_dim, 1)
        self.predict_state = nn.Linear(h_dim, state_dim)
        self.predict_action = nn.Linear(h_dim, act_dim)

    def forward(self, timesteps, states, actions, returns_to_go):
        """Runs the model on a batch of sequences.

        Args:
            timesteps: Integer tensor ``(B, T)`` indexing the timestep embedding.
            states: Float tensor ``(B, T, state_dim)``.
            actions: Integer tensor ``(B, T)`` of action indices.
            returns_to_go: Float tensor ``(B, T, 1)``.

        Returns:
            Tuple ``(state_preds, action_preds, return_preds)``. State and
            return predictions are computed from the action-token outputs;
            action logits are computed from the state-token outputs.
        """
        batch, seq_len, _ = states.shape

        # The same timestep embedding is added to all three modalities.
        pos = self.embed_timestep(timesteps)
        rtg_tokens = self.embed_rtg(returns_to_go) + pos
        state_tokens = self.embed_state(states) + pos
        action_tokens = self.embed_action(actions) + pos

        # (B, T, 3, H) -> (B, 3T, H): r_0, s_0, a_0, r_1, s_1, a_1, ...
        interleaved = torch.stack((rtg_tokens, state_tokens, action_tokens), dim=2)
        interleaved = interleaved.reshape(batch, 3 * seq_len, self.h_dim)

        hidden = self.transformer(self.embed_ln(interleaved))

        # Undo the interleaving; index 1 on the token axis is the state token
        # and index 2 is the action token.
        hidden = hidden.reshape(batch, seq_len, 3, self.h_dim)
        after_state = hidden[:, :, 1]
        after_action = hidden[:, :, 2]

        return_preds = self.predict_rtg(after_action)
        state_preds = self.predict_state(after_action)
        action_preds = self.predict_action(after_state)

        return state_preds, action_preds, return_preds

## 5. Utilities and Dataset Class
Helper functions for training and evaluation, and a custom `Dataset` class to handle the generated trajectories.

In [None]:
def discount_cumsum(x, gamma):
    """Computes the discounted reverse cumulative sum along the first axis.

    ``out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...``

    Args:
        x: Array-like of per-step values (e.g. rewards).
        gamma: Discount factor.

    Returns:
        A NumPy array with the same shape as ``x``. Floating-point inputs keep
        their dtype; any other dtype is promoted to float64 so that fractional
        discounted sums are not truncated. An empty input returns an empty array.
    """
    x = np.asarray(x)
    # zeros_like() on an integer array would silently truncate gamma-weighted sums.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    disc_cumsum = np.zeros(x.shape, dtype=out_dtype)

    # Nothing to accumulate; also avoids indexing [-1] on an empty array.
    if x.shape[0] == 0:
        return disc_cumsum

    disc_cumsum[-1] = x[-1]
    # Walk backwards so each entry reuses the already-computed tail sum.
    for t in reversed(range(x.shape[0] - 1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
    return disc_cumsum

def compute_dataset_stats(trajectories):
    """Returns the per-dimension mean and standard deviation of all observations.

    Args:
        trajectories: Iterable of dicts, each with an ``'observations'`` array
            whose first axis indexes steps.

    Returns:
        ``(state_mean, state_std)``; ``state_std`` has ``1e-6`` added so that
        dividing by it is safe for constant dimensions.
    """
    per_episode_obs = []
    for traj in trajectories:
        per_episode_obs.append(traj['observations'])

    # Pool every step of every episode into a single array.
    pooled = np.concatenate(per_episode_obs, axis=0)

    eps = 1e-6
    return pooled.mean(axis=0), pooled.std(axis=0) + eps

class TrajectoryDataset(Dataset):
    """Serves fixed-length windows sampled from a list of trajectories.

    Indexing picks trajectory ``idx`` and returns a window that starts at a
    random step. Windows shorter than ``context_len`` are padded with zeros on
    the right, and ``traj_mask`` marks which positions hold real data.

    Padding goes on the right because the model in this notebook applies only a
    causal attention mask: with trailing padding no real token can attend to a
    padded one, and the real part of the window looks exactly like the unpadded
    sequences the model receives during evaluation.

    Args:
        trajectories: List of dicts with NumPy arrays ``'observations'``,
            ``'actions'`` and ``'rewards'``.
        context_len: Window length returned by ``__getitem__``.
        rtg_scale: Divisor applied to the undiscounted returns-to-go.
        state_mean: Per-dimension observation mean (NumPy array).
        state_std: Per-dimension observation standard deviation (NumPy array).
    """

    def __init__(self, trajectories, context_len, rtg_scale, state_mean, state_std):
        self.context_len = context_len
        self.state_mean = torch.from_numpy(state_mean).float()
        self.state_std = torch.from_numpy(state_std).float()
        self.trajectories = []

        for traj in trajectories:
            # Normalize states
            obs_norm = (torch.from_numpy(traj['observations']).float() - self.state_mean) / self.state_std
            # Compute returns-to-go (gamma = 1.0, i.e. undiscounted) and scale them
            rtg = discount_cumsum(traj['rewards'], 1.0) / rtg_scale

            self.trajectories.append({
                'observations': obs_norm,
                'actions': torch.from_numpy(traj['actions']).long(),
                'returns_to_go': torch.from_numpy(rtg).float()
            })

    def __len__(self):
        """Number of trajectories (one sampled window per trajectory per epoch)."""
        return len(self.trajectories)

    def __getitem__(self, idx):
        """Samples one window from trajectory ``idx``.

        Returns:
            Tuple ``(timesteps, states, actions, returns_to_go, traj_mask)`` with
            shapes ``(K,)``, ``(K, state_dim)``, ``(K,)``, ``(K, 1)`` and ``(K,)``
            where ``K = context_len``. ``traj_mask`` is 1 for real steps and 0
            for padding.
        """
        traj = self.trajectories[idx]
        traj_len = traj['observations'].shape[0]

        # Any step may start a window; slices running past the end are shorter.
        si = random.randint(0, traj_len - 1)
        window = slice(si, si + self.context_len)

        states = traj['observations'][window]
        actions = traj['actions'][window]
        returns_to_go = traj['returns_to_go'][window].unsqueeze(-1)

        tlen = states.shape[0]
        pad_len = self.context_len - tlen

        # Right-pad with zeros so real tokens never have padding in their causal past.
        states = torch.cat([states, torch.zeros((pad_len, states.shape[1]))], 0)
        actions = torch.cat([actions, torch.zeros(pad_len, dtype=torch.long)], 0)
        returns_to_go = torch.cat([returns_to_go, torch.zeros((pad_len, 1))], 0)

        timesteps = torch.cat(
            [torch.arange(si, si + tlen), torch.zeros(pad_len, dtype=torch.long)], 0
        )
        traj_mask = torch.cat(
            [torch.ones(tlen, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)], 0
        )

        return timesteps, states, actions, returns_to_go, traj_mask

def evaluate_on_env(model, device, context_len, env, rtg_target, rtg_scale,
                    num_eval_ep, max_test_ep_len, state_mean, state_std):
    """Rolls out the model greedily and returns the mean episode reward.

    The model is switched to eval mode and is not switched back. At every step
    it is fed the most recent ``context_len`` timesteps and the action with the
    highest logit at the last position is executed. The conditioning
    return-to-go starts at ``rtg_target / rtg_scale`` and is reduced by each
    scaled reward received.

    Args:
        model: Object exposing ``eval()`` and
            ``forward(timesteps, states, actions, returns_to_go)``.
        device: Torch device on which the rollout buffers are allocated.
        context_len: Maximum number of timesteps passed to the model.
        env: Environment with the Gymnasium ``reset`` / ``step`` API.
        rtg_target: Unscaled target return.
        rtg_scale: Divisor applied to returns.
        num_eval_ep: Number of episodes to run.
        max_test_ep_len: Step cap per episode (also the buffer length).
        state_mean: NumPy array subtracted from each observation.
        state_std: NumPy array each centred observation is divided by.

    Returns:
        Sum of episode rewards divided by ``num_eval_ep``.
    """
    model.eval()
    state_dim = env.observation_space.shape[0]
    # Not used below; presumably only present on discrete action spaces, in
    # which case this read fails early for other spaces. TODO confirm.
    act_dim = env.action_space.n

    mean_t = torch.from_numpy(state_mean).to(device)
    std_t = torch.from_numpy(state_std).to(device)

    episode_rewards = []
    with torch.no_grad():
        for _ in range(num_eval_ep):
            obs, _ = env.reset()
            rtg = rtg_target / rtg_scale

            # Whole-episode buffers; the model only ever sees a sliding window of them.
            states = torch.zeros((1, max_test_ep_len, state_dim), dtype=torch.float32, device=device)
            actions = torch.zeros((1, max_test_ep_len), dtype=torch.long, device=device)
            rewards_to_go = torch.zeros((1, max_test_ep_len, 1), dtype=torch.float32, device=device)
            timesteps = torch.arange(max_test_ep_len, dtype=torch.long, device=device).unsqueeze(0)

            ep_reward = 0
            for t in range(max_test_ep_len):
                # Store the raw observation first, then normalise it in place.
                states[0, t] = torch.from_numpy(obs).to(device)
                states[0, t] = (states[0, t] - mean_t) / std_t
                rewards_to_go[0, t] = rtg

                # Last `context_len` steps up to and including t. The action
                # slot at t still holds its zero placeholder; the model reads
                # action logits from the state token.
                first = max(0, t - context_len + 1)
                window = slice(first, t + 1)
                _, act_preds, _ = model.forward(timesteps[:, window],
                                                states[:, window],
                                                actions[:, window],
                                                rewards_to_go[:, window])

                # Greedy action from the logits at the final position.
                act = act_preds[0, -1].argmax().item()
                actions[0, t] = act

                obs, reward, terminated, truncated, _ = env.step(act)
                ep_reward += reward
                rtg -= (reward / rtg_scale)

                if terminated or truncated:
                    break

            episode_rewards.append(ep_reward)

    return sum(episode_rewards) / num_eval_ep

## 6. Training Script

In [None]:
# Timestamped checkpoint name so that separate runs do not overwrite each other.
start_time = datetime.now().replace(microsecond=0)
start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")
prefix = "dt_" + env_name
save_model_name = prefix + "_model_" + start_time_str + ".pt"
save_model_path = os.path.join(log_dir, save_model_name)

print("=" * 60)
print("start time: " + start_time_str)
print("model save path: " + save_model_path)
print("=" * 60)

# --- Initialize Dataset and DataLoader ---
state_mean, state_std = compute_dataset_stats(trajectories)
train_dataset = TrajectoryDataset(trajectories, context_len, rtg_scale, state_mean, state_std)
# Pinned host memory only helps host-to-GPU copies; requesting it on a CPU-only
# machine just triggers a DataLoader warning.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=(device.type == 'cuda'),
    drop_last=True,
)
data_iter = iter(train_loader)

# --- Initialize Environment and Model ---
env = gym.make(env_name)

# try/finally guarantees the evaluation environment is closed even if training fails.
try:
    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    model = DecisionTransformer(
        state_dim=state_dim,
        act_dim=act_dim,
        n_blocks=n_blocks,
        h_dim=embed_dim,
        context_len=context_len,
        n_heads=n_heads,
        drop_p=dropout_p,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay)
    # Linear warmup of the learning-rate multiplier up to 1 over `warmup_steps` updates.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1))

    total_updates = 0
    for i_train_iter in range(max_train_iters):
        log_action_losses = []
        # evaluate_on_env() leaves the model in eval mode, so re-enable training mode.
        model.train()

        for _ in range(num_updates_per_iter):
            # Restart the loader transparently whenever an epoch is exhausted.
            try:
                timesteps, states, actions, returns_to_go, traj_mask = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                timesteps, states, actions, returns_to_go, traj_mask = next(data_iter)

            timesteps, states, actions, returns_to_go, traj_mask = (
                timesteps.to(device),
                states.to(device),
                actions.to(device),
                returns_to_go.to(device),
                traj_mask.to(device)
            )

            # Call the module itself rather than .forward() so that nn.Module
            # hooks are honoured.
            state_preds, action_preds, return_preds = model(
                timesteps=timesteps,
                states=states,
                actions=actions,
                returns_to_go=returns_to_go
            )

            # Apply mask and reshape for loss calculation: only real (non-padded)
            # positions contribute.
            valid = traj_mask.view(-1,) > 0
            action_preds_masked = action_preds.view(-1, act_dim)[valid]
            action_target_masked = actions.view(-1)[valid]

            # Only the action head is trained; state and return predictions are unused.
            action_loss = F.cross_entropy(action_preds_masked, action_target_masked)

            optimizer.zero_grad()
            action_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            optimizer.step()
            scheduler.step()

            log_action_losses.append(action_loss.item())

        # --- Evaluation ---
        eval_avg_reward = evaluate_on_env(model, device, context_len, env, rtg_target, rtg_scale,
                                        num_eval_ep, max_eval_ep_len, state_mean, state_std)

        mean_action_loss = np.mean(log_action_losses)
        time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)
        total_updates += num_updates_per_iter

        log_str = (
            f"\n{'=' * 60}\n"
            f"Time elapsed: {time_elapsed}\n"
            f"Num of updates: {total_updates}\n"
            f"Action loss: {mean_action_loss:.5f}\n"
            f"Eval avg reward: {eval_avg_reward:.5f}\n"
        )
        print(log_str)

        # Save model (the same file is overwritten after every iteration)
        print("Saving current model at: " + save_model_path)
        torch.save(model.state_dict(), save_model_path)
finally:
    env.close()

print("\nFinished training!")