In [None]:
import h5py
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from transformers import DecisionTransformerConfig, DecisionTransformerModel
import torch.optim as optim
import torch.nn as nn

# --- Constants ---
# Environment dimensions (antv5)
STATE_DIM = 105  # size of one observation vector
ACT_DIM = 8      # size of one action vector

# Model / optimisation hyperparameters
MAX_LENGTH = 20  # number of timesteps in the DT context window
BATCH_SIZE = 16
LR = 1e-4
EPOCHS = 20

# File locations (relative to the current working directory)
DATA_PATH = "../data/processed/concatenated_data.hdf5"
DT_MODEL_PATH = "../models/best_DT.pth"


In [None]:
# --- Data Loading ---
with h5py.File(DATA_PATH, "r") as f:
    data = f["concatenated_data"]
    # `[:]` reads each dataset while the file is still open.
    actions, observations, rewards_to_go, time_to_go = (
        data[name][:]
        for name in ("actions", "observations", "rewards_to_go", "time_to_go")
    )

# Reshape both auxiliary signals to a single column so they can be appended
# to the observations along the last axis.
rewards_to_go = np.reshape(rewards_to_go, (-1, 1))
time_to_go = np.reshape(time_to_go, (-1, 1))

# Model inputs per row: [observation | rewards_to_go | time_to_go]; targets are the actions.
X = np.concatenate([observations, rewards_to_go, time_to_go], axis=-1)
y = actions

In [None]:
# --- Split Data into Sequences (Necessary for DT) ---
def create_sequences(X, y, seq_len=MAX_LENGTH):
    """Slice aligned feature/target arrays into overlapping fixed-length windows.

    Window ``i`` covers rows ``i`` through ``i + seq_len - 1``; consecutive
    windows are shifted by one row.

    Args:
        X: Per-step features, indexed by step along axis 0.
        y: Per-step targets, same length as ``X``.
        seq_len: Number of consecutive steps in each window (must be >= 1).

    Returns:
        Tuple ``(X_windows, y_windows)``; each array has
        ``len(X) - seq_len + 1`` windows of ``seq_len`` rows.

    Raises:
        ValueError: If ``seq_len`` is not positive, if ``X`` and ``y`` differ
            in length, or if there are fewer than ``seq_len`` rows.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )

    num_sequences = len(X) - seq_len + 1
    if num_sequences < 1:
        # Without this guard np.stack fails with an unhelpful
        # "need at least one array to stack" message.
        raise ValueError(
            f"Need at least seq_len={seq_len} rows to build a window, got {len(X)}"
        )

    # NOTE(review): windows are taken over the whole array, so if X holds several
    # concatenated episodes some windows will straddle an episode boundary -- confirm
    # whether that is acceptable for this dataset.
    starts = range(num_sequences)
    sequences_X = np.stack([X[i:i + seq_len] for i in starts])
    sequences_y = np.stack([y[i:i + seq_len] for i in starts])
    return sequences_X, sequences_y


In [None]:
# --- DataLoader ---
def create_dataloaders(X_train, X_test, X_val, y_train, y_test, y_val, batch_size=BATCH_SIZE):
    """Wrap the train/validation/test splits in PyTorch ``DataLoader`` objects.

    Args:
        X_train, X_test, X_val: Input arrays for each split
            (states, rtg and timesteps packed together).
        y_train, y_test, y_val: Action arrays for each split.
        batch_size: Batch size shared by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its samples.
    """
    def _as_dataset(inputs, actions):
        # Both members are converted to float tensors.
        return TensorDataset(torch.FloatTensor(inputs), torch.FloatTensor(actions))

    train_ds = _as_dataset(X_train, y_train)
    val_ds = _as_dataset(X_val, y_val)
    test_ds = _as_dataset(X_test, y_test)

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size),
        DataLoader(test_ds, batch_size=batch_size),
    )

In [None]:
# --- Decision Transformer Setup ---
# Build the Decision Transformer together with the optimizer and loss that
# train_dt / evaluate_dt read from module scope.
config = DecisionTransformerConfig(
    state_dim=STATE_DIM,
    act_dim=ACT_DIM,
    # NOTE(review): `max_length` is not a documented DecisionTransformerConfig
    # field; it is kept so `config.max_length` stays available -- confirm it is needed.
    max_length=MAX_LENGTH,
    # The model feeds three tokens per timestep (return-to-go, state, action)
    # to its GPT-2 backbone, so the backbone context must span 3 * MAX_LENGTH
    # positions rather than MAX_LENGTH.
    n_positions=3 * MAX_LENGTH,
    # NOTE(review): timestep indices must stay below the config's `max_ep_len`
    # (left at the library default here) -- confirm against the dataset.
)
model = DecisionTransformerModel(config)
optimizer = optim.Adam(model.parameters(), lr=LR)
# Regression loss between predicted and recorded actions.
loss_fn = nn.MSELoss()

In [None]:
# --- Training Loop ---
def train_dt(model, train_loader, val_loader, epochs=EPOCHS):
    """Train the Decision Transformer and checkpoint the best validation weights.

    Uses the module-level ``optimizer`` and ``loss_fn`` and calls
    ``evaluate_dt`` once per epoch on ``val_loader``.

    Args:
        model: Decision Transformer called with ``states``, ``actions``,
            ``returns_to_go`` and ``timesteps``; its output must expose
            ``action_preds``.
        train_loader: Yields ``(features, targets)`` batches, where ``features``
            is ``(batch, seq_len, STATE_DIM + 2)`` laid out as
            ``[state | return-to-go | timestep]`` and ``targets`` is
            ``(batch, seq_len, ACT_DIM)``.
        val_loader: Loader with the same batch layout, used for validation.
        epochs: Number of passes over ``train_loader``.

    Side effects:
        Prints the train/validation loss after every epoch and writes
        ``model.state_dict()`` to ``DT_MODEL_PATH`` whenever the validation
        loss improves.
    """
    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for features, targets in train_loader:
            optimizer.zero_grad()

            # Split the packed inputs into states, returns-to-go and timesteps.
            states = features[:, :, :STATE_DIM]
            rtg = features[:, :, STATE_DIM].unsqueeze(-1)  # (batch, seq_len, 1)
            # The timestep embedding expects integer indices of shape
            # (batch, seq_len), so no trailing singleton dimension here.
            # NOTE(review): this column is filled from `time_to_go`; values must be
            # non-negative and below the config's max_ep_len -- confirm.
            timesteps = features[:, :, STATE_DIM + 1].long()

            # Forward pass. The forward keyword is `returns_to_go`, and the
            # predicted actions are exposed as `action_preds` (there is no
            # `logits` field on the Decision Transformer output).
            outputs = model(
                states=states,
                actions=targets,  # DT uses past actions for prediction
                returns_to_go=rtg,
                timesteps=timesteps,
            ).action_preds

            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Mean loss per batch for this epoch.
        train_loss /= len(train_loader)

        # Validation
        val_loss = evaluate_dt(model, val_loader)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Keep only the weights with the lowest validation loss seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), DT_MODEL_PATH)

In [None]:
def evaluate_dt(model, loader):
    """Compute the mean per-batch loss of the model over a data loader.

    Uses the module-level ``loss_fn``. The model is switched to eval mode and
    gradients are disabled for the whole pass.

    Args:
        model: Decision Transformer called with ``states``, ``actions``,
            ``returns_to_go`` and ``timesteps``; its output must expose
            ``action_preds``.
        loader: Yields ``(features, targets)`` batches, where the last axis of
            ``features`` holds ``STATE_DIM`` state values followed by the
            return-to-go and the timestep.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for features, targets in loader:
            states = features[:, :, :STATE_DIM]
            rtg = features[:, :, STATE_DIM].unsqueeze(-1)  # (batch, seq_len, 1)
            # Integer indices of shape (batch, seq_len) for the timestep
            # embedding -- no trailing singleton dimension.
            timesteps = features[:, :, STATE_DIM + 1].long()

            # The forward keyword is `returns_to_go`; predicted actions are
            # exposed as `action_preds` (the output has no `logits` field).
            outputs = model(
                states=states,
                actions=targets,
                returns_to_go=rtg,
                timesteps=timesteps,
            ).action_preds

            loss = loss_fn(outputs, targets)
            total_loss += loss.item()
    return total_loss / len(loader)

In [None]:
# --- Main ---

def train_test_val_split(X, y, test_size=0.1, val_size=0.1, random_state=42):
    """Randomly split aligned arrays into train, test and validation subsets.

    Args:
        X: Input samples, indexed along axis 0.
        y: Targets aligned with ``X``.
        test_size: Fraction of all samples assigned to the test split.
        val_size: Fraction of all samples assigned to the validation split.
        random_state: Seed passed to both underlying splits for reproducibility.

    Returns:
        ``X_train, X_test, X_val, y_train, y_test, y_val`` (the argument order
        expected by ``create_dataloaders``).
    """
    holdout = test_size + val_size
    X_train, X_holdout, y_train, y_holdout = train_test_split(
        X, y, test_size=holdout, random_state=random_state
    )
    # Divide the held-out part between test and validation.
    X_test, X_val, y_test, y_val = train_test_split(
        X_holdout, y_holdout, test_size=val_size / holdout, random_state=random_state
    )
    return X_train, X_test, X_val, y_train, y_test, y_val


# X and y are built by the "Data Loading" cell; no loader function exists in
# this notebook, so they are used directly.

# Split into sequences
X_seq, y_seq = create_sequences(X, y)
# NOTE(review): neighbouring windows overlap, so a random split places nearly
# identical windows in different subsets -- consider splitting the raw timeline
# before windowing if a leak-free validation/test estimate is required.
X_train, X_test, X_val, y_train, y_test, y_val = train_test_val_split(X_seq, y_seq)

# Create dataloaders
train_loader, val_loader, test_loader = create_dataloaders(
    X_train, X_test, X_val, y_train, y_test, y_val
)

# Train DT
train_dt(model, train_loader, val_loader)

# Test on the best checkpoint saved during training, not the last-epoch weights.
model.load_state_dict(torch.load(DT_MODEL_PATH))
test_loss = evaluate_dt(model, test_loader)
print(f"Test Loss: {test_loss:.4f}")