In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import math
import typing

import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn.functional as F
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm

from torchcps.kernel.nn import KNN, Mixture

In [None]:
# create a dataset using objects in 3D space
n_simulations = 1000  # number of independent simulations (one target trajectory each)
n_steps = 10  # number of time steps per trajectory
n_sensors = 4  # number of sensors placed in every simulation
n_dimensions = 3  # spatial dimensions; the target state holds 2 * n_dimensions values


width = 1000  # width of the area where targets are placed
min_height = 1000  # base height of the volume in which targets are placed
padding = 0.25  # fraction of `width` kept free at each edge when sampling initial positions
vel_initial = 5.0  # scale factor applied to the initial target velocity
vel_noise = 0.5  # std of the per-step noise that is fed through G_matrix

# sensor parameters
sensor_noise_range = 10  # std of the additive noise on the measured range
sensor_noise_bearing = 0.035  # std of the additive noise on the bearing angle (radians)
sensor_noise_inclination = 0.035  # std of the additive noise on the inclination angle (radians)

In [None]:
# Nearly-constant-velocity model with a unit time step.
# The state is three positions followed by three velocities.
#
# F_matrix: identity plus ones on the third super-diagonal, i.e. every
# position is advanced by its matching velocity and velocities are kept.
F_matrix = torch.eye(6, dtype=torch.float32) + torch.diag(
    torch.ones(3, dtype=torch.float32), diagonal=3
)
# G_matrix: maps a 3-vector of noise into the state; the position rows get a
# factor of 0.5 and the velocity rows a factor of 1.
G_matrix = torch.cat(
    (0.5 * torch.eye(3, dtype=torch.float32), torch.eye(3, dtype=torch.float32)),
    dim=0,
)

In [None]:
with torch.no_grad():
    ## create target trajectories
    # shape: (n_simulations, n_steps, 2 * n_dimensions); the first half of the
    # last axis is the position, the second half the velocity
    targets = torch.zeros(n_simulations, n_steps, 2 * n_dimensions)
    # targets are placed uniformly in the area, keeping a margin of
    # `width * padding` to every face of the volume
    targets[:, 0, 0].uniform_(-width / 2 + width * padding, width / 2 - width * padding)
    targets[:, 0, 1].uniform_(-width / 2 + width * padding, width / 2 - width * padding)
    targets[:, 0, 2].uniform_(
        min_height + width * padding, min_height + width - width * padding
    )
    # initial velocity has constant magnitude `vel_initial` and a random direction:
    # draw an isotropic gaussian vector and rescale it to the desired length.
    # (Dividing by the norm is what normalises it; multiplying by the norm would
    # make the speed depend on the squared length of the random draw.)
    targets[:, 0, 3:].normal_()
    targets[:, 0, 3:] *= vel_initial / torch.norm(
        targets[:, 0, 3:], dim=1, keepdim=True
    )

    # propagate the state: x_i = F x_{i-1} + G w_i with w_i ~ N(0, vel_noise^2)
    for i in range(1, n_steps):
        targets[:, i, :, None] = F_matrix @ targets[
            :, i - 1, :, None
        ] + G_matrix @ torch.normal(0.0, vel_noise, (n_simulations, n_dimensions, 1))

In [None]:
# 3D view of the first few simulated target trajectories
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax = typing.cast(Axes3D, ax)

ax.set_xlim(-width / 2, width / 2)
ax.set_ylim(-width / 2, width / 2)
ax.set_zlim(min_height, min_height + width)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
# Only the position part of the state (first three entries) is plotted.
# Unpacking all six state rows would hand the velocity rows to `plot` as
# extra positional arguments.
for i in range(min(10, targets.shape[0])):
    ax.plot(*targets[i, :, :3].T)

plt.show()

In [None]:
with torch.no_grad():
    # Sensors sit on the ground plane (z = 0), uniformly spread over the area.
    # shape: (n_simulations, n_sensors, n_dimensions)
    sensors = torch.stack(
        (
            torch.empty(n_simulations, n_sensors).uniform_(-width / 2, width / 2),
            torch.empty(n_simulations, n_sensors).uniform_(-width / 2, width / 2),
            torch.zeros(n_simulations, n_sensors),
        ),
        dim=-1,
    )

    # Target position relative to every sensor.
    # shape: (n_simulations, n_steps, n_sensors, n_dimensions)
    dx = targets[:, :, None, :3] - sensors[:, None, :, :]
    r = torch.norm(dx, dim=-1)

    # Perturb the observation in spherical coordinates; each component gets
    # its own independent gaussian noise of this shape.
    noise_shape = (n_simulations, n_steps, n_sensors)

    # range: distance between sensor and target
    measurements_range = r + torch.randn(noise_shape) * sensor_noise_range

    # bearing (azimuth): theta = atan2(y, x)
    true_bearing = torch.atan2(dx[..., 1], dx[..., 0])
    measurements_bearing = true_bearing + torch.randn(noise_shape) * sensor_noise_bearing

    # inclination (angle from the +z axis): phi = acos(z / r)
    # note: the variable keeps the name "elevation" although it holds the inclination
    true_inclination = torch.acos(dx[..., 2] / r)
    measurements_elevation = (
        true_inclination + torch.randn(noise_shape) * sensor_noise_inclination
    )

    # Map the noisy spherical observation back to cartesian coordinates and
    # shift by the sensor position to get absolute coordinates:
    #   x = r cos(theta) sin(phi)
    #   y = r sin(theta) sin(phi)
    #   z = r cos(phi)
    sin_phi = torch.sin(measurements_elevation)
    offset_x = measurements_range * torch.cos(measurements_bearing) * sin_phi
    offset_y = measurements_range * torch.sin(measurements_bearing) * sin_phi
    offset_z = measurements_range * torch.cos(measurements_elevation)
    measurements = sensors[:, None, :, :] + torch.stack(
        (offset_x, offset_y, offset_z), dim=-1
    )

In [None]:
# sanity check: expect (n_simulations, n_steps, n_sensors, 3),
# i.e. one cartesian point per sensor and time step
print(measurements.shape)

In [None]:
# 3D plot of the target path and the noisy measurements for a specific simulation
sim_idx = 0
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax = typing.cast(Axes3D, ax)
ax.plot(*targets[sim_idx, :, :3].T, label="target")
# NOTE(review): the sensor scatter is disabled; presumably because the sensors
# lie at z = 0, far below the targets, and would stretch the z axis.
# ax.scatter(*sensors[sim_idx, :, :].T, label="sensor")
# flatten (n_steps, n_sensors, 3) into a list of points before plotting
ax.scatter(*measurements[sim_idx].reshape(-1, 3).T, label="measurement", c="r")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
# the labels above are only shown once a legend is requested
ax.legend()
plt.show()

In [None]:
# Kernel network that maps the mixture of measurements to a mixture describing
# the target. Channel counts below mirror how the training loop builds the
# input weights and parses the output weights.
model = KNN(
    n_dimensions=3,
    # input weights per component: the sensor position (3 values) followed by
    # the time-step index of the measurement (1 value)
    in_channels=4,
    hidden_channels=128,
    # output weights per component: three values turned into per-axis sigmas
    # via softplus, plus one existence logit; the predicted target position is
    # read from the output mixture's positions, not from these channels
    out_channels=4,
    n_layers=4,
    n_layers_mlp=2,
    hidden_channels_mlp=128,
    sigma=50,
    max_filter_kernels=256,
    update_positions=True,
    alpha=None,
).cuda()

In [None]:
def compute_logp(
    mu: torch.Tensor,
    sigmas: torch.Tensor,
    existance_logp: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Log of the batch-averaged likelihood of ``y`` under a gaussian mixture.

    Every batch element ``b`` has ``K`` axis-aligned gaussian components; the
    function evaluates ``log( 1/B * sum_b sum_k p_bk * N(y_b | mu_bk, sigma_bk) )``.

    Args:
        mu: component means, shape ``(B, K, D)``.
        sigmas: per-axis standard deviations (strictly positive), shape ``(B, K, D)``.
        existance_logp: log mixture weight of every component, shape ``(B, K)``.
        y: points at which the mixture is evaluated, shape ``(B, D)``.

    Returns:
        A scalar tensor holding the log of the mean (over the batch) mixture
        density at ``y``.
    """
    # \sum_i p_i * gaussian(x | mu_i, sigma_i)
    # `y[:, None]` adds the component axis so that each target is compared with
    # all K components of its own batch element. Use the argument `y` here, not
    # a variable from the enclosing notebook scope.
    normal_logp = (
        torch.distributions.Normal(mu, sigmas).log_prob(y[:, None])
        # product in linear space to get multivariate normal
        .sum(-1)
    )
    # product in linear space to scale by probabilities
    logp = normal_logp + existance_logp
    # sum over all components and average over the batch (both in linear space)
    return logp.logsumexp(dim=(0, 1)) - math.log(y.shape[0])


# Repeat the sensor positions along the time axis so that every measurement is
# stored next to the position of the sensor that produced it.
sensors_per_step = sensors.unsqueeze(1).expand_as(measurements)
# Each sample: (measurements, matching sensor positions, target trajectory), all on the GPU.
dataset = TensorDataset(
    *(tensor.cuda() for tensor in (measurements, sensors_per_step, targets))
)
# Shuffled mini-batches of 10 simulations; an incomplete last batch is dropped.
dataloader = DataLoader(dataset=dataset, batch_size=10, shuffle=True, drop_last=True)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5, weight_decay=0.0)

# Train for 100 epochs, minimising the negative log-likelihood of the final
# target position under the predicted mixture. `log` stores per-step metrics,
# `log_ema` an exponential moving average of them for the progress bar.
pbar = tqdm(range(100))
log = []
log_ema = None
for epoch in pbar:
    for x, s, y in dataloader:
        B = x.shape[0]
        # every measurement becomes one mixture component
        n_components = n_steps * n_sensors
        x_positions = x.reshape(B * n_components, n_dimensions)
        # the weights hold the sensor position followed by the step index
        # in [0, n_steps - 1]
        step_index = (
            torch.arange(n_steps, device=x.device)
            .float()[None, :, None, None]
            .expand(B, n_steps, n_sensors, 1)
        )
        x_weights = torch.concatenate((s, step_index), dim=-1).reshape(
            B * n_components, 4
        )

        # the length of each batch is n_steps * n_sensors
        x_batch = torch.full((B,), n_components, device=x.device)
        x_mixture = Mixture(x_positions, x_weights, x_batch)
        # call the module itself rather than `.forward` so registered hooks run
        z_mixture = model(x_mixture)

        # parse mixture: positions are the component means, the first three
        # weight channels give the per-axis sigmas and the last one a logit
        mu = z_mixture.positions.reshape(B, n_components, 3)
        sigmas = F.softplus(z_mixture.weights[:, :-1]).reshape(B, n_components, 3)
        # interpret as probabilities for each component within a batch
        existance_logp = (
            z_mixture.weights[:, -1].reshape(B, n_components).log_softmax(-1)
        )

        # we want to predict the final position of the target
        y_positions = y[:, -1, :3]

        # Monitoring metric: mean euclidean distance between the target and the
        # most likely component of *its own* batch element. Indexing with
        # (arange(B), best) picks one component per batch element; `mu[:, best]`
        # would instead build a (B, B, 3) tensor mixing different batch elements.
        with torch.no_grad():
            best = existance_logp.argmax(-1)
            mu_best = mu[torch.arange(B, device=mu.device), best]
            rmse = (y_positions - mu_best).pow(2).sum(-1).sqrt().mean()

        # log-probability loss
        value = compute_logp(mu, sigmas, existance_logp, y_positions)
        loss = -value

        optimizer.zero_grad()
        loss.backward()
        # Check if any gradients are non-finite (inf or nan)
        if any(
            not torch.isfinite(p.grad).all()
            for p in model.parameters()
            if p.grad is not None
        ):
            print("Warning: non-finite gradient detected. Skipping this iteration.")
            continue
        optimizer.step()

        log.append(
            {
                "logp": value.item(),
                "prob": value.exp().item(),
                "rmse": rmse.item(),
                "sigma": sigmas.mean().item(),
            }
        )
        if log_ema is None:
            log_ema = log[-1]
        else:
            log_ema = {k: 0.9 * v + 0.1 * log[-1][k] for k, v in log_ema.items()}

        pbar.set_postfix(log_ema)

In [None]:
from scipy.signal.windows import gaussian

# name of the logged quantity to show; must be a key of the entries in `log`
metric = "logp"

# smoothing kernel: gaussian spanning six standard deviations, scaled to sum to one
window_std = 100
window = gaussian(6 * window_std, std=window_std)
window = window / window.sum()
# "valid" keeps only the positions where the kernel fully overlaps the series
value = np.array([entry[metric] for entry in log])
value_smooth = scipy.signal.convolve(value, window, mode="valid")

# smoothed metric over the training steps
plt.figure()
plt.plot(value_smooth, label=metric)
plt.legend()
plt.xlabel("step")
plt.show()