# Using transformers for motion prediction
I thought it would be fun, and illustrative, to try using transformers for a bit more than just text. In this notebook we'll use transformers for two things:
1) to predict the motion of a small dynamical system (a pendulum) and 
2) to do a motion prediction task for vehicles, pedestrians, and cyclists.

## Setting up the dataset for motion prediction
We're using the Argoverse 2 dataset, so there are a few things you'll need.
1) You'll need to pull in the Argoverse code repo (`av2-api`). We've set it up as a git submodule, which you can fetch with
```
git submodule update --init --recursive
```
2) You'll need to install the Argoverse API. To do this, navigate into `submodules/av2-api/conda` and run
```
chmod +x install.sh # make the install script executable
./install.sh # run the install script
```
It does take a while! After that, just activate the corresponding av2 conda environment.

## Actually getting the data
Follow the download instructions at [this link](https://argoverse.github.io/user-guide/getting_started.html#downloading-the-data) (see also https://www.argoverse.org/av2.html#download-link). **Note: you only need the motion forecasting dataset, and since it's a lot of data, I would just download and keep the validation split.**

Now that you have the data, let's visualize a scene.
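Before rendering anything, it can save time to check that your download has the layout the visualization code assumes: each `scenario_<id>.parquet` file sits in the same directory as a `log_map_archive_<id>.json` static map. The helper below is a minimal sketch using only `pathlib`; the name `find_scenario_pairs` is hypothetical and not part of the av2 API.

```python
from pathlib import Path


def find_scenario_pairs(root: Path) -> tuple[list[tuple[Path, Path]], list[Path]]:
    """Pair each scenario parquet under `root` with its static map (hypothetical helper).

    Assumes the layout used by the visualization code:
    scenario_<id>.parquet and log_map_archive_<id>.json in the same directory.

    Returns (pairs, missing), where `missing` lists parquet files with no map next to them.
    """
    pairs: list[tuple[Path, Path]] = []
    missing: list[Path] = []
    for scenario_path in sorted(root.rglob("*.parquet")):
        # Same id extraction as the visualization code: the last "_"-separated chunk of the stem.
        scenario_id = scenario_path.stem.split("_")[-1]
        map_path = scenario_path.parent / f"log_map_archive_{scenario_id}.json"
        if map_path.exists():
            pairs.append((scenario_path, map_path))
        else:
            missing.append(scenario_path)
    return pairs, missing


# Example usage (point this at your local validation split):
# pairs, missing = find_scenario_pairs(Path("av_data/val"))
# print(f"{len(pairs)} scenarios with maps, {len(missing)} missing a map")
```

If `missing` is non-empty, the visualization will fail on those scenarios when it tries to load the map.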


In [2]:
from enum import Enum, unique
from pathlib import Path
from random import choices
from typing import Final

from joblib import Parallel, delayed
from rich.progress import track

from av2.datasets.motion_forecasting import scenario_serialization
from av2.datasets.motion_forecasting.viz.scenario_visualization import (
    visualize_scenario,
)
from av2.map.map_api import ArgoverseStaticMap

_DEFAULT_N_JOBS: Final[int] = -2  # Use all CPUs but one


@unique
class SelectionCriteria(str, Enum):
    """Valid criteria used to select Argoverse scenarios for visualization."""

    FIRST: str = "first"
    RANDOM: str = "random"


def generate_scenario_visualizations(
    argoverse_scenario_dir: Path,
    viz_output_dir: Path,
    num_scenarios: int,
    selection_criteria: SelectionCriteria,
    *,
    debug: bool = False,
) -> None:
    """Generate and save dynamic visualizations for selected scenarios within `argoverse_scenario_dir`.

    Args:
        argoverse_scenario_dir: Path to local directory where Argoverse scenarios are stored.
        viz_output_dir: Path to local directory where generated visualizations should be saved.
        num_scenarios: Maximum number of scenarios for which to generate visualizations.
        selection_criteria: Controls how scenarios are selected for visualization.
        debug: Generates visualizations in single-threaded mode when enabled.
    """
    Path(viz_output_dir).mkdir(parents=True, exist_ok=True)
    all_scenario_files = sorted(argoverse_scenario_dir.rglob("*.parquet"))
    scenario_file_list = (
        all_scenario_files[:num_scenarios]
        if selection_criteria == SelectionCriteria.FIRST
        else choices(all_scenario_files, k=num_scenarios)
    )  # Note: `choices` samples with replacement, so "random" can pick the same scenario twice.

    # Build inner function to generate visualization for a single scenario.
    def generate_scenario_visualization(scenario_path: Path) -> None:
        """Generate and save dynamic visualization for a single Argoverse scenario.

        NOTE: This function assumes that the static map is stored in the same directory as the scenario file.

        Args:
            scenario_path: Path to the parquet file corresponding to the Argoverse scenario to visualize.
        """
        scenario_id = scenario_path.stem.split("_")[-1]
        static_map_path = (
            scenario_path.parents[0] / f"log_map_archive_{scenario_id}.json"
        )
        viz_save_path = viz_output_dir / f"{scenario_id}.mp4"

        scenario = scenario_serialization.load_argoverse_scenario_parquet(scenario_path)
        static_map = ArgoverseStaticMap.from_json(static_map_path)
        visualize_scenario(scenario, static_map, viz_save_path)

    # Generate visualization for each selected scenario in parallel (except if running in debug mode)
    if debug:
        for scenario_path in track(scenario_file_list):
            generate_scenario_visualization(scenario_path)
    else:
        Parallel(n_jobs=_DEFAULT_N_JOBS)(
            delayed(generate_scenario_visualization)(scenario_path)
            for scenario_path in track(scenario_file_list)
        )


def run_generate_scenario_visualizations(
    argoverse_scenario_dir: str,
    viz_output_dir: str,
    num_scenarios: int,
    selection_criteria: str,
    debug: bool,
) -> None:
    """Convert string arguments to the expected types and call `generate_scenario_visualizations`."""
    generate_scenario_visualizations(
        Path(argoverse_scenario_dir),
        Path(viz_output_dir),
        num_scenarios,
        SelectionCriteria(selection_criteria.lower()),
        debug=debug,
    )

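# Point this at your local copy of the Argoverse 2 motion forecasting validation split.
run_generate_scenario_visualizations(
    "/home/eugene/code/tr-gy-8013-Fa-24/Lectures/Data/Lecture_7/av_data/val",  # argoverse_scenario_dir
    "./",  # viz_output_dir: the .mp4 is written here
    1,  # num_scenarios
    "first",  # selection_criteria: "first" or "random"
    False,  # debug: run single-threaded when True
)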

## Okay, now let's construct a dataset from this
We need a way to construct features from this data. To keep things simple, the features for each agent in the scene are just its x, y, heading (yaw), and velocity (v_x, v_y) at every observed timestep. We also need a target: the future x, y positions of the focal agent, which is the agent Argoverse designates as the one whose future we're trying to predict.

Of course, this is a very incomplete feature representation. What else could we add?
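To make the layout concrete, here is a small NumPy sketch of what the dataset class below builds for a single agent, using synthetic numbers rather than real Argoverse data. It assumes 110 timesteps per scenario (11 s at 10 Hz), the class's default of 30 predicted points, and a made-up agent that only appears at timestep 20 and drives straight along x at 5 m/s. The important detail is that each state is written at its scenario-level `timestep`, so an agent that shows up late stays aligned in time with everyone else.

```python
import numpy as np

TRAJ_LEN = 110                        # timesteps per scenario (11 s at 10 Hz)
NUM_TO_PREDICT = 30                   # future points to predict (the dataset's default)
HIST_LEN = TRAJ_LEN - NUM_TO_PREDICT  # 80 history steps
DT = 0.1                              # seconds per timestep
SPEED = 5.0                           # synthetic agent speed in m/s

# Synthetic agent observed from timestep 20 through 109, moving along +x.
timesteps = np.arange(20, TRAJ_LEN)
positions = np.column_stack([SPEED * DT * timesteps, np.zeros(len(timesteps))])
headings = np.zeros(len(timesteps))
velocities = np.tile([SPEED, 0.0], (len(timesteps), 1))

# History features: one row per timestep, columns [x, y, heading, v_x, v_y].
# Rows where the agent was not observed stay zero (padding).
features = np.zeros((HIST_LEN, 5))
hist = timesteps < HIST_LEN
features[timesteps[hist]] = np.column_stack([positions[hist], headings[hist], velocities[hist]])

# Target: the next NUM_TO_PREDICT (x, y) positions after the history window.
future = timesteps >= HIST_LEN
target = positions[future][:NUM_TO_PREDICT]

print(features.shape, target.shape)  # (80, 5) (30, 2)
print(features[19], features[20])    # a padding row, then the first observed state
```

If each track's timesteps were instead shifted so its first state landed in row 0, this agent would appear to be 2 s (20 steps) further along its path than it really was relative to the other agents.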

In [None]:
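from pathlib import Path
from random import choices

import numpy as np
import torch
from torch.utils.data import Dataset

# SelectionCriteria and scenario_serialization come from the visualization cell above.


class ArgoverseScenarioDataset(Dataset):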
    """PyTorch Dataset class for loading Argoverse scenarios from disk."""

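    def __init__(
        self,
        argoverse_scenario_dir: str,
        num_scenarios: int,
        selection_criteria: str,
        max_num_agents: int = 50,
        num_points_to_predict: int = 30,
    ):
        """Initialize the ArgoverseScenarioDataset.

        Args:
            argoverse_scenario_dir: Path to local directory where Argoverse scenarios are stored.
            num_scenarios: Maximum number of scenarios to load.
            selection_criteria: Controls how scenarios are selected for loading ("first" or "random").
            max_num_agents: Maximum number of agents kept per scenario; extra agents are dropped
                and missing ones are zero-padded so every item has the same shape.
            num_points_to_predict: Number of future focal-agent positions to predict. The official
                AV2 benchmark observes 50 steps and predicts 60; the default of 30 gives an
                80-step history instead.
        """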
        self.scenario_files = sorted(Path(argoverse_scenario_dir).rglob("*.parquet"))
        self.scenario_files = (
            self.scenario_files[:num_scenarios]
            if selection_criteria == SelectionCriteria.FIRST
            else choices(self.scenario_files, k=num_scenarios)
        )
        # Maximum number of agents to consider in a scenario. We pad to this size so that
        # every item is a tensor of the same shape.
        self.max_num_agents = max_num_agents
        self.num_points_to_predict = num_points_to_predict  # Number of future points to predict
        self.traj_len = 110  # Timesteps per AV2 scenario (11 s sampled at 10 Hz)

    def __len__(self) -> int:
        """Return the number of scenarios in the dataset."""
        return len(self.scenario_files)

    def __getitem__(self, idx: int) -> dict:
        """Load and return the Argoverse scenario at the specified index.

        Returns:
            dict containing:
                x: tensor of shape (max_num_agents, traj_len - num_points_to_predict, 5)
                   containing [x, y, heading, v_x, v_y] for each agent; row 0 is always
                   the focal agent
                y: tensor of shape (num_points_to_predict, 2) containing future [x, y]
                   coordinates for the focal agent
                mask: tensor of shape (max_num_agents,) marking agents observed in the history
                scenario_id: stem of the scenario's parquet file name
        """
        scenario_object = scenario_serialization.load_argoverse_scenario_parquet(
            self.scenario_files[idx]
        )

        # Initialize zero-padded tensors
        hist_len = self.traj_len - self.num_points_to_predict
        x = torch.zeros(self.max_num_agents, hist_len, 5)
        y = torch.zeros(self.num_points_to_predict, 2)
        mask = torch.zeros(self.max_num_agents)

        focal_track_id = scenario_object.focal_track_id

        # Put the focal agent first so the max_num_agents cutoff can never drop it
        # (False sorts before True, and sorted() keeps the other tracks in order).
        tracks = sorted(
            scenario_object.tracks, key=lambda track: track.track_id != focal_track_id
        )

        for index, track in enumerate(tracks[: self.max_num_agents]):
            # Extract trajectory data
            states = track.object_states
            # `timestep` is already a scenario-level index in [0, traj_len), so use it
            # directly. Subtracting each track's own first timestep would shift agents
            # that appear partway through the scenario out of sync with everyone else.
            timesteps = np.array([state.timestep for state in states], dtype=np.int64)
            positions = np.array([state.position for state in states], dtype=np.float32)
            headings = np.array([state.heading for state in states], dtype=np.float32)
            velocities = np.array([state.velocity for state in states], dtype=np.float32)

            # Fill the observed history (same features for the focal and other agents)
            hist_mask = timesteps < hist_len
            if np.any(hist_mask):
                features = np.column_stack(
                    [positions[hist_mask], headings[hist_mask], velocities[hist_mask]]
                )
                x[index, torch.from_numpy(timesteps[hist_mask])] = torch.from_numpy(features)
                mask[index] = 1  # Mark this agent as present

            if track.track_id == focal_track_id:
                # Fill the future trajectory (labels) for the focal agent only
                future_positions = positions[timesteps >= hist_len][: self.num_points_to_predict]
                y[: len(future_positions)] = torch.from_numpy(future_positions)

        return {
            'x': x,  # Input features for all agents
            'y': y,  # Future trajectory labels for focal agent
            'mask': mask,  # Mask indicating which agents are present
            'scenario_id': self.scenario_files[idx].stem  # Scenario identifier
        }

    def collate_fn(self, batch):
        """Custom collate function for batching scenarios.
        
        Args:
            batch: List of dictionaries from __getitem__
            
        Returns:
            dict containing batched tensors
        """
        return {
            'x': torch.stack([item['x'] for item in batch]),
            'y': torch.stack([item['y'] for item in batch]),
            'mask': torch.stack([item['mask'] for item in batch]),
            'scenario_ids': [item['scenario_id'] for item in batch]
        }
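Because `collate_fn` is a method here, you pass the bound method to PyTorch's `DataLoader`. Running the real class needs the parquet files on disk, so this sketch swaps in a hypothetical `FakeScenarioDataset` that returns zero tensors with the same keys and shapes as `ArgoverseScenarioDataset` under its defaults (50 agents, 80 history steps, 30 future points), just to show what a batch looks like.

```python
import torch
from torch.utils.data import DataLoader, Dataset

MAX_AGENTS, HIST_LEN, NUM_TO_PREDICT = 50, 80, 30


class FakeScenarioDataset(Dataset):
    """Hypothetical stand-in that returns dicts shaped like ArgoverseScenarioDataset items."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> dict:
        return {
            "x": torch.zeros(MAX_AGENTS, HIST_LEN, 5),
            "y": torch.zeros(NUM_TO_PREDICT, 2),
            "mask": torch.zeros(MAX_AGENTS),
            "scenario_id": f"scenario_{idx}",
        }

    def collate_fn(self, batch):
        # Same logic as ArgoverseScenarioDataset.collate_fn: stack tensors, keep ids as a list.
        return {
            "x": torch.stack([item["x"] for item in batch]),
            "y": torch.stack([item["y"] for item in batch]),
            "mask": torch.stack([item["mask"] for item in batch]),
            "scenario_ids": [item["scenario_id"] for item in batch],
        }


dataset = FakeScenarioDataset(n=10)
# With the real class this would be, e.g.,
# DataLoader(argo_dataset, batch_size=4, collate_fn=argo_dataset.collate_fn).
loader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=dataset.collate_fn)

batch = next(iter(loader))
print(batch["x"].shape)       # torch.Size([4, 50, 80, 5])
print(batch["y"].shape)       # torch.Size([4, 30, 2])
print(batch["mask"].shape)    # torch.Size([4, 50])
print(batch["scenario_ids"])  # ['scenario_0', 'scenario_1', 'scenario_2', 'scenario_3']
```

With 10 items and `batch_size=4`, the last batch holds only 2 scenarios; pass `drop_last=True` to `DataLoader` if your model needs a fixed batch size.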
            
                

In [3]:
scenario_files = sorted(Path("/home/eugene/code/tr-gy-8013-Fa-24/Lectures/Data/Lecture_7/av_data/val").rglob("*.parquet"))
scenario_object = scenario_serialization.load_argoverse_scenario_parquet(scenario_files[0])

In [7]:
scenario_object.tracks[0].object_states

[ObjectState(observed=True, timestep=0, position=(-11.823673649007935, -567.4023974854001), heading=2.8502121782024448, velocity=(-10.135779773162483, 3.0359271296884756)),
 ObjectState(observed=True, timestep=1, position=(-12.39320578754739, -567.2253156611156), heading=2.8500757486227837, velocity=(-10.164790448705373, 3.0329899998597862)),
 ObjectState(observed=True, timestep=2, position=(-13.080372362384209, -567.0096938093519), heading=2.849729679583497, velocity=(-10.185671800671363, 3.031218812995912)),
 ObjectState(observed=True, timestep=3, position=(-13.880497176060121, -566.7544103059896), heading=2.8490452955231054, velocity=(-10.144088795057396, 3.0140308580389377)),
 ObjectState(observed=True, timestep=4, position=(-14.773626809536655, -566.4653864371573), heading=2.8479616883793915, velocity=(-10.107290461858751, 2.997264820097428)),
 ObjectState(observed=True, timestep=5, position=(-15.74815042844231, -566.1455193078166), heading=2.846542156750936, velocity=(-10.0823344