In [5]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense,LayerNormalization, BatchNormalization, Dropout, Conv1D, Flatten, Reshape,TimeDistributed
from tensorflow.keras.models import Model

# define transformer block
class TransformerBlock(Layer):
    """Transformer encoder-style block: self-attention followed by a feed-forward net.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input (residual connection) and is then layer-normalised.

    Args:
        embed_dim: Used as the per-head ``key_dim`` of the attention layer and
            as the output width of the feed-forward network. The residual
            additions in ``call`` require it to be compatible with the last
            axis of ``inputs``.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied after attention and after the feed-forward net.
        **kwargs: Forwarded to the base ``Layer`` constructor (e.g. ``name``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        # Zero-argument super() keeps working when the class is re-defined by
        # re-running the notebook cell; **kwargs lets callers set name/dtype.
        super().__init__(**kwargs)
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)

        # Position-wise feed-forward network: expand to ff_dim, project back.
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None, mask=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Tensor used as query, key and value of the self-attention.
            training: Forwarded to both dropout layers. Defaults to ``None``
                so the block can be called without an explicit flag and Keras
                decides the mode; previously the argument was required.
            mask: Optional attention mask forwarded to the attention layer as
                ``attention_mask``.

        Returns:
            The layer-normalised output of the second residual connection.
        """
        # Self-attention: the same tensor serves as query, value and key.
        attn_output = self.att(inputs, inputs, inputs, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

In [10]:
#### define DecisionTransformer model
class DecisionTransformer(tf.keras.Model):
    """Stack of transformer blocks over summed state and action embeddings.

    States and actions are each projected to ``embedding_dim`` with a ReLU
    dense layer, added element-wise, passed through ``num_layers``
    ``TransformerBlock`` layers and finally projected per time step to
    ``action_dim`` outputs.

    Args:
        state_dim: Accepted for interface symmetry; not used here because the
            dense embedding infers its input width on first call.
        action_dim: Width of the per-time-step output projection.
        embedding_dim: Width of the state/action embeddings and of each block.
        num_heads: Number of attention heads in each block.
        ff_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of stacked transformer blocks.
    """

    def __init__(self, state_dim, action_dim, embedding_dim, num_heads, ff_dim, num_layers):
        super().__init__()
        self.state_embed = Dense(embedding_dim, activation="relu")
        self.action_embed = Dense(embedding_dim, activation="relu")
        self.transformer_layers = [
            TransformerBlock(embedding_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ]
        self.dense = TimeDistributed(Dense(action_dim))

    def call(self, states, actions, training=None):
        """Compute per-time-step outputs for a batch of state/action sequences.

        Args:
            states: State tensor; embedded by ``state_embed``.
            actions: Action tensor; embedded by ``action_embed``. Its embedding
                is added to the state embedding, so leading axes must be
                compatible with those of ``states``.
            training: Training-mode flag forwarded to every transformer block.
                Previously the blocks were always called with
                ``training=True``, which kept dropout active at inference and
                ignored the caller's choice.

        Returns:
            Output of the time-distributed dense projection (last axis
            ``action_dim``).
        """
        state_embeddings = self.state_embed(states)
        action_embeddings = self.action_embed(actions)

        # Fuse the two modalities by element-wise addition of their embeddings.
        x = state_embeddings + action_embeddings
        for transformer_layer in self.transformer_layers:
            # Propagate the caller's mode instead of forcing training=True.
            x = transformer_layer(x, training=training)
        return self.dense(x)

In [11]:
# Example: build a DecisionTransformer and run one forward pass on random data.
state_dim = 20
action_dim = 5
embedding_dim = 128
num_heads = 4
ff_dim = 512
num_layers = 6

# Initialize DecisionTransformer
dt = DecisionTransformer(state_dim, action_dim, embedding_dim, num_heads, ff_dim, num_layers)

# Generate random states and actions.
# Batch of 32 sequences of 100 states
states = tf.random.uniform((32, 100, state_dim))
# Batch of 32 sequences of 100 actions
actions = tf.random.uniform((32, 100, action_dim))

# Get model prediction. This is a prediction, not a training step, so request
# inference mode rather than training=True.
output = dt(states, actions, training=False)

print(output.shape)

(32, 100, 5)
