In [None]:
import os
import warnings

# Set the environment variable to suppress the tensorflow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Suppress the warnings
warnings.filterwarnings("ignore")

In [None]:
import math
import os.path as osp
from typing import Any, Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch

from planner.data import AV2DataModule
from planner.data.dataclass import Scenario

In [None]:
# Constants: absolute paths, resolved against the notebook's working directory.
DATA_DIR = osp.abspath(osp.join("..", "data"))
CHECKPOINT_DIR = osp.abspath(
    osp.join("..", "logs", "train_ddp", "runs", "senevam_av2_radius_50")
)

In [None]:
# Argoverse 2 data module rooted at DATA_DIR/av2.
# NOTE(review): radius=None presumably disables radius-based scene cropping —
# confirm against AV2DataModule.
dm = AV2DataModule(
    root=osp.join(DATA_DIR, "av2"),
    radius=None,
    batch_size=4,
    num_workers=2,
    pin_memory=True,
)

In [None]:
# Prefer the first CUDA device; fall back to the CPU when none is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Iterator over validation batches (drawn one at a time with next()), plus the
# underlying dataset, which provides the map API for plotting.
dl = iter(dm.val_dataloader())
ds = dm.val_dataset

In [None]:
from planner.model import SeNeVAMLightningModule

# Restore the trained model from the selected run/epoch onto `device`.
# `.eval()` (which returns the module itself) switches layers such as dropout
# and batch-norm to inference behavior. Lightning's checkpoint guide calls it
# explicitly after `load_from_checkpoint`. Without it, the predictions below
# would be computed in training mode.
ckpt_dir = osp.join(CHECKPOINT_DIR, "2025-04-05_01-57-40/checkpoints")
model = SeNeVAMLightningModule.load_from_checkpoint(
    osp.join(ckpt_dir, "epoch_032_b2ed65b.ckpt"),
    map_location=device,
    strict=True,
).eval()

In [None]:
# Draw one validation batch. The isinstance check fails fast on an unexpected
# batch type and also narrows the type for static checkers.
data = next(dl)
assert isinstance(data, Scenario)

# Move the batch onto the inference device and show a summary of it.
data = data.to(device=device)
print(data)

In [None]:
# Inference only: keep autograd off so no computation graph is recorded.
with torch.set_grad_enabled(False):
    current, curr_valid = model.get_current(scenario=data, include_sdc=False)
    target, tar_valid = model.get_target(scenario=data, include_sdc=False)
    output = model.forward(scenario=data, horizon=60, include_sdc=False)

# Quick shape sanity check of each tensor alongside its companion tensor.
print(current.shape, curr_valid.shape)
print(target.shape, tar_valid.shape)
print(output.y_means.shape, output.y_covars.shape)

In [None]:
# Helper function: plot the uncertainty heatmap of a single agent.
# TODO: consider aggregating multi-agent probabilities, e.g. by summing
# log-probabilities.
def plot_full_uncertainty(
    means: npt.NDArray,
    covars: npt.NDArray,
    mixtures: npt.NDArray,
    ax: Optional[plt.Axes] = None,
    n_std: float = 1,
    heatmap_type: Literal["prob", "log_prob", "logit"] = "log_prob",
    colorbar: bool = True,
    *args: Any,
    resolution: int = 100,
    **kwargs: Any,
) -> plt.Axes:
    """Draw a filled-contour heatmap of a Gaussian-mixture trajectory density.

    Each mixture component ``k`` is a sequence of ``T`` bivariate Gaussians
    ``N(means[k, t], covars[k, t])`` with weight ``mixtures[k]``. At each grid
    point, the weighted densities are summed over components and then
    averaged over the ``T`` time steps. A grid point is masked out when, for
    every component, it lies beyond the Mahalanobis threshold at every step.

    Args:
        means: Component mean positions, shape ``(K, T, 2)``.
        covars: Component covariances, shape ``(K, T, 2, 2)``.
        mixtures: Component weights, shape ``(K,)``. ``zip`` is used, so
            components beyond ``len(mixtures)`` are silently ignored.
        ax: Axes to draw on. A new figure and axes are created when ``None``.
        n_std: Number of marginal standard deviations used for the plot
            extent and for the outlier-mask threshold.
        heatmap_type: ``"prob"`` plots the averaged density, ``"log_prob"``
            its logarithm, and ``"logit"`` ``log(p / (1 - p))``. Any other
            value is treated like ``"prob"``.
        colorbar: Attach a colorbar. One is only added while the figure has
            fewer than two axes.
        *args: Extra positional arguments forwarded to ``ax.contourf``.
        resolution: Number of grid points along each axis (keyword-only).
        **kwargs: Extra keyword arguments forwarded to ``ax.contourf``.

    Returns:
        The axes the heatmap was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # Step 1: axis-aligned box spanning +/- n_std marginal standard deviations
    # around every component mean at every time step.
    std_x = np.sqrt(covars[..., 0, 0])
    std_y = np.sqrt(covars[..., 1, 1])
    minx = np.min(means[..., 0] - n_std * std_x)
    maxx = np.max(means[..., 0] + n_std * std_x)
    miny = np.min(means[..., 1] - n_std * std_y)
    maxy = np.max(means[..., 1] + n_std * std_y)

    # Step 2: evaluate all grid points at once (vectorized over points and time).
    grid_x, grid_y = np.meshgrid(
        np.linspace(minx, maxx, resolution),
        np.linspace(miny, maxy, resolution),
    )
    points = np.stack((grid_x.ravel(), grid_y.ravel()), axis=-1)  # (R*R, 2)

    densities, outliers = [], []
    for mean, covar, pi in zip(means, covars, mixtures):
        inv_covar = np.linalg.inv(covar)  # (T, 2, 2)
        det_covar = np.linalg.det(covar)  # (T,)
        diff = points[:, None, ...] - mean[None, ...]  # (R*R, T, 2)
        # Squared Mahalanobis distance d^T Sigma^-1 d per grid point and step.
        mahalanobis = np.einsum("btj,tij,bti->bt", diff, inv_covar, diff)
        # Log of the 2-D Gaussian pdf, plus the log mixture weight.
        log_prob = -0.5 * (
            mahalanobis + np.log(det_covar + 1e-10) + 2 * np.log(2 * np.pi)
        )
        log_prob += np.log(pi + 1e-10)
        densities.append(np.exp(log_prob))

        # Outlier for this component: beyond the threshold at every step.
        # NOTE(review): `mahalanobis` is a *squared* distance, yet it is
        # compared with n_std * sqrt(2) rather than n_std**2 — confirm intent.
        outliers.append(
            np.reshape(
                np.all(mahalanobis > n_std * math.sqrt(2), axis=-1),
                (resolution, resolution),
            )
        )
    # Hide a grid point only when it is an outlier for every component.
    mask = np.all(outliers, axis=0)

    # Sum weighted densities over components, then average over time steps.
    # The shape is passed positionally: the `shape=` keyword of np.reshape
    # only exists from NumPy 2.1 (older releases call it `newshape`).
    density = np.mean(np.sum(np.stack(densities, axis=0), axis=0), axis=-1)
    probs = np.ma.masked_where(mask, np.reshape(density, (resolution, resolution)))

    if heatmap_type == "log_prob":
        probs = np.log(probs + 1e-10)
        heatmap_name = "Log-Probability"
    elif heatmap_type == "logit":
        # NOTE: values are densities, not probabilities. Where they reach 1 the
        # denominator becomes non-positive and the result is nan/inf.
        probs = np.log((probs + 1e-10) / (1 - probs - 1e-10))
        heatmap_name = "Logit"
    else:
        heatmap_name = "Probability"
    contours = ax.contourf(grid_x, grid_y, probs, *args, **kwargs)

    # Only add a colorbar while the figure has fewer than two axes, which
    # presumably keeps repeated calls on one figure from stacking colorbars.
    if colorbar and len(ax.get_figure().axes) < 2:
        cbar = ax.get_figure().colorbar(contours, ax=ax)
        cbar.ax.set_ylabel(
            heatmap_name + " to be visited",
            rotation=90,
            labelpad=5,
            fontdict={"color": "white"},
        )
        # Recolor tick labels in place instead of re-setting them: calling
        # set_yticklabels(get_yticklabels()) pins a FixedFormatter (warning)
        # and may capture empty labels if the figure has not been drawn yet.
        cbar.ax.tick_params(axis="y", labelcolor="white")

    return ax

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt

from planner.data.viz import plot_scenario

BATCH_INDEX = 1

fig, ax = plt.subplots(1, 1)

# plot the scenario
ax = plot_scenario(
    scenario=data[BATCH_INDEX].to("cpu"), ax=ax, crop_to_bounds=False
)
scenario_id = data[BATCH_INDEX].scenario_id.decode("utf-8")
map_api = ds.get_map_api(scenario_id=scenario_id)
for _, area in map_api.vector_drivable_areas.items():
    ax.fill(area.xyz[..., 0], area.xyz[..., 1], color="#5E5E5D", zorder=0)

# plot the ground-truth observations
cum_target = target.cumsum(dim=-2) + current
for tar, val in zip(cum_target[BATCH_INDEX], tar_valid[BATCH_INDEX]):
    tar = tar[val]
    ax.plot(
        tar[..., 0].cpu().numpy(),
        tar[..., 1].cpu().numpy(),
        "g-",
        lw=2,
        alpha=0.75,
        zorder=20,
    )

# plot the predictions
cum_y_means = output.y_means.cumsum(dim=-2) + current.unsqueeze(-3)
cum_y_covars = output.y_covars.cumsum(dim=-3)
for y_means, val in zip(cum_y_means[BATCH_INDEX], tar_valid[BATCH_INDEX]):
    for y_mean in y_means:
        xy = y_mean[val][..., 0:2].cpu().numpy()
        ax.plot(
            xy[..., 0],
            xy[..., 1],
            "r--",
            lw=2,
            alpha=0.5,
            zorder=15,
        )

ax = plot_full_uncertainty(
    means=cum_y_means[BATCH_INDEX, 0, ..., 0:2].cpu().numpy(),
    covars=cum_y_covars[BATCH_INDEX, 0, ..., 0:2, 0:2].cpu().numpy(),
    mixtures=np.ones(6) / 6,
    ax=ax,
    n_std=1,
    heatmap_type="log_prob",
    alpha=0.75,
    zorder=15,
)
fig.set_facecolor("#000000")

In [None]:
from planner.model.function.eval import MinADE

# minADE over the whole batch, using only the xy coordinates.
# The ground truth and its validity mask each get a singleton axis, so they
# broadcast against every predicted mode. This is the same axis position that
# `current.unsqueeze(-3)` targets when building cum_y_means; presumably it is
# the mode axis (confirm).
min_ade = MinADE().to(device=device)
min_ade(
    valid=tar_valid.unsqueeze(-2),
    input_xy=cum_y_means[..., 0:2],
    target_xy=cum_target.unsqueeze(-3)[..., 0:2],
)
min_ade.compute()