# Implementing the Decision Transformer from scratch. 
Evaluating on CartPole-v1 environment.
Paper: https://arxiv.org/pdf/2106.01345.pdf

Pavel Nakaznenko, 2023

# Setup & Install

In [None]:
!pip install gym tqdm numpy torch torchvision

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import gym
import random
import numpy as np
from tqdm.auto import tqdm

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Prepare the dataset

In [4]:
def encode_one_hot(target_dim, num_dims):
    '''A helper function to one-hot encode agent's actions'''
    result = np.zeros(num_dims)
    result[target_dim] = 1
    return result

Decision Transformer requires a dataset of trajectories from an expert policy to train on.
There is a closeform solution for CartPole-v1, which always yields maximum reward. See here for more details:
https://towardsdatascience.com/how-to-beat-the-cartpole-game-in-5-lines-5ab4e738c93f

In [5]:
def theta_omega_policy_4_cart_pole(obs):
    '''Closeform solution for CartPole-v1'''
    theta, w = obs[2:4]
    if abs(theta) < 0.03:
        return 0 if w < 0 else 1
    else:
        return 0 if theta < 0 else 1

In [6]:
class EnvDataset(torch.utils.data.Dataset):
    '''Environment trajectories dataset'''
    def __init__(self, env, max_length, num_trajectories, goal):
        self.data = []
        self.max_length = max_length
        self.env = env

        # Collect trajectories
        pbar = tqdm(range(num_trajectories), desc="Generating trajectories", total=num_trajectories)
        while len(self.data)<num_trajectories:
            state = env.reset()
            trajectory = []
            total_reward = 0
            for t in range(max_length):
                action = theta_omega_policy_4_cart_pole(state)
                next_state, reward, done, _ = env.step(action)
                trajectory.append((state, action, reward))
                total_reward += reward
                state = next_state
                if done:
                    break
                    
            # Filter trajectories by total reward
            if total_reward >= goal:
                self.data.append(trajectory)
                pbar.update(1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        trajectory = self.data[idx]
        states, actions, rewards = zip(*trajectory)
        actions = [encode_one_hot(action, self.env.action_space.n) for action in actions]
        # Compute returns-to-go
        returns_to_go = np.cumsum(rewards[::-1])[::-1]

        return (returns_to_go, states, actions)

In [7]:
def process_sample(R, s, a, max_length):
    '''Helper function to properly stack and padd the trajectories'''
    pad_len_R = max_length - len(R)
    pad_len_s = max_length - len(s)
    pad_len_a = max_length - len(a)
    states = np.vstack((np.zeros((pad_len_s, len(s[0]))), s))
    actions = np.array([np.zeros(env.action_space.n)] * pad_len_a + list(a))
    returns_to_go = np.pad(np.array(R), (pad_len_R, 0), 'constant')
    mask = np.zeros(max_length)
    if pad_len_s > 0:
        mask[-pad_len_s:] = 1
    return (returns_to_go, states, actions, mask)

In [8]:
def collate_batch(batch):
    '''Collationg function'''
    result = []

    max_length = max(len(sample[0]) for sample in batch)
    for sample in batch:
        returns_to_go, states, actions, mask = process_sample(sample[0], sample[1], sample[2], max_length)

        result.append((torch.FloatTensor(returns_to_go),
                       torch.FloatTensor(states),
                       torch.FloatTensor(actions),
                       torch.LongTensor(mask)))

    return zip(*result)

In [9]:
# Initialize environment
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Create dataset
max_length = 500  # Maximum length of a sequence
max_return = 500  # Maximum cumulative reward possible
dataset_num = 30000 # Number of trajectories in dataset
dataset = EnvDataset(env, max_length=max_length, num_trajectories=dataset_num, goal=-1)

# Splitting dataset into train and test
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Randomly split dataset into train and test datasets
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create DataLoaders for train and test datasets
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=collate_batch)

Generating trajectories:   0%|          | 0/30000 [00:00<?, ?it/s]

# Decision Transformer model

In [10]:
class EmbeddingLayer(nn.Module):
    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)

    def forward(self, x, pos_embedding):
        return self.embedding(x) + pos_embedding

class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, action_dim, max_length, embed_dim, num_heads, num_layers, dim_feedforward = 2048):
        super().__init__()
        self.embed_dim = embed_dim
        self.embed_s = EmbeddingLayer(state_dim, embed_dim)
        self.embed_a = EmbeddingLayer(action_dim, embed_dim)
        self.embed_R = EmbeddingLayer(1, embed_dim)  # returns-to-go is 1D
        self.embed_t = nn.Embedding(max_length, embed_dim)
        self.embed_ln = nn.LayerNorm(embed_dim)

        self.transformer = nn.Transformer(
            d_model=embed_dim, nhead=num_heads,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward, batch_first=True)

        self.pred_a = nn.Linear(embed_dim, action_dim)
        self.max_length = max_length

    def forward(self, R, s, a, t, mask = None):
        #print(f"{R.shape=}, {s.shape=}, {a.shape=}, {t.shape=}")
        # Compute embeddings for tokens
        pos_embedding = self.embed_t(t)
        s_embedding = self.embed_s(s, pos_embedding)
        a_embedding = self.embed_a(a, pos_embedding)
        R_embedding = self.embed_R(R, pos_embedding)

        # Interleave tokens
        input_embeds = torch.stack((R_embedding, s_embedding, a_embedding), dim=1).permute(0, 2, 1, 3).reshape(s.size(0), 3*s.size(1), self.embed_dim)
        input_embeds = self.embed_ln(input_embeds)

        mask_size = s.size(1)*3

        # Make mask to correspond
        if mask is not None:
            mask = torch.stack((mask, mask, mask), dim=1).permute(0, 2, 1).reshape(s.size(0), mask_size)
            mask = mask.bool()
        else:
            mask = torch.zeros((s.size(0), mask_size)).bool().to(DEVICE)

        attn_mask = self.transformer.generate_square_subsequent_mask(sz=mask_size).to(DEVICE)
        attn_mask = torch.isfinite(attn_mask) # convert to bool
        attn_mask = ~attn_mask # True - is not allowed to attend, False - unchanged
        # Use transformer to get hidden states
        hidden_states = self.transformer(input_embeds, input_embeds,
                                         src_key_padding_mask=mask,
                                         tgt_key_padding_mask=mask,
                                         memory_key_padding_mask=mask,
                                         memory_is_causal=True,
                                         src_is_causal=True,
                                         tgt_is_causal=True,
                                         src_mask=attn_mask,
                                         tgt_mask=attn_mask,
                                         memory_mask=attn_mask
                                         )

        # Get hidden states representations such that
        # hidden_states[:, 0, t] is a hidden state after attending [sequence,r_t]
        # hidden_states[:, 1, t] is a hidden state after attending [sequence,r_t,s_t]
        # hidden_states[:, 2, t] is a hidden state after attending [sequence,r_t,s_t,a_t]
        hidden_states = hidden_states.reshape(s.size(0), s.size(1), 3, self.embed_dim).permute(0, 2, 1, 3)

        # We care about actions only
        a_hidden = hidden_states[:, 1, :]

        # Predict action
        return self.pred_a(a_hidden)

# Train, validation, eval functions

In [11]:
def evaluate_model(model, env, max_length, target_return):
    model.eval()
    state = env.reset()
    done = False
    total_reward = 0
    R = [target_return]
    s, a = [state], [encode_one_hot(0, env.action_space.n)]
    timesteps = np.arange(max_length)
    with torch.no_grad():
        while not done and len(s) < max_length:
            returns_to_go, states, actions, mask = process_sample(R, s, a, len(s))
            returns_to_go_tensor = torch.tensor(returns_to_go, dtype=torch.float).unsqueeze(0).unsqueeze(-1).to(DEVICE)
            states_tensor = torch.tensor(states, dtype=torch.float).unsqueeze(0).to(DEVICE)
            actions_tensor = torch.tensor(actions, dtype=torch.float).unsqueeze(0).to(DEVICE)
            timesteps_tensor = torch.tensor(timesteps, dtype=torch.long)[:len(s)].unsqueeze(0).to(DEVICE)
            mask_tensor = torch.tensor(mask, dtype=torch.long).unsqueeze(0).to(DEVICE)

            action = model(returns_to_go_tensor / target_return,
                           states_tensor,
                           actions_tensor,
                           timesteps_tensor,
                           mask_tensor)

            # Remove batch dim and take the hidden output for the last predicted action
            action = action.squeeze(0)[-1]
            # Most likely action to be taken
            action = torch.argmax(torch.softmax(action, dim=-1))
            action = action.cpu().detach().item()
            state, reward, done, _ = env.step(action)
            total_reward += reward

            # Update sequences
            R.append(R[-1] - reward)
            s.append(state)

            # Update the latest action
            a[-1] = encode_one_hot(action, env.action_space.n)

            # Push a placeholder
            a.append(encode_one_hot(0, env.action_space.n))

    return total_reward

In [12]:
def validate_model(model, dataloader, max_length, target_return):
    model.eval()  # Set the model to evaluation mode
    criterion = torch.nn.CrossEntropyLoss()
    timesteps = torch.LongTensor(np.arange(max_length)).unsqueeze(0).to(DEVICE)

    total_loss = 0
    with torch.no_grad():
        pbar = tqdm(dataloader, total=len(dataloader), desc="Validation")
        for returns_to_go, states, actions, masks in pbar:
            tensor_actions = torch.stack(actions).to(DEVICE)
            tensor_masks = torch.stack(masks).to(DEVICE)
            tensor_returns_to_go = torch.stack(returns_to_go).unsqueeze(-1).to(DEVICE)
            tensor_states = torch.stack(states).to(DEVICE)
            batch_timesteps = timesteps.repeat(len(states), 1)

            predicted_actions = model(tensor_returns_to_go / target_return,
                                      tensor_states,
                                      tensor_actions,
                                      batch_timesteps,
                                      tensor_masks)

            tensor_actions = tensor_actions.reshape(-1, actions[0].shape[1])
            predicted_actions = predicted_actions.reshape(-1, actions[0].shape[1])

            tensor_actions = tensor_actions[tensor_masks.reshape(-1) == 0]
            predicted_actions = predicted_actions[tensor_masks.reshape(-1) == 0]
            loss = criterion(predicted_actions, tensor_actions.detach())

            total_loss += loss.item()
            pbar.set_postfix({"loss":loss.item()})

    average_loss = total_loss / len(dataloader)
    print(f"Average validation loss: {average_loss}")
    return average_loss

In [13]:
def train_model(model, train_dataloader, test_dataloader, optimizer, epochs, grad_clip_norm, target_return):
    model.train()
    timesteps = torch.LongTensor(np.arange(max_length)).unsqueeze(0).to(DEVICE)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs), desc="Epoch"):
        total_loss = 0
        pbar = tqdm(train_dataloader, desc="Batch", total=len(train_dataloader))
        for returns_to_go, states, actions, masks in pbar:
            optimizer.zero_grad()
            tensor_actions = torch.stack(actions).to(DEVICE)
            tensor_masks = torch.stack(masks).to(DEVICE)
            tensor_returns_to_go = torch.stack(returns_to_go).unsqueeze(-1).to(DEVICE)
            tensor_states = torch.stack(states).to(DEVICE)
            batch_timesteps = timesteps.repeat(len(states), 1)

            predicted_actions = model(tensor_returns_to_go / target_return,
                                      tensor_states,
                                      tensor_actions,
                                      batch_timesteps,
                                      tensor_masks)

            tensor_actions = tensor_actions.reshape(-1, actions[0].shape[1])
            predicted_actions = predicted_actions.reshape(-1, actions[0].shape[1])

            tensor_actions = tensor_actions[tensor_masks.reshape(-1) == 0]
            predicted_actions = predicted_actions[tensor_masks.reshape(-1) == 0]
            loss = criterion(predicted_actions, tensor_actions.detach())

            loss.backward()

            if grad_clip_norm >= 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()
            total_loss += loss.item()
            pbar.set_postfix({"loss":loss.item()})

        print(f"Epoch {epoch}, Loss: {total_loss / len(train_dataloader)}")
        average_reward = np.mean([evaluate_model(model, env, max_length=max_length, target_return=max_return) for _ in tqdm(range(100), desc="Eval epsiode")])
        print(f"Average Total Reward at epoch {epoch}: {average_reward}")

        validate_model(model, test_dataloader, max_length, target_return=max_return)

# Initial eval & Train

In [14]:
# Initialize the model
embed_dim = 128  # Embedding dimension
num_heads = 2  # Number of attention heads
num_layers = 3  # Number of transformer layers
model = DecisionTransformer(state_dim, action_dim, max_length, embed_dim, num_heads, num_layers).to(DEVICE)
grad_clip_norm = 0.25 # Max grad norm, -1 means no grad clip norm

In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [16]:
# Untrained eval
average_reward = np.mean([evaluate_model(model, env, max_length=max_length, target_return=max_return) for _ in tqdm(range(10), desc="Untrained eval")])
print(f"Untrained Average Reward: {average_reward}")

Untrained eval:   0%|          | 0/10 [00:00<?, ?it/s]

Untrained Average Reward: 17.1


In [17]:
train_model(model, train_dataloader, test_dataloader, optimizer, epochs=5, grad_clip_norm=grad_clip_norm, target_return=max_return)

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Batch:   0%|          | 0/3000 [00:00<?, ?it/s]

Epoch 0, Loss: 0.21201970035086076


Eval epsiode:   0%|          | 0/100 [00:00<?, ?it/s]

Average Total Reward at epoch 0: 309.82


Validation:   0%|          | 0/750 [00:00<?, ?it/s]

Average validation loss: 0.09930670402447382


Batch:   0%|          | 0/3000 [00:00<?, ?it/s]

Epoch 1, Loss: 0.07029800980289777


Eval epsiode:   0%|          | 0/100 [00:00<?, ?it/s]

Average Total Reward at epoch 1: 405.27


Validation:   0%|          | 0/750 [00:00<?, ?it/s]

Average validation loss: 0.056869195863604546


Batch:   0%|          | 0/3000 [00:00<?, ?it/s]

Epoch 2, Loss: 0.05393536743024985


Eval epsiode:   0%|          | 0/100 [00:00<?, ?it/s]

Average Total Reward at epoch 2: 442.02


Validation:   0%|          | 0/750 [00:00<?, ?it/s]

Average validation loss: 0.04737702581286431


Batch:   0%|          | 0/3000 [00:00<?, ?it/s]

Epoch 3, Loss: 0.04628381958976388


Eval epsiode:   0%|          | 0/100 [00:00<?, ?it/s]

Average Total Reward at epoch 3: 475.53


Validation:   0%|          | 0/750 [00:00<?, ?it/s]

Average validation loss: 0.040190607354044915


Batch:   0%|          | 0/3000 [00:00<?, ?it/s]

Epoch 4, Loss: 0.039727651101226606


Eval epsiode:   0%|          | 0/100 [00:00<?, ?it/s]

Average Total Reward at epoch 4: 486.35


Validation:   0%|          | 0/750 [00:00<?, ?it/s]

Average validation loss: 0.03840503282099962


# Validation and eval

In [18]:
validate_model(model, test_dataloader, max_length, target_return=max_return)

Validation:   0%|          | 0/750 [00:00<?, ?it/s]

Average validation loss: 0.03840503288557132


0.03840503288557132

In [19]:
# Trained eval
average_reward = np.mean([evaluate_model(model, env, max_length=max_length, target_return=max_return) for _ in tqdm(range(100), desc="Trained eval")])
print(f"Trained Average Reward: {average_reward}")

Trained eval:   0%|          | 0/100 [00:00<?, ?it/s]

Trained Average Reward: 484.76
