In [1]:
from utils.file_utils import get_config
from utils.dataset_utils import get_dataloader
from utils.training_utils import parse_data, get_scene_vehicle_nums
import torch
# Load the configuration from ./config.yaml. The path is relative, so it is
# resolved against the current working directory of the kernel.
config = get_config('./config.yaml')

In [2]:

train_loader, val_loader, test_loader = get_dataloader(config)

# NOTE: torch.cuda.current_device() requires a working CUDA setup; this cell
# will raise on a CPU-only machine.
gpu_id = torch.cuda.current_device()

# Take the first training batch that contains at least one 'cur' vehicle and
# parse it into model inputs / ground truth for the cells below.
for batch in train_loader:
    ## if no vehicles in the batch, skip it
    vehicle_nums_dic = get_scene_vehicle_nums(batch)
    if vehicle_nums_dic['cur'] == 0:
        print('No vehicles in the batch')
        continue
    input_dict, ground_truth_dict = parse_data(data=batch, gpu_id=gpu_id, config=config)
    break
else:
    # Reached only when the loop finished without `break`, i.e. the loader was
    # empty or every batch was skipped. Fail loudly here instead of letting
    # later cells hit a NameError (or silently reuse stale kernel variables).
    raise RuntimeError(
        "train_loader yielded no batch with at least one 'cur' vehicle; "
        "input_dict / ground_truth_dict were not created"
    )
    

In [3]:
# Display the keys of the input dictionary returned by parse_data.
input_dict.keys()

dict_keys(['prv/state/his/timestamp', 'prv/state/his/x_position', 'prv/state/his/y_position', 'prv/state/his/x_velocity', 'prv/state/his/y_velocity', 'prv/state/his/yaw_angle', 'prv/state/his/occluded_occupancy_map', 'prv/state/his/observed_occupancy_map', 'prv/state/his/flow_map', 'cur/state/his/timestamp', 'cur/state/his/x_position', 'cur/state/his/y_position', 'cur/state/his/x_velocity', 'cur/state/his/y_velocity', 'cur/state/his/yaw_angle', 'cur/state/his/occluded_occupancy_map', 'cur/state/his/observed_occupancy_map', 'cur/state/his/flow_map', 'nxt/state/his/timestamp', 'nxt/state/his/x_position', 'nxt/state/his/y_position', 'nxt/state/his/x_velocity', 'nxt/state/his/y_velocity', 'nxt/state/his/yaw_angle', 'nxt/state/his/occluded_occupancy_map', 'nxt/state/his/observed_occupancy_map', 'nxt/state/his/flow_map', 'prv/meta/length', 'prv/meta/width', 'prv/meta/class', 'prv/meta/direction', 'cur/meta/length', 'cur/meta/width', 'cur/meta/class', 'cur/meta/direction', 'nxt/meta/length', 

torch.Size([2, 2, 40])

In [14]:
# Keys of the 'cur' group used to build the two feature tensors below:
# the meta entries go into `vector_feature`, the state-history entries
# into `node_feature`.
vector_features_list = [
    'cur/meta/length',
    'cur/meta/width',
    'cur/meta/class',
    'cur/meta/direction',
]
node_features_list = [
    'cur/state/his/timestamp',
    'cur/state/his/x_position',
    'cur/state/his/y_position',
    'cur/state/his/x_velocity',
    'cur/state/his/y_velocity',
    'cur/state/his/yaw_angle',
]

# Give every meta tensor a new axis at position 1, then join them along it.
vector_feature = torch.cat(
    [input_dict[key].unsqueeze(1) for key in vector_features_list],
    dim=1,
)
print(len(vector_features_list))
print(len(node_features_list))

# Swap the last two axes of every history tensor and join along the last one.
# assumes each history tensor is 3-D — TODO confirm against parse_data
node_feature = torch.cat(
    [input_dict[key].transpose(1, 2) for key in node_features_list],
    dim=2,
)
print(vector_feature.shape)
print(node_feature.shape)

4
6
torch.Size([2, 4])
torch.Size([2, 40, 12])


In [5]:
from modules.SwinTransformerEncoder import SwinTransformerEncoder
from modules.FlowGuidedMultiHeadSelfAttention import FlowGuidedMultiHeadSelfAttention
# Random stand-in tensors used to exercise the modules imported above.
# presumably laid out as (batch, height, width, time[, flow components]) —
# TODO confirm against the encoder's expected input layout
occupancy_map = torch.randn(2, 256, 256, 40)
flow_map = torch.randn(2, 256, 256, 40, 2)

In [7]:

# Run the encoder on the occupancy tensor together with the flow slice taken
# at the last index of flow_map's 4th axis; no road map is supplied.
swin = SwinTransformerEncoder(config=config)
res = swin(
    occupancy_map=occupancy_map,
    flow_map=flow_map[..., -1, :],
    road_map=None,
)
q = res[-1]  # keep only the last element of what the encoder returned

# The attention module returns three values; `pos` is not used in this cell.
fg_msa = FlowGuidedMultiHeadSelfAttention(config=config)
res, pos, ref = fg_msa(q)

# Add the attention output back onto its input, then merge the two middle
# axes so q becomes (B, H*W, D).
q = res + q
B, H, W, D = q.shape
q = q.flatten(start_dim=1, end_dim=2)

# NOTE(review): T is hard-coded; it must agree with the number of elements in
# `ref` for the reshape below to succeed — consider reading it from config.
T = 12

# Copy q T times along a new axis 1, view `ref` as (B, T, H*W, D), and sum.
query = q.unsqueeze(1).repeat_interleave(T, dim=1)
ref = ref.reshape(B, T, H*W, D)
query = ref + query
print(query.shape)



torch.Size([2, 12, 256, 384])
