# Test the Optimizer subclasses (SkoptOptimizer, NevergradOptimizer)

In [2]:
from brian2 import *
from brian2tools import *

In [4]:
# Synthetic data: 10 time steps (rows) x 5 traces (columns).
# Every trace is 0 V for the first 5 steps; from step 5 on, trace i is
# clamped to i*10 mV.
input_traces = zeros((10, 5))*volt
for i in range(input_traces.shape[1]):
    input_traces[5:, i] = i*10*mV

# Target output: an ohmic current scaled by 10 nS, i.e. what `model` below
# produces for g = 10 nS and E = 0 V.
output_traces = input_traces*(10*nS)

# Model to fit: one current with two constant parameters, g and E.
model = Equations('''
    I = g*(v-E) : amp
    g : siemens (constant)
    E : volt (constant)
    ''')

In [26]:
def test_optimizer(optimizer, n_samples=10, method='DE'):
    parameter_names = {'E', 'g'}
    bounds=[[-5, 5], [0, 10]]
    
    optim = optimizer(method=method, parameter_names=parameter_names, bounds=bounds)

    parameters = optim.ask(n_samples=n_samples)
    errors = np.random.rand(n_samples)
    
    print('shape parameters', shape(parameters))
    print('shape errors', shape(errors))

    print('parameters', parameters)
    print('errors', errors)
    
    optim.tell(parameters, errors)
    optim.recommend()
    
    return parameters, errors

In [28]:
# Smoke-test SkoptOptimizer with the 'gp' method and 10 candidates
test_optimizer(SkoptOptimizer, n_samples=10, method='gp')

shape parameters (10, 2)
shape errors (10,)
parameters [[2.0264499244435896, 0.6165173181422324], [-1.2310812382614325, 1.4564339539751738], [-0.06824526266854125, 8.889254049717007], [-4.415128140915489, 2.5569481782737853], [-0.21043422400669343, 6.760874977543353], [-1.843575730312545, 9.64957360874301], [-4.757008467187856, 9.369741245296256], [3.293873456124226, 1.3446141565794747], [1.5680902589998649, 6.191443526329299], [4.165048285451618, 5.005715949235508]]
errors [0.09706058 0.31756129 0.09542415 0.30379887 0.35098854 0.00992724
 0.45333657 0.96304378 0.88333876 0.31883338]


([[2.0264499244435896, 0.6165173181422324],
  [-1.2310812382614325, 1.4564339539751738],
  [-0.06824526266854125, 8.889254049717007],
  [-4.415128140915489, 2.5569481782737853],
  [-0.21043422400669343, 6.760874977543353],
  [-1.843575730312545, 9.64957360874301],
  [-4.757008467187856, 9.369741245296256],
  [3.293873456124226, 1.3446141565794747],
  [1.5680902589998649, 6.191443526329299],
  [4.165048285451618, 5.005715949235508]],
 array([0.09706058, 0.31756129, 0.09542415, 0.30379887, 0.35098854,
        0.00992724, 0.45333657, 0.96304378, 0.88333876, 0.31883338]))

In [29]:
# Smoke-test NevergradOptimizer with the 'DE' method and 10 candidates
test_optimizer(NevergradOptimizer, n_samples=10, method='DE')

shape parameters (10, 2)
shape errors (10,)
parameters [[-0.6351182304942128, 7.08509164468683], [-2.896700691329783, 5.577113549039541], [-1.8278066046565982, 7.48425605274756], [0.5547862267189654, 3.4314199267218117], [-1.3855594624949998, 1.8601283104041992], [1.1732382281420521, 4.180847167094241], [2.8135519725301084, 7.594221506835334], [-2.942151689947753, 6.061708588717931], [2.8945836796045374, 4.025866932842289], [1.7647534127793487, 4.516481628721874]]
errors [0.23384999 0.19836175 0.12089629 0.4472833  0.84181991 0.38544886
 0.58871133 0.01905118 0.21929367 0.51515305]


([[-0.6351182304942128, 7.08509164468683],
  [-2.896700691329783, 5.577113549039541],
  [-1.8278066046565982, 7.48425605274756],
  [0.5547862267189654, 3.4314199267218117],
  [-1.3855594624949998, 1.8601283104041992],
  [1.1732382281420521, 4.180847167094241],
  [2.8135519725301084, 7.594221506835334],
  [-2.942151689947753, 6.061708588717931],
  [2.8945836796045374, 4.025866932842289],
  [1.7647534127793487, 4.516481628721874]],
 array([0.23384999, 0.19836175, 0.12089629, 0.4472833 , 0.84181991,
        0.38544886, 0.58871133, 0.01905118, 0.21929367, 0.51515305]))

In [None]:
# One ask/evaluate/tell round: Nevergrad proposes candidates, the model is
# simulated to score them, and the scores are reported back.

# Set up the Nevergrad optimizer. The parameter names are a list (not a set)
# because the bounds are matched to them by position, and a set has no
# guaranteed iteration order.
n_opt = NevergradOptimizer(method='DE', parameter_names=['E', 'g'],
                           bounds=[[-5, 5], [0, 10]])

# Ask for 10 candidate parameter sets
parameters = n_opt.ask(10)

# Pass parameters to the NeuronGroup and collect the resulting errors.
# NOTE(review): the g/E ranges given here (nS / mV) do not match the unitless
# optimizer bounds above. Confirm which ones fit_traces_ask_tell actually uses
# and how it maps the columns of `parameters` to g and E.
errors = fit_traces_ask_tell(model=model, input_var='v', output_var='I',
                             input=input_traces, output=output_traces,
                             dt=0.1*ms,
                             g=[1*nS, 30*nS], E=[-20*mV, 100*mV],
                             update=parameters,
                             method='linear')

# Give the errors back to the optimizer
n_opt.tell(parameters, errors)

ans = n_opt.recommend()

# Show each candidate with its error, then the recommended parameters
for candidate, error in zip(parameters, errors):
    print((candidate, error))

print(list(ans.args))