In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import typing

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import pyplot as plt

from torchcps.kernel.nn import sample_kernel, GaussianKernel, Mixture

from mtt.data.sparse import SparseData, SparseDataset
from mtt.models.sparse import SparseLabel
from mtt.models.kernel import KNN, SparseBase
from mtt.models.utils import load_model
from mtt.models.sparse import kernel_loss

from mtt.utils import compute_ospa, compute_ospa_components

# Seeded generator so that random draws are repeatable from a fresh kernel.
rng = np.random.default_rng(0)

# Matplotlib overrides applied on top of the seaborn "paper" / "whitegrid" theme.
rc_overrides = {
    "figure.figsize": (3.5, 3.5),
    "figure.dpi": 150,
    "savefig.dpi": 1000,
    "figure.constrained_layout.use": True,
    "pdf.fonttype": 42,
}
sns.set_theme(context="paper", style="whitegrid", rc=rc_overrides)

In [None]:
# Fetch the trained run together with its name and hyper-parameters, switch the
# network to evaluation mode, and narrow its static type to KNN for the cells
# that follow.
model: SparseBase
model, name, params = load_model("wandb://damowerko/mtt/v1au78ao")
model.eval()
model = typing.cast(KNN, model)

In [None]:
# Build the dataset used by the evaluation cells; `length` is taken from the
# `input_length` entry of the loaded run's parameters.
# NOTE(review): the path points at "data/train", so cells that draw from
# `dataset` evaluate on that split — confirm this is intended.
dataset = SparseDataset(
    "data/train",  # plain string: the original f-prefix had no placeholders
    length=params["input_length"],
    slim=True,
)

In [None]:
# Compare the kernel loss of the model output against a naive baseline that
# keeps the predicted positions but assigns every component the same weight
# (the mean of the predicted weights). Also record predicted vs. true
# cardinality for each sample.
N_SAMPLES = 100  # number of random (simulation, step) pairs to evaluate
MAX_STEP = 100  # step indices are drawn uniformly from [0, MAX_STEP)

rng = np.random.default_rng(0)
# Move data to wherever the model lives instead of hard-coding CUDA.
device = next(model.parameters()).device

metrics = []
with torch.no_grad():
    for _ in range(N_SAMPLES):
        sim_idx = rng.integers(0, len(dataset))
        step_idx = rng.integers(0, MAX_STEP)
        data = SparseData(*(x.to(device) for x in dataset.get(sim_idx, step_idx)))

        # Named `model_input` so the builtin `input` is not shadowed.
        model_input = model.forward_input(data)
        label = SparseLabel.from_sparse_data(data, model.input_length)
        output = model.forward(model_input)

        # Component weights: exponentiate the log-weights once and reuse them.
        weights = output.logp.exp()
        mse_model = kernel_loss(
            output.mu,
            weights,
            label.y,
            model.kernel_sigma,
        ).item()

        # Baseline weights: every entry equals the mean predicted weight.
        # (Equivalent to the former exp(log(mean)) round trip, up to float rounding.)
        naive_weights = torch.full_like(weights, weights.mean().item())
        mse_naive = kernel_loss(
            output.mu,
            naive_weights,
            label.y,
            model.kernel_sigma,
        ).item()

        metrics.append(
            dict(
                mse_model=mse_model,
                mse_naive=mse_naive,
                # Sum of the weights is used as the predicted number of targets.
                output_cardinality=weights.sum(-1).item(),
                target_cardinality=label.y.shape[0],
            )
        )
metrics_df = pd.DataFrame(metrics)

In [None]:
# Report the average loss of the model next to the naive baseline.
mean_model_mse = metrics_df["mse_model"].mean()
mean_naive_mse = metrics_df["mse_naive"].mean()
print(f"MSE Model: {mean_model_mse} , MSE Naive: {mean_naive_mse}")

In [None]:
# Signed cardinality error (predicted minus true), stored as a new column and
# shown as a distribution plot.
metrics_df["cardinality_error"] = metrics_df["output_cardinality"].sub(
    metrics_df["target_cardinality"]
)
sns.displot(data=metrics_df, x="cardinality_error")
plt.show()

In [None]:
from math import ceil
from mtt.peaks import sample_rkhs, fit_gmm, fit_kmeans

# Extract point estimates from the model output and score them with OSPA.
# For each random (simulation, step) pair: draw samples from the output
# mixture, fit a clustering model to those samples, take the fitted means as
# position estimates, and compare them with the ground-truth positions.
N_SAMPLES = 100  # number of random (simulation, step) pairs to evaluate
MAX_STEP = 100  # step indices are drawn uniformly from [0, MAX_STEP)

rng = np.random.default_rng(0)
method = "gmm"  # peak-fitting method: "gmm" or "kmeans"
# Move data to wherever the model lives instead of hard-coding CUDA.
device = next(model.parameters()).device

ospa = []
ospa_cardinality = []
ospa_distance = []
with torch.no_grad():
    for _ in range(N_SAMPLES):
        sim_idx = rng.integers(0, len(dataset))
        step_idx = rng.integers(0, MAX_STEP)
        data = SparseData(*(x.to(device) for x in dataset.get(sim_idx, step_idx)))

        # Named `model_input` so the builtin `input` is not shadowed.
        model_input = model.forward_input(data)
        label = SparseLabel.from_sparse_data(data, model.input_length)
        output = model.forward(model_input)

        # The weights are exponentials and therefore positive, so their sum
        # (their L1 norm) is the total weight; round it up to get the number of
        # components to fit.
        weights = output.logp.exp()
        n_components = ceil(weights.sum(-1).item())

        # Interpret the output as a likelihood and sample from it.
        # NOTE(review): the last argument (1000) is presumably the number of
        # samples to draw — confirm against mtt.peaks.sample_rkhs.
        samples = sample_rkhs(
            output.mu.cpu().numpy(),
            weights.cpu().numpy(),
            model.kernel_sigma,
            1000,
        )

        # Fit the selected clustering model to the samples to find the peaks.
        if method == "gmm":
            peaks = fit_gmm(samples, n_components=n_components)
        elif method == "kmeans":
            peaks = fit_kmeans(
                samples,
                n_components=n_components,
                n_components_range=2,
            )
        else:
            raise ValueError(f"Unknown method: {method}")
        estimated_positions = peaks.means
        target_positions = label.y.cpu().numpy()

        # NOTE(review): the trailing positional arguments look like the OSPA
        # cutoff (500) and order p; the overall metric is called with 2 while
        # the component breakdown is called with 1 — confirm that this
        # difference is intentional.
        ospa.append(
            compute_ospa(
                target_positions,
                estimated_positions,
                500,
                2,
            )
        )
        ospa_components = compute_ospa_components(
            target_positions,
            estimated_positions,
            500,
            1,
        )
        ospa_distance.append(ospa_components[0])
        ospa_cardinality.append(ospa_components[1])

In [None]:
# Print mean and standard deviation for the overall OSPA and its two components.
for title, values in (
    ("OSPA", ospa),
    ("OSPA Cardinality", ospa_cardinality),
    ("OSPA Distance", ospa_distance),
):
    print(f"{title}: mean = {np.mean(values)}, std = {np.std(values)}")

In [None]:
# Qualitative check: for a few random simulations at a fixed step, plot
#   column 0 - output component positions, measurements and true targets,
#   column 1 - the input mixture rasterised on a grid,
#   column 2 - the output mixture rasterised on the same grid.
N_ROWS = 5  # number of simulations to visualise (one per row)
EXTENT = 500  # plots and rasters cover [-EXTENT, EXTENT] on both axes
RESOLUTION = 512  # raster grid is RESOLUTION x RESOLUTION

# Seeded generator (the original used the unseeded global NumPy RNG, so the
# figure changed on every run).
rng = np.random.default_rng(0)
# Move data to wherever the model lives instead of hard-coding CUDA.
device = next(model.parameters()).device

with sns.axes_style("ticks"), torch.no_grad():
    # Kernel and evaluation grid do not depend on the sample: build them once.
    # NOTE(review): the evaluation cells use `model.kernel_sigma`, while this
    # kernel uses `model.sigma` — confirm which bandwidth is intended here.
    kernel = GaussianKernel(model.sigma)
    grid_axis = torch.linspace(-EXTENT, EXTENT, RESOLUTION)
    XY = torch.cartesian_prod(grid_axis, grid_axis).to(device)

    def render_mixture(positions: torch.Tensor, weights: torch.Tensor) -> np.ndarray:
        """Evaluate the kernel mixture on the grid and return it as a 2-D array."""
        sampled = sample_kernel(
            kernel,
            Mixture(positions.contiguous(), weights.contiguous()),
            XY,
        )
        return sampled.weights.reshape(RESOLUTION, RESOLUTION).cpu().numpy()

    image_extent = (-EXTENT, EXTENT, -EXTENT, EXTENT)
    fig, axs = plt.subplots(N_ROWS, 3, figsize=(15, 5 * N_ROWS))
    for i in range(N_ROWS):
        sim_idx = rng.integers(0, len(dataset))
        step_idx = 50
        data = SparseData(*(x.to(device) for x in dataset.get(sim_idx, step_idx)))

        # Named `model_input` so the builtin `input` is not shadowed.
        model_input = model.forward_input(data)
        label = SparseLabel.from_sparse_data(data, model.input_length)
        output = model.forward(model_input)

        measurements = data.measurement_position.cpu().numpy()
        positions = label.y.cpu().numpy()
        estimates = output.mu.cpu().numpy()

        ax_points, ax_input, ax_output = axs[i]

        ax_points.plot(*estimates.T, "b.", label="Output Samples")
        ax_points.plot(*measurements.T, "g+", label="Measurement", alpha=1)
        ax_points.plot(*positions.T, "ro", label="Target")

        # Only the last weight channel of the input is rendered.
        input_image = render_mixture(model_input.positions, model_input.weights[..., -1])
        output_image = render_mixture(output.mu, output.logp.exp())

        ax_input.plot(*positions.T, "ro", label="Target")
        ax_input.imshow(
            input_image.T, extent=image_extent, origin="lower", cmap="viridis"
        )
        ax_input.set_title(f"Input, Sim: {sim_idx}, Step: {step_idx}")

        ax_output.plot(*positions.T, "ro", label="Target")
        ax_output.imshow(
            output_image.T, extent=image_extent, origin="lower", cmap="viridis"
        )
        ax_output.set_title(f"Output, Sim: {sim_idx}, Step: {step_idx}")

        for ax in axs[i]:
            ax.set_xlim(-EXTENT, EXTENT)
            ax.set_ylim(-EXTENT, EXTENT)
            ax.set_aspect("equal")
            ax.legend(loc="upper right")