In [39]:
def drecexpbary(oracle, x0,
                nu = 2, sigma = 0.5, 
                zeta = 0, lambda_ = 1, 
                iterations = 1000):
    '''
     Recursive barycenter algorithm for direct optimization
     
     In:
       - oracle     [function]  : Oracle function e.g. lambda x: numpy.power(x, 2)
       - x0         [np.array]  : Initial query values
       - nu         [double]    : positive value (Caution on its value due overflow)
       - sigma      [double]    : Std deviation of normal distribution
       - zeta       [double]    : Proportional value for mean of normal distribution
       - lambda     [double]    : Forgetting factor between 0 and 1
       - iterations [int]       : Maximum number of iterations
     
     Out:
        - x [np.array]: Optimum position
    '''
    import numpy as np
    import scipy as scp
    
    # Initialization
    xhat_1 = x0
    m_1 = 0

    deltax_1 = np.zeros((len(x0), 1))
    solution_is_found = False
    normrnd = scp.random.normal
    
    # Optimization loop
    i = 1
    while(not solution_is_found):    
        z = normrnd(zeta*deltax_1, sigma).T
        
        x = xhat_1 + z
        e_i = np.exp(-nu*oracle(x))
        m = lambda_*m_1 + e_i
        xhat = (1/m)*(lambda_*m_1*xhat_1 + x*e_i)
        print(xhat)
        solution_is_found = i >= iterations
        
        # Update previous variables
        m_1=m
        xhat_1=xhat
        
        i = i + 1
    
    return xhat

In [43]:
import numpy as np

oracle = lambda x: np.power(x, 2)
x0 = np.array([10, 10])

x = drecexpbary(oracle, x0, 2, 0.5)


[[10.14602077 10.83967377]]
[[10.14602077 10.25301632]]
[[10.14659073 10.25301632]]
[[10.14659073  9.38677093]]
[[10.149525    9.05085545]]
[[10.149525    8.58875665]]
[[9.92567437 8.59653555]]
[[9.59590258 8.59902424]]
[[9.59640423 8.02683847]]
[[9.43572013 8.02120008]]
[[9.3003981  8.02503267]]
[[8.07575695 7.78408684]]
[[8.01333488 7.34989531]]
[[8.0133349  6.51222975]]
[[7.25255481 6.2133975 ]]
[[7.25255591 6.2133975 ]]
[[7.16554201 6.21339753]]
[[7.1652959  6.21342976]]
[[7.13161632 5.33262618]]
[[7.13161632 4.87184256]]
[[6.92558923 4.88360519]]
[[6.81267378 4.12802185]]
[[6.81274921 4.13261921]]
[[6.78371968 4.13611988]]
[[6.78555821 4.13722505]]
[[6.78556002 4.10897037]]
[[6.56041557 3.42093481]]
[[5.62492718 3.43529215]]
[[4.82521151 3.44154014]]
[[4.42044102 3.4344329 ]]
[[4.04507294 3.44199849]]
[[3.90673089 3.4420004 ]]
[[3.90673091 2.98037957]]
[[3.04441144 2.88466734]]
[[2.38176264 2.8849333 ]]
[[1.75249931 2.51699532]]
[[1.78922417 2.5172288 ]]
[[1.78924887 2.52555967]]


In [34]:
x

array([[7.36758868, 8.1837395 ]])