In [1]:
from warnings import warn
from numpy import mean, transpose, cov, cos, sin, shape, exp, newaxis, concatenate
from numpy.linalg import linalg, LinAlgError, solve
from scipy.stats import chi2


#testing purposes
from numpy.random import seed
import numpy 
import abc

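## Mean Embedding Test

Two-sample test comparing the Gaussian-kernel mean embeddings of samples `X` and `Y` at a small number of random test locations; the resulting statistic $n\,\mu^\top \Sigma^{-1} \mu$ is referred to a $\chi^2$ distribution to obtain a p-value.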

In [2]:
# Synthetic data for the two-sample test. X and Y are both standard normal,
# except that X's second coordinate is scaled by 3, so the two distributions
# differ and a working test should reject H0 at the usual 0.05 significance
# level (0.05 is a significance level, not the p-value; that is computed below).
seed(120)  # fixes X and Y, and also the random test locations drawn later
num_samples = 500
dimensions = 10
X = numpy.random.randn(num_samples, dimensions)
X[:, 1] *= 3  # the only difference between the two distributions
Y = numpy.random.randn(num_samples, dimensions)
X.shape, Y.shape

((500, 10), (500, 10))

In [3]:
# Test configuration.
# `scale` multiplies both the data (here) and the test locations (in the
# statistic cell), so exp(-||scale*x - scale*t||^2 / 2) is a Gaussian kernel
# with bandwidth 1/scale.
scale = 1
data_x, data_y = X * scale, Y * scale
# Number of random test locations ("frequencies") at which the two mean
# embeddings are compared; it is also the degrees of freedom passed to the
# chi^2 distribution when converting the statistic to a p-value.
number_of_frequencies = 5
# The statistic pairs x_i with y_i and evaluates both samples at test points
# drawn in data_x's dimension, so the two samples must have the same shape.
assert data_x.shape == data_y.shape, "X and Y must have the same shape"
_, dimension = numpy.shape(data_x)

In [4]:
# Random test locations ("frequencies"): one row per location, in the data's
# dimension. NOTE(review): these draws continue the global NumPy stream seeded
# in the data-construction cell, so re-running this cell on its own gives new
# locations and a different p-value; use Restart & Run All to reproduce it.
points = numpy.random.randn(number_of_frequencies, dimension)


def _gaussian_kernel_values(data, point, scale):
    """Return exp(-||row - scale * point||^2 / 2) for every row of `data`."""
    sq_dist = numpy.linalg.norm(data - scale * point, axis=1) ** 2
    return numpy.exp(-sq_dist / 2.0)


# a[j, i] = k(x_i, t_j) - k(y_i, t_j): per-sample difference of the two
# empirical mean embeddings at test location t_j.
a = numpy.zeros([points.shape[0], data_x.shape[0]])
for j, point in enumerate(points):
    a[j] = (_gaussian_kernel_values(data_x, point, scale)
            - _gaussian_kernel_values(data_y, point, scale))

obs = a.T  # one row per sample, one column per test location

num_samples, _ = shape(obs)
# With a single test location numpy.cov returns a 0-d array, which solve()
# rejects; atleast_2d promotes it to a 1x1 matrix and is a no-op otherwise.
sigma = numpy.atleast_2d(cov(transpose(obs)))
mu = mean(obs, 0)
# Test statistic n * mu^T Sigma^{-1} mu, via solve() rather than an explicit inverse.
stat = num_samples * mu.dot(solve(sigma, mu.T))
# p-value: upper tail of chi^2 with number_of_frequencies degrees of freedom.
pval = chi2.sf(stat, number_of_frequencies)

print(pval)

3.89449630549191e-05


In [12]:
# Sanity checks on the statistic computed in the previous cell.
# 1) The solve()-based statistic must agree with the explicit-inverse formula
#    n * mu^T Sigma^{-1} mu. The two evaluate in different orders, so compare
#    up to floating-point rounding rather than with exact ==.
numpy.testing.assert_allclose(
    num_samples * mu.dot(numpy.linalg.inv(sigma)).dot(mu), stat)
# 2) Show the unbiased sample covariance (centered, divided by n - 1) next to
#    the uncentered second moment obs^T obs / n. They differ by the mean term
#    and by the n vs n - 1 normalisation.
cov(transpose(obs)), transpose(obs).dot(obs) / len(obs)

(array([[ 4.51025742e-05, -6.00022456e-06,  7.70028118e-05,
         -1.99651306e-06, -2.40734207e-06],
        [-6.00022456e-06,  2.97395703e-04, -5.25071619e-06,
          1.56852963e-05,  9.96891039e-05],
        [ 7.70028118e-05, -5.25071619e-06,  9.08104891e-04,
         -8.91269634e-06,  1.35870725e-05],
        [-1.99651306e-06,  1.56852963e-05, -8.91269634e-06,
          2.25518413e-04,  3.77469003e-05],
        [-2.40734207e-06,  9.96891039e-05,  1.35870725e-05,
          3.77469003e-05,  3.41933374e-04]]),
 array([[ 4.50450914e-05, -5.48752394e-06,  7.76222198e-05,
         -1.72316062e-06, -2.11997885e-06],
        [-5.48752394e-06,  3.04462357e-04,  6.59414613e-06,
          1.97755189e-05,  1.03813132e-04],
        [ 7.76222198e-05,  6.59414613e-06,  9.24568796e-04,
         -2.52839306e-06,  2.02381095e-05],
        [-1.72316062e-06,  1.97755189e-05, -2.52839306e-06,
          2.27284651e-04,  3.99972499e-05],
        [-2.11997885e-06,  1.03813132e-04,  2.02381095e-05,
  