# Barycenter method - recursive and batch versions 

## Setup

In [9]:
# Add pybary to sys.path so that `from pybary import ...` below resolves
import os
import sys

currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)

# The pybary sources live in a sibling directory of the notebook folder.
# os.path.join builds the path portably, and the membership check keeps
# re-runs of this cell from appending the same entry again and again.
pybary_dir = os.path.join(parentdir, 'pybary')
if pybary_dir not in sys.path:
    sys.path.append(pybary_dir)

In [10]:
# Main imports
from numpy import power, array, exp, zeros, append, arange
from numpy.random import normal
from numpy.linalg import norm
from functools import reduce
import matplotlib.pyplot as plt

from IPython.display import display
from ipywidgets import interactive, FloatSlider, IntSlider
from collections import namedtuple

from pybary import bary_batch, bary_recursive, bary_recur_formula


## Hyperparameters

In [8]:
# Oracle function: the Euclidean norm of the point (smallest at the origin)
def oracle(x):
    """Return the Euclidean norm of ``x``; used as the cost passed to the barycenter routines."""
    return norm(x)

# Hyperparameters shared by the batch and recursive runs below
nu = 5
sigma = 0.5
zeta = 0
lambda_ = 1

## Batch version

In [4]:
# Batch setup: draw a cloud of points and reduce it in a single call.

# Sampling parameters for the batch: a (100, 2) array of N(mu_x, sigma_x) draws
mu_x = 0
sigma_x = 1
size_batch = [100, 2]
xs = normal(loc=mu_x, scale=sigma_x, size=size_batch)

# Batch run with the oracle and nu from the hyperparameters cell
xhat_batch = bary_batch(oracle, xs, nu)

# Results
print(xhat_batch)

[[-0.09435003 -0.02260265]]


## Recursive version

In [5]:
# Recursive setup

# Initial point of the recursion (an integer numpy array [1, 1])
x0 = array([1, 1])

# Number of recursive iterations to run
iterations = 100

# Recursive run: pybary's bary_recursive, using the oracle and the
# hyperparameters (nu, sigma, zeta, lambda_) from the hyperparameters cell
xhat_recursive = bary_recursive(oracle, x0, nu, sigma, zeta, lambda_, iterations)

print(xhat_recursive)


[[-0.02900835  0.01636211]]


## Iterative recursive version (interactive, averaged over instances)

In [6]:
# Field names of the record types used by the interactive section.
# The per-run and the instance-averaged results share the same layout.
RecurResultsProps = ["steps", "ms", "xhats"]
instanceAverageResultsProps = list(RecurResultsProps)
hyperparametersProps = ["nu", "sigma", "zeta", "lambda_"]

# Record types: one recursive run, the average over several runs,
# and the hyperparameter set shown in the plot title.
RecurResults = namedtuple("RecurResults", RecurResultsProps)
instanceAverageResults = namedtuple("instanceAverageResults", instanceAverageResultsProps)
hyperparameters = namedtuple("hyperparameters", hyperparametersProps)

def update_mean(curr_count, curr_mean, x):
    """
    Fold one more sample into a running arithmetic mean.

    ``curr_mean`` is the mean of the ``curr_count`` samples seen so far.
    The result is the mean of those samples plus ``x``, computed
    element-wise. ``x`` is converted with ``numpy.array`` first.
    """
    new_count = curr_count + 1
    kept_part = curr_count * curr_mean / new_count
    new_part = array(x) / new_count
    return kept_part + new_part

def bary_recur_(oracle_fun, x0, nu, sigma, zeta, lambda_, iterations):
    """
    Run the recursive barycenter method and record every intermediate estimate.

    Each step samples a candidate ``x = xhat_prev + z`` with
    ``z ~ N(zeta * deltax_prev, sigma)``. Here ``deltax_prev`` is how much
    the estimate moved in the previous step. The candidate is then folded
    into the running estimate by ``bary_recur_formula``.

    Parameters
    ----------
    oracle_fun : callable
        Cost function handed to ``bary_recur_formula``.
    x0 : array_like
        Initial point. It is handled internally as a float row of shape (1, d).
    nu, lambda_ : float
        Passed through unchanged to ``bary_recur_formula``.
    sigma : float
        Standard deviation of the Gaussian exploration noise.
    zeta : float
        Weight of the previous step in the mean of the noise.
    iterations : int
        Number of recursive steps to run.

    Returns
    -------
    ms : list
        The ``m`` value returned by ``bary_recur_formula`` at each step.
    xhats : list of list of float
        The estimate (first row of ``xhat``) after each step.
    """
    # Keep every vector as a (1, d) float row. The noise, the candidate point
    # and the estimate then always share one shape, and never broadcast
    # into a (d, d) matrix.
    xhat_1 = array(x0, dtype=float).reshape(1, -1)
    deltax_1 = zeros(xhat_1.shape)
    m_1 = 0

    ms = []
    xhats = []
    for _ in range(iterations):
        z = normal(zeta * deltax_1, sigma)
        x = xhat_1 + z
        m, xhat = bary_recur_formula(m_1, xhat_1, x, oracle_fun, nu, lambda_)
        # assumes bary_recur_formula returns an estimate with d entries
        # (e.g. shaped like x) — TODO confirm; reshape keeps it a (1, d) row
        xhat = array(xhat, dtype=float).reshape(1, -1)

        ms.append(m)
        xhats.append(list(xhat[0]))

        # The step has to be measured BEFORE the previous estimate is
        # overwritten. Otherwise it is always zero and zeta has no effect.
        deltax_1 = xhat - xhat_1
        m_1 = m
        xhat_1 = xhat

    return ms, xhats
    
    
def do_recur(nu, sigma, zeta, lambda_, iterations, x0=(1, 1), center=(0, 0)):
    """
    Run one recursive barycenter instance on a shifted-norm oracle.

    The oracle is ``norm(x - center)``. The recursion starts from ``x0``.

    Parameters
    ----------
    nu, sigma, zeta, lambda_ : float
        Hyperparameters forwarded to ``bary_recur_``.
    iterations : int
        Number of recursive steps.
    x0 : array_like, optional
        Initial point (default ``(1, 1)``).
    center : array_like, optional
        Point whose distance the oracle measures (default ``(0, 0)``).

    Returns
    -------
    RecurResults
        A new record with ``steps`` (1 .. number of steps), ``ms`` and
        ``xhats`` as numpy arrays.
    """
    x0 = array(x0)
    center = array(center)

    def oracle_fun(x):
        # Euclidean distance to `center`
        return norm(x - center)

    ms, xhats = bary_recur_(oracle_fun, x0, nu, sigma, zeta, lambda_, iterations)

    # Build a fresh record instead of assigning to the RecurResults class.
    # Class-level assignment overwrites the namedtuple field descriptors,
    # and every call would return the same shared object.
    return RecurResults(
        steps=arange(1, len(ms) + 1),
        ms=array(ms),
        xhats=array(xhats),
    )

def do_plot(plotResults, hyperparameters):
    """
    Plot multi-images

    Draws two figures:

    1. a stem plot of ``ms`` against ``steps``;
    2. one stem subplot per coordinate of ``xhats``, with the
       hyperparameters in the figure title.

    Parameters
    ----------
    plotResults
        Object with ``steps``, ``ms`` and ``xhats`` attributes. ``xhats`` must
        be 2-D, with one row per step and one column per coordinate.
    hyperparameters
        Object with ``nu``, ``sigma``, ``zeta`` and ``lambda_`` attributes.
    """
    steps = array(plotResults.steps)
    ms = array(plotResults.ms)
    xhats = array(plotResults.xhats)

    # stem wants 1-D heights, but averaged ms may arrive as a (steps, 1) column.
    plt.stem(steps, ms.ravel())
    plt.show()

    n_coords = xhats.shape[1]
    # squeeze=False keeps axs 2-D even when there is a single coordinate.
    fig, axs = plt.subplots(1, n_coords, squeeze=False)

    title_template = r'$\nu = {}$, $\sigma = {}$, $\zeta = {}$, $\lambda = {}$'
    fig.suptitle(title_template.format(
        hyperparameters.nu,
        hyperparameters.sigma,
        hyperparameters.zeta,
        hyperparameters.lambda_,
    ))

    for coord, ax in enumerate(axs[0]):
        ax.stem(steps, xhats[:, coord])

    plt.show()

def handle_event(nu, sigma, zeta, lambda_, iterations, instances):
    """
    Process events from the ipywidgets.interactive handler.

    Argument names in the event handler must match the keys in the "interactive" call (below).

    Runs ``instances`` independent recursive runs with the given
    hyperparameters. It averages their ``ms`` and ``xhats`` traces step by
    step and plots the averages.

    Raises
    ------
    ValueError
        If ``instances`` is smaller than 1, since there would be nothing to average.
    """
    if instances < 1:
        raise ValueError("instances must be at least 1")

    # A fresh record; the hyperparameters class itself is left untouched.
    params = hyperparameters(nu=nu, sigma=sigma, zeta=zeta, lambda_=lambda_)

    # The running means start at scalar 0, and update_mean(0, 0, x) returns x.
    # So the shapes come from the first run instead of being hard-coded.
    mean_ms = 0
    mean_xhats = 0
    for instance_count in range(instances):
        results = do_recur(nu, sigma, zeta, lambda_, iterations)
        # Both means must use the same running count; with a fixed count of 0
        # only the last run would survive.
        mean_ms = update_mean(instance_count, mean_ms, results.ms)
        mean_xhats = update_mean(instance_count, mean_xhats, array(results.xhats))

    averages = instanceAverageResults(
        steps=arange(1, len(mean_xhats) + 1),
        ms=mean_ms,
        xhats=mean_xhats,
    )

    do_plot(averages, params)

def float_slider_config(min_value, max_value, step_value, slider_value):
    """
    Build a FloatSlider from the shared (min, max, step, value) settings.
    """
    settings = dict(
        min=min_value,
        max=max_value,
        step=step_value,
        value=slider_value,
    )
    return FloatSlider(**settings)

def int_slider_config(min_value, max_value, step_value, slider_value):
    """
    Build an IntSlider from the shared (min, max, step, value) settings.
    """
    settings = dict(
        min=min_value,
        max=max_value,
        step=step_value,
        value=slider_value,
    )
    return IntSlider(**settings)

In [7]:
# Source: https://codesolid.com/creating-a-python-interactive-plot/

# Make the slider controls interactive, and display them.
# The keyword names below must match handle_event's parameter names.
slider_controls = interactive(
    handle_event,
    nu         = float_slider_config(1.0,    5,   0.25,    4.0),
    sigma      = float_slider_config(0.1,    1,   0.10,    0.5),
    zeta       = float_slider_config(0.0,    1,   0.10,    0.0),
    lambda_    = float_slider_config(0.9,    1,   0.01,    1.0),
    # Integer sliders get integer bounds, step and value.
    iterations = int_slider_config(100,   1000,    100,   1000),
    instances  = int_slider_config(  1,    100,      1,      1),
)

display(slider_controls)

interactive(children=(FloatSlider(value=4.0, description='nu', max=5.0, min=1.0, step=0.25), FloatSlider(value…