# **D-T** *(Decision Transformer)*

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import h5py
import numpy as np
import math

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split


### **DEVICE HANDLING**

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


Device: cuda


### **LOGGING**

In [3]:
writer = SummaryWriter(log_dir = '.runs/DT')


### **DATA HANDLING**

In [4]:
file_path = r"C:\OFFLINE RL\hopper_medium-v2.hdf5"

with h5py.File(file_path, 'r') as f:
    
    obs = np.array(f['observations'])
    act = np.array(f['actions'])
    rew = np.array(f['rewards'])
    terminals = np.array(f['terminals'])
    timeouts = np.array(f['timeouts'])

done = terminals | timeouts
    
state_dim = obs.shape[1]
action_dim = act.shape[1]
max_action = np.abs(act).max()
dataset_size = obs.shape[0]

print(f'state dim: {state_dim} | action dim: {action_dim} | dataset size: {dataset_size} | max actions: {max_action}')


state dim: 11 | action dim: 3 | dataset size: 1000000 | max actions: 0.9999945163726807


### **HYPER PARAMS**

In [5]:
batch_size = 64

head_1 = 64
head_2 = 128
head_3 = 128
head_4 = 64

num_heads = 2
d_model = 128
max_len = 10_000

K = 64

rtg_dim = 1
dropout = 0.1
num_layers = 8
discount = 1.0


### **CREATE SEQUENCE**

In [6]:
def split_into_episodes(obs, act, rew, done = done):
    
    episodes = []
    start = 0
    
    for i in range(len(obs)):
        
        if done[i] == 1:
            
            episodes.append((obs[start:i+1], act[start:i+1], rew[start:i+1]))
            start = i + 1
            
    if start < len(obs):  # in case last traj doesn’t end with terminal
        
        episodes.append((obs[start:], act[start:], rew[start:]))
        
    return episodes


episodes = split_into_episodes(obs, act, rew)


### **HOPPER DATASET**

In [7]:
class HopperDataset(Dataset):
    
    def __init__(self, episodes, K=20):
        
        """
        episodes: list of (obs, act, rew) trajectories
        K: context length (sequence length)
        """
        
        self.episodes = episodes
        self.K = K

    def __len__(self):
        
        return len(self.episodes)

    def __getitem__(self, idx):
        
        obs, act, rew = self.episodes[idx]

        # pad or truncate to fixed K length
        tlen = len(obs)
        
        if tlen >= self.K:
            
            start = np.random.randint(0, tlen - self.K + 1)
            obs = obs[start:start+self.K]
            act = act[start:start+self.K]
            rew = rew[start:start+self.K]
            
        else:  # pad
            
            pad_len = self.K - tlen
            obs = np.concatenate([obs, np.zeros((pad_len, obs.shape[1]))], axis=0)
            act = np.concatenate([act, np.zeros((pad_len, act.shape[1]))], axis=0)
            rew = np.concatenate([rew, np.zeros(pad_len)], axis=0)

        return (
            torch.tensor(obs, dtype=torch.float32, device = device),
            torch.tensor(act, dtype=torch.float32, device = device),
            torch.tensor(rew, dtype=torch.float32, device = device).unsqueeze(1),
        )

    
    

# convert numpy array to tensor

Hopper_Data = HopperDataset(episodes, K = K)

train_data, test_data = train_test_split(Hopper_Data, test_size = 0.2, random_state = 42)
Train_Loader, Test_Loader = DataLoader(train_data, batch_size, shuffle = True, drop_last = True), DataLoader(test_data, batch_size, shuffle = True, drop_last = True)


### **INPUT EMBEDDINGS**


In [8]:
class prepare_embeds(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, rtg_dim = rtg_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4):
        super(prepare_embeds, self).__init__()
        
        # embed
        
        def create_embed(input_dim):
            
            mlp = nn.Sequential(
                
                nn.Linear(input_dim, head_1),
                nn.LayerNorm(head_1),
                nn.SiLU(),
                
                nn.Linear(head_1, head_2),
                nn.LayerNorm(head_2),
                nn.SiLU(),
    
                nn.Linear(head_2, head_3),
                nn.SiLU()
                
            )
            
            return mlp
        
        # obs embed
        
        self.obs_embed = create_embed(state_dim)
        self.act_embed = create_embed(action_dim)
        self.rtg_embed = create_embed(rtg_dim)
        
    def forward(self, state, action, rtg):
        
        s_embed = self.obs_embed(state)
        a_embed = self.act_embed(action)
        rtg_embed = self.rtg_embed(rtg)
        
        return s_embed, a_embed, rtg_embed


### **SETUP**

In [9]:
INPUT_EMBEDDING = prepare_embeds().to(device)

print(INPUT_EMBEDDING)


prepare_embeds(
  (obs_embed): Sequential(
    (0): Linear(in_features=11, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=64, out_features=128, bias=True)
    (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): SiLU()
  )
  (act_embed): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=64, out_features=128, bias=True)
    (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): SiLU()
  )
  (rtg_embed): Sequential(
    (0): Linear(in_features=1, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=64, out_features=12

### **POSITIONAL ENCODER**

In [10]:
class positional_encoder(nn.Module):
    
    def __init__(self, max_len = max_len, d_model = d_model):
        super(positional_encoder, self).__init__()
        
        self.pos = torch.arange(0, max_len, dtype = torch.float32).unsqueeze(1)
        pe = torch.zeros(max_len, d_model)
        
        self.div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10_000) / d_model))
        
        pe[:, 0::2] = torch.sin(self.pos * self.div_term)
        pe[:, 1::2] = torch.cos(self.pos * self.div_term)
        
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        
        seq_length = x.size(1)
        
        x = x + self.pe[:, :seq_length]
        
        return x


### **SET UP**

In [11]:
POSITIONAL_ENCODER = positional_encoder().to(device)


### **ATTENTION MECHANISM**

In [12]:
class single_dot_attention(nn.Module):
    
    def __init__(self, head_1 = head_1, head_2 = head_2):
        super(single_dot_attention, self).__init__()
        
        
        self.Q = nn.Linear(head_1, head_2)
        self.K = nn.Linear(head_1, head_2)
        self.V = nn.Linear(head_1, head_2)
        
        self.norm = nn.LayerNorm(head_2)
        self.proj = nn.Linear(head_2, head_2)
        
    def forward(self, x):
        
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)
        
        score = torch.matmul(Q , K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
        
        weight = F.softmax(score, dim = -1)
        
        attn_out = torch.matmul(weight, V)
        
        proj_out = self.norm(self.proj(attn_out))
        
        return proj_out


### **SETUP**

In [13]:
Single_Dot_Attention = single_dot_attention().to(device)

print(Single_Dot_Attention)


single_dot_attention(
  (Q): Linear(in_features=64, out_features=128, bias=True)
  (K): Linear(in_features=64, out_features=128, bias=True)
  (V): Linear(in_features=64, out_features=128, bias=True)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (proj): Linear(in_features=128, out_features=128, bias=True)
)


### **CASUAL MASK**

In [14]:
def casual_mask(seq_length, device):
    
    mask = torch.tril(torch.ones((seq_length, seq_length), device = device))
    
    return mask.unsqueeze(0).unsqueeze(0)


### **MULTI HEAD ATTENTION**

In [15]:
class multi_head_attention(nn.Module):
    
    def __init__(self, head_2 = head_2, num_heads = num_heads):
        super(multi_head_attention, self).__init__()
        
        # per head dim
        
        assert head_2 % num_heads == 0
        
        self.head_dim = head_2 // num_heads
        self.num_heads = num_heads
        
        # project Q, K, V
        
        self.Q = nn.Linear(head_2, head_2)
        self.V = nn.Linear(head_2, head_2)
        self.K = nn.Linear(head_2, head_2)
        
        # final projection
        
        self.norm = nn.LayerNorm(head_2)
        self.proj = nn.Linear(head_2, head_2)
        
    def forward(self, x):
        
        # get dimension : [ batch, time , dim ]
        
        B, T, D = x.size()
        
        # project
        
        Q = self.Q(x)
        V = self.V(x)
        K = self.K(x)
        
        # nor transform: [ batch, time, num head, neuros per head]
        
        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        
        # get scores
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (math.sqrt(self.head_dim))
        
        mask = casual_mask(T, x.device)

        scores = scores.masked_fill(mask == 0, -torch.inf)
        
        weights = F.softmax(scores, dim = -1)
        
        attn = torch.matmul(weights, V)
        
        # now transform again : [ batch, time, num heads * neurons per head ]
        
        attn = attn.transpose(1, 2).contiguous().view(B, T, D)
        
        # final pr0jection
        
        out = self.norm(x + self.proj(attn))
        
        return out


### **SETUP**

In [16]:
Multi_Head_Attention = multi_head_attention().to(device)

print(Multi_Head_Attention)


multi_head_attention(
  (Q): Linear(in_features=128, out_features=128, bias=True)
  (V): Linear(in_features=128, out_features=128, bias=True)
  (K): Linear(in_features=128, out_features=128, bias=True)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (proj): Linear(in_features=128, out_features=128, bias=True)
)


### **FEED FORWARD NN**

In [17]:
class feed_forward_network(nn.Module):
    
    def __init__(self, head_2 = head_2, head_3 = head_3, head_4 = head_4):
        super(feed_forward_network, self).__init__()
        
        # Pre norm layer
        
        self.pre_norm = nn.LayerNorm(head_2)
        
        # mlp 
        
        self.mlp = nn.Sequential(
        
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.SiLU(),
            
            nn.Linear(head_4, head_2)
            
        )

        
    def forward(self, x):
        
        # pre norm
        
        pre_norm = self.pre_norm(x)
        
        # mlp
        
        ffn = self.mlp(pre_norm)
        
        # residula network

        out = x + ffn

        return out


### **SETUP**

In [18]:
FFN = feed_forward_network().to(device)

print(FFN)


feed_forward_network(
  (pre_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): SiLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): SiLU()
    (4): Linear(in_features=64, out_features=128, bias=True)
  )
)


### **TRANSFORMER BLOCK**

In [19]:
class transformer_block(nn.Module):
    
    def __init__(self, dropout = dropout, head_2 = head_2):
        super(transformer_block, self).__init__()

        
        self.norm1 = nn.LayerNorm(head_2)
        self.multi_head = multi_head_attention()
        self.dropout1 = nn.Dropout(dropout)
        
        self.norm2 = nn.LayerNorm(head_2)
        self.ffn = feed_forward_network()
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self, x):
        
        norm1 = self.norm1(x)
        attn = self.multi_head(norm1)
        x = x + self.dropout1(attn)
        
        ffn = self.ffn(self.norm2(x))
        x = x + self.dropout2(ffn)
        
        return x


### **SET UP**

In [20]:
Transformer_Block = transformer_block(dropout)

print(Transformer_Block)


transformer_block(
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (multi_head): multi_head_attention(
    (Q): Linear(in_features=128, out_features=128, bias=True)
    (V): Linear(in_features=128, out_features=128, bias=True)
    (K): Linear(in_features=128, out_features=128, bias=True)
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (proj): Linear(in_features=128, out_features=128, bias=True)
  )
  (dropout1): Dropout(p=0.1, inplace=False)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (ffn): feed_forward_network(
    (pre_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (mlp): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=64, bias=True)
      (3): SiLU()
      (4): Linear(in_features=64, out_features=128, bias=True)
    )
  )
  (dropout2): Dropout(p=0.1, inplace=False)
)


### **ACTION HEAD**

In [21]:
class policy_net(nn.Module):
    
    def __init__(self, head_2 = head_2, action_dim = action_dim, max_action = max_action):
        super(policy_net, self).__init__()
        
        self.mlp = nn.Sequential(
            
            nn.Linear(head_2, head_2),
            nn.SiLU(),
            nn.Linear(head_2, action_dim)
            
        )

        # max action
        
        self.max_action = max_action
        
    def forward(self, x):
        
        action = self.mlp(x)
        action = action * self.max_action
        
        return action


### **COMPLETE DT**

In [22]:
class decision_transformer(nn.ModuleList):
    
    def __init__(self, num_layers = num_layers):
        super(decision_transformer, self).__init__()
        
        # embeddings
        
        self.embedding = prepare_embeds()
        
        # positional encodings
        
        self.pos_encodings = positional_encoder()
        
        # Encoder block
        
        self.layers = nn.ModuleList([
            
            transformer_block().to(device)
            for _ in range(num_layers)
            
        ])
        
        # policy
        
        self.policy = policy_net()
        
        # normalization
        
        self.apply(self.init_weight)
        
    def forward(self, rtg, state, action):
        
        # Get embeddings
        
        rtg_embed, state_embed, action_embed = self.embedding.forward(state = state, action = action, rtg = rtg)

        # postions
        
        cat = torch.cat([ rtg_embed, state_embed, action_embed ], dim = 1)
                
        x = self.pos_encodings.forward(cat)
        
        for layer in self.layers:
            
            x = layer(x)
            
        action = self.policy(x)
            
        return action
    
    def init_weight(self, m):
        
        if isinstance(m, nn.Linear):
            
            nn.init.orthogonal_(m.weight)
            
            if m.bias is not None:
                
                nn.init.zeros_(m.bias)


### **SET UP**

In [23]:
Decision_Transformer = decision_transformer().to(device)

print(Decision_Transformer)


decision_transformer(
  (0): prepare_embeds(
    (obs_embed): Sequential(
      (0): Linear(in_features=11, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=64, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): SiLU()
    )
    (act_embed): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=64, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): SiLU()
    )
    (rtg_embed): Sequential(
      (0): Linear(in_features=1, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_a

In [None]:
count = [p.numel() for p in Decision_Transformer.parameters()]

print(f'Number of learnable params: {count}')


### **OPTIMIZER**

In [24]:
# lr

policy_lr = 3e-4
embed_lr = 1e-4
transformer_lr = 1e-5

# iteration

total_iter = 50
T_max = 500

# params

policy_params = Decision_Transformer.policy.parameters()
embed_params = Decision_Transformer.embedding.parameters()
tranformer_params = Decision_Transformer.layers.parameters()

# optimizer

OPTIMIZER = optim.AdamW([
    
    {'params': policy_params, 'lr': policy_lr, 'weight_decay': 0},
    {'params': embed_params, 'lr': embed_lr, 'weight_decay': 1e-6},
    {'params': tranformer_params, 'lr': transformer_lr, 'weight_decay': 1e-6}
    
])

# Scheduler

warmup = optim.lr_scheduler.ConstantLR(OPTIMIZER, factor = 0.5, total_iters = total_iter)
cosine = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max = T_max - total_iter, eta_min = 1e-5)

SCHEDULER = optim.lr_scheduler.SequentialLR(OPTIMIZER, [warmup, cosine], [total_iter])


### **COMPUTE RTG**

In [25]:
def compute_rtg(rewards, discount = discount):
    
    rtg = []
    running = 0
    
    B, T, _ = rewards.size()
    
    for r in reversed(range(T)):
        
        running = rewards[:, r] + discount * running
        
        rtg.insert(0, running)
        
    rtg = torch.stack(rtg).to(device)
    
    rtg = rtg.transpose(1, 0)
    
    return rtg


### **TRAINING**

In [26]:
def train_loop(epochs, loader, Decision_Transformer = Decision_Transformer, OPTIMIZER = OPTIMIZER, SCHEDULER = SCHEDULER): 
    
    for epoch in range(epochs):
        
        total_loss = 0.0
        
        for states, true_action, rewards in loader:
            
            rtg = compute_rtg(rewards)
            
            pred_action = Decision_Transformer.forward(rtg, states, true_action[:, :-1])
            
            B, seq_length, act_dim = pred_action.size()

            action_indices = torch.arange(2, seq_length, 3, device = pred_action.device)
            
            pred_action = pred_action[:, action_indices, :]
            
            loss = F.mse_loss(pred_action, true_action[:, 1:], reduction = 'mean')
            
            OPTIMIZER.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(Decision_Transformer.parameters(), max_norm = 0.5)
            OPTIMIZER.step()
            SCHEDULER.step()
            
            total_loss += loss
            
        avg_loss = total_loss / len(loader)
        
        writer.add_scalar('DT_LOSS', avg_loss, epoch)
        
        print(f'Epoch: {epoch} | avg loss: {avg_loss:.3f}')


In [27]:
train_loop(500, Train_Loader)


Epoch: 0 | avg loss: 87.017




Epoch: 1 | avg loss: 21.315
Epoch: 2 | avg loss: 6.830
Epoch: 3 | avg loss: 3.228
Epoch: 4 | avg loss: 1.890
Epoch: 5 | avg loss: 1.257
Epoch: 6 | avg loss: 0.964
Epoch: 7 | avg loss: 0.777
Epoch: 8 | avg loss: 0.667
Epoch: 9 | avg loss: 0.605
Epoch: 10 | avg loss: 0.560
Epoch: 11 | avg loss: 0.531
Epoch: 12 | avg loss: 0.503
Epoch: 13 | avg loss: 0.483
Epoch: 14 | avg loss: 0.469
Epoch: 15 | avg loss: 0.456
Epoch: 16 | avg loss: 0.446
Epoch: 17 | avg loss: 0.441
Epoch: 18 | avg loss: 0.434
Epoch: 19 | avg loss: 0.428
Epoch: 20 | avg loss: 0.422
Epoch: 21 | avg loss: 0.416
Epoch: 22 | avg loss: 0.408
Epoch: 23 | avg loss: 0.401
Epoch: 24 | avg loss: 0.390
Epoch: 25 | avg loss: 0.380
Epoch: 26 | avg loss: 0.372
Epoch: 27 | avg loss: 0.363
Epoch: 28 | avg loss: 0.355
Epoch: 29 | avg loss: 0.347
Epoch: 30 | avg loss: 0.342
Epoch: 31 | avg loss: 0.337
Epoch: 32 | avg loss: 0.332
Epoch: 33 | avg loss: 0.327
Epoch: 34 | avg loss: 0.323
Epoch: 35 | avg loss: 0.318
Epoch: 36 | avg loss: 0.314
