In [None]:
# Automatically re-import edited modules (such as the project's `mtt` package)
# before each cell runs, so library edits take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import typing

import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt

from mtt.data.sparse import SparseData, SparseDataset
from mtt.models.kernel import KNN, SparseBase
from mtt.models.transformer import SpatialTransformer
from mtt.models.utils import load_model
from mtt.utils import compute_ospa, compute_ospa_components

# Seeded NumPy generator: draw any random choices in this notebook from it
# (rather than the global np.random state) so runs are reproducible.
rng = np.random.default_rng(0)

# Publication-oriented plotting defaults: "paper" context on a white grid,
# 3.5 x 3.5 inch figures, 150 dpi on screen / 1000 dpi when saved,
# constrained layout, and pdf.fonttype 42 (TrueType) fonts in exported PDFs.
paper_rc = {
    "figure.figsize": (3.5, 3.5),
    "figure.dpi": 150,
    "savefig.dpi": 1000,
    "figure.constrained_layout.use": True,
    "pdf.fonttype": 42,
}
sns.set_theme(context="paper", style="whitegrid", rc=paper_rc)

In [None]:
# Checkpoint to evaluate (a wandb:// run reference).
#   wfgumxwj -> mse_sigma = 25   (used here)
#   phvf95ld -> mse_sigma = 50   (alternative run; swap the URI to compare)
MODEL_URI = "wandb://damowerko/mtt/wfgumxwj"

model: SparseBase
model, name, params = load_model(MODEL_URI)
# Inference only: put the network in evaluation mode.
model.eval()

In [None]:
# Dataset of simulations, windowed with the loaded model's `input_length`
# hyperparameter so samples match what the model expects as input.
# NOTE(review): this points at the *training* split; the comparisons below
# would be fairer on held-out data if such a split exists — confirm.
DATA_DIR = "data/train"  # plain string: the old f-string had no placeholders

dataset = SparseDataset(
    DATA_DIR,
    length=params["input_length"],
    slim=True,
)

In [None]:
# Compare the model's `mse_loss` with a naive baseline over several simulations.
# The baseline keeps the model's estimated positions (`mu`) but gives every
# estimate the same log-probability: the log of the mean predicted probability.
# A gap between the two numbers therefore measures how much the per-estimate
# probabilities help.
# Losses are averaged over the evaluated simulations, so the printed values are
# per-simulation means (the previous version printed running sums).
N_EVAL_SIMS = 100  # simulations to evaluate (capped at the dataset size)
EVAL_STEP = 50  # time step taken from each simulation


def _mse(output, logp, label) -> float:
    """Return the model's `mse_loss` for `output.mu` weighted by `logp` against `label`."""
    return model.mse_loss(
        output.mu,
        logp,
        output.batch,
        label.y,
        label.y_batch,
        model.kernel_sigma,
    ).item()


n_eval = min(N_EVAL_SIMS, len(dataset))
assert n_eval > 0, "dataset is empty"

losses_model = []
losses_naive = []
with torch.no_grad():
    for sim_i in range(n_eval):
        sample = SparseData(*(x.cuda() for x in dataset.get(sim_i, EVAL_STEP)))
        st_input = model.to_stinput(sample)
        st_label = model.to_stlabel(sample)
        st_output = model.forward(*(x.cuda().contiguous() for x in st_input))

        # Baseline: a constant log-probability equal to log(mean probability).
        mean_logp = st_output.logp.exp().mean().log().item()
        logp_naive = torch.full_like(st_output.logp, mean_logp)

        losses_model.append(_mse(st_output, st_output.logp, st_label))
        losses_naive.append(_mse(st_output, logp_naive, st_label))

mse_model = float(np.mean(losses_model))
mse_naive = float(np.mean(losses_naive))
print(f"MSE Model: {mse_model}, MSE Naive: {mse_naive} (mean over {n_eval} simulations)")

In [None]:
# Visualise one simulation at a single time step:
#   left  - measurements, true target positions and the model's estimates;
#   right - true targets over a heat map of sum_j p_j * exp(-|x - mu_j|^2 / (2 sigma^2)),
#           i.e. the estimates smoothed with the model's kernel width.
# The simulation index comes from the seeded `rng`, so Restart & Run All shows
# the same simulation (re-running only this cell advances the generator).
sim_idx = int(rng.integers(len(dataset)))
step_idx = 50
data = SparseData(*(x.cuda() for x in dataset.get(sim_idx, step_idx)))

with torch.no_grad():
    st_input = model.to_stinput(data)
    label = model.to_stlabel(data)
    # .contiguous() matches how the evaluation loop above feeds the model.
    output = model.forward(*(x.cuda().contiguous() for x in st_input))

measurements = data.measurement_position.cpu().numpy()
positions = label.y.cpu().numpy()
estimates = output.mu.cpu().numpy()
probs = output.logp.exp().cpu().numpy()

HALF_WIDTH = 500.0  # plotted region is [-HALF_WIDTH, HALF_WIDTH] on both axes
RESOLUTION = 128  # heat-map pixels per side
extent = (-HALF_WIDTH, HALF_WIDTH, -HALF_WIDTH, HALF_WIDTH)

# Evaluate the density at pixel *centres*: imshow places pixel centres half a
# pixel inside `extent`, so a grid running exactly from -500 to 500 would be
# misaligned with the drawn image.
edges = np.linspace(-HALF_WIDTH, HALF_WIDTH, RESOLUTION + 1)
centers = (edges[:-1] + edges[1:]) / 2
gx, gy = np.meshgrid(centers, centers, indexing="xy")  # rows follow y, columns follow x
grid = np.stack([gx.ravel(), gy.ravel()], axis=-1)  # (RESOLUTION**2, 2)

sq_dist = ((grid[:, None, :] - estimates[None]) ** 2).sum(axis=-1)
kernel = np.exp(-sq_dist / (2 * model.kernel_sigma**2))  # Gaussian kernel
# Row-major reshape gives image[row=y, col=x], as imshow(origin="lower") expects.
image = (kernel @ probs).reshape(RESOLUTION, RESOLUTION)

with sns.axes_style("ticks"):
    fig, axs = plt.subplots(1, 2, figsize=(7.0, 3.5))

    axs[0].plot(*measurements.T, "g+", label="Measurement", alpha=1)
    axs[0].plot(*positions.T, "ro", label="Target")
    axs[0].plot(*estimates.T, "bx", label="Estimate")
    axs[0].set_title("Measurements and estimates")

    axs[1].plot(*positions.T, "ro", label="Target")
    axs[1].imshow(image, extent=extent, origin="lower", cmap="viridis")
    axs[1].set_title("Kernel-smoothed estimates")

    for ax in axs:
        ax.set_xlim(-HALF_WIDTH, HALF_WIDTH)
        ax.set_ylim(-HALF_WIDTH, HALF_WIDTH)
        ax.set_aspect("equal")
        ax.legend()

In [None]:
# Scatter the `kernel_positions` of the first block's convolution.
# Only the KNN model variant has this attribute, hence the assert.
# Uses its own variable name so the target `positions` from the previous cell
# are not silently overwritten.
assert isinstance(model, KNN), f"expected a KNN model, got {type(model).__name__}"

# NOTE(review): assumes kernel_positions is (num_kernels, 2) — with a third
# column, scatter(*T) would treat it as marker sizes; confirm.
kernel_positions = model.model.blocks[0].conv.kernel_positions.detach().cpu().numpy()

fig, ax = plt.subplots()
ax.scatter(*kernel_positions.T)
ax.set_aspect("equal")
ax.set_title("Kernel positions (block 0)")
plt.show()