In [108]:
# Import libraries
import math
import numpy as np
from typing import Callable, List
from sklearn.neighbors import NearestNeighbors
from scipy.stats import bernoulli

In [160]:
# Total distance to K-NN
def k_distances(
        queries: List[float], 
        samples: List[float], 
        k: int, 
        exclude: bool = False
    ) -> np.ndarray:
    """Sum of the distances from each query point to its nearest samples.

    Args:
        queries: Points to look neighbours up for. Either a 2-D array-like of
            shape (n_queries, dim) or a flat sequence, which is treated as
            n_queries one-dimensional points.
        samples: Points to search among, in the same layout as ``queries``.
        k: Number of neighbours whose distances are summed.
        exclude: When True, ``k + 1`` neighbours are retrieved instead of
            ``k``. This is meant for the case where every query is itself a
            member of ``samples`` (as when a set is compared with itself):
            the extra neighbour is then the query's own zero distance, so
            the sum still covers the ``k`` nearest *other* points.

    Returns:
        1-D array with one summed distance per query.

    Raises:
        ValueError: If more neighbours are requested than there are samples.
    """
    queries = np.asarray(queries)
    samples = np.asarray(samples)

    # scikit-learn rejects 1-D input, so promote flat sequences to column
    # vectors (n points living in one dimension).
    if queries.ndim == 1:
        queries = queries.reshape(-1, 1)
    if samples.ndim == 1:
        samples = samples.reshape(-1, 1)

    # One extra neighbour absorbs the zero self-distance when excluding.
    n_neighbors = k + 1 if exclude else k
    if n_neighbors > len(samples):
        raise ValueError(
            f"need at least {n_neighbors} samples for k={k} "
            f"(exclude={exclude}), got {len(samples)}"
        )

    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(samples)
    distances, _ = nbrs.kneighbors(queries)

    # distances has shape (n_queries, n_neighbors); collapse the neighbour axis.
    return np.sum(distances, 1)


# K-NN estimate of KL-divergence
def d_approx(forecasts,targets,d,k) -> float:
    """Nearest-neighbour estimate of the divergence between two sample sets.

    Computes ``(d / n) * sum_i log(v_i / p_i) + log(m / (n - 1))`` where, for
    each forecast point ``i``, ``p_i`` is its k-NN distance among the other
    forecasts and ``v_i`` is its k-NN distance among the targets.

    NOTE(review): ``k_distances`` returns the *sum* of the distances to the k
    neighbours. For ``k = 1`` that is simply the nearest-neighbour distance;
    the usual k-NN divergence estimator uses the distance to the k-th
    neighbour alone, so confirm the sum is intended before using ``k > 1``.

    Args:
        forecasts: Array-like of n points drawn from the forecast distribution.
        targets: Array-like of m points drawn from the true distribution.
        d: Dimension of the points (used as the scaling factor of the log-sum).
        k: Number of neighbours passed to ``k_distances``.

    Returns:
        The estimate as a Python float.

    Raises:
        ValueError: If fewer than two forecasts are given, or if any
            neighbour distance is zero (duplicate points), which would make
            the log-ratio undefined.
    """
    n = len(forecasts)
    m = len(targets)

    # The bias term divides by (n - 1) and every forecast needs at least one
    # other forecast as a neighbour.
    if n < 2:
        raise ValueError(f"d_approx needs at least 2 forecasts, got {n}")

    # Distances within the forecast set (self-match excluded) and from each
    # forecast to the target set.
    p = np.asarray(k_distances(forecasts, forecasts, k, exclude=True), dtype=float)
    v = np.asarray(k_distances(forecasts, targets, k), dtype=float)

    # A zero distance means coincident points; log(v / p) would silently
    # become inf/nan or raise a bare math domain error, so fail clearly.
    if np.any(p <= 0) or np.any(v <= 0):
        raise ValueError(
            "zero nearest-neighbour distance encountered (duplicate points); "
            "the k-NN divergence estimate is undefined"
        )

    log_ratio_sum = np.sum(np.log(v / p))

    return float((d / n) * log_ratio_sum + math.log(m / (n - 1)))

In [161]:
# Experiment configuration
k = 1    # neighbours used by the estimator
n = 100  # points drawn from the forecast distribution
m = 1    # points drawn from the true distribution
d = 5    # dimension of every point

# Forecasts and one target set, both standard normal, drawn directly in
# their final (rows, d) shape.
queries = np.random.normal(0, 1, size=(n, d))
samples = np.random.normal(0, 1, size=(m, d))
# Second target set: fair coin flips mapped from {0, 1} to {-1, +1}.
samples2 = 2 * bernoulli.rvs(0.5, size=m*d).reshape(m, d) - 1

# Estimate against the normal targets first, then against the +/-1 targets.
print(d_approx(queries, samples, d, k))
print(d_approx(queries, samples2, d, k))

-0.4480305597216594
0.4429035650004032


In [162]:
# One list of estimates per forecast distribution; the targets are always
# drawn from a normal with mean 0 and standard deviation 1.
a = []  # forecasts: mean 0, std 1
b = []  # forecasts: first half of the values mean -1, second half mean +1 (std 1)
c = []  # forecasts: mean 0, std 0.5
e = []  # forecasts: mean 0, std 100
f = []  # forecasts: mean 1, std 1

for _ in range(1009):
    # Draw order is kept fixed so the random stream is consumed identically.
    queries = np.random.normal(0, 1, size=(n, d))
    queries1 = np.concatenate((
        np.random.normal(-1, 1, int(n*d/2)),
        np.random.normal(1, 1, int(n*d/2)),
    )).reshape(n, d)
    queries2 = np.random.normal(0, 0.5, size=(n, d))
    queries3 = np.random.normal(0, 100, size=(n, d))
    queries4 = np.random.normal(1, 1, size=(n, d))

    samples = np.random.normal(0, 1, size=(m, d))

    # Score every forecast set against the same target draw.
    a.append(d_approx(queries, samples, d, k))
    b.append(d_approx(queries1, samples, d, k))
    c.append(d_approx(queries2, samples, d, k))
    e.append(d_approx(queries3, samples, d, k))
    f.append(d_approx(queries4, samples, d, k))

# Mean and standard deviation of the estimates, one line per distribution.
print(np.mean(a), np.std(a))
print(np.mean(b), np.std(b))
print(np.mean(c), np.std(c))
print(np.mean(e), np.std(e))
print(np.mean(f), np.std(f))

    

0.45073517664880125 0.8628203702680186
0.8431528077922128 0.5993548226654711
2.749170545597159 1.3976916511323265
-1.3014409868140362 0.17004369477269204
1.510249729926307 1.0156140591526026


In [158]:
def d_approx(
    queries: List[float], 
    sample: float,
    d,
    k
    ):
    """Distance from ``sample`` to the closest entry of ``queries``.

    ``sample`` is subtracted from ``queries`` (with numpy broadcasting), the
    Euclidean norm of each resulting entry along the first axis is taken, and
    the smallest norm is returned.

    ``d`` and ``k`` are accepted but not used.

    NOTE(review): this definition reuses the name ``d_approx`` and therefore
    replaces the k-NN divergence estimator of the same name once the cell runs.
    """
    offsets = queries - sample
    lengths = [np.linalg.norm(offset) for offset in offsets]
    return min(lengths)

In [122]:
# Broadcasting check: subtracting a length-3 vector from a 3x3 array subtracts it from every row.
np.array([[1,2,3],[1,1,1],[0,0,0]]) - np.array([1,2,3])

array([[ 0,  0,  0],
       [ 0, -1, -2],
       [-1, -2, -3]])

In [185]:

# 1-D experiment: for every target sample, take the squared distance to the
# closest forecast draw, then average over the targets.
# Note: this rebinds queries / queries1 / samples to flat arrays.
queries = np.random.normal(0, 1, size=1000)
queries1 = np.concatenate((
    np.random.normal(-1, 1, size=500),
    np.random.normal(1, 1, size=500),
))
samples = np.random.normal(0, 1, size=100)
print(np.mean([np.abs(queries - sample).min() ** 2 for sample in samples]))
print(np.mean([np.abs(queries1 - sample).min() ** 2 for sample in samples]))

9.323695080646305e-06
1.831153791648671e-05
