In [None]:
# Automatically re-import edited Python modules (e.g. the project's `mtt` package)
# before executing code, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import typing
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch.utils.data import DataLoader
from torchcps.kernel.nn import sample_kernel
from torchcps.kernel.rkhs import GaussianKernel, Mixture
from tqdm.notebook import tqdm

from mtt.data.sparse import SparseData, SparseDataset
from mtt.models.egnn import EGNNConv
from mtt.models.kernel import KNN
from mtt.models.sparse import SparseInput, SparseLabel, SparseOutput

In [None]:
# Window length used for every sample (forwarded to SparseDataset as `length`
# and reused by the models and plots below).
input_length = 20
dataset = SparseDataset(length=input_length, slim=True)
# Reproducible 90% / 10% train/validation split (fractional lengths, fixed seed).
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    lengths=[0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Draw one training batch, build the KNN model's input mixture for it, and
# evaluate that mixture with a Gaussian kernel on a dense 128x128 grid spanning
# [-500, 500] on both axes (one copy of the grid per batch element).
batch_size = 4
data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=SparseDataset.collate_fn,
    # Seed the shuffle so Restart & Run All shows the same batch every time.
    generator=torch.Generator().manual_seed(42),
)
data = next(iter(data_loader))

model = KNN(input_length=input_length, n_samples=10000)
input_mixture = model.forward_input(data)
# Regroup the flat mixture per batch element:
#   positions -> (batch, -1, 2), weights -> (batch, -1, input_length).
# Assumes every batch element contributes the same number of components
# (presumably n_samples each) -- TODO confirm against KNN.forward_input.
input_mixture = Mixture(
    input_mixture.positions.reshape(batch_size, -1, 2).contiguous(),
    input_mixture.weights.reshape(batch_size, -1, model.input_length).contiguous(),
)
# Evaluation grid of shape (batch_size, 128 * 128, 2). Expand to `batch_size`
# (not a hard-coded 4) so the grid always matches the mixture's batch dimension.
# `.expand` returns a non-contiguous view; materialize it like the sigma sweep
# cell does, in case sample_kernel's backend needs contiguous memory.
grid_axis = torch.linspace(-500, 500, 128)
XY = (
    torch.cartesian_prod(grid_axis, grid_axis)[None, ...]
    .expand(batch_size, -1, -1)
    .contiguous()
)
kernel = GaussianKernel(model.sigma)
x = sample_kernel(kernel, input_mixture, XY)

In [None]:
# Overlay ground-truth target positions (red) on the sampled input intensity
# `x` for several time steps (rows) of every batch element (columns).
data = typing.cast(SparseData, data)  # static-typing hint only; no runtime effect
# Batch-element index of every target: element i repeated target_batch_sizes[i] times.
batch_idx = torch.repeat_interleave(
    torch.arange(
        data.target_batch_sizes.shape[0], device=data.target_batch_sizes.device
    ),
    data.target_batch_sizes,
)

# Every 5th step counting back from the last one (19, 14, 9, 4 for input_length=20);
# derived from input_length instead of a hard-coded 19.
steps = list(range(input_length - 1, 0, -5))
fig, axs = plt.subplots(
    len(steps),
    batch_size,
    figsize=(5 * batch_size, 5 * len(steps)),
    squeeze=False,  # keep `axs` 2-D even when len(steps) or batch_size is 1
)
for (it, t), b in product(enumerate(steps), range(batch_size)):
    ax = axs[it, b]
    # The grid was built with cartesian_prod(xs, ys), so reshape(128, 128) is
    # indexed [x, y]; transpose so image rows follow y (with origin="lower").
    ax.imshow(
        x.weights[b, :, t].reshape(128, 128).detach().cpu().numpy().T,
        cmap="viridis",
        origin="lower",
        extent=[-500, 500, -500, 500],
    )
    ax.set_title(f"Step {t}, Batch {b}")
    targets_here = (batch_idx == b) & (data.target_time == t)
    ax.scatter(*data.target_position[targets_here].T, color="red")

In [None]:
# Sigma sweep: build a single-step KNN input (input_length=1) for the same batch
# with several Gaussian kernel bandwidths, and plot the sampled intensity
# (rows = sigma, columns = batch element).
# Working variables use a `sweep_` prefix so this cell does not overwrite
# `model`, `input`, `XY`, `kernel` or `x` from the cells above. Otherwise,
# re-running the per-step plot after this cell would pick up the single-channel
# sweep output instead of the full-length one.
sigmas = [10, 15, 20, 25, 30]

# The evaluation grid does not depend on sigma, so build it once:
# shape (batch_size, 128 * 128, 2), expanded to batch_size rather than a literal 4.
sweep_axis = torch.linspace(-500, 500, 128)
sweep_XY = (
    torch.cartesian_prod(sweep_axis, sweep_axis)[None, ...]
    .expand(batch_size, -1, -1)
    .contiguous()
)

fig, axs = plt.subplots(
    len(sigmas),
    batch_size,
    figsize=(5 * batch_size, 5 * len(sigmas)),
    squeeze=False,  # keep `axs` 2-D even for a single sigma or batch element
)
for i, sigma in enumerate(sigmas):
    sweep_model = KNN(input_length=1, sigma=sigma, n_samples=1000)
    sweep_flat = sweep_model.forward_input(data)
    # Regroup per batch element; assumes equal component counts per element
    # (TODO confirm against KNN.forward_input).
    sweep_input = Mixture(
        sweep_flat.positions.reshape(batch_size, -1, 2).contiguous(),
        sweep_flat.weights.reshape(
            batch_size, -1, sweep_model.input_length
        ).contiguous(),
    )
    sweep_kernel = GaussianKernel(sweep_model.sigma)
    sweep_x = sample_kernel(sweep_kernel, sweep_input, sweep_XY)

    for b in range(batch_size):
        ax = axs[i, b]
        # Last (and only) input channel; transpose so image rows follow y.
        ax.imshow(
            sweep_x.weights[b, :, -1].reshape(128, 128).detach().cpu().numpy().T,
            cmap="viridis",
            origin="lower",
            extent=[-500, 500, -500, 500],
        )
        ax.set_title(f"Sigma {sigma}, Batch {b}")