In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import typing

import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from tqdm.notebook import tqdm

from mtt.data.sparse import SparseData, SparseDataset
from mtt.models.egnn import EGNNConv
from mtt.models.sparse import SparseInput, SparseLabel, SparseOutput

In [None]:
# Value forwarded as the dataset's `length`; later cells reuse it when building
# labels/inputs from the sparse data.
input_length = 4
dataset = SparseDataset(slim=True, length=input_length)

# 90% / 10% train/validation partition. The explicitly seeded generator makes
# the split identical on every run of this cell.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset=dataset,
    lengths=[0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
class Model(nn.Module):
    """EGNN stack with top-k pooling that turns sparse measurements into a mixture.

    Pipeline implemented by ``forward``:

    1. The measurement time is lifted to ``f_hidden`` node features.
    2. Measurement and sensor positions are stacked into ``x_hidden`` coordinate
       channels (the unused channels are zero-padded).
    3. For every layer: build a k-NN graph on the current positions, apply an
       ``EGNNConv`` residual update to features and coordinates, then keep the
       top ``ratio`` fraction of nodes per graph with ``SelectTopK``.
    4. The surviving nodes are read out as mixture components: ``mu`` is the
       first coordinate channel, ``sigma`` and ``logp`` come from the readout MLP.
    """

    def __init__(self):
        super().__init__()
        n_f_in = 1  # time
        n_f_out = 2  # sigma, logp
        self.n_x_in = 2  # measurement_position, sensor_position

        f_hidden = 32
        self.x_hidden = 8
        n_hidden = 128
        ratio = 0.75
        n_layers = 4

        self.radius = 100.0

        self.f_readin = nn.Linear(n_f_in, f_hidden)
        # One EGNN layer and one pooling step per round; the two lists are
        # iterated in lock-step by `forward`.
        self.egnn = nn.ModuleList(
            [
                EGNNConv(f_hidden, f_hidden, self.x_hidden, self.x_hidden, n_hidden)
                for _ in range(n_layers)
            ]
        )
        # The cast only helps static type checkers; at runtime this is a ModuleList.
        self.select = typing.cast(
            list[gnn.pool.select.SelectTopK],
            nn.ModuleList(
                [
                    gnn.pool.select.SelectTopK(f_hidden, ratio=ratio)
                    for _ in range(n_layers)
                ]
            ),
        )
        self.f_readout = gnn.MLP(
            [f_hidden, n_hidden, n_f_out], act=nn.LeakyReLU(), plain_last=True
        )

    def forward(self, data: SparseData):
        """Run the network on a (batched) ``SparseData``.

        Args:
            data: Reads ``measurement_position``, ``sensor_position``,
                ``measurement_time`` and ``measurement_batch_sizes``.

        Returns:
            ``SparseOutput`` with one row per node that survived pooling:
            ``mu`` (positions), ``sigma`` (positive, same value for both
            coordinates), ``logp`` (log-sigmoid, i.e. <= 0) and ``batch`` (number
            of surviving nodes per input graph, one entry per graph).
        """
        n_graphs = len(data.measurement_batch_sizes)
        positions = data.measurement_position

        # f_in.shape = (N, 1)
        f_in = data.measurement_time.float()[:, None]
        f = self.f_readin(f_in).relu()

        # x.shape = (N, x_hidden, 2): channel 0 is the measurement position,
        # channel 1 the sensor position, the remaining channels start at zero.
        # The last dimension holds the x,y coordinates.
        n_padding = self.x_hidden - self.n_x_in
        x = torch.stack(
            [data.measurement_position, data.sensor_position]
            + [torch.zeros_like(data.measurement_position)] * n_padding,
            dim=1,
        )

        # Graph id of every node, e.g. batch sizes [2, 3] -> [0, 0, 1, 1, 1].
        batch_idx = torch.repeat_interleave(
            torch.arange(n_graphs, device=data.measurement_batch_sizes.device),
            data.measurement_batch_sizes,
        )

        for conv, select in zip(self.egnn, self.select):
            # compute graph based on the current positions
            edge_index = gnn.pool.knn_graph(positions, k=16, batch=batch_idx)
            df, dx = conv(f, x, edge_index)

            # Rescale every coordinate update to length `radius`.
            # NOTE(review): the original comment said "limit the dx update to
            # the radius", but this normalises (small updates are scaled *up*
            # as well) rather than clamps. Behaviour kept as-is; confirm intent.
            dx = self.radius * dx / (torch.norm(dx, dim=-1, keepdim=True) + 1e-8)

            f = f + df
            x = x + dx

            # graph pooling
            selection = select(f, batch_idx)
            keep = selection.node_index
            assert selection.weight is not None
            # Gate the kept features by their selection score. Indexing alone is
            # not differentiable, so without this product the SelectTopK
            # parameters would never receive a gradient.
            f = f[keep] * selection.weight[:, None]
            x = x[keep]
            batch_idx = batch_idx[keep]
            positions = x[:, 0, :]

        # Map the hidden features to the two outputs (sigma, logp). Previously
        # the readout MLP was constructed but never applied.
        f_out = self.f_readout(f)

        mu = positions
        sigma = nn.functional.softplus(f_out[:, 0, None]).expand(-1, 2) + 1e-8
        logp = nn.functional.logsigmoid(f_out[:, 1])
        return SparseOutput(
            mu=mu,
            sigma=sigma,
            logp=logp,
            # minlength keeps one entry per input graph even when trailing
            # graphs have no nodes.
            batch=torch.bincount(batch_idx, minlength=n_graphs),
        )


# Build the network, then move its parameters onto the GPU.
model = Model()
model = model.cuda()

In [None]:
from mtt.models.sparse import (
    logp_loss,
    parallel_assignment,
    mse_loss,
    log_kernel_loss,
    kernel_loss,
)

# Later cells move the model to the CPU and switch it to eval mode; restore the
# training configuration so this cell can be re-run safely.
model.cuda()
model.train()

dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=SparseDataset.collate_fn
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-12)

# One entry per optimizer step: {"loss": float, "sigma": float}.
log = []
for epoch in range(10):
    pbar = tqdm(dataloader, desc=f"epoch {epoch}")
    for data in pbar:
        data = SparseData(*(x.cuda() for x in data))
        label = SparseLabel.from_sparse_data(data, input_length)

        output = model(data)

        # Split the flat per-node / per-target tensors back into one chunk per
        # sample. `tensor_split` needs the boundary indices on the CPU.
        x_split_idx = output.batch.cumsum(0)[:-1].cpu()
        y_split_idx = label.batch.cumsum(0)[:-1].cpu()
        mu_split = output.mu.tensor_split(x_split_idx)
        logp_split = output.logp.tensor_split(x_split_idx)
        y_split = label.y.tensor_split(y_split_idx)

        batch_size = output.batch.shape[0]
        loss = torch.zeros((batch_size,), device=output.mu.device)
        for batch_idx in range(batch_size):
            # Samples without any predicted component contribute zero loss.
            if mu_split[batch_idx].shape[0] == 0:
                continue

            # 50.0 is kernel_loss's fourth positional argument - presumably the
            # kernel width; confirm against mtt.models.sparse.
            loss[batch_idx] += kernel_loss(
                mu_split[batch_idx],
                logp_split[batch_idx].exp(),
                y_split[batch_idx],
                50.0,
            )
        # An assignment-based alternative (parallel_assignment with logp_loss /
        # mse_loss) was tried here before; those helpers remain imported above.
        loss = loss.mean()

        if not loss.requires_grad:
            # Every sample in the batch was empty, so nothing is connected to
            # the model and backward() would raise; skip this step.
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log.append(
            {
                "loss": loss.item(),
                "sigma": output.sigma.mean().item(),
            }
        )
        pbar.set_postfix(log[-1])

In [None]:
import matplotlib.pyplot as plt

# Training curves: one panel per quantity recorded in `log` (one point per
# optimizer step). Only two quantities are logged, so only two panels are made.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].plot([entry["loss"] for entry in log])
axs[0].set(title="Training loss", xlabel="optimizer step", ylabel="loss")

axs[1].plot([entry["sigma"] for entry in log])
axs[1].set(title="Mean predicted sigma", xlabel="optimizer step", ylabel="sigma")

fig.tight_layout()
plt.show()

In [None]:
import numpy as np

# Plot the predicted mu positions for a randomly drawn sample of the dataset.
# NOTE(review): 10000 and 50 are passed straight to `dataset.get`; their meaning
# (presumably an index range and a time step) should be confirmed.
sample = dataset.get(np.random.randint(10000), 50)

# Run inference on the CPU in eval mode, then restore whichever mode the model
# was in so that re-running the training cell is unaffected.
was_training = model.training
model.cpu().eval()
try:
    with torch.no_grad():
        label = SparseLabel.from_sparse_data(sample, input_length)
        output = model(sample)
finally:
    model.train(was_training)

extent = 500  # half-width of the plotted square region
grid_size = 128  # image resolution per axis
kernel_sigma = 20.0  # width of the Gaussian kernel used for the image

# Render an image by placing a Gaussian kernel (sigma = kernel_sigma) on every
# predicted mu, weighted by its probability exp(logp).
XY = torch.cartesian_prod(*[torch.linspace(-extent, extent, grid_size)] * 2)
dist = (XY[:, None, :] - output.mu[None, ...]).norm(dim=-1)
K = torch.exp(-(dist**2) / (2 * kernel_sigma**2))
# `logp` is already 1-D; squeezing it would turn a single-component output into
# a 0-d tensor and make the matrix product fail.
Z = (K @ output.logp.exp()).reshape(grid_size, grid_size).cpu().numpy()

fig, axs = plt.subplots(1, 2, figsize=(20, 10))
axs[0].scatter(*sample.measurement_position.T, c="blue", label="measurements")
axs[0].scatter(*output.mu.cpu().T, c="black", label="predicted mu")
axs[0].scatter(*label.y.cpu().T, c="red", label="targets")
axs[0].set_title("Measurements, predictions and targets")
axs[0].legend()

# Z is indexed [x, y]; imshow expects [row = y, column = x], hence the transpose.
axs[1].imshow(
    Z.T, extent=(-extent, extent, -extent, extent), origin="lower", cmap="viridis"
)
axs[1].scatter(*label.y.cpu().T, c="red", label="targets")
axs[1].set_title("Predicted intensity")
axs[1].legend()

for ax in axs:
    ax.set_xlim(-extent, extent)
    ax.set_ylim(-extent, extent)
    ax.set_aspect("equal")
    ax.set_xlabel("x")
    ax.set_ylabel("y")

plt.show()

In [None]:
# Display the shape of the sampled measurement-position tensor.
sample.measurement_position.size()

In [None]:
# Clutter-classification metrics on the validation split: accumulate
# confusion-matrix counts over all batches, then derive accuracy / F1 /
# precision / recall.
#
# NOTE(review): this cell looks stale relative to the `Model` class in this
# notebook and probably does not run as written:
#   * `Mixture` is neither imported nor defined in the cells shown here.
#   * `Model.forward` takes a `SparseData` and returns a `SparseOutput`, not an
#     `(estimates, clutter_prob)` pair.
# Confirm which model this cell was written for before relying on it.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=32, collate_fn=SparseDataset.collate_fn
)

# Switches the model to eval mode; the mode is not restored afterwards.
model.eval()
tp, tn, fp, fn = 0, 0, 0, 0
for data in val_dataloader:
    # The cast only informs static type checkers; `data` is used unchanged.
    data = typing.cast(SparseData, data)
    input = SparseInput.from_sparse_data(data, input_length)
    label = SparseLabel.from_sparse_data(data, input_length)

    # NOTE(review): no torch.no_grad() here and the batch is not moved to a
    # device - presumably the model is expected to be on the CPU at this point.
    estimates, clutter_prob = model.forward(
        Mixture(input.x_pos, input.x, input.x_batch)
    )
    is_clutter = data.is_clutter.float()

    # "Positive" means clutter: predicted when clutter_prob > 0.5.
    tp += ((clutter_prob > 0.5) & (is_clutter == 1)).sum()
    tn += ((clutter_prob <= 0.5) & (is_clutter == 0)).sum()
    fp += ((clutter_prob > 0.5) & (is_clutter == 0)).sum()
    fn += ((clutter_prob <= 0.5) & (is_clutter == 1)).sum()

# No guard against empty denominators (e.g. no predicted positives).
accuracy = (tp + tn) / (tp + tn + fp + fn)
f1 = 2 * tp / (2 * tp + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
print(
    f"Accuracy: {accuracy:.2f}, F1: {f1:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}"
)