# Data

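Define `MoEDataset`, a subclass of UniTraj's `BaseDataset` (for now it only forwards its constructor arguments to the base class).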

In [1]:
from unitraj.datasets.base_dataset import BaseDataset

class MoEDataset(BaseDataset):
    """Dataset class intended for the MoE model.

    Currently a thin pass-through: construction is delegated unchanged to
    ``BaseDataset``, so instances behave exactly like the base class.

    Args:
        config: dataset configuration handed straight to ``BaseDataset``.
        is_validation: flag handed straight to ``BaseDataset``.
    """

    def __init__(self, config=None, is_validation=False):
        # Forward both arguments positionally, exactly as received.
        super(MoEDataset, self).__init__(config, is_validation)

Build the dataset from a config, wrap it in a `DataLoader`, and check that samples load correctly.

In [2]:
import torch
from datasets import build_dataset
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

SEED = 42  # fixes the shuffle order so the batch inspected below is reproducible

# Create config dictionary.
# NOTE(review): training and validation point at the same sample directory;
# fine for a smoke test, but use distinct splits for any real evaluation.
cfg = OmegaConf.create({
    "load_num_workers": 0,  # number of workers for loading data
    "train_data_path": ["data_samples/nuscenes"],  # list of paths to the training data
    "val_data_path": ["data_samples/nuscenes"],  # list of paths to the validation data
    "cache_path": "./cache",
    "max_data_num": [None],  # maximum number of data for each training dataset, None means all data
    "starting_frame": [0],  # history trajectory starts at this frame for each training dataset
    "past_len": 21,  # history trajectory length, 2.1s
    "future_len": 60,  # future trajectory length, 6s
    "object_type": ["VEHICLE"],  # object types included in the training set
    "line_type": ["lane", "stop_sign", "road_edge", "road_line", "crosswalk", "speed_bump"],  # line type to be considered in the input
    "masked_attributes": ["z_axis", "size"],  # attributes to be masked in the input
    "trajectory_sample_interval": 1,  # sample interval for the trajectory
    "only_train_on_ego": False,  # only train on AV
    "center_offset_of_map": [30.0, 0.0],  # center offset of the map
    "use_cache": False,  # use cache for data loading
    "overwrite_cache": False,  # overwrite existing cache
    "store_data_in_memory": False,  # store data in memory
    "method": {"model_name": "autobot"}
})

dataset = build_dataset(cfg)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    # Read the worker count from the config instead of duplicating it as a literal.
    num_workers=cfg.load_num_workers,
    collate_fn=dataset.collate_fn,
    # A seeded generator makes the shuffled order (and thus `inputs`) reproducible.
    generator=torch.Generator().manual_seed(SEED),
)

# Print dataset stats
print(f"Total samples: {len(dataset)}")
print(f"Batches per epoch: {len(dataloader)}\n")

# Take the first (shuffled) batch for inspection in the cells below.
batch = next(iter(dataloader))
inputs = batch["input_dict"]

Loading training data...
Loaded 61 samples from data_samples/nuscenes
Data loaded
Total samples: 61
Batches per epoch: 61



Pack the first batch into model inputs (`ego_in`, `agents_in`, `roads`) and check the contents of the resulting dictionary.

In [4]:
import torch

# Pack the batch into the inputs dictionary shown below.
# NOTE(review): assumed shapes — obj_trajs (B, N_agents, T, F),
# obj_trajs_mask (B, N_agents, T); confirm against the dataset's collate_fn.
agents_in, agents_mask = inputs['obj_trajs'], inputs['obj_trajs_mask']

# For every sample b, select the agent at index track_index_to_predict[b].
# Batched advanced indexing gives the same result as
# gather(..., view(-1, 1, 1, 1).repeat(...)).squeeze(1), without building
# a repeated index tensor.
target_idx = inputs['track_index_to_predict'].reshape(-1).long()
batch_idx = torch.arange(agents_in.shape[0], device=agents_in.device)
ego_in = agents_in[batch_idx, target_idx]
ego_mask = agents_mask[batch_idx, target_idx]

# Append each mask as one extra trailing feature channel.
agents_in = torch.cat([agents_in, agents_mask.unsqueeze(-1)], dim=-1)
agents_in = agents_in.transpose(1, 2)  # swap dims 1 and 2 (presumably agents <-> time)
ego_in = torch.cat([ego_in, ego_mask.unsqueeze(-1)], dim=-1)
roads = torch.cat([inputs['map_polylines'], inputs['map_polylines_mask'].unsqueeze(-1)], dim=-1)

model_input = {'ego_in': ego_in, 'agents_in': agents_in, 'roads': roads}

model_input.keys()

dict_keys(['ego_in', 'agents_in', 'roads'])

# Model

In [3]:
import torch
from torch import nn
from unitraj.models.base_model.base_model import BaseModel

class MoE(BaseModel):
    """Dense mixture-of-experts: a softmax gate mixes the outputs of MLP experts.

    Every expert runs on every input; the gate produces one weight per expert
    (softmax over the expert axis), and the result is the weighted sum of the
    expert outputs.

    Config keys (read with ``config.get``; defaults in parentheses):
        num_experts (4), expert_hidden_size (256), gate_hidden_size (128),
        input_size (128), output_size (128).

    NOTE(review): ``forward`` takes a plain feature tensor rather than a batch
    dict; confirm this matches what ``BaseModel``'s training/validation steps
    pass to ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_experts = config.get('num_experts', 4)
        self.expert_hidden_size = config.get('expert_hidden_size', 256)
        self.gate_hidden_size = config.get('gate_hidden_size', 128)
        self.input_size = config.get('input_size', 128)
        self.output_size = config.get('output_size', 128)

        # Define experts: identical two-layer MLPs, input_size -> output_size.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.input_size, self.expert_hidden_size),
                nn.ReLU(),
                nn.Linear(self.expert_hidden_size, self.output_size)
            ) for _ in range(self.num_experts)
        ])

        # Define gate network: input_size -> num_experts weights that sum to 1.
        self.gate = nn.Sequential(
            nn.Linear(self.input_size, self.gate_hidden_size),
            nn.ReLU(),
            nn.Linear(self.gate_hidden_size, self.num_experts),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Mix the expert outputs using the gate weights.

        Args:
            x: tensor whose last dimension is ``input_size``. Any number of
                leading dimensions is accepted, e.g. ``(B, input_size)`` or
                ``(B, S, input_size)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``output_size``.
        """
        # (..., num_experts)
        gate_outputs = self.gate(x)

        # Stack on a new second-to-last axis -> (..., num_experts, output_size),
        # so the expert axis lines up with the gate's last axis for any number
        # of leading dims. For 2-D x this is identical to stacking on dim=1.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)

        # Weighted sum over the expert axis -> (..., output_size)
        final_output = torch.sum(gate_outputs.unsqueeze(-1) * expert_outputs, dim=-2)

        return final_output

In [1]:
# Smoke-test the MoE model on random data and assert the expected shapes.
# Requires the `MoE` class from the model cell above; on a fresh kernel run
# that cell first, otherwise `MoE(...)` raises NameError.
import torch

torch.manual_seed(0)  # reproducible weight init and random input

# Create sample input and instantiate model
input_size = 10
output_size = 5
batch_size = 8

# Create config dictionary
config = {
    'input_size': input_size,
    'output_size': output_size,
    'num_experts': 4,
    'expert_hidden_size': 256,
    'gate_hidden_size': 128
}

# Create model and sample input
model = MoE(config)
x = torch.randn(batch_size, input_size)

# Forward pass with shape checking
print(f"Input shape: {x.shape}")

with torch.no_grad():  # inspection only; no gradients needed
    # Check gate output shape
    gate_outputs = model.gate(x)
    print(f"Gate outputs shape: {gate_outputs.shape}")

    # Check expert outputs shape
    expert_outputs = torch.stack([expert(x) for expert in model.experts], dim=1)
    print(f"Expert outputs shape: {expert_outputs.shape}")

    # Get final output and check shape
    final_output = model(x)
    print(f"Final output shape: {final_output.shape}")

# Turn the printed checks into real invariants.
num_experts = config['num_experts']
assert gate_outputs.shape == (batch_size, num_experts)
assert torch.allclose(gate_outputs.sum(dim=-1), torch.ones(batch_size))  # softmax rows sum to 1
assert expert_outputs.shape == (batch_size, num_experts, output_size)
assert final_output.shape == (batch_size, output_size)
# forward() must equal the gate-weighted sum of the expert outputs computed above.
assert torch.allclose(final_output, (gate_outputs.unsqueeze(-1) * expert_outputs).sum(dim=1), atol=1e-6)

NameError: name 'MoE' is not defined