In [226]:
import torch
import torch.nn as nn
from typing import Optional

In [227]:
class MyPlannerTransformer(nn.Module):
    def __init__(self, num_features=6, numOfEncoderLayers = 6, numOfDecoderLayers=6) -> None:
        super().__init__()
        self.transformer_model = nn.Transformer(
            d_model=num_features,
            nhead=3,
            num_encoder_layers=numOfEncoderLayers,
            num_decoder_layers=numOfDecoderLayers,
            dim_feedforward=2048,
            dropout=0.1,
        )
        # map tgt to have the same number of features as src
        self.tgt = nn.Linear(2, num_features)

    def forward(
        self,
        src: torch.tensor,
        tgt: torch.tensor,
        src_mask: Optional[torch.tensor]=None,
        tgt_mask: Optional[torch.tensor]=None,
        memory_mask: Optional[torch.tensor]=None,
        src_key_padding_mask: Optional[torch.tensor]=None,
        tgt_key_padding_mask: Optional[torch.tensor]=None,
        memory_key_padding_mask: Optional[torch.tensor]=None,
        max_len: Optional[int]=None,
        
    )->torch.tensor:
        tgt = self.tgt(tgt)
        out = self.transformer_model(src, tgt)
        return out[:,:2]
    
    def predict(self, src: torch.tensor, tgt: torch.tensor, max_len: Optional[int]=None)->torch.tensor:
        self.eval()
        with torch.no_grad():
            prediction = self.forward(src, tgt)
        return prediction
    
    

In [228]:
batch_size = 10
src = torch.rand(batch_size, 6)
tgt = torch.rand(batch_size, 2)

In [229]:
src.shape

torch.Size([10, 6])

In [230]:
tgt.shape

torch.Size([10, 2])

In [231]:
src[0]

tensor([0.0169, 0.9111, 0.0025, 0.1272, 0.6298, 0.8852])

In [232]:
tgt[0]

tensor([0.9917, 0.0126])

In [209]:
model = MyPlannerTransformer()
model.forward(src, tgt).shape



torch.Size([10, 2])

In [233]:
model(src, tgt)

tensor([[0.4863, 0.6740],
        [0.4883, 0.7036],
        [0.4937, 0.7097],
        [0.4837, 0.6759],
        [0.5022, 0.7103],
        [0.5036, 0.7114],
        [0.5009, 0.7100],
        [0.5020, 0.7053],
        [0.4883, 0.7040],
        [0.4959, 0.7045]], grad_fn=<SliceBackward0>)

In [234]:
def training_loop(epochs):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(epochs):
        
        model.train()
        
        for batch in range(10):

                # Zero the gradients
                optimizer.zero_grad()

                # Forward pass
                output = model(src, tgt)

                # Compute the loss
                loss = criterion(output, tgt)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                if (batch + 1) % 10 == 0:
                    print(f'Epoch [{epoch + 1}/{epochs}], Batch [{batch + 1}/{10}], Loss: {loss.item():.4f}')
training_loop(1)


Epoch [1/1], Batch [10/10], Loss: 0.0990


In [235]:
# save model
torch.save(model.state_dict(), 'model.pth')

In [236]:
# load model
model.load_state_dict(torch.load('model.pth'))

<All keys matched successfully>

In [237]:
tgt

tensor([[0.9917, 0.0126],
        [0.2232, 0.5795],
        [0.1295, 0.8194],
        [0.8521, 0.0363],
        [0.7630, 0.8032],
        [0.7915, 0.8569],
        [0.6897, 0.7898],
        [0.9975, 0.6156],
        [0.1996, 0.5916],
        [0.6765, 0.5879]])

In [238]:
# predict
yhat = model.predict(src, tgt)

In [239]:
yhat

tensor([[0.6641, 0.5229],
        [0.6551, 0.5556],
        [0.6551, 0.5627],
        [0.6619, 0.5256],
        [0.6637, 0.5580],
        [0.6640, 0.5592],
        [0.6628, 0.5582],
        [0.6670, 0.5504],
        [0.6548, 0.5562],
        [0.6625, 0.5521]])

In [240]:
# calculate MSE between yhat and tgt
criterion = nn.MSELoss()
loss = criterion(yhat, tgt)
print(loss)

tensor(0.0863)
