In [2]:
# pip install torch

Collecting torch
  Downloading torch-2.0.0-cp310-none-macosx_11_0_arm64.whl (55.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.8/55.8 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hCollecting networkx
  Downloading networkx-3.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting jinja2
  Using cached Jinja2-3.1.2-py3-none-any.whl (133 kB)
Collecting sympy
  Using cached sympy-1.11.1-py3-none-any.whl (6.5 MB)
Collecting mpmath>=0.19
  Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
Installing collected packages: mpmath, sympy, networkx, jinja2, torch
Successfully installed jinja2-3.1.2 mpmath-1.3.0 networkx-3.1 sympy-1.11.1 torch-2.0.0
Note: you may need to restart the kernel to use updated packages.


In [5]:
import torch
import torch.nn as nn
class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, action_dim, reward_dim, embedding_dim, nhead, num_layers):
        super(DecisionTransformer, self).__init__() 
        self.state_embedding = nn.Linear(state_dim, embedding_dim) 
        self.action_embedding = nn.Linear(action_dim, embedding_dim) 
        self.reward_embedding = nn.Linear(reward_dim, embedding_dim)
        
        self.transformer = nn.Transformer(embedding_dim, nhead, num_layers) 
        self.output_layer = nn.Linear(embedding_dim, action_dim)
        
    def forward(self, states, actions, rewards, mask=None): 
        state_embeds = self.state_embedding(states) 
        action_embeds = self.action_embedding(actions) 
        reward_embeds = self.reward_embedding(rewards)
        input_embeds = torch.cat((state_embeds, action_embeds, reward_embeds), dim=1) 
        # tgt = torch.rand((20, 32, 512))
        # TODO: what is tgt?
        transformer_output = self.transformer(input_embeds.transpose(0, 1), src_key_padding_mask=mask, tgt=input_embeds.transpose(0, 1)) 
        action_logits = self.output_layer(transformer_output[-1])
        return action_logits

state_dim = 10 
action_dim = 5 
reward_dim = 1 
embedding_dim = 64 
nhead = 4 
num_layers = 2

model = DecisionTransformer(state_dim, action_dim, reward_dim, embedding_dim, nhead, num_layers)

# Load dataset
states = torch.randn(32, 10, state_dim)
actions = torch.randint(0, action_dim, (32, 10, 1))
rewards = torch.randn(32, 10, 1)

# Transfer to one-hot
actions_one_hot = torch.zeros(32, 10, action_dim) 
actions_one_hot.scatter_(2, actions, 1)

# Got outputs of the model
action_logits = model(states, actions_one_hot, rewards)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Training (one batch)
optimizer.zero_grad()
loss = loss_fn(action_logits, actions.squeeze(-1)[:, -1]) 
loss.backward()
optimizer.step()

In [6]:
action_logits

tensor([[-0.5293,  0.0439,  0.8802, -0.7718, -1.1242],
        [ 0.1930,  0.4248,  1.2229, -0.2353, -0.4624],
        [-0.1271,  0.2617,  0.8487, -0.2455, -0.5983],
        [-0.5104, -0.2908,  1.0161, -0.5251, -0.6145],
        [-0.2598,  0.1549,  1.0200, -0.4023, -0.6141],
        [-0.0993,  0.8159,  1.1459, -0.6534, -0.3159],
        [-0.8067,  0.1132,  0.4998, -0.0729, -1.0338],
        [-0.0929,  0.1405,  1.0532, -0.5955, -0.9101],
        [ 0.0291,  0.2613,  0.8143, -0.2087, -0.1612],
        [-0.0151,  0.2863,  0.4164, -0.5595, -0.3650],
        [-0.5437,  0.4898,  0.4624, -0.4046, -0.8115],
        [ 0.1059,  0.3046,  0.8754, -0.5264, -0.6838],
        [-0.0451, -0.2717,  0.1272, -0.0477, -1.1460],
        [-0.0667,  0.1108,  0.4717, -0.1777, -1.0654],
        [-0.0175,  0.5096,  0.7076, -0.5101, -0.1396],
        [ 0.0511,  0.5153,  1.2029, -0.3754, -1.0642],
        [-0.2405,  0.2329,  1.1951, -0.1411, -1.0951],
        [-0.0241,  0.3872,  1.2932, -0.4436, -0.5085],
        [ 