In [1]:
from scipy.stats import uniform
from mango.tuner import Tuner
import math
import numpy as np

In [2]:
# Modified Branin function to include a categorical variable as well
# ref: https://pdfs.semanticscholar.org/5284/110bd42233ef08334ca3567c54ff9e01cc3f.pdf

In [4]:
# Search space handed to the tuner: two continuous inputs drawn from
# scipy's uniform(loc=0, scale=1), i.e. the interval [0, 1], and one
# categorical input with three levels.
space = dict(
    x1=uniform(0, 1),
    x2=uniform(0, 1),
    x3=['a', 'b', 'c'],
)

def branin_cat(x1, x2, x3):
    """Evaluate the modified Branin function with a categorical switch.

    ``x1`` and ``x2`` are rescaled inside the function (``x1*15 - 5`` and
    ``15*x2``), so inputs in [0, 1] cover [-5, 10] x [0, 15]. The Branin
    value is then shifted and scaled by fixed constants, and ``x3`` selects
    how that normalised value is turned into the result.

    Args:
        x1: first continuous input.
        x2: second continuous input.
        x3: category, one of 'a', 'b' or 'c'.

    Returns:
        The function value for the selected category.

    Raises:
        ValueError: if ``x3`` is not 'a', 'b' or 'c'. ``math.log`` also
            raises ValueError for category 'c' when the normalised value
            is exactly zero.
    """
    # Branin coefficients.
    coef_a = 1.
    coef_b = 5.1 / (4.*np.pi**2)
    coef_c = 5. / np.pi
    coef_r = 6.
    coef_s = 10.
    coef_t = 1. / (8.*np.pi)

    # Rescaled first coordinate; the second is rescaled inline below.
    u = x1*15. - 5
    inner = 15.0*x2 - coef_b*u**2 + coef_c*u - coef_r
    branin = coef_a*inner**2 + coef_s*(1 - coef_t)*np.cos(u) + coef_s
    normalised = (branin - 54.8104) / 51.9496

    # The category is checked only after the numeric part has been computed.
    if x3 == 'a':
        return normalised + 0.2
    if x3 == 'b':
        return normalised * 0.5
    if x3 == 'c':
        return 1.03 + x1**2 - 2*x2**2 - math.log(abs(normalised)**0.5)
    raise ValueError("x3:%s not recognized" % x3)

In [5]:
def objfunc(args_list):
    """Score a batch of parameter dicts with the negated ``branin_cat``.

    Args:
        args_list: iterable of dicts, each unpacked as keyword arguments
            to ``branin_cat``.

    Returns:
        A list with ``-branin_cat(**params)`` for each entry, in input
        order. The sign flip means that maximising this objective
        minimises ``branin_cat``.
    """
    return [-branin_cat(**hyper_par) for hyper_par in args_list]

# Tuner settings, passed to mango's Tuner through conf_dict.
batch_size = 1
config = dict(num_iteration=100, batch_size=batch_size, domain_size=100)
n_trials = 10

# One row per trial: the running best objective at ten checkpoints.
res = []
for trial in range(n_trials):
    # A fresh tuner per trial so that the runs do not share state.
    tuner = Tuner(space, objfunc, conf_dict=config)
    results = tuner.maximize()
    print(len(results['objective_values']))
    print(results['best_params'])

    vals = list(results['objective_values'])
    # Best value seen within the first 10, 20, ..., 100 batches' worth of
    # evaluations, in evaluation order.
    y = [max(vals[:(k + 1) * batch_size * 10]) for k in range(10)]

    print(y)
    res.append(y)

# Average each checkpoint across trials (columns of `res`).
avg = [sum(column) * 1.0 / n_trials for column in zip(*res)]
print(avg)

101
{'x1': 0.652254639453049, 'x2': 0.9938149819759036, 'x3': 'c'}
[0.6451072510428864, 0.6701210738659363, 0.683313412979756, 0.683313412979756, 1.0151251006070252, 1.0151251006070252, 1.0151251006070252, 1.0151251006070252, 1.0151251006070252, 1.0151251006070252]
101
{'x1': 0.7265931187330508, 'x2': 0.9925866750672363, 'x3': 'c'}
[0.6619583496912418, 0.9380828794960155, 0.9380828794960155, 0.9380828794960155, 0.9559676303446515, 0.9559676303446515, 0.9559676303446515, 0.9559676303446515, 0.9559676303446515, 0.9559676303446515]
101
{'x1': 0.6086149761787906, 'x2': 0.9829049733342899, 'x3': 'c'}
[0.740394254791958, 0.7705889389468477, 0.7705889389468477, 0.7705889389468477, 0.9580383218373278, 0.9580383218373278, 0.9580383218373278, 0.9580383218373278, 0.9580383218373278, 0.9580383218373278]
101
{'x1': 0.5071650026728405, 'x2': 0.9864548700339378, 'x3': 'c'}
[0.5531440711240905, 0.7176976932818764, 0.7533603183236601, 0.930421339105069, 0.930421339105069, 0.930421339105069, 0.930421339