In [1]:
import numpy as np
import torch
import json

# Example trajectories
trajectories = [
    [
        ([50, 950], [950, 50], (4, 2), 50, [150, 950], [950, 150], False),
        ([150, 950], [950, 150], (4, 2), 30, [250, 950], [950, 250], False),
        ([250, 950], [950, 250], (0, 3), 0, [250, 950], [850, 250], False),
        ([250, 950], [850, 250], (1, 1), 20, [250, 850], [850, 150], False),
        ([250, 850], [850, 150], (0, 4), 0, [250, 850], [950, 150], False),
        ([250, 850], [950, 150], (1, 1), 20, [250, 750], [950, 50], False),
        ([250, 750], [950, 50], (4, 1), -100, [350, 750], [950, 50], False),
        ([350, 750], [950, 50], (1, 1), -80, [350, 650], [950, 50], False),
        ([350, 650], [950, 50], (4, 0), 10, [450, 650], [950, 50], False),
        ([450, 650], [950, 50], (4, 2), 30, [550, 650], [950, 150], False),
        ([550, 650], [950, 150], (2, 4), -70, [550, 750], [950, 150], False),
        ([550, 750], [950, 150], (0, 3), 0, [550, 750], [850, 150], False),
        
        # Add more steps...
    ],[
        ([50, 950], [950, 50], (4, 2), 50, [150, 950], [950, 150], False),
        ([150, 950], [950, 150], (4, 2), 30, [250, 950], [950, 250], False),
        ([250, 950], [950, 250], (0, 3), 0, [250, 950], [850, 250], False),
        ([250, 950], [850, 250], (1, 1), 20, [250, 850], [850, 150], False),
        ([250, 850], [850, 150], (0, 4), 0, [250, 850], [950, 150], False),
        ([250, 850], [950, 150], (1, 1), 20, [250, 750], [950, 50], False),
        ([250, 750], [950, 50], (4, 1), -100, [350, 750], [950, 50], False),
        ([350, 750], [950, 50], (1, 1), -80, [350, 650], [950, 50], False),
        ([350, 650], [950, 50], (4, 0), 10, [450, 650], [950, 50], False),
        ([450, 650], [950, 50], (4, 2), 30, [550, 650], [950, 150], False),
        ([550, 650], [950, 150], (2, 4), -70, [550, 750], [950, 150], False),
        ([550, 750], [950, 150], (0, 3), 0, [550, 750], [850, 150], False),
    ]
    # Add more trajectories...
]
# Let's read the content of the 'data.txt' file and store it in a variable.
# with open('data_1.json', 'r') as file:
#     json_data = json.load(file)




def preprocess_trajectories(trajectories):
    returns, states, actions, timesteps = [], [], [], []
    for traj in trajectories:
        traj_returns = []
        traj_states = []
        traj_actions = []
        traj_timesteps = []
        for t, step in enumerate(traj):
            state_uav1, state_uav2, action, reward, next_state_uav1, next_state_uav2, done = step
            state = state_uav1 + state_uav2
            traj_returns.append(reward)
            traj_states.append(state)
            traj_actions.append(list(action))
            traj_timesteps.append(t)  # Generate timesteps as a sequence of integers
        returns.append(traj_returns)
        states.append(traj_states)
        actions.append(traj_actions)
        timesteps.append(traj_timesteps)
    return np.array(returns), np.array(states), np.array(actions), np.array(timesteps)

returns, states, actions, timesteps = preprocess_trajectories(trajectories)

# Convert to PyTorch tensors
returns_tensor = torch.tensor(returns, dtype=torch.float32)
states_tensor = torch.tensor(states, dtype=torch.float32)
actions_tensor = torch.tensor(actions, dtype=torch.float32)
timesteps_tensor = torch.tensor(timesteps, dtype=torch.float32)

print(returns_tensor.shape)  # Expected: (num_trajectories, num_steps)
print(states_tensor.shape)  # Expected: (num_trajectories, num_steps, state_dim)
print(actions_tensor.shape)  # Expected: (num_trajectories, num_steps, action_dim)
print(timesteps_tensor.shape)  # Expected: (num_trajectories, num_steps)


torch.Size([2, 12])
torch.Size([2, 12, 4])
torch.Size([2, 12, 2])
torch.Size([2, 12])


In [2]:
import torch.nn as nn
from transformers import GPT2Model, GPT2Config

class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim, max_length=50):
        super(DecisionTransformer, self).__init__()
        config = GPT2Config(vocab_size=1, n_positions=max_length, n_embd=hidden_dim, n_layer=4, n_head=8)
        self.transformer = GPT2Model(config)
        
        # Embedding layers
        self.embed_t = nn.Embedding(max_length, hidden_dim)
        self.state_emb = nn.Linear(state_dim, hidden_dim)
        self.action_emb = nn.Linear(action_dim, hidden_dim)
        self.reward_emb = nn.Linear(1, hidden_dim)
        
        self.predict_action = nn.Linear(hidden_dim, action_dim)
        
    def forward(self, returns, states, actions, timesteps):
        pos_embedding = self.embed_t(timesteps.long())
        s_embedding = self.state_emb(states) + pos_embedding
        a_embedding = self.action_emb(actions) + pos_embedding
        R_embedding = self.reward_emb(returns.unsqueeze(-1)) + pos_embedding
        
        # Interleave tokens as (R_1, s_1, a_1, ..., R_K, s_K)
        input_embeds = torch.cat((R_embedding.unsqueeze(1), s_embedding.unsqueeze(1), a_embedding.unsqueeze(1)), dim=1).view(s_embedding.size(0), -1, s_embedding.size(-1))
        
        # Debugging print statements
        print(f"input_embeds shape: {input_embeds.shape}")
        
        # Use transformer to get hidden states
        transformer_outputs = self.transformer(inputs_embeds=input_embeds)
        hidden_states = transformer_outputs.last_hidden_state
        
        print(f"hidden_states shape: {hidden_states.shape}")
        
        # Select hidden states for action prediction tokens
        a_hidden = hidden_states[:, 2::3]  # Assuming the actions are at these positions
        
        print(f"a_hidden shape: {a_hidden.shape}")

        # Predict action
        return self.predict_action(a_hidden)


In [3]:
def forward(self, returns, states, actions, timesteps):
    print(f"returns shape: {returns.shape}")
    print(f"states shape: {states.shape}")
    print(f"actions shape: {actions.shape}")
    print(f"timesteps shape: {timesteps.shape}")
    
    pos_embedding = self.embed_t(timesteps.long())
    s_embedding = self.state_emb(states) + pos_embedding
    a_embedding = self.action_emb(actions) + pos_embedding
    R_embedding = self.reward_emb(returns.unsqueeze(-1)) + pos_embedding

    print(f"pos_embedding shape: {pos_embedding.shape}")
    print(f"s_embedding shape: {s_embedding.shape}")
    print(f"a_embedding shape: {a_embedding.shape}")
    print(f"R_embedding shape: {R_embedding.shape}")

    # Interleave tokens as (R_1, s_1, a_1, ..., R_K, s_K)
    input_embeds = torch.stack((R_embedding, s_embedding, a_embedding), dim=1).view(s_embedding.size(0), -1, s_embedding.size(-1))
    
    print(f"input_embeds shape: {input_embeds.shape}")
    
    # Use transformer to get hidden states
    hidden_states = self.transformer(inputs_embeds=input_embeds).last_hidden_state
    
    print(f"hidden_states shape: {hidden_states.shape}")
    
    # Select hidden states for action prediction tokens
    a_hidden = hidden_states[:, 2::3]  # Assuming the actions are at these positions
    
    print(f"a_hidden shape: {a_hidden.shape}")

    # Predict action
    return self.predict_action(a_hidden)


In [4]:
from torch.utils.data import DataLoader, TensorDataset

# Create a DataLoader for batching
dataset = TensorDataset(returns_tensor, states_tensor, actions_tensor, timesteps_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model = DecisionTransformer(state_dim=4, action_dim=2, hidden_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

for epoch in range(100):  # Number of epochs
    for R_batch, s_batch, a_batch, t_batch in dataloader:
        optimizer.zero_grad()
        a_preds = model(R_batch, s_batch, a_batch, t_batch)
        loss = loss_fn(a_preds, a_batch)
        loss.backward()
        optimizer.step()
        
    print(f'Epoch {epoch}, Loss: {loss.item()}')


input_embeds shape: torch.Size([2, 36, 128])
hidden_states shape: torch.Size([2, 36, 128])
a_hidden shape: torch.Size([2, 12, 128])
Epoch 0, Loss: 6.664906978607178
input_embeds shape: torch.Size([2, 36, 128])
hidden_states shape: torch.Size([2, 36, 128])
a_hidden shape: torch.Size([2, 12, 128])
Epoch 1, Loss: 6.3777241706848145
input_embeds shape: torch.Size([2, 36, 128])
hidden_states shape: torch.Size([2, 36, 128])
a_hidden shape: torch.Size([2, 12, 128])
Epoch 2, Loss: 5.943504810333252
input_embeds shape: torch.Size([2, 36, 128])
hidden_states shape: torch.Size([2, 36, 128])
a_hidden shape: torch.Size([2, 12, 128])
Epoch 3, Loss: 6.439935684204102
input_embeds shape: torch.Size([2, 36, 128])
hidden_states shape: torch.Size([2, 36, 128])
a_hidden shape: torch.Size([2, 12, 128])
Epoch 4, Loss: 5.826087474822998
input_embeds shape: torch.Size([2, 36, 128])
hidden_states shape: torch.Size([2, 36, 128])
a_hidden shape: torch.Size([2, 12, 128])
Epoch 5, Loss: 5.825540065765381
input_emb