In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the encoder source imported
# below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
from nuplan_extent.planning.training.modeling.models.encoders.gameformer_encoder import GameFormerEncoder

In [6]:
# Sizes used to build the random encoder inputs in the next cell.
# Leading dimension shared by every input tensor (presumably the batch size).
B = 5
# History tensors are (B, n_agents, T, ds): T is the third axis, ds the last.
T, ds = 11, 11
# Lane tensor is (B, 1, Nm_lane, Np_lane, dp_lane).
Nm_lane, Np_lane, dp_lane = 6, 100, 15
# Crosswalk tensor is (B, 1, Nm_cw, Np_cw, dp_cw).
Nm_cw, Np_cw, dp_cw = 4, 100, 3
# Second axis of the agent history tensor (the ego history uses 1 there).
N = 10

In [5]:
# Seed the global torch RNG so the random inputs below (and anything sampled
# after them) are reproducible across Restart & Run All.
torch.manual_seed(0)

# Ego and agent histories: (B, 1, T, ds) and (B, N, T, ds).
ego_hist = torch.randn(B, 1, T, ds)
agent_hist = torch.randn(B, N, T, ds)

# Lane features (B, 1, Nm_lane, Np_lane, dp_lane) and the tensor passed as their
# mask, which has the same shape without the trailing feature axis.
lanes = torch.randn(B, 1, Nm_lane, Np_lane, dp_lane)
# NOTE(review): the masks are sampled from randn (arbitrary floats), not
# booleans / 0-1 values — confirm this is acceptable for a shape-only smoke test.
lanes_mask = torch.randn(B, 1, Nm_lane, Np_lane)

# Crosswalk features (B, 1, Nm_cw, Np_cw, dp_cw) and their mask tensor.
cw = torch.randn(B, 1, Nm_cw, Np_cw, dp_cw)
cw_mask = torch.randn(B, 1, Nm_cw, Np_cw)

In [6]:
# Build the encoder from the feature sizes declared in the constants cell
# (last axis of the history / lane / crosswalk tensors) instead of repeating
# the literals 11 / 15 / 3, so the model and the inputs cannot drift apart.
encoder = GameFormerEncoder(
    input_agent_dim=ds,
    input_lane_dim=dp_lane,
    input_crosswalk_dim=dp_cw,
)

# Single forward pass on the random inputs; the result is indexed by key below.
out = encoder(ego_hist, agent_hist, lanes, lanes_mask, cw, cw_mask)

# Smoke check: print the shape of each returned tensor, in the same order as
# the recorded output.
for key in (
    'agent_context_encoding',
    'scene_context_encoding',
    'scene_context_mask',
):
    print(out[key].shape)

torch.Size([11, 5, 256])
torch.Size([91, 5, 256])
torch.Size([5, 91])


# Test unravelling of map max-pool indices (`encoder.unravel_mask_index`)

In [35]:
# Fresh encoder instance for the index-unravelling test; only its
# `unravel_mask_index` method is used in the cells that follow.
# Feature sizes come from the constants cell rather than repeated literals.
encoder = GameFormerEncoder(
    input_agent_dim=ds,
    input_lane_dim=dp_lane,
    input_crosswalk_dim=dp_cw,
)

In [51]:
# Synthetic map tensor: all zeros except position 3 along the Np_lane axis,
# so the location of the only non-zero entries is known in advance.
fake_map_tensor = torch.zeros((1, 1, 1, Np_lane, 2))
fake_map_tensor[..., 3, :] = 3.14
# Show the shape first, then the full contents.
print(fake_map_tensor.shape, fake_map_tensor, sep="\n")

torch.Size([1, 1, 1, 100, 2])
tensor([[[[[0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [3.1400, 3.1400],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0

In [52]:
# Max-pool with a (1, 10, 1) window and matching stride: only the fourth axis
# of a 5-D input is reduced, in non-overlapping groups of 10. The indices of
# the selected maxima are returned alongside the pooled values.
pooler = nn.MaxPool3d(
    kernel_size=(1, 10, 1),
    stride=(1, 10, 1),
    return_indices=True,
)

In [53]:
# Because the pooler was built with return_indices=True it returns a pair:
# the pooled values and the index of each selected maximum. These indices are
# what gets passed to `unravel_mask_index` further down.
output, max_indices = pooler(fake_map_tensor)

In [54]:
# Pooled values and their indices should share one shape; print both, one per line.
print(output.shape, max_indices.shape, sep="\n")

torch.Size([1, 1, 1, 10, 2])
torch.Size([1, 1, 1, 10, 2])


In [55]:
# Convert the max-pool indices back into positions for the un-pooled tensor.
# The reference shape is read from `fake_map_tensor` itself (as a plain list of
# ints, matching the original literal) so it stays correct if Np_lane changes.
unravelled = encoder.unravel_mask_index(max_indices, list(fake_map_tensor.shape))
print(unravelled)
print(unravelled.shape)

tensor([[[[[ 3,  3],
           [10, 10],
           [20, 20],
           [30, 30],
           [40, 40],
           [50, 50],
           [60, 60],
           [70, 70],
           [80, 80],
           [90, 90]]]]])
torch.Size([1, 1, 1, 10, 2])


In [56]:
# Look up the original tensor at the unravelled positions along axis 3.
gathered = torch.gather(fake_map_tensor, 3, unravelled)
print(gathered)

# If the indices were unravelled correctly, gathering with them must reproduce
# exactly the values that max-pooling selected.
assert torch.equal(gathered, output), "unravelled indices do not recover the max-pooled values"

tensor([[[[[3.1400, 3.1400],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000],
           [0.0000, 0.0000]]]]])
