In [None]:
# Automatically reload edited project modules (e.g. mtt.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import typing

import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from tqdm.notebook import tqdm

from mtt.data.sparse import SparseData, SparseDataset
from mtt.models.kernel import KernelEncoderLayer, KernelDecoderLayer, Mixture
from mtt.models.sparse import SparseInput, SparseLabel, SparseOutput

In [None]:
# Sequence length used by SparseDataset / SparseInput / SparseLabel; the model's
# read-in MLP takes 2 * input_length features per component.
input_length = 4
dataset = SparseDataset(length=input_length, slim=True)

# Seeded 90/10 train/validation split so the partition is identical on every run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, lengths=[0.9, 0.1], generator=split_generator
)

In [None]:
class Model(nn.Module):
    """Kernel-mixture encoder/decoder.

    The encoder embeds the input mixture weights with an MLP and refines them with a
    stack of ``KernelEncoderLayer``s; a clutter probability is read out for every
    encoded component. The decoder starts from a fixed 5x5 grid of query points per
    sample, refines them against the encoded mixture with ``KernelDecoderLayer``s,
    and reads out position offsets, scales and log-probabilities.
    """

    def __init__(self):
        super().__init__()
        self.readin = gnn.MLP([2 * input_length, 256, 32], act=nn.LeakyReLU())
        # NOTE(review): positional arguments follow KernelEncoderLayer's /
        # KernelDecoderLayer's signatures in mtt.models.kernel (the float arguments
        # look like kernel length scales) — confirm there before changing them.
        self.encoder_layers = typing.cast(
            list[KernelEncoderLayer],
            nn.ModuleList(
                [
                    KernelEncoderLayer(2, 25, 32, 256, 10.0, False),
                    KernelEncoderLayer(2, 25, 32, 256, 50.0, False),
                    KernelEncoderLayer(2, 25, 32, 256, 25.0, False),
                    KernelEncoderLayer(2, 25, 32, 256, 10.0, False),
                ]
            ),
        )
        # One logit per encoded component.
        self.clutter_readout = gnn.MLP(
            [32, 256, 1], plain_last=True, act=nn.LeakyReLU()
        )
        self.decoder_layers = typing.cast(
            list[KernelDecoderLayer],
            nn.ModuleList(
                [
                    KernelDecoderLayer(2, 25, 32, 256, 25.0, 100.0, False),
                    KernelDecoderLayer(2, 25, 32, 256, 25.0, 50.0, False),
                    KernelDecoderLayer(2, 25, 32, 256, 25.0, 25.0, False),
                    KernelDecoderLayer(2, 25, 32, 256, 25.0, 10.0, False),
                ]
            ),
        )
        # 4 outputs per query: 2 position offsets, 1 scale, 1 logit.
        self.readout = gnn.MLP([32, 256, 4], plain_last=True, act=nn.LeakyReLU())

    def forward(self, x: Mixture) -> tuple[SparseOutput, torch.Tensor]:
        """Encode ``x`` and decode a set of Gaussian estimates.

        Args:
            x: input mixture. ``x.batch`` has one entry per sample (the per-sample
                component counts, laid out like ``d_batch`` below) or is None for a
                single sample.

        Returns:
            ``(estimates, clutter_prob)`` where ``estimates`` holds per-query means,
            scales, log-probabilities and batch counts, and ``clutter_prob`` is the
            sigmoid of the clutter readout, one row per encoded component
            (shape ``(num_components, 1)``, values in [0, 1]).
        """
        # encoder
        e = x.map_weights(self.readin.forward)
        for layer in self.encoder_layers:
            e = layer.forward(e)

        # The clutter head emits a single logit per component. A softmax over that
        # size-1 axis is identically 1.0 (with zero gradient), so use a sigmoid.
        clutter_prob = self.clutter_readout(e.weights).sigmoid()

        # decoder: the same 5x5 grid over [-500, 500]^2 is used as queries for every sample
        batch_size = 1 if x.batch is None else x.batch.shape[0]
        with torch.no_grad(), torch.device(x.weights.device):
            grid = torch.cartesian_prod(*[torch.linspace(-500, 500, 5)] * 2).reshape(
                -1, 2
            )
            n_queries = grid.shape[0]
            # Tile the whole grid once per sample so rows
            # [s * n_queries, (s + 1) * n_queries) belong to sample s, matching the
            # contiguous layout described by d_batch. (repeat_interleave would repeat
            # each grid point batch_size times in a row instead.)
            d_pos = grid.repeat(batch_size, 1)
            d_weights = torch.zeros(d_pos.shape[0], e.weights.shape[1])
            d_batch = torch.full((batch_size,), n_queries)
            z = Mixture(d_pos, d_weights, d_batch)
        for layer in self.decoder_layers:
            z = layer.forward(z, e)

        # readout
        z = z.map_weights(self.readout.forward)
        delta_mu, sigma, logits = torch.split(z.weights, [2, 1, 1], dim=-1)
        # Means are offsets relative to the query positions.
        mu = z.positions + delta_mu
        # Positive scale shared by both axes; epsilon keeps it strictly > 0.
        sigma = nn.functional.softplus(sigma).expand(-1, 2) + 1e-8
        # Log-probability per query (log-sigmoid of the last readout channel).
        logits = nn.functional.logsigmoid(logits)[..., 0]
        assert z.batch is not None
        estimates = SparseOutput(mu, sigma, logits, z.batch)
        return estimates, clutter_prob


model = Model().cuda()

In [None]:
from mtt.models.sparse import logp_loss, parallel_assignment

n_epochs = 10
dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=SparseDataset.collate_fn
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-8)

# The evaluation cell further down calls model.eval(); switch back explicitly so that
# re-running this cell trains with mode-dependent layers in training mode.
model.train()

log = []
for epoch in range(n_epochs):
    pbar = tqdm(dataloader, desc=f"epoch {epoch + 1}/{n_epochs}")
    for data in pbar:
        data = SparseData(*(x.cuda() for x in data))
        input = SparseInput.from_sparse_data(data, input_length)
        label = SparseLabel.from_sparse_data(data, input_length)

        output, clutter_prob = model.forward(
            Mixture(input.x_pos, input.x, input.x_batch)
        )

        # Squeeze only the trailing channel axis of the (N, 1) probabilities: a bare
        # .squeeze() turns a batch with a single component into a 0-d tensor that no
        # longer matches the target's shape.
        clutter_loss = nn.functional.binary_cross_entropy(
            clutter_prob.squeeze(-1), data.is_clutter.float()
        )

        # `batch` holds the number of rows belonging to each sample, so the
        # exclusive cumulative sum gives the split points between samples.
        x_split_idx = output.batch.cumsum(0)[:-1].cpu()
        y_split_idx = label.batch.cumsum(0)[:-1].cpu()
        mu_split = output.mu.tensor_split(x_split_idx)
        sigma_split = output.sigma.tensor_split(x_split_idx)

        logp_split = output.logp.tensor_split(x_split_idx)
        y_split = label.y.tensor_split(y_split_idx)

        batch_size = output.batch.shape[0]
        logp = torch.zeros((batch_size,), device=output.mu.device)
        # row_idx / col_idx are presumably the matched (estimate, target) indices
        # produced by the assignment — confirm in mtt.models.sparse.
        for batch_idx, row_idx, col_idx in parallel_assignment(
            mu_split, y_split, None
        ):
            logp[batch_idx] = logp_loss(
                mu_split[batch_idx],
                sigma_split[batch_idx],
                logp_split[batch_idx],
                y_split[batch_idx],
                (row_idx, col_idx),
            )
        logp = logp.mean()
        loss = logp + clutter_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log.append(
            {
                "logp": logp.item(),
                "clutter_loss": clutter_loss.item(),
                "loss": loss.item(),
                "sigma": output.sigma.mean().item(),
            }
        )
        pbar.set_postfix(log[-1])

In [None]:
import matplotlib.pyplot as plt

# One panel per logged quantity, in the order they are recorded by the training loop.
curve_panels = [
    ("logp", "logp loss"),
    ("clutter_loss", "clutter BCE loss"),
    ("sigma", "mean predicted sigma"),
]

fig, axs = plt.subplots(1, len(curve_panels), figsize=(15, 5))
for ax, (key, title) in zip(axs, curve_panels):
    ax.plot([entry[key] for entry in log])
    ax.set_title(title)
    ax.set_xlabel("training step")
plt.show()

In [None]:
# plot the mu positions for a sample in the dataset
device = next(model.parameters()).device
sample = dataset.get(100, 50)
# The model lives on `device`; dataset items are presumably on the CPU (the training
# loop moves every batch explicitly), so move the sample the same way.
sample = SparseData(*(t.to(device) for t in sample))

# Inference mode: normalisation layers use (and do not update) their running stats.
model.eval()
with torch.no_grad():
    input = SparseInput.from_sparse_data(sample, input_length)
    label = SparseLabel.from_sparse_data(sample, input_length)
    output, _ = model.forward(Mixture(input.x_pos, input.x, input.x_batch))

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    axs[0].scatter(*output.mu.cpu().T, c="black")
    axs[0].scatter(*label.y.cpu().T, c="red")
    axs[0].set_title("estimated mu (black) vs. targets (red)")

    # Intensity image: sum over estimates of exp(logp) times an isotropic Gaussian
    # kernel whose width is the estimate's predicted sigma (averaged over both axes).
    XY = torch.cartesian_prod(*[torch.linspace(-500, 500, 128, device=device)] * 2)
    dist = (XY[:, None, :] - output.mu[None, ...]).norm(dim=-1)
    K = torch.exp(-(dist**2) / (2 * output.sigma.mean(-1)[None, :] ** 2))
    Z = (K @ output.logp.exp().squeeze()).reshape(128, 128).cpu().numpy()

axs[1].imshow(Z.T, extent=(-500, 500, -500, 500), origin="lower", cmap="viridis")
axs[1].scatter(*label.y.cpu().T, c="red")
axs[1].set_title("estimated intensity vs. targets (red)")

for ax in axs:
    ax.set_xlim(-500, 500)
    ax.set_ylim(-500, 500)
    ax.set_aspect("equal")

plt.show()

In [None]:
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=32, collate_fn=SparseDataset.collate_fn
)
device = next(model.parameters()).device


def safe_ratio(num: int, den: int) -> float:
    """Return num / den, or NaN when the denominator is zero (e.g. no positives)."""
    return num / den if den else float("nan")


model.eval()
tp, tn, fp, fn = 0, 0, 0, 0
with torch.no_grad():
    for data in val_dataloader:
        # Same conversion as the training loop: wrap as SparseData and move to the
        # model's device (the model is on CUDA, the loader yields CPU tensors).
        data = SparseData(*(x.to(device) for x in data))
        input = SparseInput.from_sparse_data(data, input_length)

        _, clutter_prob = model.forward(Mixture(input.x_pos, input.x, input.x_batch))
        # clutter_prob is (N, 1); drop the channel axis so the comparisons against
        # the (N,) targets are elementwise instead of broadcasting to (N, N).
        prob = clutter_prob.squeeze(-1)
        is_clutter = data.is_clutter.float()

        pred_pos, pred_neg = prob > 0.5, prob <= 0.5
        true_pos, true_neg = is_clutter == 1, is_clutter == 0

        tp += (pred_pos & true_pos).sum().item()
        tn += (pred_neg & true_neg).sum().item()
        fp += (pred_pos & true_neg).sum().item()
        fn += (pred_neg & true_pos).sum().item()

accuracy = safe_ratio(tp + tn, tp + tn + fp + fn)
f1 = safe_ratio(2 * tp, 2 * tp + fp + fn)
precision = safe_ratio(tp, tp + fp)
recall = safe_ratio(tp, tp + fn)
print(
    f"Accuracy: {accuracy:.2f}, F1: {f1:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}"
)