In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [16]:
# Sizes of the (batch_size, seq_len, embedded_dim) tensor used in the
# self-attention walkthrough below.
batch_size, seq_len, embedded_dim = 1, 20, 128

In [17]:
# Fix the RNG seed so the random input is reproducible, then draw uniform
# values in [0, 1) and scale them to [0, 100).
torch.manual_seed(42)
coords = 100 * torch.rand((batch_size, seq_len, embedded_dim))
print(coords.shape)

torch.Size([1, 20, 128])


In [18]:
# Three independent square linear layers (created in Q, K, V order) used as
# the query, key and value projections of the attention computation below.
Q, K, V = (nn.Linear(embedded_dim, embedded_dim) for _ in range(3))

In [19]:
# Project the same input through each layer to get queries, keys and values.
q, k, v = (projection(coords) for projection in (Q, K, V))

In [20]:
# Each projection keeps the input shape (batch_size, seq_len, embedded_dim).
tuple(projected.shape for projected in (q, k, v))

(torch.Size([1, 20, 128]), torch.Size([1, 20, 128]), torch.Size([1, 20, 128]))

In [21]:
# Swapping the last two axes of k makes it multiplicable with q.
torch.transpose(k, -2, -1).shape

torch.Size([1, 128, 20])

In [22]:
# Scaled dot-product: query/key similarities divided by sqrt(embedded_dim).
pre_scores = torch.div(q @ k.transpose(-2, -1), embedded_dim ** 0.5)

In [23]:
# One raw score per (query position, key position) pair.
pre_scores.size()

torch.Size([1, 20, 20])

In [28]:
# Normalise over the last axis so each row of scores sums to 1.
scores = torch.softmax(pre_scores, dim=-1)

In [29]:
# Softmax leaves the shape of the score matrix unchanged.
scores.size()

torch.Size([1, 20, 20])

In [30]:
# Attention output: score-weighted sum of the value vectors.
atten_val = scores @ v

In [31]:
# The attention output has the same shape as the value tensor v.
atten_val.size()

torch.Size([1, 20, 128])

In [49]:
# Sizes used by the encoder/decoder cells below (these rebind batch_size and
# seq_len from the earlier walkthrough).
batch_size, seq_len, embed_dim = 256, 20, 128

In [None]:
class Encoder(nn.Module):
    """Embed a sequence of feature vectors and add learned positional embeddings.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        embed_dim: Size of the produced embedding.
        seq_len: Number of rows in the positional embedding table. Inputs
            with more time steps than this cannot be embedded.
    """

    def __init__(self, input_dim=2, embed_dim=128, seq_len=20):
        super().__init__()
        self.embedding = nn.Embedding(seq_len, embed_dim)
        self.fc = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        """Return the projected input plus the positional embedding.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim).

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        # x: (batch, seq_len, input_dim)
        _, seq_len, _ = x.size()
        # Position indices live on the same device as the input so the
        # embedding lookup works on CPU and GPU alike.
        positions = torch.arange(seq_len, device=x.device)
        # A leading singleton axis lets the addition broadcast over the batch
        # instead of materialising one copy of the positions per sample.
        pos_emb = self.embedding(positions).unsqueeze(0)  # (1, seq_len, embed_dim)
        x_proj = self.fc(x)                               # (batch, seq_len, embed_dim)
        return x_proj + pos_emb

In [None]:
# Build an encoder for 2-feature inputs and run it on a random batch; the
# final expression displays the shape of the encoded output.
enc = Encoder(input_dim=2, embed_dim=embed_dim, seq_len=seq_len)
enc_out = enc(torch.rand((batch_size, seq_len, 2)))
enc_out.size()

torch.Size([256, 20, 128])

In [None]:
class DecoderWithAttention(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    Args:
        embed_dim: Feature size of the encoder outputs and the decoder input.
        hidden_dim: Hidden size of the GRU.
        output_dim: Feature size of the per-step output.
        seq_len: Output width of the attention scoring layer.
    """

    def __init__(self, embed_dim, hidden_dim, output_dim, seq_len):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim + embed_dim, seq_len)
        # The combiner consumes cat(decoder_input, context), both of width
        # embed_dim, so its input size is 2 * embed_dim regardless of
        # hidden_dim.
        self.attn_combine = nn.Linear(embed_dim * 2, embed_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, encoder_outputs, decoder_input, hidden):
        """Run one decoding step.

        Args:
            encoder_outputs: Tensor of shape (batch, enc_len, embed_dim).
            decoder_input: Tensor of shape (batch, 1, embed_dim).
            hidden: GRU state of shape (1, batch, hidden_dim).

        Returns:
            Tuple ``(output, hidden, attn_weights)`` where ``output`` is
            (batch, output_dim), ``hidden`` is the updated GRU state and
            ``attn_weights`` is (batch, enc_len, seq_len), normalised over
            the encoder positions (dim 1).
        """
        enc_len = encoder_outputs.size(1)

        # Pair the last layer's hidden state with every encoder position.
        hidden_repeat = hidden[-1].unsqueeze(1).expand(-1, enc_len, -1)      # (batch, enc_len, hidden_dim)
        attn_input = torch.cat((encoder_outputs, hidden_repeat), dim=2)      # (batch, enc_len, embed_dim+hidden_dim)
        attn_weights = F.softmax(self.attn(attn_input), dim=1)               # (batch, enc_len, seq_len)

        # Compute context vectors as weighted sums of encoder outputs.
        context = torch.bmm(attn_weights.transpose(1, 2), encoder_outputs)   # (batch, seq_len, embed_dim)
        # NOTE(review): only the first of the seq_len context vectors is used
        # for the current step; the remaining attention columns do not
        # influence the output.
        context = context[:, 0:1, :]

        # Combine context with decoder input.
        rnn_input = torch.cat((decoder_input, context), dim=2)               # (batch, 1, embed_dim*2)
        rnn_input = self.attn_combine(rnn_input)                             # (batch, 1, embed_dim)
        rnn_input = F.relu(rnn_input)

        output, hidden = self.gru(rnn_input, hidden)                         # output: (batch, 1, hidden_dim)
        output = self.out(output.squeeze(1))                                 # (batch, output_dim)
        return output, hidden, attn_weights


In [None]:

# Example usage: run a single decoding step on the encoder output.
# Sizes come from this section's `embed_dim` (not the earlier `embedded_dim`)
# so the decoder always matches `enc_out`, and the hidden size is named once
# instead of being repeated as a literal.
hidden_dim = 128
decoder = DecoderWithAttention(embed_dim=embed_dim, hidden_dim=hidden_dim, output_dim=embed_dim, seq_len=seq_len)
decoder_input = torch.zeros(batch_size, 1, embed_dim)  # initial input (e.g., <SOS>)
hidden = torch.zeros(1, batch_size, hidden_dim)        # initial GRU state
output, hidden, attn_weights = decoder(enc_out, decoder_input, hidden)


In [1]:
# Drop task ids 0, 1 and 2 from the sequence while keeping the original order.
task_seq = [0, 3, 1, 5, 7, 2]
task_seq = [task_id for task_id in task_seq if task_id not in (0, 1, 2)]
print(task_seq)


[3, 5, 7]


In [5]:
# Lookup tables keyed by integer id: the list position is the robot / task id
# and the value is an (x, y) coordinate pair.
ROBOT_DEPOTS = dict(enumerate([
    (0.1, 0.1),   # Robot 0's base
    (0.9, 0.1),   # Robot 1's base
    (0.5, 0.9),   # Robot 2's base
]))
TASK_COORDINATES = dict(enumerate([
    (0.2, 0.2),   # Task 0
    (0.8, 0.2),   # Task 1
    (0.5, 0.8),   # Task 2
    (0.3, 0.7),   # Task 3
    (0.6, 0.5),   # Task 4
    (0.4, 0.4),   # Task 5
]))

In [9]:
# Build the route for robot 0: its depot first, then the coordinate of each
# task in visiting order.
task_seq = [0, 3, 1, 5, 4, 2]
depot_coord = ROBOT_DEPOTS[0]
# Disabled filter kept for reference:
# task_seq = [tid for tid in task_seq if tid not in [0, 1, 2]]
coords = [depot_coord, *(TASK_COORDINATES[task_id] for task_id in task_seq)]
print(coords)

[(0.1, 0.1), (0.2, 0.2), (0.3, 0.7), (0.8, 0.2), (0.4, 0.4), (0.6, 0.5), (0.5, 0.8)]
