## Decision Transformer

---

> Internship neural networks
>
> Group 4: Reinforcement learning
>
> Deadline 28.02.23 23:59

---

In [None]:
# source: https://github.com/kzl/decision-transformer
# source: https://github.com/nikhilbarhate99/min-decision-transformer

In [None]:
import math
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define decision transformer architecture

In [None]:
class CausalAttention(nn.Module):
    '''
    Multi-head self-attention in which every token may only attend to itself
    and to earlier tokens of the sequence.

    hidden_dim: width of the token embeddings (input and output size)
    input_seq_len: longest sequence the mask is built for,
                   3 (states, actions, rtgs) * length of trajectory
    n_heads: number of attention heads the hidden dimension is split into
    drop_p: dropout probability used after the attention and after the projection
    '''
    def __init__(self, hidden_dim, input_seq_len, n_heads, drop_p):
        super().__init__()

        self.n_heads = n_heads

        # one linear map each for values, keys and queries, plus the output projection
        self.value_net = nn.Linear(hidden_dim, hidden_dim)
        self.key_net = nn.Linear(hidden_dim, hidden_dim)
        self.query_net = nn.Linear(hidden_dim, hidden_dim)
        self.proj_net = nn.Linear(hidden_dim, hidden_dim)

        # lower-triangular matrix of ones: entry (i, j) is 1 iff token i may see token j
        lower_triangle = torch.ones((input_seq_len, input_seq_len)).tril()
        # stored as a buffer: it follows the module across devices but is not a parameter
        self.register_buffer('mask', lower_triangle.view(1, 1, input_seq_len, input_seq_len))

        # regularization
        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

    def forward(self, x):
        # x: (batch, tokens, width); width is split evenly across the heads
        n_batch, n_tokens, width = x.shape
        head_dim = width // self.n_heads
        per_head_shape = (n_batch, n_tokens, self.n_heads, head_dim)

        # project and rearrange to (batch, n_heads, tokens, head_dim)
        v = self.value_net(x).view(*per_head_shape).transpose(1, 2)
        k = self.key_net(x).view(*per_head_shape).transpose(1, 2)
        q = self.query_net(x).view(*per_head_shape).transpose(1, 2)

        # scaled dot-product scores, (batch, n_heads, tokens, tokens)
        scores = (q @ k.transpose(2, 3)) / math.sqrt(head_dim)

        # positions a token must not see get -inf so softmax gives them weight 0
        hidden_positions = self.mask[..., :n_tokens, :n_tokens] == 0
        scores = scores.masked_fill(hidden_positions, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # dropout is applied to the weighted values, (batch, n_heads, tokens, head_dim)
        mixed = self.att_drop(weights @ v)

        # concatenate the heads again: (batch, tokens, width)
        merged = mixed.transpose(1, 2).contiguous().view(n_batch, n_tokens, width)

        return self.proj_drop(self.proj_net(merged))

In [None]:
class Block(nn.Module):
    '''
    One transformer layer: causal self-attention followed by a position-wise
    MLP. Each sub-layer is wrapped in a residual connection and the layer
    norm is applied after the residual sum.

    hidden_dim: hidden dimension of transformer
    input_seq_len: 3 (states, actions, rtgs) * length of trajectory
    n_heads: number of attention heads
    drop_p: dropout probability for attention, projection and MLP output
    '''
    def __init__(self, hidden_dim, input_seq_len, n_heads, drop_p):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.attention = CausalAttention(hidden_dim, input_seq_len, n_heads, drop_p)

        # feed-forward part widens to 4x the hidden size and narrows back
        feedforward_dim = 4 * hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, feedforward_dim),
            nn.GELU(),
            nn.Linear(feedforward_dim, hidden_dim),
            nn.Dropout(drop_p),
        )

    def forward(self, x):
        # residual around attention, then normalise
        attended = self.ln1(x + self.attention(x))
        # residual around the MLP, then normalise
        return self.ln2(attended + self.mlp(attended))

In [None]:
class DecisionTransformer(nn.Module):
    '''
    This is the decision transformer. It embeds (return-to-go, state, action)
    triples, runs them through a stack of causal transformer blocks and
    predicts a distribution over discrete actions for every timestep.

    state_dim: dimension of the states
    act_dim: dimension of the actions
    n_blocks: number of transformer blocks
    hidden_dim: hidden dimension of transformer
    context_len: length of the context our decision transformer looks at
    n_heads: number of attention heads
    drop_p: dropout probability for attention and projection
    vocab_size: number of possible actions
    max_timestep: maximum length of a game (size of the timestep embedding table)
    '''
    def __init__(self, state_dim, act_dim, n_blocks, hidden_dim, context_len,
                 n_heads, drop_p, vocab_size, max_timestep=21):
        super().__init__()

        self.act_dim = act_dim
        self.hidden_dim = hidden_dim
        # kept on the module so forward() does not depend on a global of the same name
        self.vocab_size = vocab_size

        ### transformer
        # every timestep contributes three tokens: rtg, state, action
        input_seq_len = 3 * context_len
        blocks = [Block(hidden_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        ### prediction heads
        # predict_rtg and predict_state are not used by forward(); they are
        # kept so the set of parameters (and saved state dicts) stays the same
        self.predict_rtg = nn.Linear(hidden_dim, 1)
        self.predict_state = nn.Linear(hidden_dim, state_dim)
        self.predict_action = nn.Linear(hidden_dim, vocab_size, bias = False)
        self.SL = nn.Softmax(dim = 2)

        # embedding for actions (actions are integer ids in [0, vocab_size))
        self.action_embedding = nn.Embedding(vocab_size, hidden_dim)

        ### embeddings for projection
        # despite its name, embedding_layer is a LayerNorm applied to the transformer output
        self.embedding_layer = nn.LayerNorm(hidden_dim)
        self.t_embedding = nn.Embedding(max_timestep, hidden_dim)
        self.rtg_embedding = nn.Linear(1, hidden_dim)
        self.state_embedding = nn.Linear(state_dim, hidden_dim)

    def forward(self, timesteps, states, actions, returns_to_go, traj_mask=None):
        '''
        timesteps: integer timestep indices, (batch, traj_length)
        states: (batch, traj_length, state_dim)
        actions: integer action ids, (batch, traj_length)
        returns_to_go: (batch, traj_length, 1)
        traj_mask: optional mask, entries > 0 mark non padded timesteps;
                   when given, the cross-entropy loss over those timesteps is computed

        returns: (action probabilities of shape (batch, traj_length, vocab_size),
                  loss or None if no traj_mask was given)
        '''
        batch_size, traj_length, _ = states.shape

        # embeddings; the same timestep embedding is added to all three token types
        time_embeddings = self.t_embedding(timesteps)
        state_embeddings = self.state_embedding(states) + time_embeddings
        action_embeddings = self.action_embedding(actions) + time_embeddings
        returns_embeddings = self.rtg_embedding(returns_to_go) + time_embeddings

        # get rtg, states and actions in form (r_0, s_0, a_0, r_1, s_1, a_1, ...)
        h = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 3 * traj_length, self.hidden_dim)

        # transformer and final layer norm
        h = self.transformer(h)
        h = self.embedding_layer(h)

        # predict action given r, s: undo the interleaving and take the outputs
        # at the state token positions (index 1 of each triple)
        state_tokens = h.reshape(batch_size, traj_length, 3, self.hidden_dim).permute(0, 2, 1, 3)[:, 1]
        action_logits = self.predict_action(state_tokens)

        loss = None
        if traj_mask is not None:
            # only consider non padded elements
            valid = traj_mask.reshape(-1).to(action_logits.device) > 0
            # targets follow the logits' device instead of a global `device`
            action_target = actions.detach().to(device=action_logits.device, dtype=torch.int64)

            logits = action_logits.to(torch.float32).reshape(-1, self.act_dim, self.vocab_size)
            logits = logits[valid].reshape(-1, self.vocab_size)
            targets = action_target.reshape(-1, self.act_dim)[valid].reshape(-1)

            # cross_entropy expects raw logits, so softmax is only applied afterwards
            loss = F.cross_entropy(logits, targets, reduction='mean')

        action_preds = self.SL(action_logits)

        return action_preds, loss

# Decision transformer agent

In [None]:
class DTAgent():
    '''
    This is the agent class for our decision transformer. It owns the model
    and the per-game history (states, actions, returns-to-go) that is fed
    back into the model at every move.

    state_dim: dimension of the states
    act_dim: dimension of the actions
    n_blocks: number of transformer blocks
    hidden_dim: hidden dimension of transformer
    context_len: length of the context our decision transformer looks at
    n_heads: number of attention heads
    drop_p: dropout probability for attention and projection
    rtg_target: target reward our decision transformer wants to get
    vocab_size: number of possible actions
    max_timestep: maximum length of a game; also used as the size of the
                  model's timestep embedding so both always agree
    '''
    def __init__(self, state_dim, act_dim, n_blocks, hidden_dim, context_len,
                 n_heads, drop_p, rtg_target, vocab_size, max_timestep=21):
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.n_blocks = n_blocks
        self.hidden_dim = hidden_dim
        self.context_len = context_len
        self.n_heads= n_heads
        self.drop_p = drop_p
        self.rtg_target = rtg_target
        self.vocab_size = vocab_size
        self.max_timestep = max_timestep

        # `device` is a module-level name that is not defined in this section;
        # it is expected to be set elsewhere in the notebook before this runs
        self.device = device

        # max_timestep is forwarded so the timestep embedding covers every
        # index in self.timesteps (it used to stay at the model default of 21)
        self.model = DecisionTransformer(
            state_dim=self.state_dim,
            act_dim=self.act_dim,
            n_blocks=self.n_blocks,
            hidden_dim=self.hidden_dim,
            context_len=self.context_len,
            n_heads=self.n_heads,
            drop_p=self.drop_p,
            vocab_size=vocab_size,
            max_timestep=self.max_timestep).to(self.device)

        # timestep indices 0 .. max_timestep-1, shape (1, max_timestep)
        self.timesteps = torch.arange(self.max_timestep, device=self.device).unsqueeze(0)

        # allocate the zero placeholders and the running return-to-go
        self.reset_agent()

    def select_action(self, t, running_reward, state, available_actions, EPS = 0, againstDQN=True):
        '''
        selects action for actual state in game and records it in the history

        t: timestep
        running_reward: actual reward we got in game
        state: actual state of the game (array-like, flattened before it is stored)
        available_actions: the available actions/columns to throw a coin in
        EPS: epsilon for exploration/explotation ratio (not necessary for dt, unused)
        againstDQN: tells if dt plays against dqn -> if yes, action selection over probabilities of the actions to get some variance in the games

        returns: action to take (one element of available_actions)
        raises: ValueError if available_actions is empty
        '''
        if len(available_actions) == 0:
            raise ValueError("available_actions must contain at least one action")

        self.model.eval()
        with torch.no_grad():
            # add the flattened state in the placeholder
            flat_state = np.asarray(state).flatten()
            self.states[0, t] = torch.as_tensor(flat_state, dtype=torch.float32, device=self.device)

            # calculate running rtg and add in placeholder
            self.running_rtg = self.running_rtg - running_reward
            self.rewards_to_go[0, t] = self.running_rtg

            # feed the last (at most context_len) timesteps up to and including t.
            # The model is causal, so the prediction at t does not depend on
            # later placeholder entries and they need not be passed in.
            first = max(0, t - self.context_len + 1)
            window = slice(first, t + 1)
            act_preds, _ = self.model(self.timesteps[:, window],
                                      self.states[:, window],
                                      self.actions[:, window],
                                      self.rewards_to_go[:, window])
            step_probs = act_preds[0, -1].detach().cpu()

            # restrict to the legal moves; plain floats work with random and max
            probs = [float(step_probs[i]) for i in available_actions]
            if againstDQN:
                # sample proportionally to the predicted probabilities
                choice = random.choices(range(len(probs)), weights=probs)[0]
            else:
                # greedy: first index with the highest probability
                choice = max(range(len(probs)), key=probs.__getitem__)
            action = available_actions[choice]

            # add action in placeholder
            self.actions[0, t] = int(action)

            return action

    def reset_agent(self):
        '''
        resets the dt for next game: clears the stored history and restores
        the running return-to-go to rtg_target
        '''
        # zeros place holders
        self.actions = torch.zeros((1, self.max_timestep),
                                dtype=torch.int32, device=self.device)

        self.states = torch.zeros((1, self.max_timestep, self.state_dim),
                                dtype=torch.float32, device=self.device)

        self.rewards_to_go = torch.zeros((1, self.max_timestep, 1),
                                dtype=torch.float32, device=self.device)

        self.running_rtg = self.rtg_target
    