In [299]:
from scipy.spatial import distance
import pickle
import torch
from torch import linalg
import numpy as np
from scipy.linalg import cho_solve

In [300]:
def load_object(repo: str, file: str):
    # Path constant to save the object
    PATH = f'{repo}/{file}.pkl'

    with open(PATH, 'rb') as f:
        return pickle.load(f)

In [301]:
features = load_object('../../data/train', 'donut')
features = torch.from_numpy(features)

In [302]:
print(features)
print(features.shape)
reshaped_feats = features.t() #torch.reshape(features, (2, 1000))
print(reshaped_feats)
print(features.t())

tensor([[ 4.8302,  1.7729],
        [-3.0147,  3.6299],
        [-2.7704, -4.4909],
        ...,
        [-3.2588,  3.8029],
        [-4.8496,  2.3319],
        [ 3.7889,  3.3190]], dtype=torch.float64)
torch.Size([1000, 2])
tensor([[ 4.8302, -3.0147, -2.7704,  ..., -3.2588, -4.8496,  3.7889],
        [ 1.7729,  3.6299, -4.4909,  ...,  3.8029,  2.3319,  3.3190]],
       dtype=torch.float64)
tensor([[ 4.8302, -3.0147, -2.7704,  ..., -3.2588, -4.8496,  3.7889],
        [ 1.7729,  3.6299, -4.4909,  ...,  3.8029,  2.3319,  3.3190]],
       dtype=torch.float64)


In [318]:
means = torch.zeros(2, dtype=torch.float64).normal_()
sigma = torch.zeros(2, 2, dtype=torch.float64).normal_()
L = torch.linalg.cholesky(sigma @ sigma.t() + torch.eye(2))
print(L)

tensor([[ 1.2016,  0.0000],
        [-0.5309,  1.6547]], dtype=torch.float64)


In [319]:
mahalanobis = lambda x, mu, S_inv: (x - mu).t() @ S_inv @ (x - mu)

print(mahalanobis(features[1], means, torch.inverse(L)))

tensor(6.4935, dtype=torch.float64)


In [320]:
def mahalanobis(u, v, cov):
    delta = u - v
    print(delta.shape)
    m = torch.matmul(torch.matmul(delta.t(), torch.inverse(cov)), delta)
    return m

def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
    shape, but `bL` one should be able to broadcasted to `bx` one.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = bx.shape[:outer_batch_dims]
    for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (list(range(outer_batch_dims)) +
                    list(range(outer_batch_dims, new_batch_dims, 2)) +
                    list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
                    [new_batch_dims])
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    M_swap = torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)  

In [335]:
#mahalanobis(features.t(), means.unsqueeze(1), L).shape
print(means.shape)
_batch_mahalanobis(L, (features - means))

torch.Size([2])


tensor([20.1001,  5.0694, 24.7087, 17.6917,  6.5707, 17.0211, 16.1480, 18.0451,
        16.2420, 16.0514,  3.2556,  8.1285, 16.9439, 23.2921, 23.8508, 11.2674,
        24.9308,  9.8712, 18.9821, 17.2443, 16.3630, 19.7476, 21.6267, 17.6330,
        18.6070, 18.2136, 29.3943,  5.5698, 18.2921, 21.3648, 19.6743,  9.3026,
        21.4542, 22.8172, 15.0363, 15.8300,  9.7038, 14.7042, 19.4855, 20.2349,
        17.1715,  4.6161, 13.6980, 17.5796,  4.7823, 16.5145, 20.8763,  5.3385,
        14.8522,  3.6058, 16.8907, 20.2990, 19.0398,  8.8930, 13.2721, 15.0168,
        10.0832,  6.6415, 15.0075,  5.3959, 16.4204, 13.5941,  8.1815, 19.0371,
         6.6960, 26.9291, 25.6815, 18.7053, 20.1590, 14.8073, 18.3407, 23.8836,
        19.0003,  3.0680, 12.7399, 20.8446, 13.8286, 26.2801, 14.8152, 13.3350,
        22.5234, 11.3606,  5.0073, 23.4201, 22.1629, 23.1404, 18.0632, 18.2459,
        17.8623, 24.3920, 11.6619, 24.5748, 16.9772, 23.8919, 22.3738, 18.2644,
         4.3031, 16.5787, 18.0221,  6.89

In [287]:

x = reshaped_feats - means
y = cho_solve((L, True), x)

RuntimeError: The size of tensor a (1000) must match the size of tensor b (2) at non-singleton dimension 1

In [288]:
print(y.shape)

(2, 1000)


In [289]:
print((reshaped_feats - means).t().shape)

y = torch.linalg.solve_triangular(L, (reshaped_feats - means).t(), upper=False)
y = torch.linalg.solve_triangular(L.t(), y, upper=True)

RuntimeError: The size of tensor a (1000) must match the size of tensor b (2) at non-singleton dimension 1

In [290]:
log_determinant_part = -torch.sum(torch.log(torch.diag(L)))
quadratic_part = -0.5 * features.double().matmul(torch.Tensor(y).double())
const_part = -0.5 * len(L) * np.log(2 * np.pi)

In [291]:
logpdf = const_part + log_determinant_part + quadratic_part

In [292]:
torch.exp(logpdf)

tensor([[1.5043e-06, 3.9502e-01, 7.5263e+03,  ..., 5.0314e-01, 6.6063e+01,
         1.5123e-06],
        [1.2215e-01, 5.6694e-03, 5.6967e-02,  ..., 4.9925e-03, 4.4468e-03,
         5.6896e-02],
        [1.9087e+03, 4.6720e-02, 3.3142e-07,  ..., 3.9963e-02, 2.8434e-04,
         3.3246e+03],
        ...,
        [1.5050e-01, 4.8296e-03, 4.7138e-02,  ..., 4.2068e-03, 3.4446e-03,
         6.7190e-02],
        [1.4760e+01, 3.2130e-03, 2.5051e-04,  ..., 2.5728e-03, 2.4421e-04,
         7.5852e+00],
        [1.3440e-06, 1.6353e-01, 1.1651e+04,  ..., 1.9963e-01, 3.0173e+01,
         9.9014e-07]], dtype=torch.float64)

In [338]:
S = torch.zeros(2, 2, dtype=torch.float64).normal_()
L = linalg.cholesky(torch.tril(S) @ torch.tril(S).t() + torch.eye(2))

_batch_mahalanobis(L.unsqueeze(0), (features - means.unsqueeze(0)))/L.det() == _batch_mahalanobis(L, (features - means))/L.det()

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr