In [27]:
# Add pybary to the import path.
# The kernel's working directory is expected to be a sibling of the `pybary/`
# directory (both live in the same parent). This cell must run before
# `from pybary import ...`.
import os, sys

currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)

# os.path.join picks the right separator on every OS; the membership check keeps
# repeated runs of this cell from piling duplicate entries onto sys.path.
pybary_dir = os.path.join(parentdir, 'pybary')
if pybary_dir not in sys.path:
    sys.path.append(pybary_dir)

In [17]:
# Main imports
from functools import reduce

from numpy import power, array, exp, zeros
from numpy.linalg import norm
from numpy.random import default_rng, normal

from pybary import bary_batch, bary_recursive

In [11]:
# Oracle function
def oracle(x):
    """Objective to minimise: the Euclidean (2-)norm of a 1-D point ``x``.

    Its minimum is at the origin, which is where the barycenter estimates
    below should land.
    """
    return norm(x)


# Hyperparameters
# nu: scale in the weight exp(-nu * oracle(x)); passed to bary_batch and bary_recursive.
nu = 1
# sigma, zeta, lambda_: forwarded only to bary_recursive; their meaning is defined
# by pybary (TODO: document from the pybary API).
sigma = 0.1
zeta = 0
lambda_ = 1  # trailing underscore because `lambda` is a Python keyword

In [30]:
# Small hand-checkable example: the four corners of the unit square.
xs_test = array(
    [
        [0, 0], [1, 0], [0, 1], [1, 1]
    ]
)


def bexp_fun(x):
    """Exponential weight exp(-nu * oracle(x)) of a single point ``x``.

    Reads the notebook-level ``nu`` and ``oracle`` at call time, so it always
    reflects the current hyperparameters.
    """
    return exp(-nu*oracle(x))


# Hand-computed barycenter over the FIRST TWO points only (xs_test[0] and
# xs_test[1]), so it can be checked by eye: with the norm oracle and nu = 1 the
# weights are 1 and e^-1, giving x = e^-1 / (1 + e^-1) on the first axis.
den = bexp_fun(xs_test[0]) + bexp_fun(xs_test[1])
num = xs_test[0]*bexp_fun(xs_test[0]) + xs_test[1]*bexp_fun(xs_test[1])

# Weights of all four points, for reference.
oracle_eval = list(map(bexp_fun, xs_test))

print('Evaluations : '+str(oracle_eval))
print('Numerator   : '+str(num))
print('Denominator : '+str(den))
print('Barycenter  : '+str(num/den))

Evaluations : [1.0, 0.3678794411714424, 0.3678794411714424, 0.2431167344342142]
Numerator   : [0.36787944 0.        ]
Denominator : 1.3678794411714423
Barycenter  : [0.26894142 0.        ]


In [32]:
# Weighted barycenter sum_i w_i * x_i / sum_i w_i, with w_i = bexp_fun(x_i),
# over ALL points of xs, written with functools.reduce.
xs = xs_test

n = len(xs[0])
size_x = (n, 1)


def prod_func(elems):
    """Multiply the two entries of a (weight, point) pair."""
    return elems[0]*elems[1]


def sum_func(acc, a):
    """Binary addition used as the reduce step."""
    return acc + a


# Evaluate every weight exactly once and reuse the list for both the numerator
# and the denominator, so each point costs a single oracle call.
weights = list(map(bexp_fun, xs))

# The accumulator zeros(size_x).T has shape (1, n), so num is a (1, n) row vector.
num = reduce(sum_func, map(prod_func, zip(weights, xs)), zeros(size_x).T)

# Scalar sum of the weights, starting from 0.
den = reduce(sum_func, weights, 0)

print(num)
print(den)
print(num/den)

[[0.61099618 0.61099618]]
1.978875616777099
[[0.30875926 0.30875926]]


In [33]:
# Batch setup

# Points for the batch barycenter: size_batch[0] samples in size_batch[1]
# dimensions, each coordinate drawn from a normal with mean mu_x and std sigma_x.
mu_x = 0
sigma_x = 1
size_batch = [1000, 2]

# Seeded generator so Restart & Run All reproduces the same sample, and
# therefore the same batch barycenter.
batch_seed = 42
rng = default_rng(batch_seed)
xs = rng.normal(mu_x, sigma_x, size_batch)

# Batch run
xhat_batch = bary_batch(oracle, xs, nu)

# Results
print(xhat_batch)

[[0.0135538  0.03120113]]


In [34]:
# Recursive setup

# Initial point: the 2-D origin, given as floats so its dtype matches real-valued
# estimates. An int array would break if bary_recursive updated it in place; that
# is not visible from here, so this is defensive.
x0 = array([0.0, 0.0])

# Iteration cardinality
iterations = 1000

# Recursive run
# NOTE(review): no seed is set for this run. If pybary draws random numbers from
# NumPy's global RNG, calling numpy.random.seed(...) first would make the printed
# estimate reproducible. TODO: confirm how pybary samples.
xhat_recursive = bary_recursive(oracle, x0, nu, sigma, zeta, lambda_, iterations)

print(xhat_recursive)


[[-0.01122628  0.05922258]]
