In [1]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.optim import *
from functorch import vmap
from IPython.display import display, clear_output
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import Module
from torch import  Tensor
from typing import Callable

%matplotlib inline

# The Model

In [2]:
class SoftNearestNeighbors(Module):
    def __init__(self,
                 encoder: Module,
                 similarity: Callable[[Tensor, Tensor], Tensor],
                 prototypes: tuple[Tensor, Tensor] | tuple[None, None] = (None, None)):
        super().__init__()
        self.sfn = similarity
        self.enc = encoder
        self.X, self.Y = prototypes

    @property
    def prototypes(self):
        return self.X, self.Y

    @prototypes.setter
    def prototypes(self, db):
        self.X, self.Y = db

    def forward(self, X):
        # Compute latent representation z of x.
        Z = self.enc(X)
        # Compute latent prototypes.
        # Note: db stands for database.
        Z_db = self.enc(self.X)
        # Compute the similarities between Z and Z_db.
        S = vmap(lambda z: vmap(lambda z_db: self.sfn(z, z_db))(Z_db))(Z)
        # Compute weights.
        W = F.softmax(S, dim=1)
        # Compute predictions.
        Y = vmap(lambda w: (w.view(-1, 1) * self.Y).sum(dim=0))(W)
        return Y

In [3]:
class Dense(nn.Module):
    """Densely connected layer: a linear map whose output is appended to its input.

    Args:
        in_size: Number of input features (size of the last axis of ``x``).
        hidden_size: Number of new features produced by the linear map.

    The output therefore has ``in_size + hidden_size`` features.
    """

    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.linear = nn.Linear(in_size, hidden_size)

    def forward(self, x):
        """Return ``x`` concatenated with ``linear(x)`` along the feature axis."""
        h = self.linear(x)
        # nn.Linear acts on the last axis, so concatenate there as well.
        # For 2-D input this is the same as dim=1; it also supports un-batched
        # 1-D input and inputs with extra leading axes.
        return th.cat([x, h], dim=-1)

In [4]:
# Densely connected encoder. Every stage is BatchNorm -> Dense -> Mish, and
# each Dense stage appends `growth` features to its `width` input features
# (1 -> 11 -> 22 -> 44 -> 88 -> 176). A final BatchNorm + Linear head maps the
# 176 features to a single output.
encoder = nn.Sequential(
    *(layer
      for width, growth in ((1, 10), (11, 11), (22, 22), (44, 44), (88, 88))
      for layer in (nn.BatchNorm1d(width), Dense(width, growth), nn.Mish())),
    nn.BatchNorm1d(176),
    nn.Linear(176, 1),
)

In [5]:
def similarity(a, b):
    """Return the dot product of the two 1-D tensors ``a`` and ``b``."""
    return a.dot(b)

In [6]:
# Build the model. No prototypes are passed here, so they keep their
# (None, None) default and must be assigned via `soft_nn.prototypes = (X, Y)`.
soft_nn = SoftNearestNeighbors(encoder=encoder, similarity=similarity)