## Implementation of the trajectory quantizer/tokenizer from the MotionLM paper
## https://arxiv.org/abs/2309.16534

In [1]:

import torch
import numpy as np
from itertools import product
import matplotlib.pyplot as plt

class TrajectoryQuantizer:
    """Tokenize 2-D trajectories into discrete delta-action indices (MotionLM-style).

    Pipeline:
      1. Express the trajectory in the frame of its initial pose: translate so the
         initial position is the origin and rotate so the initial heading points
         along +x.
      2. Take consecutive displacements and bin each axis uniformly into
         ``num_bins`` bins over ``[delta_min, delta_max)`` (out-of-range values are
         clamped into the first/last bin).
      3. Map each (x_bin, y_bin) pair to a single token via ``cartesian_mapping``,
         which enumerates ``product(range(verlet_bins), repeat=2)`` so that the
         token equals ``x_bin * verlet_bins + y_bin``.
    """

    def __init__(self, num_bins=128, delta_min=-18, delta_max=18, verlet_bins=128):
        self.num_bins = num_bins
        self.delta_min = delta_min
        self.delta_max = delta_max
        # Size of each axis of the token grid. apply_verlet_wrapper is currently the
        # identity, so this must be >= num_bins (128 matches the default num_bins).
        self.verlet_bins = verlet_bins
        self.bin_width = (delta_max - delta_min) / num_bins
        self.cartesian_mapping = self.create_cartesian_product_mapping()

    def normalize_trajectory(self, trajectory, initial_position, initial_heading):
        """Return ``trajectory`` expressed in the frame of its initial pose.

        :param trajectory: tensor-like of points, one (x, y) point per row
            (assumed shape (T, 2) — TODO confirm for batched use).
        :param initial_position: (x, y) origin of the new frame.
        :param initial_heading: heading angle in radians (tensor or Python number)
            that should end up aligned with the +x axis.
        :return: floating-point tensor with the same shape as ``trajectory``.
        """
        trajectory = torch.as_tensor(trajectory)
        # Keep the caller's float dtype; promote integer input so the rotation works.
        dtype = trajectory.dtype if trajectory.is_floating_point() else torch.get_default_dtype()
        trajectory = trajectory.to(dtype)

        # Translate the trajectory so that the initial position is at the origin.
        translated = trajectory - torch.as_tensor(initial_position, dtype=dtype, device=trajectory.device)

        # Bringing a heading of theta onto +x means rotating every point by -theta:
        # R(-theta) = [[cos, sin], [-sin, cos]] (R(+theta) would send it to 2*theta).
        theta = torch.as_tensor(initial_heading, dtype=dtype, device=trajectory.device)
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        rotation = torch.stack((torch.stack((cos_t, sin_t)), torch.stack((-sin_t, cos_t))))

        # Points are row vectors, so p' = R p for every row is P @ R.T.
        return torch.matmul(translated, rotation.T)

    def compute_delta_actions(self, normalized_traj):
        """Bin the consecutive displacements of a normalized trajectory.

        :param normalized_traj: (T, 2) tensor of points (e.g. from normalize_trajectory).
        :return: ``(delta_x_bins, delta_y_bins)``, two int64 tensors of length T-1
            with values in ``[0, num_bins - 1]``; displacements outside
            ``[delta_min, delta_max)`` are clamped into the edge bins.
        """
        # Displacements between consecutive points (LaneGCN-style deltas).
        deltas = normalized_traj[1:] - normalized_traj[:-1]

        # Uniform quantization; both axes share the same range and bin width.
        delta_x_bins = torch.floor((deltas[:, 0] - self.delta_min) / self.bin_width).long()
        delta_y_bins = torch.floor((deltas[:, 1] - self.delta_min) / self.bin_width).long()

        # Clip values to ensure they fall within the valid bin range.
        delta_x_bins = torch.clamp(delta_x_bins, 0, self.num_bins - 1)
        delta_y_bins = torch.clamp(delta_y_bins, 0, self.num_bins - 1)

        return delta_x_bins, delta_y_bins

    def apply_verlet_wrapper(self, delta_bins):
        """Placeholder for Verlet re-binning; currently returns ``delta_bins`` unchanged."""
        # TODO: add the verlet integration to reduce the bin size from 128 to 13 (as mentioned in the paper)
        return delta_bins

    def create_cartesian_product_mapping(self):
        """Map every (x_bin, y_bin) pair in ``range(verlet_bins)**2`` to a unique index.

        Enumeration follows ``itertools.product`` order, so the index of
        ``(x, y)`` is ``x * verlet_bins + y``.
        """
        cartesian_product = product(range(self.verlet_bins), repeat=2)
        return {bin_indices: i for i, bin_indices in enumerate(cartesian_product)}

    def map_to_single_index(self, delta_x_bins, delta_y_bins):
        """Combine per-axis bins into single token indices.

        :return: int64 tensor of token ids (empty when the inputs are empty).
        :raises KeyError: if a bin pair lies outside the ``verlet_bins`` grid.
        """
        indices = [
            self.cartesian_mapping[(x.item(), y.item())]
            for x, y in zip(delta_x_bins, delta_y_bins)
        ]
        # Explicit dtype: torch.tensor([]) would otherwise default to float32.
        return torch.tensor(indices, dtype=torch.long)

    def quantize_trajectory(self, trajectory, initial_position, initial_heading, *, plot=True):
        """Tokenize a trajectory into a 1-D int64 tensor of delta-action indices.

        :param trajectory: points of the trajectory, one (x, y) point per row.
        :param initial_position: origin of the normalized frame.
        :param initial_heading: heading (radians) aligned with +x after normalization.
        :param plot: when True (default, as before), call ``plt.plot`` on the
            normalized trajectory for visual inspection.
        :return: tensor of length ``len(trajectory) - 1``.
        """
        normalized_traj = self.normalize_trajectory(trajectory, initial_position, initial_heading)
        delta_x_bins, delta_y_bins = self.compute_delta_actions(normalized_traj)
        verlet_x_bins = self.apply_verlet_wrapper(delta_x_bins)
        verlet_y_bins = self.apply_verlet_wrapper(delta_y_bins)
        single_indices = self.map_to_single_index(verlet_x_bins, verlet_y_bins)
        if plot:
            plt.plot(normalized_traj)  # to see the normalized trajectory

        return single_indices

## Uncomment the following to run a toy example in which the trajectory is sampled from sine and cosine curves.
# if __name__ == "__main__":
#     # Create a sample trajectory
#     trajectory_length = 8*2 # 2Hz at 8s 
#     sample_pt = np.linspace(-1, np.pi, trajectory_length).astype(np.float32)
#     delta_x = torch.tensor(np.sin(sample_pt))
#     delta_y = torch.tensor(np.cos(sample_pt))
    
#     initial_position = torch.tensor([delta_x[0], delta_y[0]])
#     initial_heading = torch.tensor(0.0)  # Assuming initial heading is along x-axis
#     quantizer = TrajectoryQuantizer()
#     quantized_trajectory = quantizer.quantize_trajectory(torch.stack([delta_x, delta_y], dim=1), initial_position, initial_heading)
#     plt.plot(delta_x, delta_y, 'o' )
    
#     print(f"Quantized trajectory indices: \n {quantized_trajectory}")
#     print(f"Vocabulary size: {len(quantizer.cartesian_mapping)}")

## Dataloader for Argoverse 2

In [None]:
import os, random
import matplotlib.pyplot as plt
from av2.datasets.motion_forecasting.data_schema import ObjectType
from torch.utils.data import Dataset
from pathlib import Path
import numpy as np
# relative imports
from av2.map.map_api import ArgoverseStaticMap
from av2.datasets.motion_forecasting import scenario_serialization
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader

from av2.datasets.motion_forecasting.data_schema import (
    ArgoverseScenario,
    ObjectType,
    TrackCategory,
    ObjectState
)

# Fix the NumPy and PyTorch RNG seeds so notebook runs are reproducible.
_SEED = 42
np.random.seed(_SEED)
torch.manual_seed(_SEED)

def get_common_ax_plots():
    """Get matplotlib figure number 1 (10x10 inches, 300 dpi) plus one subplot.

    :return: ``(figure, axes)`` where ``axes`` comes from ``add_subplot(111)``.
    """
    figure = plt.figure(1, figsize=(10, 10), dpi=300)
    return figure, figure.add_subplot(111)

@dataclass
class Av2Configs:
    """Settings for loading Argoverse 2 motion-forecasting data.

    data_root_dir: root directory; DataProcessorAV2 reads ``<data_root_dir>/<split>``.
    deterministic_loading: when False, DataProcessorAV2 shuffles its file list
        with ``random.shuffle`` at construction time.
    """
    data_root_dir: str = "data/2_av2/"  # default relative path to the AV2 data root
    deterministic_loading: bool = True  # False -> file list is shuffled once in __init__

def get_ft_obs(scenario: ArgoverseScenario):
    """
    Return the focal track of a scenario together with its positions and headings.

    :param scenario: Argoverse 2 scenario to search for the focal track.
    :return: ``(track, obs_past, obs_fut, heading_angle)`` for the first track whose
        category is ``FOCAL_TRACK``:
        - ``obs_past``: positions of the states with ``observed`` True,
        - ``obs_fut``: positions of the states with ``observed`` False,
        - ``heading_angle``: headings of *all* states (past and future), in track order.
        No object-type filtering happens here; DataProcessorAV2 filters by type.
    :raises ValueError: if the scenario has no focal track (previously the function
        fell through and returned ``None``, breaking tuple unpacking in the caller).
    """
    for track in scenario.tracks:
        if track.category != TrackCategory.FOCAL_TRACK:
            continue

        observed_states_past = [state.position for state in track.object_states if state.observed]
        observed_states_future = [state.position for state in track.object_states if not state.observed]
        heading_angle = [state.heading for state in track.object_states]

        return track, observed_states_past, observed_states_future, heading_angle

    raise ValueError(
        f"Scenario {getattr(scenario, 'scenario_id', '<unknown>')} has no focal track."
    )

def custom_collate(batch):
    """Turn a list of sample dicts into a dict of per-key lists.

    The key set (and its order) is taken from the first sample; every sample must
    provide those keys. Values are collected as-is without stacking, so samples
    may carry arbitrary Python objects or ``None`` placeholders.
    """
    first_sample = batch[0]
    collated = {}
    for key in first_sample:
        collated[key] = [sample[key] for sample in batch]
    return collated

class DataProcessorAV2(Dataset):
    """Map-style dataset over Argoverse 2 motion-forecasting scenario directories.

    Every entry of ``<config.data_root_dir>/<split>`` is treated as one scenario
    directory holding ``scenario_<name>.parquet`` and the static map loaded by
    ``ArgoverseStaticMap.from_map_dir``. Items are dicts with keys ``track``,
    ``ft_obs_past``, ``ft_obs_future``, ``ft_heading_angle`` and ``is_valid``;
    when the focal track is not a vehicle the first four are ``None`` and
    ``is_valid`` is False.
    """

    def __init__(self, config: Av2Configs, args, split):
        """
        :param config: dataset configuration (root directory, ordering).
        :param args: extra arguments; stored as ``self.args`` but not read here.
        :param split: split subdirectory name, e.g. "train".
        """
        self.data_root_dir = Path(config.data_root_dir) / split  # for the dataset
        self.config = config
        self.args = args
        # os.listdir returns entries in arbitrary, filesystem-dependent order; sort
        # so deterministic_loading=True really gives a reproducible ordering.
        self.file_names = sorted(os.listdir(self.data_root_dir))
        self.total_trajectories = []  # NOTE(review): not used within this class
        if not config.deterministic_loading:
            random.shuffle(self.file_names)

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        """Load scenario ``idx`` and extract its focal-track observations.

        :raises FileNotFoundError: if the scenario directory does not exist.
        """
        file_name = self.file_names[idx]
        log_map_path = Path(self.data_root_dir) / file_name
        scenario_path = log_map_path / f"scenario_{file_name}.parquet"

        if not log_map_path.exists():
            raise FileNotFoundError(f"The map directory {log_map_path} does not exist.")

        scenario = scenario_serialization.load_argoverse_scenario_parquet(scenario_path)
        # NOTE(review): the static map is loaded (so a missing/broken map still
        # raises) but its result is not used; drop this call if maps aren't needed.
        ArgoverseStaticMap.from_map_dir(log_map_path, build_raster=False)
        track, ft_obs_past, ft_obs_future, ft_heading_angle = get_ft_obs(scenario)
        required_obj = [ObjectType.VEHICLE]

        if track.object_type not in required_obj:
            # Placeholder with every key present so custom_collate still works.
            return {
                "track": None,
                "ft_obs_past": None,
                "ft_obs_future": None,
                "ft_heading_angle": None,
                "is_valid": False,
            }

        return {
            "track": track,
            "ft_obs_past": ft_obs_past,
            "ft_obs_future": ft_obs_future,
            "ft_heading_angle": ft_heading_angle,
            "is_valid": True,
        }


av_processor = DataProcessorAV2(Av2Configs(), args=None, split="train")
loader = DataLoader(av_processor, batch_size=1, shuffle=False, collate_fn=custom_collate)

import sys  # NOTE(review): no longer used by this loop; kept in case other cells rely on it

# Build the quantizer once: quantize_trajectory does not modify its state, and
# construction rebuilds the full verlet_bins x verlet_bins index mapping.
quantizer = TrajectoryQuantizer()

for i, data in enumerate(loader):
    if True in data["is_valid"]:
        # Full trajectory = ft_obs_past + ft_obs_future (batch_size=1, so each
        # list holds exactly one sample's positions).
        full_trajectory = np.concatenate(data["ft_obs_past"] + data["ft_obs_future"])
        # Normalize all waypoints relative to the first (xy) coordinate and the
        # first heading angle.
        initial_position = full_trajectory[0]
        heading_angles = np.concatenate(data["ft_heading_angle"])
        initial_heading = heading_angles[0]
        quantized_trajectory = quantizer.quantize_trajectory(
            torch.tensor(full_trajectory.astype(np.float32)),
            torch.tensor(initial_position.astype(np.float32)),
            torch.tensor(initial_heading.astype(np.float32)),
        )

        print(f"Quantized trajectory indices: \n {quantized_trajectory}")
        print(f"Vocabulary size: {len(quantizer.cartesian_mapping)}")

    # Inspect loader items 0..10 only. Use `break` rather than sys.exit(): inside a
    # notebook sys.exit() raises SystemExit in the kernel instead of ending the loop.
    if i >= 10:
        break
