# **T - T** *(Trajectory Transformer)*

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import random
import h5py
import math

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


### **DEVICE**

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


Device: cuda


### LOGGING

In [3]:
writer = SummaryWriter(log_dir = './runs/TT')


### **HYPER PARAMS**

In [4]:
head_1 = 128
head_2 = 256
head_3 = 256
head_4 = 128

K = 64
batch_size = 128
num_heads = 2

reward_dim = 1

maxlen = 10_000
dropout = 0.1
num_layers = 8

state_vocab = 512
action_vocab = 128
reward_vocab = 16


### **DATA ENGINEERING**

In [5]:
path = r"C:\OFFLINE RL\hopper_medium-v2.hdf5"

with h5py.File(path, mode = 'r') as f:
    
    obs = np.array(f['observations'])
    act = np.array(f['actions'])
    rew = np.array(f['rewards'])
    next_obs = np.array(f['next_observations'])
    terminals = np.array(f['terminals'])
    timeouts = np.array(f['timeouts'])
    
done = timeouts | terminals

state_dim = obs.shape[1]
action_dim = act.shape[1]
max_action = np.abs(act).max()
dataset_size = obs.shape[0]

print(f'state dim: {state_dim} | action dim: {action_dim} | max action: {max_action} | dataset size: {dataset_size}')


state dim: 11 | action dim: 3 | max action: 0.9999945163726807 | dataset size: 1000000


### **DATA HANDLING**

In [6]:
# create sequence

episodes = []
start = 0

for i in range(len(obs)):
    
    if done[i] == 1:
        
        episodes.append((obs[start: i + 1], act[start: i + 1], rew[start: i + 1], next_obs[start: i + 1]))
        
if start > len(obs):
    
    episodes.append((obs[start: ], act[start: ], rew[start: ], next_obs[start: ]))

class HopperDataset(Dataset):
    
    def __init__(self, episodes, K):
        
        self.episodes = episodes
        self.K = K
        
    def __len__(self):
        
        return len(self.episodes)
    
    def __getitem__(self, idx):
        
        obs, act, rew, next_obs = self.episodes[idx]
        
        t_len = len(obs)
        
        if t_len >= self.K:
            
            start = random.randint(0, t_len - self.K)
                        
            obs = obs[start: start + self.K]
            act = act[start: start + self.K]
            rew = rew[start: start + self.K]
            next_obs = next_obs[start: start + self.K]
            
        else: # padding
            
            pad_len = self.K - t_len
            
            obs = np.concatenate([obs, np.zeros((pad_len, obs.shape[1]))], axis = 0)
            act = np.concatenate([act, np.zeros((pad_len, act.shape[1]))], axis = 0)
            rew = np.concatenate([rew, np.zeros(pad_len)], axis = 0)
            next_obs = np.concatenate([next_obs, np.zeros((pad_len, next_obs.shape[1]))], axis = 0)
            
        return (
            
            torch.from_numpy(obs).float().to(device),
            torch.from_numpy(act).float().to(device),
            torch.from_numpy(rew).float().to(device).unsqueeze(-1),
            torch.from_numpy(next_obs).float().to(device)
            
        )
    

hopper_dataset = HopperDataset(episodes, K)

data_loader = DataLoader(hopper_dataset, batch_size, shuffle = True, drop_last = True)


### **INPUT EMBEDDINGS**

In [7]:
class input_embedding(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, reward_dim = reward_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4):
        super(input_embedding, self).__init__()
        
        def create_embed(input_dim):
            
            embed = nn.Sequential(
                
                nn.Linear(input_dim, head_1),
                nn.LayerNorm(head_1),
                nn.SiLU(),
                
                nn.Linear(head_1, head_2),
                nn.LayerNorm(head_2),
                nn.SiLU(),
                
                nn.Linear(head_2, head_3),
                nn.LayerNorm(head_3),
                nn.SiLU(),
                
                nn.Linear(head_3, head_4)
                
            )
            
            return embed
        
        # create embeds 
        
        self.state_embed = create_embed(state_dim)
        self.action_embed = create_embed(action_dim)
        self.reward_embed = create_embed(reward_dim)
        
    def forward(self, state, action, reward):
        
        state_embed = self.state_embed(state)
        action_embed = self.action_embed(action)
        reward_embed = self.reward_embed(reward)
        
        return state_embed, action_embed, reward_embed
    
    

### **SETUP**


In [8]:
INPUT_EMBEDDINGS = input_embedding().to(device)
print(INPUT_EMBEDDINGS)


input_embedding(
  (state_embed): Sequential(
    (0): Linear(in_features=11, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (8): SiLU()
    (9): Linear(in_features=256, out_features=128, bias=True)
  )
  (action_embed): Sequential(
    (0): Linear(in_features=3, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (8): SiLU()
    (9): Linear

### **POSITIONAL ENCODING**

In [9]:
class pos_encoding(nn.Module):
    
    def __init__(self, maxlen = maxlen, head_4 = head_4):
        super().__init__()
        
        pe = torch.zeros(maxlen, head_4)
        pos = torch.arange(0, maxlen, dtype = torch.float32).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, head_4, 2).float() * (-math.log(10_000.0) / head_4))
        
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        
        seq_length = x.size(1)
        
        x = x + self.pe[:, :seq_length]
        
        return x


### **SETUP**

In [10]:
POSITIONAL_ENCODING = pos_encoding()

print(POSITIONAL_ENCODING)


pos_encoding()


### **MASKING**

In [11]:
def casual_mask(seq_length, device):
    
    mask = torch.tril(torch.ones((seq_length, seq_length), device = device))
    
    return mask.unsqueeze(0).unsqueeze(0)


### **MULTI HEAD ATTENTION**

In [12]:
class Multi_Head_Attention(nn.Module):
    
    def __init__(self, head_4 = head_4, num_heads = num_heads):
        super(Multi_Head_Attention, self).__init__()
        
        assert head_4 % num_heads == 0
        
        self.head_dim = head_4 // num_heads
        self.num_heads = num_heads
        
        self.Q = nn.Linear(head_4, head_4)
        self.V = nn.Linear(head_4, head_4)
        self.K = nn.Linear(head_4, head_4)
        
        self.norm = nn.LayerNorm(head_4)
        self.proj = nn.Linear(head_4, head_4)
        
        
    def forward(self, x):
        
        B, T, D = x.size()
        
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)
        
        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        mask = casual_mask(T, Q.device)
        
        scores = scores.masked_fill(mask == 0, -np.inf)
        
        weights = torch.softmax(scores, dim = -1)
        
        attn = torch.matmul(weights, V)
        
        attn = attn.transpose(1, 2).contiguous().view(B, T, D)
        
        out = self.norm(self.proj(attn))
        
        return out


### **SET UP**

In [13]:
MHA = Multi_Head_Attention()

print(MHA)


Multi_Head_Attention(
  (Q): Linear(in_features=128, out_features=128, bias=True)
  (V): Linear(in_features=128, out_features=128, bias=True)
  (K): Linear(in_features=128, out_features=128, bias=True)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (proj): Linear(in_features=128, out_features=128, bias=True)
)


### **FEED FORWARD NETWORK**

In [14]:
class feed_forward_network(nn.Module):
    
    def __init__(self, head_4 = head_4, head_3 = head_3):
        super(feed_forward_network, self).__init__()
        
        self.pre_norm = nn.LayerNorm(head_4)
        
        self.ffn = nn.Sequential(
            
            nn.Linear(head_4, head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4)
            
        )
        
    def forward(self, x):
        
        
        norm = self.pre_norm(x)
        
        ffn = self.ffn(norm)
        
        x = x + ffn
        
        return x


### **SETUP**

In [15]:
FEED_FORWARD_NETWORK = feed_forward_network()

print(FEED_FORWARD_NETWORK)


feed_forward_network(
  (pre_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (ffn): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
  )
)


### **TRANSFORMER BLOCK**

In [16]:
class transformer_block(nn.Module):
    
    def __init__(self, dropout = dropout, head_4 = head_4):
        super(transformer_block, self).__init__()
        
        self.norm1 = nn.LayerNorm(head_4)
        self.mha = Multi_Head_Attention()
        self.drop = nn.Dropout(dropout)
        
        self.norm2 = nn.LayerNorm(head_4)
        self.ffn = feed_forward_network()
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        
        norm1 = self.norm1(x)
        attn = self.mha(norm1)
        x = x + self.drop(attn)
        
        norm2 = self.norm2(x)
        ffn = self.ffn(norm2)
        x = x + self.drop2(ffn)
        
        return x


### **SETUP**

In [17]:
TRANSFORMER_BLOCK = transformer_block()

print(TRANSFORMER_BLOCK)


transformer_block(
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (mha): Multi_Head_Attention(
    (Q): Linear(in_features=128, out_features=128, bias=True)
    (V): Linear(in_features=128, out_features=128, bias=True)
    (K): Linear(in_features=128, out_features=128, bias=True)
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (proj): Linear(in_features=128, out_features=128, bias=True)
  )
  (drop): Dropout(p=0.1, inplace=False)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (ffn): feed_forward_network(
    (pre_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (ffn): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=128, bias=True)
    )
  )
  (drop2): Dropout(p=0.1, inplace=False)
)


### **PREDICTION HEAD**

In [18]:
class prediction_net(nn.Module):
    
    def __init__(self, head_4 = head_4, state_dim = state_dim, action_dim = action_dim, reward_dim = reward_dim):
        super(prediction_net, self).__init__()
        
        
        def create(output_dim):
        
            mlp = nn.Sequential(
                
                nn.Linear(head_4, head_4),
                nn.SiLU(),
                nn.Linear(head_4, output_dim)
                
            )
            
            return mlp
        
        self.pred_state = create(state_dim)
        self.pred_action = create(action_dim)
        self.pred_reward = create(reward_dim)

    def forward(self, x):
        
        pred_state = self.pred_state(x)
        pred_action = self.pred_action(x)
        pred_reward = self.pred_reward(x)
        
        return pred_state, pred_action, pred_reward


### **T - T**

In [19]:
class traj_transformer(nn.Module):
    
    def __init__(self, num_layers = num_layers):
        super(traj_transformer, self).__init__()
        
        # embedding
        
        self.embedding = input_embedding()
        
        # pos encoding
        
        self.pos_encoding = pos_encoding()
        
        # tranformer block
        
        self.layers = nn.ModuleList([
            
            transformer_block()
            
            for _ in range(num_layers)
            
        ])
        
        # policy head ?
        
        self.prediction = prediction_net()
        
        # normalization
        
        self.apply(self.init_weight)
        
    def forward(self, state, action, reward):
        
        # get embedding
        
        state_embed, action_embed, reward_embed = self.embedding.forward(state, action, reward)
        
        # positional encoding
        
        cat = torch.cat([state_embed, action_embed, reward_embed], dim = 1)
        
        x = self.pos_encoding.forward(cat)
        
        # transformer block
        
        for layer in self.layers:
            
            x = layer(x)

        pred_state, pred_action, pred_reward = self.prediction.forward(x)

        return pred_state, pred_action, pred_reward
        
    def init_weight(self, m):
        
        if isinstance(m, nn.Linear):
            
            nn.init.orthogonal_(m.weight)
            
            if m.bias is not None:
                
                nn.init.zeros_(m.bias)
                
                

### **SETUP**

In [20]:
TRAJECTORY_TRANSFORMER = traj_transformer().to(device)

print(TRAJECTORY_TRANSFORMER)


traj_transformer(
  (embedding): input_embedding(
    (state_embed): Sequential(
      (0): Linear(in_features=11, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=128, out_features=256, bias=True)
      (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=256, out_features=256, bias=True)
      (7): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (8): SiLU()
      (9): Linear(in_features=256, out_features=128, bias=True)
    )
    (action_embed): Sequential(
      (0): Linear(in_features=3, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=128, out_features=256, bias=True)
      (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=256, out_features=256, bias=True)
      (7): LayerNorm(

### **NO. OF LEARNABLE PARAMS**

In [21]:
params = [p.numel() for p in TRAJECTORY_TRANSFORMER.parameters()]

print(f'Number of learnable params: {np.sum(params)}')


Number of learnable params: 1516687


### **OPTIMIZER**

In [22]:
# lr

prediction_lr = 3e-4
embed_lr = 1e-4
transformer_lr = 1e-5

# iterations

total_iters = 10_000
T_max = 50_000
epochs = 500

# params

prediction_params = TRAJECTORY_TRANSFORMER.prediction.parameters()
embed_params = TRAJECTORY_TRANSFORMER.embedding.parameters()
transformer_params = TRAJECTORY_TRANSFORMER.layers.parameters()

# optimizer

OPTIMIZER = optim.AdamW([
    
    {'params': prediction_params, 'lr': prediction_lr, 'weight_decay': 0},
    {'params': embed_params, 'lr': embed_lr, 'weight_decay': 1e-6},
    {'params': transformer_params, 'lr': transformer_lr, 'weight_decay': 1e-6}
    
])

# Scheduler

warmup = optim.lr_scheduler.LinearLR(OPTIMIZER, start_factor = 0.5, end_factor = 1, total_iters = total_iters)
cosine = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max = T_max - total_iters, eta_min = 1e-5)

SCHEDULER = optim.lr_scheduler.SequentialLR(OPTIMIZER, [warmup, cosine], milestones = [total_iters])


### **TRAINING**

In [23]:
def train_loop(loader, epochs = epochs):
    
    import os

    torch.autograd.set_detect_anomaly(True)

    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
        
    for epoch in range(epochs):
        
        total_loss = 0.0
        
        for states, actions, rewards, next_obs in loader:
            
            pred_state, pred_action, pred_reward = TRAJECTORY_TRANSFORMER.forward(states, actions, rewards)
            
            B, seq_length, act_dim = pred_action.size()

            action_indices = torch.arange(1, seq_length, 3, device = pred_action.device)
            state_indices = torch.arange(0, seq_length, 3, device = pred_state.device)
            reward_indices = torch.arange(2, seq_length, 3, device = pred_reward.device)
        
            pred_states  = pred_state[:, state_indices, :]
            pred_actions = pred_action[:, action_indices, :]
            pred_rewards = pred_reward[:, reward_indices, :]

            loss_1 = F.smooth_l1_loss(pred_states, states)
            loss_2 = F.smooth_l1_loss(pred_rewards, rewards)
            loss_3 = F.smooth_l1_loss(pred_actions, actions)
            
            batch_loss = loss_1 + loss_2 + loss_3
            
            OPTIMIZER.zero_grad()
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(TRAJECTORY_TRANSFORMER.parameters(), max_norm = 0.5)
            OPTIMIZER.step()
            SCHEDULER.step()
            
            total_loss += batch_loss.item()
        
        avg_loss = total_loss / len(loader)
        
        writer.add_scalar('Avg TT Loss', avg_loss, epoch)
        writer.flush()
        
        print(f'Epoch: {epoch} | loss: {avg_loss:.3f}')
            

In [24]:

train_loop(data_loader)


Epoch: 0 | loss: 8.631
Epoch: 1 | loss: 5.081
Epoch: 2 | loss: 3.621
Epoch: 3 | loss: 2.675
Epoch: 4 | loss: 2.138
Epoch: 5 | loss: 1.798
Epoch: 6 | loss: 1.583
Epoch: 7 | loss: 1.445
Epoch: 8 | loss: 1.342
Epoch: 9 | loss: 1.279
Epoch: 10 | loss: 1.215
Epoch: 11 | loss: 1.139
Epoch: 12 | loss: 1.098
Epoch: 13 | loss: 1.053
Epoch: 14 | loss: 1.007
Epoch: 15 | loss: 0.962
Epoch: 16 | loss: 0.917
Epoch: 17 | loss: 0.861
Epoch: 18 | loss: 0.821
Epoch: 19 | loss: 0.793
Epoch: 20 | loss: 0.761
Epoch: 21 | loss: 0.736
Epoch: 22 | loss: 0.714
Epoch: 23 | loss: 0.682
Epoch: 24 | loss: 0.665
Epoch: 25 | loss: 0.656
Epoch: 26 | loss: 0.652
Epoch: 27 | loss: 0.646
Epoch: 28 | loss: 0.640
Epoch: 29 | loss: 0.628
Epoch: 30 | loss: 0.621
Epoch: 31 | loss: 0.617
Epoch: 32 | loss: 0.610
Epoch: 33 | loss: 0.603
Epoch: 34 | loss: 0.594
Epoch: 35 | loss: 0.593
Epoch: 36 | loss: 0.582
Epoch: 37 | loss: 0.573
Epoch: 38 | loss: 0.579
Epoch: 39 | loss: 0.575
Epoch: 40 | loss: 0.566
Epoch: 41 | loss: 0.560
Ep