In [2]:
import numpy as np
import math
import os
import pickle
import torch
import torch.nn as nn
import scipy.special as ss

In [3]:
class empty_class:
    """Bare attribute container; the notebook stores its settings on an instance (``args``)."""

In [20]:
# Experiment configuration, stored as attributes on a bare container object.
args = empty_class()
args.d = 20                  # ambient dimension: samples are normalized rows of R^d (unit sphere)
args.k = 6                   # degree of the polynomial whose sign separates nu (>0) from mu (<0)
args.n_samples = 10000000    # candidate points drawn before rejection sampling
args.n_samples_d_f2 = 10000  # random directions used by the Monte Carlo d_F2 estimate
args.seed = 42               # RNG seed (also part of the sample cache file names)
args.alpha = 1               # used by lambda_kd; math.factorial requires an integer
args.gamma = 0.5             # NOTE(review): not used in the cells shown here

# The seed used to appear only in cache file names; actually seed the RNGs so a
# fresh-kernel "Run All" reproduces the samples and the Monte Carlo estimates.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [5]:
def get_nu_samples(args):
    """Draw (or load from cache) samples of the measure nu on the unit sphere in R^d.

    Candidates are drawn uniformly on the sphere (normalized Gaussians). Each one is
    kept with probability relu(0.99 * P(x_d)). Here P is the degree-k Jacobi polynomial
    with alpha = beta = (d-3)/2, rescaled so that P(1) = 1, and x_d is the last
    coordinate. So nu keeps points where P is positive.

    Args:
        args: settings object with attributes ``d``, ``k``, ``n_samples`` and ``seed``.
            ``seed`` only names the cache file; the draw itself uses torch's global RNG.

    Returns:
        Tensor of shape (n_accepted, d) holding the accepted rows in candidate order.
        The result is cached as a pickle under ``nu_samples/`` keyed by (d, k, n_samples, seed).
    """
    fname = os.path.join('nu_samples',
                         f'nu_samples_{args.d}_{args.k}_{args.n_samples}_{args.seed}.pkl')
    if os.path.exists(fname):
        # NOTE: pickle.load can execute arbitrary code; only load caches you created.
        with open(fname, 'rb') as f:
            return pickle.load(f)

    q_k_d = ss.jacobi(args.k, (args.d-3)/2.0, (args.d-3)/2.0)
    legendre_k_d = q_k_d/q_k_d(1)
    X0 = torch.randn(args.n_samples, args.d)
    X0 = torch.nn.functional.normalize(X0, p=2, dim=1)
    # relu keeps only the region where the polynomial is positive; the 0.99 factor
    # presumably keeps acceptance probabilities strictly below 1.
    acceptance_prob = torch.nn.functional.relu(torch.from_numpy(0.99*legendre_k_d(X0[:, args.d-1])))
    acceptance_vector = torch.bernoulli(acceptance_prob)
    print('Acc. prob. sum:', torch.norm(acceptance_prob, p=1), 'Acc. vec. sum:', torch.norm(acceptance_vector, p=1))
    # Boolean-mask selection gives the same rows in the same order as a per-row loop +
    # gather. It is vectorized and stays valid when no row is accepted.
    X = X0[acceptance_vector == 1]
    os.makedirs('nu_samples', exist_ok=True)
    with open(fname, 'wb') as f:
        pickle.dump(X, f)
    return X

In [6]:
def get_mu_samples(args):
    """Draw (or load from cache) samples of the measure mu on the unit sphere in R^d.

    This mirrors get_nu_samples with the opposite sign. Candidates are drawn uniformly
    on the sphere (normalized Gaussians). Each is kept with probability
    relu(-0.99 * P(x_d)). Here P is the degree-k Jacobi polynomial with
    alpha = beta = (d-3)/2, rescaled so that P(1) = 1, and x_d is the last coordinate.
    So mu keeps points where P is negative.

    Args:
        args: settings object with attributes ``d``, ``k``, ``n_samples`` and ``seed``.
            ``seed`` only names the cache file; the draw itself uses torch's global RNG.

    Returns:
        Tensor of shape (n_accepted, d) holding the accepted rows in candidate order.
        The result is cached as a pickle under ``mu_samples/`` keyed by (d, k, n_samples, seed).
    """
    # The cache key must include k. The acceptance rule depends on the polynomial degree,
    # so omitting k silently reused samples drawn for a different k.
    # Old cache files (without k) are simply ignored and regenerated.
    fname = os.path.join('mu_samples',
                         f'mu_samples_{args.d}_{args.k}_{args.n_samples}_{args.seed}.pkl')
    if os.path.exists(fname):
        # NOTE: pickle.load can execute arbitrary code; only load caches you created.
        with open(fname, 'rb') as f:
            return pickle.load(f)

    q_k_d = ss.jacobi(args.k, (args.d-3)/2.0, (args.d-3)/2.0)
    legendre_k_d = q_k_d/q_k_d(1)
    X0 = torch.randn(args.n_samples, args.d)
    X0 = torch.nn.functional.normalize(X0, p=2, dim=1)
    # relu keeps only the region where the polynomial is negative; the 0.99 factor
    # presumably keeps acceptance probabilities strictly below 1.
    acceptance_prob = torch.nn.functional.relu(torch.from_numpy(-0.99*legendre_k_d(X0[:, args.d-1])))
    acceptance_vector = torch.bernoulli(acceptance_prob)
    print('Acc. prob. sum:', torch.norm(acceptance_prob, p=1), 'Acc. vec. sum:', torch.norm(acceptance_vector, p=1))
    # Boolean-mask selection gives the same rows in the same order as a per-row loop +
    # gather. It is vectorized and stays valid when no row is accepted.
    X = X0[acceptance_vector == 1]
    os.makedirs('mu_samples', exist_ok=True)
    with open(fname, 'wb') as f:
        pickle.dump(X, f)
    return X

In [10]:
# Degree-k polynomial for dimension d, built as a Jacobi polynomial with
# alpha = beta = (d-3)/2 and rescaled so that its value at 1 equals 1.
# NOTE(review): later cells do not read these globals (the sampling functions
# build their own local copies); presumably kept for interactive inspection.
q_k_d = ss.jacobi(n=args.k, alpha=(args.d - 3) / 2.0, beta=(args.d - 3) / 2.0)
legendre_k_d = q_k_d / q_k_d(1.0)

In [7]:
X_nu = get_nu_samples(args=args)  # nu samples; loaded from the nu_samples/ cache when present

In [8]:
X_mu = get_mu_samples(args=args)  # mu samples; loaded from the mu_samples/ cache when present

In [17]:
print(X_nu.shape, X_mu.shape)
# Rejection sampling yields different counts for nu and mu, so truncate both to the
# smaller count. Rows are accepted independently, so keeping a prefix does not bias
# either sample.
min_num = min(X_nu.shape[0], X_mu.shape[0])
X_nu = X_nu[:min_num, :]
X_mu = X_mu[:min_num, :]
assert X_nu.shape == X_mu.shape
print(X_nu.shape, X_mu.shape)

torch.Size([7220, 20]) torch.Size([7220, 20])
torch.Size([7220, 20]) torch.Size([7220, 20])


In [33]:
def d_f1_estimate(X_nu, X_mu, a, b):
    """Single-direction estimate used as d_F1 in this notebook.

    The only feature direction is w = e_d (the last coordinate). The activation is
    phi_w(x) = a*relu(w.x) + b*relu(-w.x). Both w and -w are evaluated, and the larger
    absolute difference of the empirical means under nu and mu is returned.

    Args:
        X_nu: samples of nu, tensor of shape (n, d); only the last column is used.
        X_mu: samples of mu, tensor of shape (m, d); only the last column is used.
        a: weight of the positive ReLU part.
        b: weight of the negative ReLU part.

    Returns:
        0-d tensor: max over {w, -w} of |E_nu[phi] - E_mu[phi]|.
    """
    relu = torch.nn.functional.relu
    # Taking column -1 instead of the global args.d - 1 removes the hidden dependency on
    # the notebook config. The two agree whenever the samples have args.d columns.
    t_nu = X_nu[:, -1]
    t_mu = X_mu[:, -1]
    # Compute each ReLU mean once and reuse it for both orientations (w and -w).
    nu_plus, nu_minus = torch.mean(relu(t_nu)), torch.mean(relu(-t_nu))
    mu_plus, mu_minus = torch.mean(relu(t_mu)), torch.mean(relu(-t_mu))
    gen_moment_nu_positive = a*nu_plus + b*nu_minus
    gen_moment_nu_negative = a*nu_minus + b*nu_plus
    gen_moment_mu_positive = a*mu_plus + b*mu_minus
    gen_moment_mu_negative = a*mu_minus + b*mu_plus
    return torch.max(torch.abs(gen_moment_nu_positive - gen_moment_mu_positive),
                     torch.abs(gen_moment_nu_negative - gen_moment_mu_negative))

In [57]:
print(d_f1_estimate(X_nu, X_mu, a=1, b=0))  # plain ReLU feature along the last axis

tensor(0.0312)


In [97]:
# Dimension of the space of degree-k spherical harmonics in R^d:
#     N(k, d) = (2k + d - 2) (k + d - 3)! / (k! (d - 2)!)
# NOTE(review): the saved output of the next cell (35700.0) corresponds to d = 15, k = 6.
# With the current config (d = 20, k = 6) this evaluates to 168245.0, so that output is
# stale. Re-run the notebook from a fresh kernel.
N_kd = ((2*args.k + args.d - 2) * math.factorial(args.k + args.d - 3)
        / (math.factorial(args.k) * math.factorial(args.d - 2)))

In [98]:
print(f"{N_kd}")

35700.0


In [99]:
# Closed-form constant lambda_{k,d}:
#   (d-2) alpha! Gamma((d-1)/2) Gamma(k-alpha)
#   / (2^k Gamma((k-alpha+1)/2) Gamma((k+d+alpha)/2))
# NOTE(review): presumably the degree-k eigenvalue of the ReLU^alpha kernel; confirm
# against the reference. math.factorial needs an integer alpha, and math.gamma raises
# when k - alpha is a non-positive integer.
# NOTE(review): the saved 1/lambda_kd output (2067.69...) corresponds to d = 15, k = 6,
# alpha = 1. With the current config (d = 20) it is about 4248.6, so the output is stale.
lambda_kd = (
    (args.d - 2) * math.factorial(args.alpha) * math.gamma((args.d - 1) / 2) * math.gamma(args.k - args.alpha)
    / ((2 ** args.k) * math.gamma((args.k - args.alpha + 1) / 2) * math.gamma((args.k + args.d + args.alpha) / 2))
)

In [100]:
print(1.0 / lambda_kd)

2067.6923076923076


In [28]:
def f2_kernel_evaluation(X0, X1, fill_diag = True):
    """Closed-form ReLU feature kernel between two sets of points on the unit sphere.

    For rows x of X0 and y of X1 with t = <x, y>, returns
        k(x, y) = ((pi - arccos t) * t + sqrt(1 - t^2)) / (2 * pi * d),
    which equals E_w[relu(<w, x>) relu(<w, y>)] for w uniform on the unit sphere of
    R^d and unit-norm x, y. Here d = X0.shape[1].

    Args:
        X0: tensor of shape (n, d); rows assumed unit-norm.
        X1: tensor of shape (m, d); rows assumed unit-norm.
        fill_diag: if True, set the diagonal of the inner-product matrix to exactly 1.
            This is only meaningful when X0 and X1 are the same sample set, so the
            diagonal holds self inner products.

    Returns:
        Tensor of shape (n, m) with the kernel values.
    """
    # Ambient dimension comes from the data instead of the global args.d.
    d = X0.shape[1]
    inner_prod = torch.matmul(X0, X1.t())
    if fill_diag:
        inner_prod.fill_diagonal_(1)
    # Rounding can push |t| slightly above 1 (e.g. for duplicate points). acos and
    # sqrt(1 - t^2) would then return NaN, so clamp to the valid domain.
    inner_prod = torch.clamp(inner_prod, -1.0, 1.0)
    # Normalize by 2*pi*d. For a standard Gaussian w, the mean of relu*relu is
    # (sin theta + (pi - theta) cos theta) / (2 pi). For w uniform on the sphere,
    # E||w||^2 is 1 instead of d, which divides that by d. This matches d_f2_estimate,
    # which samples w uniformly on the unit sphere of R^d. The previous 2*pi*(d + 1)
    # under-weighted the kernel by a factor d / (d + 1).
    values = ((np.pi - torch.acos(inner_prod)) * inner_prod
              + torch.sqrt(1 - inner_prod * inner_prod)) / (2 * np.pi * d)
    return values

In [23]:
def d_f2_estimate_exact_kernel(X_nu, X_mu, a, b):
    """Kernel (MMD) counterpart of d_f2_estimate using the closed-form ReLU kernel.

    Computes sqrt(mean K(mu, mu) + mean K(nu, nu) - 2 * mean K(mu, nu)) with
    K = f2_kernel_evaluation. The means run over full matrices, so diagonal terms are
    included (a biased V-statistic).

    Args:
        X_nu: samples of nu, tensor of shape (n, d); rows assumed unit-norm (the
            self-kernel sets its diagonal inner products to 1).
        X_mu: samples of mu, tensor of shape (m, d); rows assumed unit-norm.
        a, b: accepted for signature parity with d_f2_estimate but NOT used. The kernel
            corresponds to the plain ReLU feature, (a, b) = (1, 0). By the w -> -w
            symmetry of uniform directions, it also corresponds to (0, 1).

    Returns:
        0-d tensor with the estimated distance.
    """
    kernel_eval_X_mu_X_mu = f2_kernel_evaluation(X_mu, X_mu)
    kernel_eval_X_nu_X_nu = f2_kernel_evaluation(X_nu, X_nu)
    kernel_eval_X_mu_X_nu = f2_kernel_evaluation(X_mu, X_nu, fill_diag = False)
    mmd_sq = (torch.mean(kernel_eval_X_mu_X_mu) + torch.mean(kernel_eval_X_nu_X_nu)
              - 2*torch.mean(kernel_eval_X_mu_X_nu))
    # Rounding can push a near-zero MMD^2 slightly negative, and np.sqrt would then give
    # NaN. Clamp first, and stay in torch rather than relying on NumPy wrapping tensors.
    return torch.sqrt(torch.clamp(mmd_sq, min=0.0))

In [21]:
def d_f2_estimate(X_nu, X_mu, a, b, n_directions=None):
    """Monte Carlo estimate of the F2 (kernel) distance between nu and mu.

    Draws ``n_directions`` directions w uniformly on the unit sphere of R^d and uses the
    feature phi_w(x) = a*relu(w.x) + b*relu(-w.x). It averages
        0.5*(E_nu phi_w - E_mu phi_w)^2 + 0.5*(E_nu phi_{-w} - E_mu phi_{-w})^2
    over the directions and returns the square root.

    Args:
        X_nu: samples of nu, tensor of shape (n, d).
        X_mu: samples of mu, tensor of shape (m, d).
        a: weight of the positive ReLU part.
        b: weight of the negative ReLU part.
        n_directions: number of random directions. Defaults to the global
            ``args.n_samples_d_f2`` for backward compatibility.

    Returns:
        0-d tensor with the estimated distance.
    """
    if n_directions is None:
        n_directions = args.n_samples_d_f2
    # The dimension comes from the data (equal to args.d for the notebook's samples).
    d = X_nu.shape[1]
    Y0 = torch.randn(d, n_directions)
    Y0 = torch.nn.functional.normalize(Y0, p=2, dim=0)

    def _moments(X):
        # Project once per sample set. The (n, n_directions) product dominates time and
        # memory, and it was previously recomputed four times per sample set.
        proj = torch.matmul(X, Y0)
        plus = torch.mean(torch.nn.functional.relu(proj), dim=0)
        minus = torch.mean(torch.nn.functional.relu(-proj), dim=0)
        # Moments for the directions w and -w respectively.
        return a*plus + b*minus, a*minus + b*plus

    gen_moment_nu_positive, gen_moment_nu_negative = _moments(X_nu)
    gen_moment_mu_positive, gen_moment_mu_negative = _moments(X_mu)
    d_f2_sq = torch.mean(0.5*(gen_moment_nu_positive-gen_moment_mu_positive)**2
                         + 0.5*(gen_moment_nu_negative-gen_moment_mu_negative)**2)
    return torch.sqrt(d_f2_sq)

In [22]:
print(d_f2_estimate(X_nu, X_mu, a=1, b=0))  # Monte Carlo estimate over random directions

tensor(0.0023)


In [29]:
# Closed-form kernel version. The (X_mu, X_nu) order is swapped relative to the other
# calls; this is harmless because the MMD is symmetric in its two sample sets (up to rounding).
print(d_f2_estimate_exact_kernel(X_mu, X_nu, a=1, b=0))

tensor(0.0022)
