In [None]:
%load_ext autoreload
%autoreload 2
from data.utils import *
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import random

from forecasting.lstm import LSTMModel, MotionPredictionDataset, k_best_loss

In [None]:
# Number of timesteps every track is padded/truncated to; the experiment cells
# pass this as `max_len` to preprocess_data.
MAX_LEN = 41
# Class names kept by preprocess_data (handed to filter_by_class_names).
VEHICLE_CLASSES = ("car", "truck", "bus", "construction_vehicle", "emergency_vehicle")
# Per-dimension divisors applied in preprocess_data to the base feature vector
# [dx, dy, vx, vy, yaw].
FEATURE_SCALING = np.array([40, 40, 5, 5, 1])
# Divisor applied in preprocess_data to the optional center features before
# they are concatenated to the base features.
CENTER_FEATURE_SCALING = 4

def delta_to_heading(delta):
    """
    Converts a displacement vector into an angle in radians.

    Parameters:
    delta (array-like): Displacement; only its first two entries are read.

    Returns:
    float: np.arctan2 evaluated with delta[0] as the first argument and
    delta[1] as the second argument.
    """
    # NOTE(review): the components are passed as (delta[0], delta[1]); if delta
    # is (dx, dy) this measures the angle from the second axis rather than the
    # usual atan2(dy, dx) -- confirm this matches the yaw convention of the labels.
    first_component, second_component = delta[0], delta[1]
    return np.arctan2(first_component, second_component)

def wrap_radians(radians):
    """
    Maps an angle onto the [-pi, pi) range.

    Parameters:
    radians (float): Angle in radians, of any magnitude.

    Returns:
    float: The equivalent angle shifted by a multiple of 2*pi into [-pi, pi).
    """
    full_turn = 2 * np.pi
    # Shift by pi so the modulo folds into [0, 2*pi), then shift back.
    shifted = radians + np.pi
    return shifted % full_turn - np.pi

def preprocess_data(labels, center_features=None, add_turn_label=False, max_len=20, prediction_len=6):
    """
    Builds fixed-length per-track feature/target arrays for motion forecasting.

    Tracks are grouped per "{seq_id}:{track_id}", restricted to VEHICLE_CLASSES,
    dropped entirely if any observation is 40 or more units away from the ego
    vehicle in the xy-plane, and re-indexed so the first observation is at
    timestep 0.

    Parameters:
    labels (dict): Maps sequence id to a list of per-frame annotation dicts
        (the structure consumed by filter_by_class_names / array_dict_iterator).
    center_features (dict, optional): Maps "{seq_id}:{timestamp_ns}:{track_id}"
        to a 1D feature vector that is divided by CENTER_FEATURE_SCALING and
        appended to the base features. None disables these features.
    add_turn_label (bool): If True, append a turn indicator in {-1, 0, 1}
        derived from the track's *future* observations.
    max_len (int): Number of timesteps emitted per track; timesteps without an
        observation keep all-zero features, targets and masks.
    prediction_len (int): Number of future steps in each target.

    Returns:
    dict: Maps track key to a dict with
        "feature"     -- array of shape (max_len, feature_size),
        "target"      -- array of shape (max_len, prediction_len, 2) holding
                         per-step xy displacements,
        "target_mask" -- array of shape (max_len, prediction_len), 1 where the
                         target step exists,
        "class_label" -- label of the track's first stored instance,
        "min_ts"      -- original frame index of the track's first observation.

    Raises:
    ValueError: If center_features is given but empty, so its width cannot be
        inferred.
    KeyError: If center_features lacks the key of an observed instance.
    """
    # The feature width is identical for every track and timestep, so it is
    # computed once instead of inside the per-timestep loop.
    feature_size = 5
    if add_turn_label:
        feature_size += 1
    if center_features is not None:
        if len(center_features) == 0:
            raise ValueError("center_features is empty; cannot infer the center feature width")
        feature_size += len(next(iter(center_features.values())))

    # group tracks
    vehicle_labels = filter_by_class_names(labels, VEHICLE_CLASSES)
    grouped_tracks = defaultdict(dict)
    for seq_id, frames in vehicle_labels.items():
        for timestep, frame in enumerate(frames):
            for instance in array_dict_iterator(frame, len(frame["translation"])):
                grouped_tracks[f"{seq_id}:{instance['track_id']}"][timestep] = instance

    data = {}
    for track_key, instance_by_ts in progressbar(
        grouped_tracks.items(), desc="preprocessing track data"
    ):
        seq_id, instance_id = track_key.split(":")
        class_label = next(iter(instance_by_ts.values()))["label"]
        features, targets, masks = [], [], []
        # check ego distance: skip the whole track if it ever leaves the 40-unit radius
        if np.any([
            np.linalg.norm(instance["translation"][:2] - np.array(instance["ego_translation"])[:2]) >= 40
            for instance in instance_by_ts.values()
        ]):
            continue

        # change all sequences to start at timestep=0
        min_ts = min(instance_by_ts.keys())
        instance_by_ts = {
            ts - min_ts: instance for ts, instance in instance_by_ts.items()
        }
        sorted_ts = list(sorted(instance_by_ts.keys()))
        last_timestep = sorted_ts[-1]

        # linearly interpolate missing translations between the nearest
        # observed timesteps on either side
        translations = {}
        for timestep in range(last_timestep + 1):
            if timestep in instance_by_ts:
                translations[timestep] = instance_by_ts[timestep]["translation"][:2]
            else:
                prev_timestep = next(ts for ts in reversed(sorted_ts) if ts <= timestep)
                next_timestep = next(ts for ts in sorted_ts if ts >= timestep)
                translation = (
                    instance_by_ts[next_timestep]["translation"][:2] * (timestep - prev_timestep)
                    + instance_by_ts[prev_timestep]["translation"][:2] * (next_timestep - timestep)
                ) / (next_timestep - prev_timestep)
                translations[timestep] = translation

        for timestep in range(max_len):
            # if features are missing, default to zeros
            feature = np.zeros(feature_size)
            target = np.zeros((prediction_len, 2))
            mask = np.zeros(prediction_len)
            if timestep in instance_by_ts:
                instance = instance_by_ts[timestep]
                # format features (translation relative to the track start, velocity, yaw)
                translation_delta = (
                    instance["translation"] - instance_by_ts[0]["translation"]
                )
                # normalize inputs
                feature = np.nan_to_num(
                    np.concatenate(
                        [
                            translation_delta[:2],
                            instance["velocity"][:2],
                            np.array([instance["yaw"]]),
                        ]
                    )
                    / FEATURE_SCALING
                )
                # Targets are per-step displacements between consecutive
                # (possibly interpolated) positions, not offsets from the
                # current position; steps past the end of the track stay masked.
                future_timesteps = list(range(
                    timestep + 1, min(timestep + prediction_len + 1, last_timestep + 1)
                ))
                for i, future_ts in enumerate(future_timesteps):
                    target[i] = translations[future_ts] - translations[future_ts - 1]
                    mask[i] = 1

                if center_features is not None:
                    center_feature = center_features[f"{seq_id}:{instance['timestamp_ns']}:{instance_id}"] / CENTER_FEATURE_SCALING
                    feature = np.concatenate([feature, center_feature])

                if add_turn_label:
                    # Only observed (non-interpolated) future timesteps are used here.
                    future_timesteps = [ts for ts in future_timesteps if ts in instance_by_ts]
                    if len(future_timesteps) == 0:
                        turn_radian = 0
                    else:
                        final_delta = instance_by_ts[future_timesteps[-1]]["translation"] - instance["translation"]
                        # Displacements of norm 2 or less count as "no turn".
                        turn_radian = np.nan_to_num(
                            wrap_radians(delta_to_heading(final_delta) - instance["yaw"])
                        ) if np.linalg.norm(final_delta) > 2 else 0
                    turn_indicator = 0
                    if 0.1 * np.pi <= turn_radian:
                        turn_indicator = 1
                    if turn_radian <= -0.1 * np.pi:
                        turn_indicator = -1
                    feature = np.concatenate([feature, np.array([turn_indicator])])

            features.append(feature)
            targets.append(target)
            masks.append(mask)

        data[track_key] = {
            "feature": np.stack(features, axis=0),
            "target": np.stack(targets, axis=0),
            "target_mask": np.stack(masks, axis=0),
            "class_label": class_label,
            "min_ts": min_ts,
        }

    return data

class CenterFeatureModel(LSTMModel):
    """
    LSTMModel variant whose input carries extra "center features" in its
    trailing CENTER_FEATURE_DIM channels.

    The leading channels go through the parent's input projection, the trailing
    channels through a separate linear projection, and the two embeddings are
    summed before the LSTM stack.

    Relies on attributes that are presumably created by LSTMModel.__init__:
    input_proj, embedding_dim, lstm_layers, dropout, output_proj, k and
    prediction_len -- verify against forecasting.lstm.
    """

    # Width of the trailing center-feature slice. Kept as a class attribute
    # (not an instance attribute) so that instances restored by unpickling a
    # previously saved model still resolve it.
    CENTER_FEATURE_DIM = 128

    def __init__(self, input_dim, prediction_len, k, **kwargs):
        """
        Parameters:
        input_dim (int): Total input width, including the center features.
        prediction_len (int): Forwarded to LSTMModel.
        k (int): Forwarded to LSTMModel.
        **kwargs: Forwarded to LSTMModel.
        """
        super().__init__(input_dim - self.CENTER_FEATURE_DIM, prediction_len, k, **kwargs)
        self.center_feature_proj = nn.Linear(self.CENTER_FEATURE_DIM, self.embedding_dim)
        nn.init.kaiming_uniform_(self.center_feature_proj.weight)

    def forward(self, input):
        """
        Parameters:
        input (Tensor): Indexed as (batch, length, channels); the last
            CENTER_FEATURE_DIM channels are the center features.

        Returns:
        Tensor: Output reshaped to (batch, length, k, prediction_len, -1).
        """
        split = self.CENTER_FEATURE_DIM
        center_features = input[:, :, -split:]
        input = input[:, :, :-split]
        B, L, _ = input.shape
        embedding = (
            self.input_proj(input.reshape(B * L, -1)).reshape(B, L, -1) +
            self.center_feature_proj(center_features.reshape(B * L, -1)).reshape(B, L, -1)
        )
        x = F.relu(embedding)
        for lstm_layer in self.lstm_layers:
            # The recurrent state is not carried between layers; each layer is
            # applied as a residual update with dropout on its output.
            x_out, _ = lstm_layer(x)
            x = x + self.dropout(x_out)
        output = self.output_proj(x.reshape(B * L, -1)).reshape(
            B, L, self.k, self.prediction_len, -1
        )
        return output

def shared_values(d1, d2):
    """
    Pairs up the values of two mappings for every key they have in common.

    Parameters:
    d1 (Mapping): First mapping.
    d2 (Mapping): Second mapping.

    Returns:
    generator: Yields (d1[key], d2[key]) for each key present in both
    mappings, in arbitrary (set) order. The key intersection is computed when
    the function is called, not when iteration starts.
    """
    return ((d1[key], d2[key]) for key in set(d1.keys()).intersection(d2.keys()))

def run_and_evaluate_inference(model, dataset):
    """
    Runs the model over a dataset, converts its per-step displacement
    predictions to absolute positions, and scores them per trajectory profile.

    Reads the notebook-level globals `device`, `labels`, `forecast_label`,
    `velocity_profile` and `PREDICTION_LENGTH`.
    NOTE(review): `labels` must be the same split the dataset was built from,
    since predictions are anchored on detections looked up in `labels` --
    confirm at the call sites.

    Parameters:
    model: Forecasting model; called on a batch it must return an array-like of
        per-timestep predictions indexed as (batch, timestep, k, step, xy).
    dataset: Dataset exposing `keys` (track keys in iteration order) and `data`
        (track key -> dict with "min_ts").

    Returns:
    tuple: (forecasts, metric_by_profile) where
        forecasts maps seq_id -> timestamp_ns -> track_id -> array of k
        predicted absolute xy trajectories, and
        metric_by_profile maps each profile in `velocity_profile` to
        {"ade", "fde", "count"}; "ade"/"fde" are NaN when no sample exists.
    """
    # shuffle=False is required: batches are matched to dataset.keys by position.
    dataloader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
    )
    model = model.to(device).eval()
    forecasts = defaultdict(lambda: defaultdict(dict))
    start_idx = 0
    for input, _, _, _ in dataloader:
        # Only the inputs are needed for inference; targets/masks stay on the CPU.
        input = input.to(device)
        with torch.no_grad():
            prediction = model(input)
        prediction = prediction.cpu().numpy()
        track_keys = dataset.keys[start_idx:start_idx + len(input)]
        start_idx += len(input)
        for track_key, prediction_ot in zip(track_keys, prediction):
            start_ts = dataset.data[track_key]["min_ts"]
            seq_id, track_id = track_key.split(":")
            for i, delta_at_t in enumerate(prediction_ot):
                # Map the track-relative index back to the sequence's frame index.
                timestep = start_ts + i
                if timestep >= len(labels[seq_id]):
                    continue
                detection_frame = labels[seq_id][timestep]
                frame_track_ids = list(detection_frame["track_id"])
                if track_id not in frame_track_ids:
                    continue
                timestamp = detection_frame["timestamp_ns"]
                detection = index_array_values(detection_frame, frame_track_ids.index(track_id))
                current_translation = detection["translation"][:2]
                # Accumulate per-step displacements into absolute positions.
                prediction_at_t = current_translation + np.cumsum(delta_at_t, axis=1)
                forecasts[seq_id][timestamp][track_id] = prediction_at_t
    forecasts = {seq_id: dict(preds_by_ts) for seq_id, preds_by_ts in forecasts.items()}

    errors_by_profile = {p: {"ade": [], "fde": []} for p in velocity_profile}
    for forecast_frames, forecast_label_frames in shared_values(forecasts, forecast_label):
        for prediction_by_instance_id, label_list in shared_values(forecast_frames, forecast_label_frames):
            label_by_instance_id = {label_agent["instance_id"]: label_agent for label_agent in label_list}
            for prediction, label_agent in shared_values(prediction_by_instance_id, label_by_instance_id):
                profile = label_agent["trajectory_type"]
                label_length = label_agent["future_translation"].shape[0]
                errors = np.linalg.norm(prediction[:, :label_length] - label_agent["future_translation"][np.newaxis], axis=-1)
                # Best-of-k average displacement error.
                ade = errors.mean(axis=1).min(axis=0)
                errors_by_profile[profile]["ade"].append(ade)
                # Final displacement error only counts full-length futures.
                if label_length == PREDICTION_LENGTH:
                    fde = errors[:, -1].min(axis=0)
                    errors_by_profile[profile]["fde"].append(fde)

    def _mean_or_nan(values):
        # np.mean([]) also yields NaN but emits a RuntimeWarning; be explicit.
        return np.mean(values) if len(values) > 0 else np.nan

    metric_by_profile = {
        p: {
            "ade": _mean_or_nan(errors_by_profile[p]["ade"]),
            "fde": _mean_or_nan(errors_by_profile[p]["fde"]),
            "count": len(errors_by_profile[p]["ade"]),
        }
        for p in velocity_profile
    }
    return forecasts, metric_by_profile


In [None]:
# Number of future steps predicted per timestep; passed as prediction_len to
# preprocess_data and used by the evaluation to decide which samples count for FDE.
PREDICTION_LENGTH = 6
# Number of trajectory modes predicted by the model (passed as k= to the models).
K = 5

# `load` is not defined in this notebook; presumably it comes from the
# `from data.utils import *` star import -- verify.
train_labels = load(f"dataset/nuscenes-train/labels.pkl")
# `labels` holds the validation split and is read as a global by
# run_and_evaluate_inference.
labels = load(f"dataset/nuscenes-val/labels.pkl")
center_features = load("dataset/center_features.pkl")

# evaluation: build the ground-truth forecast labels from the validation split
# and tag every agent with its trajectory type (stored under "trajectory_type").
from forecasting.evaluate import convert_forecast_labels, trajectory_type, av2_velocity, velocity_profile
forecast_label = convert_forecast_labels(labels, PREDICTION_LENGTH, np.inf)
for frames in forecast_label.values():
    for frame in frames.values():
        for agent in frame:
            agent["trajectory_type"] = trajectory_type(agent, av2_velocity, PREDICTION_LENGTH)


In [None]:
# RUN oracle turn signals experiment
# "Oracle" because the turn label built by preprocess_data is derived from the
# track's future observations, i.e. information unavailable at prediction time.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
learning_rate = 1e-3
# Read as a global by run_and_evaluate_inference.
device = "cuda"
# Results keyed by (K, use_center_features, add_turn_label).
sweep_results = {}

use_center_features, add_turn_label = False, True
K = 5
train_data = preprocess_data(
    train_labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
    add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
    )
train_dataset = MotionPredictionDataset(train_data, PREDICTION_LENGTH)
# Validation data: built from `labels`, the validation split.
data = preprocess_data(
    labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
    add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
)
dataset = MotionPredictionDataset(data, PREDICTION_LENGTH)
# Training loader; note that the name `dataloader` is rebound to the
# validation loader further down in this cell.
dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
if use_center_features:
    model = CenterFeatureModel(train_dataset.input_dim, train_dataset.prediction_len, k=K)
else:
    model = LSTMModel(train_dataset.input_dim, train_dataset.prediction_len, k=K)

# run: 50 epochs of Adam on the training set, recording the mean loss per epoch
model = model.to(device).train()
optim = torch.optim.Adam(model.parameters(), learning_rate)
loss_traj = []

for epoch in range(50):
    epoch_loss = []
    # The second element of each batch is unused here.
    for input, _, target, mask in dataloader:
        input, target, mask = input.to(device), target.to(device), mask.to(device)
        prediction = model(input)

        optim.zero_grad()
        loss = k_best_loss(prediction, target, mask)
        loss.backward()
        optim.step()

        epoch_loss.append(loss.detach().cpu().item())
    loss_traj.append(np.mean(epoch_loss))

# torch.save(model, f"models/model_dt_{K}{'_cf' if use_center_features else ''}{'_tl' if add_turn_label else ''}.pt")
# val: mean k_best_loss over the validation set, without gradient tracking
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
)
model = model.to(device).eval()

val_loss = []
for input, _, target, mask in dataloader:
    input, target, mask = input.to(device), target.to(device), mask.to(device)
    with torch.no_grad():
        prediction = model(input)
        val_loss.append(k_best_loss(prediction, target, mask).cpu().item())

# ADE/FDE per trajectory profile; the returned forecasts are discarded.
_, metric_by_profile = run_and_evaluate_inference(model, dataset)
print(metric_by_profile)
experiment_results = {"metrics": metric_by_profile, "validation_loss": np.mean(val_loss), "train_loss": loss_traj}
sweep_results[(K, use_center_features, add_turn_label)] = experiment_results
# save(sweep_results, "sweep_results_dt.pkl")

In [None]:
# RUN K experiment
# Sweeps the number of predicted modes K with and without center features.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
learning_rate = 1e-3
# Read as a global by run_and_evaluate_inference.
device = "cuda"
# Keep results recorded by earlier experiment cells (keyed by
# (K, use_center_features, add_turn_label)) instead of wiping them; only
# create the dict when no experiment has run yet.
if "sweep_results" not in globals():
    sweep_results = {}

for use_center_features, add_turn_label in [(False, False), (True, False)]:
    train_data = preprocess_data(
        train_labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
        add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
    )
    train_dataset = MotionPredictionDataset(train_data, PREDICTION_LENGTH)
    data = preprocess_data(
        labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
        add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
    )
    dataset = MotionPredictionDataset(data, PREDICTION_LENGTH)
    # Separate names for the two loaders: reusing one name for both made every
    # K after the first train on the validation loader.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
    )
    val_dataloader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
    )
    for K in tqdm([1, 3, 5], desc="sweeping K"):
        if use_center_features:
            model = CenterFeatureModel(train_dataset.input_dim, train_dataset.prediction_len, k=K)
        else:
            model = LSTMModel(train_dataset.input_dim, train_dataset.prediction_len, k=K)

        # run: 50 epochs of Adam on the training set
        model = model.to(device).train()
        optim = torch.optim.Adam(model.parameters(), learning_rate)
        loss_traj = []

        for epoch in range(50):
            epoch_loss = []
            for input, _, target, mask in train_dataloader:
                input, target, mask = input.to(device), target.to(device), mask.to(device)
                prediction = model(input)

                optim.zero_grad()
                loss = k_best_loss(prediction, target, mask)
                loss.backward()
                optim.step()

                epoch_loss.append(loss.detach().cpu().item())
            loss_traj.append(np.mean(epoch_loss))

        # torch.save(model, f"models/model_dt_{K}{'_cf' if use_center_features else ''}{'_tl' if add_turn_label else ''}.pt")
        # val: mean k_best_loss over the validation set
        model = model.to(device).eval()

        val_loss = []
        for input, _, target, mask in val_dataloader:
            input, target, mask = input.to(device), target.to(device), mask.to(device)
            with torch.no_grad():
                prediction = model(input)
                val_loss.append(k_best_loss(prediction, target, mask).cpu().item())

        _, metric_by_profile = run_and_evaluate_inference(model, dataset)
        print(metric_by_profile)
        experiment_results = {"metrics": metric_by_profile, "validation_loss": np.mean(val_loss), "train_loss": loss_traj}
        sweep_results[(K, use_center_features, add_turn_label)] = experiment_results
# save(sweep_results, "sweep_results_dt.pkl")

In [None]:
# Print the metrics for the current K under each configuration
# (K, use_center_features, add_turn_label). A configuration that has not been
# recorded in sweep_results is reported instead of raising KeyError.
for config in [(K, False, False), (K, True, False), (K, False, True)]:
    if config in sweep_results:
        print(sweep_results[config]["metrics"])
    else:
        print(f"no results recorded for (K, use_center_features, add_turn_label)={config}")


In [None]:
# Setup for a single baseline run: reseed all RNGs and rebuild the train and
# validation datasets without center features or turn labels.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


use_center_features = False
add_turn_label = False
train_data = preprocess_data(
    train_labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
    add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
    )
train_dataset = MotionPredictionDataset(train_data, PREDICTION_LENGTH)
# Validation data: built from `labels`, the validation split.
data = preprocess_data(
    labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
    add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
)
dataset = MotionPredictionDataset(data, PREDICTION_LENGTH)

# model = torch.load(f"models/model_dt_5{'_cf' if use_center_features else ''}{'_tl' if add_turn_label else ''}.pt")


In [None]:
# Train and validate one model on the datasets built in the previous cell.
# `K`, `use_center_features`, `train_dataset` and `dataset` are not defined
# here; they are whatever earlier cells last bound.
dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
if use_center_features:
    model = CenterFeatureModel(train_dataset.input_dim, train_dataset.prediction_len, k=K)
else:
    model = LSTMModel(train_dataset.input_dim, train_dataset.prediction_len, k=K)

# run: 50 epochs of Adam, recording the mean training loss per epoch
learning_rate = 1e-3
# Read as a global by run_and_evaluate_inference.
device = "cuda"
model = model.to(device).train()
optim = torch.optim.Adam(model.parameters(), learning_rate)
loss_traj = []

for epoch in range(50):
    epoch_loss = []
    # The second element of each batch is unused here.
    for input, _, target, mask in dataloader:
        input, target, mask = input.to(device), target.to(device), mask.to(device)
        prediction = model(input)

        optim.zero_grad()
        loss = k_best_loss(prediction, target, mask)
        loss.backward()
        optim.step()

        epoch_loss.append(loss.detach().cpu().item())

    # print(epoch, np.mean(epoch_loss))
    loss_traj.append(np.mean(epoch_loss))

# val: `dataloader` is rebound to the validation set from here on
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
)
model = model.to(device).eval()

val_loss = []
for input, _, target, mask in dataloader:
    input, target, mask = input.to(device), target.to(device), mask.to(device)
    with torch.no_grad():
        prediction = model(input)
        val_loss.append(k_best_loss(prediction, target, mask).cpu().item())
print("validation loss:", np.mean(val_loss))


In [None]:
# torch.save(model, f"models/model{'_cf' if use_center_features else ''}{'_tl' if add_turn_label else ''}.pt")


In [None]:
# inference: score the current `model` on the validation `dataset` and
# pretty-print the per-profile metrics (the forecasts themselves are discarded).

import json
_, metric_by_profile = run_and_evaluate_inference(model ,dataset)
print(json.dumps(metric_by_profile, indent=4))


In [None]:
import os

from nuscenes import NuScenes

# Root of the nuScenes dataset. It can be overridden with the NUSCENES_ROOT
# environment variable so the notebook is not tied to one machine; the default
# is the location this notebook was originally run with.
data_root = os.environ.get("NUSCENES_ROOT", '/data/ashen3/datasets/nuScenes')
nusc = NuScenes("v1.0-trainval", data_root)

In [None]:
# Produce validation forecasts from two saved K=5 models: the baseline and the
# center-feature variant. The file names follow the pattern of the
# (commented-out) torch.save calls in the experiment cells.
# NOTE(review): torch.load here restores whole pickled model objects; recent
# PyTorch releases default to weights_only=True, which may reject such files --
# confirm against the installed version. Only load checkpoints you trust.
add_turn_label = False
use_center_features = False
data = preprocess_data(
    labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
    add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
)
dataset = MotionPredictionDataset(data, PREDICTION_LENGTH)
model = torch.load(f"models/model_dt_5{'_cf' if use_center_features else ''}{'_tl' if add_turn_label else ''}.pt")
baseline_forecasts, _ = run_and_evaluate_inference(model, dataset)

# Same again with center features enabled; `data`, `dataset` and `model` are
# overwritten and afterwards refer to the center-feature variant.
use_center_features = True
data = preprocess_data(
    labels, max_len=MAX_LEN, prediction_len=PREDICTION_LENGTH,
    add_turn_label=add_turn_label, center_features=(center_features if use_center_features else None),
)
dataset = MotionPredictionDataset(data, PREDICTION_LENGTH)
model = torch.load(f"models/model_dt_5{'_cf' if use_center_features else ''}{'_tl' if add_turn_label else ''}.pt")
cf_forecasts, _ = run_and_evaluate_inference(model, dataset)

In [None]:
%matplotlib inline
from nuscenes.map_expansion.map_api import NuScenesMap

def get_nusc_map(nusc, scene_token):
    """
    Loads the map of the location in which a scene was recorded.

    Parameters:
    nusc: NuScenes handle used to look up the scene and its log record.
    scene_token (str): Token of the scene.

    Returns:
    NuScenesMap: Map for the scene's log location, read from the
    notebook-level `data_root`.
    """
    log_token = nusc.get('scene', scene_token)['log_token']
    location = nusc.get('log', log_token)['location']
    return NuScenesMap(dataroot=data_root, map_name=location)

def plot_forecasts(forecasts, seq_id, timestamp, margin=10):
    """
    Renders one frame's predicted and ground-truth trajectories on the map.

    Reads the notebook-level globals `forecast_label` (ground truth) and
    `nusc_map` (map to render; must match the scene being plotted).

    Parameters:
    forecasts (dict): seq_id -> timestamp -> instance id -> array of k
        predicted xy trajectories.
    seq_id (str): Sequence to plot.
    timestamp (int): Frame timestamp within the sequence.
    margin (float): Padding added around the bounding box of all plotted
        predictions.

    Raises:
    ValueError: If the frame has no non-static agent with a prediction, so no
        plotting region can be derived.
    """
    prediction_by_instance_id = forecasts[seq_id][timestamp]
    label_by_instance_id = {agent["instance_id"]: agent for agent in forecast_label[seq_id][timestamp]}
    # The map patch is sized from the predictions of non-static agents only.
    locations = [
        prediction
        for prediction, label_agent in shared_values(prediction_by_instance_id, label_by_instance_id)
        if label_agent["trajectory_type"] != "static"
    ]
    if not locations:
        raise ValueError(
            f"no non-static agents with predictions at seq_id={seq_id!r}, timestamp={timestamp!r}"
        )
    locations = np.stack(locations, axis=0).reshape(-1, 2)
    m = margin
    box_coords = locations[:, 0].min() - m, locations[:, 1].min() - m, locations[:, 0].max() + m, locations[:, 1].max() + m
    layer_names = [
        'drivable_area',
        'road_segment',
        'road_block',
        'lane',
        'stop_line',
        'road_divider',
        'lane_divider',
    ]
    nusc_map.render_map_patch(box_coords, layer_names=layer_names, figsize=(6, 6))

    for prediction, label_agent in shared_values(prediction_by_instance_id, label_by_instance_id):
        loc = label_agent["current_translation"]
        plt.scatter(*loc, c="black")
        # skip static agents
        if label_agent["trajectory_type"] == "static":
            continue
        gt_future = label_agent["future_translation"]
        # plot the best prediction (lowest mean displacement to the ground truth)
        best_pred_i = np.argmin(np.linalg.norm(prediction[:, :len(gt_future)] - gt_future[np.newaxis], axis=-1).mean(1))
        prd_trajectory = np.concatenate([loc[np.newaxis], prediction[best_pred_i]])
        plt.plot(prd_trajectory[:, 0], prd_trajectory[:, 1], c="blue")
        # plot all predictions; they share the best prediction's colour, so the
        # best mode is not visually distinguished
        for predicted_mode in prediction:
            prd_trajectory = np.concatenate([loc[np.newaxis], predicted_mode])
            plt.plot(prd_trajectory[:, 0], prd_trajectory[:, 1], c="blue")
        # plot ground truth
        gt_trajectory = np.concatenate([loc[np.newaxis], gt_future])
        plt.plot(gt_trajectory[:, 0], gt_trajectory[:, 1], c="black")

    plt.xlim(box_coords[0], box_coords[2])
    plt.ylim(box_coords[1], box_coords[3])
    plt.show()

# visualize predictions: plot the same frame with the center-feature model's
# forecasts and with the baseline's forecasts.
seq_id = 'ed242d80ccb34b139aaf9ab89859332e'
timestamp = 1535730476396726000
# The sequence id is passed as the scene token; plot_forecasts reads
# `nusc_map` as a global.
nusc_map = get_nusc_map(nusc, seq_id)

plot_forecasts(cf_forecasts, seq_id, timestamp)
plt.figure()
plot_forecasts(baseline_forecasts, seq_id, timestamp)


In [None]:
%matplotlib widget
# Center-feature forecasts for the selected frame, zoomed to a hand-picked
# region and saved to disk. The commented-out limits are an alternative crop.
plot_forecasts(cf_forecasts, seq_id, timestamp, margin=20)
# plt.xlim(1415, 1440)
# plt.ylim(1220, 1255)
plt.xlim(1430, 1465)
plt.ylim(1225, 1265)
plt.savefig(f"plots/visual_features_neg_{seq_id}.png")


In [None]:
%matplotlib widget
# Baseline forecasts for the selected frame, zoomed to a hand-picked region
# and saved to disk. The commented-out limits are an alternative crop.
plot_forecasts(baseline_forecasts, seq_id, timestamp, margin=20)
# plt.xlim(1415, 1440)
# plt.ylim(1220, 1255)
plt.xlim(1430, 1465)
plt.ylim(1225, 1265)
plt.savefig(f"plots/baseline_neg_{seq_id}.png")


In [None]:
def shared_items(d1, d2):
    """Returns (key, (d1[key], d2[key])) pairs for every key present in both mappings."""
    return ((key, (d1[key], d2[key])) for key in set(d1.keys()).intersection(d2.keys()))


# Rank frames by how many non-static agents have both a prediction and a
# label, to find busy frames worth visualising.
# `forecasts` is not bound anywhere at notebook scope, so the frames are taken
# from the baseline forecasts here.
# NOTE(review): confirm baseline_forecasts (rather than cf_forecasts) is the intended source.
forecasts = baseline_forecasts
frames_and_count = []
# Loop variables are named `seq`/`ts` so the `seq_id`/`timestamp` globals used
# by the plotting cells are not overwritten.
for seq, (forecast_frames, forecast_label_frames) in shared_items(forecasts, forecast_label):
    for ts, (prediction_by_instance_id, label_list) in shared_items(forecast_frames, forecast_label_frames):
        label_by_instance_id = {label_agent["instance_id"]: label_agent for label_agent in label_list}
        num_non_static = sum(
            label_agent["trajectory_type"] != "static"
            for _, label_agent in shared_values(prediction_by_instance_id, label_by_instance_id)
        )
        frames_and_count.append((num_non_static, seq, ts))

sorted(frames_and_count, reverse=True)[:50]