In [1]:
# Colab setup: fetch the project (it provides `src.decision_transformer`, imported later)
# and install gymnasium.
# NOTE(review): not idempotent — after `%cd`, re-running this cell clones a nested copy
# of the repo inside the first one and cds into it. Restart the runtime before re-running.
!git clone https://github.com/adimunot21/transformer-from-scratch.git
%cd transformer-from-scratch
# NOTE(review): unpinned install; the saved output below was produced with gymnasium 1.2.3.
!pip install gymnasium -q

import torch
import gymnasium as gym
import numpy as np
# Report whether a GPU is available (the training cell uses it) and the installed gym version.
print(f"CUDA: {torch.cuda.is_available()}")
print(f"Gym: {gym.__version__}")

Cloning into 'transformer-from-scratch'...
remote: Enumerating objects: 57, done.[K
remote: Counting objects: 100% (57/57), done.[K
remote: Compressing objects: 100% (39/39), done.[K
remote: Total 57 (delta 19), reused 54 (delta 16), pack-reused 0 (from 0)[K
Receiving objects: 100% (57/57), 510.86 KiB | 17.62 MiB/s, done.
Resolving deltas: 100% (19/19), done.
/content/transformer-from-scratch
CUDA: True
Gym: 1.2.3


In [2]:
"""
Collect episodes using two policies:
1. Random policy — gives low-return episodes (avg ~20 steps)
2. Heuristic policy — push toward pole's lean direction (avg ~100-200)

This diversity is critical: the model needs to see both good and bad
trajectories to learn that higher return-to-go → better actions.
"""
import random

def collect_episodes(n_random=500, n_heuristic=500, max_steps=500):
    env = gym.make("CartPole-v1")
    episodes = []

    # Random policy
    for i in range(n_random):
        obs, _ = env.reset()
        ep = {"states": [], "actions": [], "rewards": []}
        for t in range(max_steps):
            action = env.action_space.sample()
            ep["states"].append(obs)
            ep["actions"].append(action)
            obs, reward, term, trunc, _ = env.step(action)
            ep["rewards"].append(reward)
            if term or trunc:
                break
        episodes.append(ep)

    # Heuristic policy: push in direction pole is leaning
    for i in range(n_heuristic):
        obs, _ = env.reset()
        ep = {"states": [], "actions": [], "rewards": []}
        for t in range(max_steps):
            # Pole angle is obs[2]: positive = leaning right → push right (1)
            action = 1 if obs[2] > 0 else 0
            # Add some noise for diversity
            if random.random() < 0.1:
                action = 1 - action
            ep["states"].append(obs)
            ep["actions"].append(action)
            obs, reward, term, trunc, _ = env.step(action)
            ep["rewards"].append(reward)
            if term or trunc:
                break
        episodes.append(ep)

    env.close()

    # Compute returns
    returns = [sum(ep["rewards"]) for ep in episodes]
    print(f"Collected {len(episodes)} episodes")
    print(f"Return stats: min={min(returns):.0f}, mean={np.mean(returns):.1f}, "
          f"max={max(returns):.0f}, median={np.median(returns):.0f}")

    # Show distribution
    brackets = [0, 20, 50, 100, 200, 500, 501]
    for i in range(len(brackets)-1):
        count = sum(1 for r in returns if brackets[i] <= r < brackets[i+1])
        print(f"  Return {brackets[i]:>3d}-{brackets[i+1]:>3d}: {count} episodes")

    return episodes

episodes = collect_episodes()

Collected 1000 episodes
Return stats: min=9, mean=33.7, max=117, median=30
  Return   0- 20: 262 episodes
  Return  20- 50: 564 episodes
  Return  50-100: 165 episodes
  Return 100-200: 9 episodes
  Return 200-500: 0 episodes
  Return 500-501: 0 episodes


In [3]:
"""
Convert episodes into training data for the Decision Transformer.

For each episode, we:
1. Compute return-to-go at each timestep: R̂ₜ = Σ(rewards from t onward)
2. Extract sliding windows of length K (context_len)
3. Each window becomes one training sample: (R̂, s, a, timesteps)
"""
from torch.utils.data import Dataset, DataLoader

class DTDataset(Dataset):
    """Sliding-window dataset of (return-to-go, state, action, timestep) sequences.

    Every timestep of every episode starts one window of ``context_len`` steps.
    Windows that run past the end of their episode are right-padded with zeros,
    and ``mask`` marks real positions (1.0) versus padding (0.0).

    Args:
        episodes: iterable of dicts with equal-length "states", "actions" and
            "rewards" lists. Episodes with no steps are skipped.
        context_len: window length K.
        rtg_scale: returns-to-go are divided by this before being stored. The
            default 1.0 keeps raw returns; whatever scale is used here must also
            be applied to the target return at evaluation time.

    Each item is a dict of tensors: "states" (K, state_dim) float32,
    "actions" (K,) long, "returns_to_go" (K, 1) float32, "timesteps" (K,) long,
    and "mask" (K,) float32.
    """

    def __init__(self, episodes, context_len=20, rtg_scale=1.0):
        self.context_len = context_len
        self.rtg_scale = rtg_scale
        self.samples = []

        for ep in episodes:
            T = len(ep["rewards"])
            if T == 0:
                continue  # empty episode: no windows to build
            states = np.array(ep["states"])
            actions = np.array(ep["actions"])
            rewards = np.array(ep["rewards"], dtype=np.float64)

            # Return-to-go R̂ₜ = r_t + r_{t+1} + ... + r_{T-1}: a reversed cumulative sum.
            rtg = np.cumsum(rewards[::-1])[::-1] / rtg_scale

            # One window per start position; windows near the end are shorter
            # than context_len and get zero padding plus mask = 0.
            for start in range(T):
                end = min(start + context_len, T)
                k = end - start  # number of real (unpadded) steps

                s = np.zeros((context_len, states.shape[1]))
                a = np.zeros(context_len, dtype=np.int64)
                r = np.zeros((context_len, 1))
                ts = np.zeros(context_len, dtype=np.int64)
                mask = np.zeros(context_len)

                s[:k] = states[start:end]
                a[:k] = actions[start:end]
                r[:k, 0] = rtg[start:end]
                ts[:k] = np.arange(start, end)
                mask[:k] = 1.0

                self.samples.append({
                    "states": torch.tensor(s, dtype=torch.float32),
                    "actions": torch.tensor(a, dtype=torch.long),
                    "returns_to_go": torch.tensor(r, dtype=torch.float32),
                    "timesteps": torch.tensor(ts, dtype=torch.long),
                    "mask": torch.tensor(mask, dtype=torch.float32),
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# Return-to-go statistics, for reference only. They are deliberately NOT applied:
# the dataset stores raw returns-to-go and the evaluation cell conditions the model
# on raw target returns, so normalizing only one side would mismatch train and eval.
# (To normalize, scale both consistently, e.g. DTDataset(..., rtg_scale=...) and
# the target return passed at evaluation.)
all_rtg = np.concatenate([
    np.cumsum(ep["rewards"][::-1])[::-1] for ep in episodes
])
rtg_mean, rtg_std = all_rtg.mean(), all_rtg.std()
print(f"RTG stats: mean={rtg_mean:.2f}, std={rtg_std:.2f}")

CONTEXT_LEN = 20
dataset = DTDataset(episodes, context_len=CONTEXT_LEN)
# A dedicated, seeded generator makes the shuffle order reproducible across runs.
loader = DataLoader(dataset, batch_size=64, shuffle=True,
                    generator=torch.Generator().manual_seed(0))

print(f"Dataset size: {len(dataset):,} samples")
batch = next(iter(loader))
print(f"Batch shapes:")
for k, v in batch.items():
    print(f"  {k}: {v.shape}")

RTG stats: mean=22.38, std=17.73
Dataset size: 33,657 samples
Batch shapes:
  states: torch.Size([64, 20, 4])
  actions: torch.Size([64, 20])
  returns_to_go: torch.Size([64, 20, 1])
  timesteps: torch.Size([64, 20])
  mask: torch.Size([64, 20])


In [5]:
import torch.nn.functional as F
import math, time
from src.decision_transformer import DecisionTransformer

device = torch.device("cuda")

model = DecisionTransformer(
    state_dim=4,      # CartPole state dimension
    act_dim=2,        # CartPole: left or right
    d_model=64,
    n_heads=4,
    n_layers=3,
    max_timestep=500,
    context_len=CONTEXT_LEN,
    dropout=0.1,
).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Decision Transformer parameters: {n_params:,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
MAX_STEPS = 3000

model.train()
train_iter = iter(loader)
t0 = time.time()

for step in range(MAX_STEPS):
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(loader)
        batch = next(train_iter)

    rtg = batch["returns_to_go"].to(device)
    states = batch["states"].to(device)
    actions = batch["actions"].to(device)
    timesteps = batch["timesteps"].to(device)
    mask = batch["mask"].to(device)

    # Forward pass
    action_logits = model(rtg, states, actions, timesteps)  # (B, K, 2)

    # Loss: cross-entropy on predicted actions vs actual actions
    # Only compute loss on valid (non-padded) positions
    logits_flat = action_logits.reshape(-1, 2)
    actions_flat = actions.reshape(-1)
    mask_flat = mask.reshape(-1)

    loss_all = F.cross_entropy(logits_flat, actions_flat, reduction="none")
    loss = (loss_all * mask_flat).sum() / mask_flat.sum()

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % 200 == 0:
        # Compute accuracy
        preds = action_logits.argmax(dim=-1).reshape(-1)
        correct = ((preds == actions_flat) * mask_flat).sum()
        acc = correct / mask_flat.sum()
        print(f"Step {step:4d} | loss: {loss.item():.4f} | acc: {acc:.3f} | time: {time.time()-t0:.0f}s")

print(f"\nTraining complete in {time.time()-t0:.0f}s")

Decision Transformer parameters: 182,594
Step    0 | loss: 0.7492 | acc: 0.457 | time: 1s
Step  200 | loss: 0.4664 | acc: 0.794 | time: 5s
Step  400 | loss: 0.4021 | acc: 0.827 | time: 9s
Step  600 | loss: 0.4257 | acc: 0.799 | time: 13s
Step  800 | loss: 0.4021 | acc: 0.818 | time: 18s
Step 1000 | loss: 0.4743 | acc: 0.773 | time: 21s
Step 1200 | loss: 0.3769 | acc: 0.845 | time: 26s
Step 1400 | loss: 0.3951 | acc: 0.835 | time: 30s
Step 1600 | loss: 0.3901 | acc: 0.822 | time: 34s
Step 1800 | loss: 0.3734 | acc: 0.833 | time: 38s
Step 2000 | loss: 0.4405 | acc: 0.796 | time: 42s
Step 2200 | loss: 0.3748 | acc: 0.850 | time: 46s
Step 2400 | loss: 0.3710 | acc: 0.839 | time: 51s
Step 2600 | loss: 0.3497 | acc: 0.858 | time: 55s
Step 2800 | loss: 0.3708 | acc: 0.834 | time: 59s

Training complete in 63s


In [6]:
"""
This is where it gets cool.

We condition the model on DIFFERENT desired returns and see
if it produces better actions for higher desired returns.

If it works, the model has learned: "when you want high reward,
take these actions" — purely from offline sequence data.
"""

def evaluate(model, target_return, n_episodes=50, context_len=20,
             max_steps=500, seed=None):
    """Roll out the Decision Transformer in CartPole-v1, conditioned on a target return.

    At every step the model sees the most recent ``context_len`` steps of
    (return-to-go, state, action, timestep) and acts greedily (argmax) on its
    prediction for the newest position. After each step the return-to-go is
    decremented by the reward received, matching R̂ₜ₊₁ = R̂ₜ - rₜ used to
    build the training targets.

    Args:
        model: the DecisionTransformer, called as ``model(rtg, states, actions, timesteps)``.
            Inputs are placed on the device of the model's parameters.
        target_return: desired episode return, used as the initial return-to-go
            (raw, on the same scale as the training returns-to-go).
        n_episodes: number of evaluation episodes.
        context_len: maximum number of most-recent steps fed to the model.
        max_steps: per-episode step cap.
        seed: if given, episode ``i`` is reset with ``seed + i`` for reproducibility.

    Returns:
        List of total rewards, one per episode.

    The model's train/eval mode is restored and the environment closed on exit,
    even if an exception is raised.
    """
    env = gym.make("CartPole-v1")
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    returns = []

    try:
        for ep_idx in range(n_episodes):
            obs, _ = env.reset(seed=None if seed is None else seed + ep_idx)

            # Full per-episode history; each step feeds only the last context_len entries.
            state_hist = [obs]
            action_hist = []  # actions already taken (one fewer than states)
            rtg_hist = [float(target_return)]  # desired return-to-go per step
            total_reward = 0.0

            for t in range(max_steps):
                start = max(0, t + 1 - context_len)

                states = torch.as_tensor(np.asarray(state_hist[start:]),
                                         dtype=torch.float32, device=device).unsqueeze(0)
                # The action at the newest position is not known yet; 0 is a placeholder.
                # NOTE(review): assumes the model's causal masking keeps this token from
                # influencing the prediction at the same position — confirm in the model.
                actions = torch.as_tensor(action_hist[start:] + [0],
                                          dtype=torch.long, device=device).unsqueeze(0)
                rtg = torch.as_tensor(rtg_hist[start:], dtype=torch.float32,
                                      device=device).view(1, -1, 1)
                timesteps = torch.arange(start, t + 1, device=device).unsqueeze(0)

                with torch.no_grad():
                    logits = model(rtg, states, actions, timesteps)
                action = int(logits[0, -1].argmax().item())

                obs, reward, term, trunc, _ = env.step(action)
                total_reward += reward
                if term or trunc:
                    break

                state_hist.append(obs)
                action_hist.append(action)
                rtg_hist.append(rtg_hist[-1] - reward)

            returns.append(total_reward)
    finally:
        env.close()
        model.train(was_training)

    return returns

# Test with different target returns
print("=" * 60)
print("DECISION TRANSFORMER EVALUATION")
print("=" * 60)
print("\nConditioning on different desired returns:\n")

for target in [10, 50, 100, 200, 500]:
    returns = evaluate(model, target_return=target, n_episodes=50)
    mean_r = np.mean(returns)
    std_r = np.std(returns)
    max_r = max(returns)
    print(f"  Target return: {target:>3d} → Achieved: {mean_r:.1f} ± {std_r:.1f} (max: {max_r:.0f})")

# Compare to random baseline
print(f"\n  Random policy baseline:")
env = gym.make("CartPole-v1")
rand_returns = []
for _ in range(50):
    obs, _ = env.reset()
    total = 0
    for _ in range(500):
        obs, r, term, trunc, _ = env.step(env.action_space.sample())
        total += r
        if term or trunc: break
    rand_returns.append(total)
env.close()
print(f"  Random policy → {np.mean(rand_returns):.1f} ± {np.std(rand_returns):.1f}")

DECISION TRANSFORMER EVALUATION

Conditioning on different desired returns:

  Target return:  10 → Achieved: 10.1 ± 0.3 (max: 11)
  Target return:  50 → Achieved: 45.3 ± 9.0 (max: 66)
  Target return: 100 → Achieved: 62.7 ± 23.9 (max: 110)
  Target return: 200 → Achieved: 71.5 ± 39.1 (max: 198)
  Target return: 500 → Achieved: 76.0 ± 70.2 (max: 491)

  Random policy baseline:
  Random policy → 22.2 ± 14.1
