In [84]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

%matplotlib inline

In [85]:
def loss_false(code_batch, k=1):
    """
    An activity regularizer based on the False-Nearest-Neighbor
    Algorithm of Kennel, Brown, and Abarbanel. Phys Rev A. 1992

    For every prefix of the latent coordinates (the first 1, 2, ...,
    n_latent - 1 of them) the k nearest neighbors of each sample are found
    inside the batch, and each neighbor is tested for being "false" once the
    next coordinate is added. The fraction of non-false neighbors becomes the
    weight of that next coordinate, and the loss is the weighted sum of the
    batch RMS activation of every coordinate. The first coordinate always
    gets weight 0.

    Parameters
    ----------
    code_batch: tensor
        (Batch size, Embedding Dimension) tensor of encoded inputs.
        The batch must hold at least k + 1 samples, because k + 1
        neighbors (the sample itself included) are requested per sample.
    k: int
        The number of nearest neighbors used to compute
        neighborhoods.

    Returns
    -------
    tensor
        Scalar float32 tensor holding the regularization loss. Gradients
        reach ``code_batch`` only through the RMS activations; the weights
        come from comparisons and are therefore piecewise constant.
    """
    _, n_latent = code_batch.shape

    # Fixed thresholds of the two false-neighbor tests below.
    rtol = 20.0
    atol = 2.0

    # Everything up to the weights ends in boolean comparisons, so it carries
    # no gradient. Working on a detached view avoids building a useless graph.
    codes = code_batch.detach()

    # Row d of the mask keeps latent coordinates 0..d (diagonal included), so
    # batch_masked[d] is the batch embedded in its first d + 1 coordinates.
    # The mask lives on the input's device so GPU batches work.
    tri_mask = torch.tril(torch.ones(n_latent, n_latent, device=codes.device))
    batch_masked = tri_mask.unsqueeze(1) * codes.unsqueeze(0)  # (n_latent, batch, n_latent)

    # Squared pairwise distances for every prefix: (n_latent, batch, batch)
    X_sq = torch.sum(batch_masked ** 2, dim=2, keepdim=True)
    all_dists = (
        X_sq
        + X_sq.transpose(1, 2)
        - 2 * torch.bmm(batch_masked, batch_masked.transpose(1, 2))
    )

    # Size scale of the data for every prefix: population variance over the
    # batch, summed over the active coordinates, divided by their number.
    n_active = torch.arange(1, n_latent + 1, dtype=torch.float32, device=codes.device)
    all_ra = torch.sqrt(
        torch.var(batch_masked, dim=1, unbiased=False).sum(dim=1) / n_active
    )

    # Clip distances to avoid singularities. The upper bound is the current
    # maximum, passed as a plain float so both bounds are of the same kind;
    # when every distance is below the floor, clamp returns the maximum.
    all_dists = torch.clamp(all_dists, min=1e-14, max=all_dists.max().item())

    # k + 1 smallest distances per sample; column 0 is dropped further down
    # because it is expected to be the sample itself.
    _, inds = torch.topk(-all_dists, k=k + 1, dim=-1)

    # Distances to those neighbors in the prefix where they were found, and
    # to the same neighbors once one more coordinate is added.
    neighbor_dists_d = torch.gather(all_dists, 2, inds)
    neighbor_new_dists = torch.gather(all_dists[1:], 2, inds[:-1])

    # Relative growth of the distance caused by the extra coordinate.
    # 0/0 or negative round-off gives NaN, which compares as "not false".
    scaled_dist = torch.sqrt(
        (neighbor_new_dists - neighbor_dists_d[:-1]) / neighbor_dists_d[:-1]
    )

    # A neighbor is false if its distance grew too much relative to before,
    # or if the new distance is large compared to the size of the data.
    is_false_change = scaled_dist > rtol
    is_large_jump = neighbor_new_dists > atol * all_ra[:-1, None, None]
    is_false_neighbor = torch.logical_or(is_false_change, is_large_jump)

    # Fraction of false neighbors per added coordinate (self column removed).
    false_fraction = is_false_neighbor[..., 1:(k + 1)].float().mean(dim=(1, 2))

    # One weight per latent coordinate; the first has no preceding prefix and
    # is padded with 0.
    reg_weights = F.pad(1 - false_fraction, (1, 0))

    # RMS activation of every latent coordinate over the batch (differentiable)
    activations_batch_averaged = torch.sqrt(torch.mean(code_batch ** 2, dim=0))

    # Weighted activity penalty
    loss = torch.sum(reg_weights * activations_batch_averaged)

    return loss.float()

In [86]:
# Toy batch: 32 identical rows, each equal to [1, 2, 3, 4, 5, 6].
code_batch = torch.ones(32, 6).cumsum(dim=1)

tensor(20.)