In [4]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import numpy as onp
import jax.numpy as np
from jax.scipy.special import digamma, gammaln

def nearest_neighbors_distance(X, k):
    """
    Compute the distance to the kth nearest neighbor for each point in X by
    exhaustively searching all points in X.

    Each point counts as its own 0th neighbor (distance 0), so ``k=1`` gives
    the distance to the closest *other* point.

    X : ndarray, shape (n_samples, W, H) or (n_samples, num_features)
        Trailing dimensions are flattened into a single feature axis.
    k : int
        Neighbor rank; must satisfy ``1 <= k < n_samples``.

    Returns
    -------
    ndarray, shape (n_samples,)
        Euclidean distance from each point to its kth nearest neighbor.

    Raises
    ------
    ValueError
        If k is outside ``[1, n_samples - 1]``. JAX clamps out-of-bounds
        indices instead of raising, so without this check a too-large k
        would silently return the distance to the farthest point.
    """
    X = np.asarray(X)
    n_samples = X.shape[0]
    if not 1 <= k < n_samples:
        raise ValueError(
            f'k must satisfy 1 <= k < n_samples (n_samples={n_samples}), got k={k}'
        )
    X = X.reshape(n_samples, -1)

    # Rank neighbors with |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, which needs only
    # an (n, n) matrix; broadcasting X[:, None] - X[None] would materialize
    # (n, n, num_features) values (about 8 GB in float32 for 10,000 x 20).
    # Centering is distance-preserving and limits the expansion's
    # cancellation error when the data has a large common offset.
    X_centered = X - np.mean(X, axis=0)
    sq_norms = np.sum(X_centered ** 2, axis=-1)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2 * (X_centered @ X_centered.T)
    # Rounding can leave small negative entries; clamp them, and pin the
    # diagonal to exactly 0 so every point still ranks itself first.
    sq_dist = np.maximum(sq_dist, 0)
    diag = np.arange(n_samples)
    sq_dist = sq_dist.at[diag, diag].set(0)

    # Position 0 of each row's ordering is the point itself (tied with any
    # exact duplicates), so position k is its kth nearest neighbor.
    kth_nn_index = np.argsort(sq_dist, axis=-1)[:, k]
    kth_nn = X[kth_nn_index, :]
    # Recompute the selected distance directly from the original coordinates
    # so the returned value does not carry the expansion's rounding error.
    kth_nn_dist = np.sqrt(np.sum((X - kth_nn) ** 2, axis=-1))
    return kth_nn_dist


def nearest_neighbors_entropy_estimate(X, k):
    """
    Kozachenko-Leonenko k-nearest-neighbor estimate of differential entropy (nats).

    X : ndarray, shape (n_samples, W, H) or (n_samples, num_features)
        Samples; trailing dimensions are flattened into a feature axis, the
        same way ``nearest_neighbors_distance`` flattens them.
    k : int
        Neighbor rank forwarded to ``nearest_neighbors_distance``.

    Returns
    -------
    scalar array
        ``log(N) - digamma(k) + mean(log V_i)``, where ``V_i`` is the volume of
        the d-dimensional ball whose radius is point i's kth-neighbor distance.

    Note: if some point has k or more exact duplicates its kth-neighbor
    distance is 0, ``log(0)`` is -inf, and the estimate becomes -inf.
    """
    nn = nearest_neighbors_distance(X, k)
    # Flatten trailing dimensions so (n_samples, W, H) input gets d = W * H
    # rather than failing on a 2-value unpack.
    X = np.asarray(X)
    X = X.reshape(X.shape[0], -1)
    N, d = X.shape

    # compute the log volume of the d-dimensional ball with radius of the nearest neighbor distance:
    # log(pi^(d/2) / Gamma(d/2 + 1) * r^d)
    log_vd = d * np.log(nn) + d/2 * np.log(np.pi) - gammaln(d/2 + 1)
    # The log(k) terms cancel, leaving log(N) - digamma(k) + mean(log_vd).
    # Kraskov et al. (2004) write digamma(N) in place of log(N); the two
    # differ by O(1/N).
    h = np.log(k) - digamma(k) + np.mean(log_vd + np.log(N) - np.log(k))
    return h


def multivariate_gaussian_entropy(cov_matrix):
    """
    Numerically stable computation of the analytic entropy (nats) of a multivariate Gaussian.

    H = d/2 * log(2*pi*e) + 1/2 * log det(cov), where log det(cov) is taken as
    the sum of log-eigenvalues. The determinant itself is never formed, so
    the result stays finite even when det(cov) would overflow or underflow.

    cov_matrix : ndarray, shape (d, d)
        Covariance matrix; assumed symmetric.

    Returns
    -------
    scalar array
        The differential entropy.

    Raises
    ------
    ValueError
        If the eigenvalues are not all finite and strictly positive, i.e. the
        matrix is not positive definite or contains inf/nan entries.
    """
    d = cov_matrix.shape[0]
    eigenvalues = np.linalg.eigvalsh(cov_matrix)
    # Validate the eigenvalues rather than det(cov): a float32 determinant can
    # overflow for perfectly valid matrices (e.g. 20-dim with variance 1e4),
    # while non-positive eigenvalues would otherwise silently yield NaN/-inf.
    if not bool(np.all(np.isfinite(eigenvalues) & (eigenvalues > 0))):
        raise ValueError('Covariance matrix must be finite and positive definite')
    entropy = 0.5 * d * np.log(2 * np.pi * np.e) + 0.5 * np.sum(np.log(eigenvalues))
    return entropy


num_samples = 10000
num_dimensions = 20
random_seed = 0

# Seeded generator so the printed comparison is reproducible across runs.
rng = onp.random.default_rng(random_seed)

# make a random covariance matrix: W^T W is symmetric positive semi-definite
cov_mat = rng.standard_normal((num_dimensions, num_dimensions))
cov_mat = cov_mat.T @ cov_mat
# Rescale so the largest entry is 1 (this changes the scale, not the
# conditioning of the matrix).
cov_mat_random = cov_mat / np.max(cov_mat)

# identity covariance matrix
cov_mat_identity = np.eye(num_dimensions)

# NOTE(review): for a square W the smallest eigenvalues of W^T W are typically
# close to zero, so the "Random" Gaussian is nearly degenerate; a large gap
# between the analytic entropy and the kNN estimate there is presumably
# estimator bias rather than a code bug -- confirm.
for cov_mat, name in zip([cov_mat_identity, cov_mat_random], ['Identity', 'Random']):
    print(name)
    X = rng.multivariate_normal(np.zeros(cov_mat.shape[0]), cov_mat, size=num_samples)
    entropy_true = multivariate_gaussian_entropy(cov_mat)
    entropy_nn = nearest_neighbors_entropy_estimate(X, 3)
    print('True entropy: ', entropy_true)
    print('KL NN entropy estimate: ', entropy_nn)
    print('')


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Identity
True entropy:  28.37877
KL NN entropy estimate:  29.093527

Random
True entropy:  10.929836
KL NN entropy estimate:  18.432587

