In [1]:
from utils.bmds import BMDSTrainer
from utils.nn import create_mlp_layers
import torch
from sklearn.neighbors import BallTree
from utils.preprocessing import check_tensor
from utils.distributions import exponential_log_prob
from typing import Any, Optional, Callable
from torch import nn

In [None]:
# Experiment configuration (tune here; later cells read these names).
n = 60_000          # number of synthetic data points
k = 100             # neighbours requested per point in the neighbour search
d_latent = 2        # dimensionality of the latent samples
d = 10              # dimensionality of the observed data
noise_coef = 0.001  # scale of the Gaussian noise added to the observed data
batch_size = 128    # batch size handed to the BMDS model

In [12]:
class BMDS(nn.Module):
    """Embedding model driven by each object's nearest-neighbour distances.

    An object is represented as a sparse row of length ``input_dim`` holding
    the squared distances to its neighbours (see ``get_inp``).  An MLP ``head``
    maps that row to ``hidden_dim`` features, which are projected by the
    ``mu`` / ``sigma`` parameter matrices into a per-dimension embedding mean
    and scale (see ``forward``).
    """

    # Defaults forwarded to ``create_mlp_layers`` when no custom layer factory
    # is supplied; caller-provided ``**kwargs`` override individual entries.
    default_create_layers_kwargs: dict[str, Any] = {
        'activation': 'PReLU',
        'use_batch_norm': False,
        'last_layer_activation': True,
        'last_layer_batch_norm': True,
    }

    def __init__(
            self,
            dist,
            neighbors,
            batch_size,
            input_dim: int,
            n: int,
            *,
            n_layers: int = 2,
            hidden_dim: int = 1000,
            embedding_dim: int = 100,
            create_layers: Optional[Callable[..., list[nn.Module]]] = None,
            **kwargs: Any,
    ):
        """Build the network and store the neighbour tables.

        Args:
            dist: 2-D array-like with a ``.shape``; row ``i`` is indexed by
                object id and holds that object's neighbour distances
                (stored squared).
            neighbors: array-like of the same layout holding the neighbours'
                ids; these become column indices of the sparse input, so they
                must be smaller than ``input_dim``.
            batch_size: nominal batch size; kept for backward compatibility
                (``get_inp`` now adapts to the actual batch length).
            input_dim: width of the sparse input rows.
            n: divisor of the regularisation term in ``loss``.
            n_layers: the layer factory receives ``n_layers - 1`` hidden sizes.
            hidden_dim: width of the hidden layers and of the head output.
            embedding_dim: number of output embedding dimensions.
            create_layers: optional factory called as
                ``create_layers(input_dim, hidden_sizes, hidden_dim, **kwargs)``
                and returning a list of modules.  When omitted,
                ``create_mlp_layers`` is used with the class defaults.
            **kwargs: extra keyword arguments for the layer factory.
        """
        super().__init__()

        # Non-persistent buffers: they follow ``.to()`` / ``.cuda()`` together
        # with the parameters, but are kept out of ``state_dict`` so existing
        # checkpoints keep loading unchanged.
        self.register_buffer('dist_sqr', check_tensor(dist).pow(2), persistent=False)
        self.register_buffer('neighbors', check_tensor(neighbors), persistent=False)

        self.batch_size = batch_size
        # Row indices for a full batch: [0]*k, [1]*k, ...  Retained as an
        # attribute for compatibility; ``get_inp`` builds its own.
        self.register_buffer(
            'obj_idx',
            torch.arange(batch_size).repeat(dist.shape[1], 1).T.reshape(-1),
            persistent=False,
        )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n = n

        if create_layers is None:
            create_layers = create_mlp_layers
            kwargs = {**self.default_create_layers_kwargs, **kwargs}

        head_layers = create_layers(input_dim, [hidden_dim] * (n_layers - 1), hidden_dim, **kwargs)

        self.head = nn.Sequential(*head_layers)

        # Both projection matrices start as small Gaussian noise scaled by 1 / hidden_dim.
        self.mu = nn.Parameter(torch.randn(hidden_dim, embedding_dim) / hidden_dim)
        self.sigma = nn.Parameter(torch.randn(hidden_dim, embedding_dim) / hidden_dim)

    def forward(self, inp):
        """Return ``(mean, scale)`` of the embedding for a batch of input rows.

        ``mean = head @ mu`` and
        ``scale = sqrt(head**2 @ sigma**2)`` (element-wise squares), where
        ``head`` is the output of the MLP applied to ``inp``.
        """
        head = self.head(inp)
        mean = head @ self.mu
        scale = (head.pow(2) @ self.sigma.pow(2)).pow(0.5)
        return mean, scale

    def loss(self, batch):
        """Compute the training loss for one batch of object/neighbour pairs.

        Args:
            batch: mapping whose values are, in order, the first object ids,
                the second object ids and the target squared distances.

        Returns:
            dict with ``'loss'`` (``-log_prob + reg``), ``'log_prob'`` and ``'reg'``.
        """
        # Relies on the mapping's insertion order (idx1, idx2, distance).
        idx1, idx2, true_dist_sqr = batch.values()

        mu1, sigma1 = self(self.get_inp(idx1))
        mu2, sigma2 = self(self.get_inp(idx2))

        # One noisy sample of the embedding difference: the two scales are
        # combined in quadrature and multiplied by standard normal noise.
        noise_scale = (sigma1.pow(2) + sigma2.pow(2)).pow(0.5)
        diff = mu1 - mu2 + torch.randn_like(mu1) * noise_scale
        dist_sqr = diff.pow(2).mean(1)

        # NOTE(review): the target is already named "squared" (NeighborsDataset
        # yields it under 'dist_sqr') yet is squared again here -- confirm this
        # is what exponential_log_prob expects.
        log_prob = exponential_log_prob(true_dist_sqr.pow(2), dist_sqr).mean()
        reg = (torch.log(self.mu.pow(2).mean(0) + self.sigma.pow(2).mean(0)).sum() * self.hidden_dim - torch.log(self.sigma.pow(2)).sum()) / self.n / 2

        return {'loss': -log_prob + reg, 'log_prob': log_prob, 'reg': reg}

    def get_inp(self, idx):
        """Build the sparse network input for the objects in ``idx``.

        Row ``r`` of the result corresponds to ``idx[r]``; its non-zero columns
        are that object's neighbour ids and the values are the squared
        distances to them.  The number of rows follows the length of ``idx``,
        so batches of any size (not only ``batch_size``) are accepted.

        Returns:
            Sparse COO tensor of shape ``(len(idx), input_dim)``.
        """
        # Sparse indices must be int64.
        neighbors_idx = self.neighbors[idx].reshape(-1).to(torch.long)
        neighbors_dist_sqr = self.dist_sqr[idx].reshape(-1)

        n_neighbors = self.neighbors.shape[1]
        n_rows = neighbors_idx.numel() // n_neighbors
        # [0]*k, [1]*k, ... : one row id per (object, neighbour) entry.
        row_idx = (
            torch.arange(n_rows, device=neighbors_idx.device)
            .repeat(n_neighbors, 1)
            .T.reshape(-1)
        )

        return torch.sparse_coo_tensor(
            indices=torch.stack((row_idx, neighbors_idx)),
            values=neighbors_dist_sqr,
            size=(n_rows, self.input_dim),
        )
    
    
class NeighborsDataset(torch.utils.data.Dataset):
    """Flat dataset with one item per (object, neighbour) pair.

    Item ``j`` is a dict with the object id (``'idx1'``), the neighbour id
    (``'idx2'``) and the squared distance between them after dividing the
    distance by the largest entry of ``dist`` (``'dist_sqr'``).
    """

    def __init__(self, dist, neighbors):
        """Flatten the ``dist`` / ``neighbors`` tables row by row.

        Args:
            dist: 2-D array-like exposing ``.shape`` and ``.max()``.
            neighbors: neighbour ids laid out like ``dist``.
        """
        super().__init__()

        n_objects, n_neighbors = dist.shape[0], dist.shape[1]
        # (n_neighbors, n_objects) grid of object ids; transposing before
        # flattening yields [0]*n_neighbors, [1]*n_neighbors, ...
        tiled_ids = torch.arange(n_objects).repeat(n_neighbors, 1)
        self.object_idx = tiled_ids.T.reshape(-1)

        self.neighbors_idx = check_tensor(neighbors, dtype=torch.int32).reshape(-1)

        flat_dist = check_tensor(dist).reshape(-1)
        normalised = flat_dist / dist.max()
        self.dist_sqr = normalised.pow(2)

    def __len__(self):
        """Number of (object, neighbour) pairs."""
        return self.dist_sqr.size(0)

    def __getitem__(self, idx):
        """Return the pair at position ``idx`` (key order: idx1, idx2, dist_sqr)."""
        first = self.object_idx[idx]
        second = self.neighbors_idx[idx]
        target = self.dist_sqr[idx]
        return {'idx1': first, 'idx2': second, 'dist_sqr': target}
    
    
def create_generator(dataset: torch.utils.data.Dataset, batch_size: int = 128, shuffle: bool = True, drop_last: bool = True, **kwargs):
    """Yield batches from ``dataset`` forever, restarting the loader after each pass.

    Args:
        dataset: dataset handed to ``torch.utils.data.DataLoader``.
        batch_size: number of items per batch.
        shuffle: whether the loader reshuffles on every pass.
        drop_last: whether an incomplete final batch is discarded.
        **kwargs: forwarded unchanged to ``DataLoader``.

    Yields:
        Collated batches, endlessly.

    Raises:
        ValueError: if a full pass over the loader produces no batch at all
            (for example ``drop_last=True`` with fewer than ``batch_size``
            items).  Because this is a generator, the error surfaces on the
            first ``next()`` call rather than at creation time.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **kwargs)
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            # Without this guard an empty loader would spin here forever
            # without ever yielding.
            raise ValueError(
                'DataLoader produced no batches; the dataset is empty or smaller '
                'than batch_size while drop_last=True'
            )

In [None]:
# Synthetic data: low-dimensional Gaussian latents pushed through a random
# linear map into d dimensions, plus small isotropic Gaussian noise.
# A dedicated seeded generator makes the data identical on every
# Restart & Run All without touching the global RNG used later in training.
data_rng = torch.Generator().manual_seed(0)
latent_data = torch.randn(n, d_latent, generator=data_rng)
mixing_matrix = torch.randn(d_latent, d, generator=data_rng)
data = latent_data @ mixing_matrix + torch.randn(n, d, generator=data_rng) * noise_coef

def query(data_point, *, ball_tree, n_neighbors):
    """Query ``ball_tree`` for a single point.

    The point is reshaped to one row and ``n_neighbors + 1`` results are
    requested (presumably because the point itself is stored in the tree and
    comes back as its own nearest neighbour -- verify against the caller).

    Returns:
        Whatever ``ball_tree.query`` returns for that single-row query.
    """
    single_row = data_point.reshape(1, -1)
    n_results = n_neighbors + 1
    return ball_tree.query(single_row, k=n_results)


def find_neighbors(data, n_neighbors: int, dist_fn=None):
    """Find the ``n_neighbors`` nearest neighbours of every row of ``data``.

    Args:
        data: 2-D array-like of points; used both to build the tree and as
            the query set.
        n_neighbors: number of neighbours to return per point.
        dist_fn: optional user-defined distance taking two 1-D arrays and
            returning a number.  It is passed to ``BallTree`` as a
            ``'pyfunc'`` metric, so it must be a true metric.  When ``None``
            the Euclidean distance is used.

    Returns:
        ``(distances, neighbors)``, each with ``n_neighbors`` columns.
    """
    if dist_fn is None:
        ball_tree = BallTree(data, leaf_size=1, metric='euclidean')
    else:
        # Previously this case fell through and returned None, which only
        # surfaced later as an unpacking error at the call site.
        ball_tree = BallTree(data, leaf_size=1, metric='pyfunc', func=dist_fn)

    # Every point is also in the tree, so one extra result is requested and
    # the first column dropped (assumes each point is returned first as its
    # own nearest neighbour -- exact duplicates in `data` could break that).
    distances, neighbors = ball_tree.query(data, k=n_neighbors + 1)
    return distances[:, 1:], neighbors[:, 1:]
    
# k nearest neighbours (and their distances) for every synthetic point.
distances, neighbors = find_neighbors(data, n_neighbors=k)

In [None]:
# One training item per (object, neighbour) pair.
dataset = NeighborsDataset(distances, neighbors)

# input_dim=n: the sparse input rows are indexed by neighbour id, so they need
# one column per data point.  n=len(dataset): divisor of the regulariser.
bmds = BMDS(distances, neighbors, batch_size, input_dim=n, n=len(dataset))
bmds_trainer = BMDSTrainer(bmds)

# Pass the configured batch size explicitly: the model is built for
# `batch_size`, and the generator's own default only matched it by coincidence.
batch_generator = create_generator(dataset, batch_size=batch_size)
bmds_trainer.train(batch_generator, project_name='bmds mnist', experiment_name='first trial', total_iters=100000)