In [6]:
import random

import numpy as np
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import seaborn as sns

import tengp
from gpbenchmarks import get_data


def pdivide(x, y):
    return np.divide(x, y, out=np.copy(x), where=x!=0)

def plog(x, y):
    return np.log(x, out=np.copy(x), where=x>0)

def psin(x, y):
    return np.sin(x)

def pcos(x, y):
    return np.cos(x)

def pow2(x, y):
    return x**2

def pow3(x, y):
    return x**3

def ptan(x, y):
    return np.tan(x)

def ptanh(x, y):
    return np.tanh(x)

def psqrt(x, y):
    return  np.sqrt(x)

def pexp(x, y):
    return np.exp(x)


funset = tengp.FunctionSet()
funset.add(np.add, 2)
funset.add(np.subtract, 2)
funset.add(np.multiply, 2)
funset.add(pdivide, 2)
funset.add(plog, 2)
funset.add(psin, 2)
funset.add(pcos, 2)

#     function set: +, -, *, /, sin, cos, tan, tanh, sqrt, exp, log, **2, **3
ext_funset = tengp.FunctionSet()
ext_funset.add(np.add, 2)
ext_funset.add(np.subtract, 2)
ext_funset.add(np.multiply, 2)
ext_funset.add(pdivide, 2)
ext_funset.add(psin, 2)
ext_funset.add(pcos, 2)
ext_funset.add(ptan, 2)
ext_funset.add(ptanh, 2)
ext_funset.add(psqrt, 2)
ext_funset.add(pexp, 2)
ext_funset.add(plog, 2)
ext_funset.add(pow2, 2)
ext_funset.add(pow3, 2)

params1d = tengp.Parameters(2, 1, 1, 50, funset, real_valued=True)
params2d = tengp.Parameters(3, 1, 1, 50, funset, real_valued=True)
params5d = tengp.Parameters(6, 1, 1, 50, ext_funset, real_valued=True)

functions = [('nguyenf4', params1d, [20, -1, 1]), ('nguyenf7', params1d, [20, 0, 2]), ('nguyenf10', params2d, [100, -1, 1]), ('korns12', params5d, [10000, -50, 50])]

In [7]:
import pygmo as pg

In [20]:
vectors = []
n_actives = []

class cost_function:
    def __init__(self, X, Y, params, bounds):
        self.params = params
        self.bounds = bounds
        self.X = X
        self.Y = Y
    
    def fitness(self, x):
        #vectors.append(x)
        #x[-1] = round(x[-1])
        print(list(x))
        %time individual = tengp.individual.NPIndividual(list(x), self.bounds, self.params)
        individual = tengp.individual.NPIndividual(list(x), self.bounds, self.params)
        pred = individual.transform(self.X)
        
        #n_actives.append(len(individual.active_nodes))
        
        try:
            return [mean_squared_error(pred, self.Y)]
        except ValueError:
            return [10000000000]
        
    def get_bounds(self):
        max_back = 10
        lower = [0]*len(self.bounds)
        for i in range(len(lower)):
            if i % 3 == 0:
                #print(lower[i],'\t', self.bounds[i])
                continue
            
            lower[i] = self.bounds[i] - max_back
            
            if lower[i] < 0:
                lower[i] = 0
            #print(lower[i], '\t', self.bounds[i])
            
                
        lower[-1] = self.bounds[-1] - max_back
        #print(lower[-1],'\t', self.bounds[-1])
        
        
        return (lower, [b for b in self.bounds])

In [None]:
%%time
results = []
champions = []

for f in functions:
    print(f[0])
    x, y = get_data(f[0], *f[2])
    x = np.c_[np.ones(len(x)), x]
    bounds = tengp.individual.IndividualBuilder(f[1]).create().bounds[:]

    prob = pg.problem(cost_function(x, y, f[1], bounds))
    
    problem_results = []
    pg.set_global_rng_seed(42)
    for i in range(100):
        print(i, end=', ')
        algo = pg.algorithm(pg.pso(
            gen=200,
            eta1=0.1,
            eta2=4
#             omega=0.4,
#             max_vel=1
#             variant=1,
#             neighb_type=4
        ))
        algo.set_verbosity(1)
        pop = pg.population(prob, 50)
        pop = algo.evolve(pop)
        problem_results.append(pop.champion_f[0])
        champions.append(pop.champion_x)
    results.append(problem_results)
    

nguyenf4
0, [0.8371112701438281, 0.8881315441512125, 0.35330381307400344, 3.755573732485472, 1.099589801406659, 0.40589823885656795, 3.883751813765638, 0.7072493374099462, 2.502935656073749, 1.9463948210155202, 3.368774818424023, 3.8723499281355367, 3.8885844565852827, 3.8015065030375284, 1.521408642870985, 0.0731335403801519, 2.6353887094041584, 0.8499840573665479, 3.120448956407131, 0.9923033287476009, 1.677184296446976, 4.141214773178328, 7.564316292676427, 7.037548718770104, 4.798969028517865, 1.8974524926569116, 8.671875897361678, 2.2341531846843514, 6.125118038027933, 2.477402713078631, 1.2762205644301596, 10.588619186134101, 3.005951443364516, 2.3080038994714114, 6.730546619089408, 6.908217986540656, 3.241522619349035, 8.880761707672905, 3.4699203334042243, 0.8628656035892028, 7.653484926871224, 6.205913126587644, 1.239871886285203, 9.88114961158678, 14.454549777037593, 5.401774062696669, 10.777614299855527, 12.93231293647634, 2.6155133456606627, 10.884100533585869, 13.029494511

CPU times: user 3.75 s, sys: 109 ms, total: 3.86 s
Wall time: 3.91 s
[2.0537836457778393, 0.6103808649713008, 0.08215773609480577, 1.263267694531217, 0.3853749399334133, 1.086882568026335, 0.669451715702353, 0.49387849721505583, 2.565223174457069, 2.2310240167209305, 1.7289050703944935, 3.67518281517916, 4.266836263029995, 4.917657812761674, 0.19650200503074583, 2.049594456081808, 4.75187890572766, 4.541348146134237, 2.9274529926877992, 3.7073245718722747, 2.662587907914424, 1.2862113211054331, 2.538395284045494, 3.591827895192711, 0.8171048913418931, 0.02928171802494999, 1.9422777125319899, 5.866823945719974, 2.0192607657587818, 6.115498363812, 0.8733592835643164, 6.913118011074666, 2.1019185770183304, 2.3974546515933692, 2.912289923267376, 9.397376496986254, 5.711796884500933, 3.130287890146376, 4.103358703639452, 1.0910538159889644, 6.249896922495074, 11.999462100118649, 2.783177051351708, 6.067337875423881, 10.659583927149363, 5.0198065717023805, 8.71288150516309, 7.829987485397607

[1.2629202298015736, 0.1408691462459404, 0.21906278448983907, 2.2099483733617493, 0.8229096118585483, 1.1350877278887355, 1.9002037490143835, 1.3991041919788185, 2.1127537483083736, 2.499076158949197, 2.525848607453017, 0.6087583849595889, 1.8873554894707025, 3.638875927670343, 0.6137818689329289, 4.361009818498381, 3.738902219882758, 1.0307912118637181, 3.9246749053896925, 3.4913395374800147, 0.2907233964102577, 1.1558692030029158, 4.127433547111008, 3.41938209935872, 3.316398744283092, 0.4802309605048373, 3.2404625049196465, 4.614555552331012, 5.399159827153811, 5.555946593535216, 4.261095920381184, 5.044993336839137, 7.557974795921131, 1.939625968414094, 2.8822438461086115, 6.486290056567719, 2.6688889216795912, 12.814498981670221, 5.59267261342387, 4.852197383363837, 5.979861254332709, 12.045146700915431, 0.312660039643423, 10.590156467571887, 14.456313500009447, 2.619963932597001, 6.31172842129858, 15.139868056593754, 5.10534770058383, 9.86379362225695, 10.951118161020034, 3.33564

KeyboardInterrupt: 

[4.457087302607002, 0.0348193712614364, 0.09800999814169319, 5.940887005701388, 0.7424557344635673, 1.4811583049856034, 5.102438419472245, 1.4422064150888216, 2.256905974130185, 4.797627496657916, 3.4687464266273205, 1.595156378117904, 2.485710507685858, 1.2066284730894152, 0.06346392138620355, 1.3612147144067706, 1.0935852753003528, 2.591219957776082, 5.914579894636809, 4.773112127182426, 3.8409482201450182, 5.751688514474154, 2.977994733615125, 2.475773281706258, 4.755757313163485, 5.324811547001275, 3.313476407851201, 0.3057225808073328, 6.438011065318706, 3.98404983777613, 5.4824042629697605, 10.406569619471009, 1.709746400913867, 2.5579507289917496, 6.68132604942035, 8.825028517571827, 1.8072528727868606, 8.313466546230682, 7.727482453674108, 5.876678941941085, 10.686545757267192, 9.220188250088055, 1.4461708083622784, 12.84599036438464, 11.591522826766923, 1.6240188468250492, 6.7124426824607575, 15.559997753031894, 0.48205257717534133, 14.83707108506145, 9.916369187668778, 0.8680

CPU times: user 4.98 s, sys: 172 ms, total: 5.16 s
Wall time: 5.21 s
[2.5711037786863082, 0.6613032054929118, 0.26777564917734276, 5.4842863941165, 1.2871401073786517, 0.5374919361745343, 5.224665654918488, 2.4158238821460305, 0.8553804111440974, 3.2511119934976827, 0.40843526750071585, 3.4073426890221885, 3.313396494945308, 0.7352297560491351, 3.7263637765625917, 5.038929051157169, 2.0393417173360677, 3.842017490037484, 1.4451280620495286, 0.785648377922978, 0.49219491518280206, 0.38695328750729796, 4.979809591689162, 2.256782207661777, 2.9390637928952414, 8.467758260487786, 1.5147920612654775, 3.5881874783001146, 0.8392667802243426, 0.9786339734929956, 4.848495183074136, 9.994483704519912, 7.2527174903662335, 5.537811435749467, 5.6417637760451536, 10.896354392484954, 0.4055531751961664, 9.482073253020516, 11.769751758991951, 1.7522751171488427, 13.278971521754599, 5.064286721120427, 2.9604749325347983, 14.012384228798137, 8.537463662298592, 0.944206216020808, 10.513682457501497, 6.87

CPU times: user 1.17 s, sys: 31.2 ms, total: 1.2 s
Wall time: 1.21 s
[3.2173422322663567, 0.353799199397129, 0.09082245508509662, 4.573934209630002, 0.7045245474322309, 0.755984243243315, 5.120974168251575, 0.390413239819579, 2.7663152166017104, 4.084090003832282, 0.3126731361167761, 2.5572790099068783, 1.2039368070597058, 1.9639002139928237, 3.6544841221680646, 1.4383077981949515, 2.822066295210338, 5.715846746396154, 5.741306984652143, 6.575834055216866, 3.147949222547734, 1.0923559568957195, 2.7936905774831082, 1.1952104033790691, 2.188102471025347, 5.082289315923043, 0.8973360588439525, 4.008825264231371, 8.239872649867978, 8.988437883332882, 4.9957265907577515, 4.628825153338031, 5.912754885471504, 2.636417531701617, 2.5338579569469797, 9.625344625408998, 4.385649630428049, 12.332504751200638, 5.72309656667581, 0.6215326443916629, 9.080546792514708, 10.68871841766462, 5.283007632231129, 12.000597684608284, 14.277856527557239, 1.143635897331977, 12.660179153636744, 7.99734169703363

CPU times: user 1.75 s, sys: 15.6 ms, total: 1.77 s
Wall time: 1.83 s
[5.280561569385186, 0.9014434679213448, 0.7583893983612929, 3.160174223941783, 0.7536435543435416, 0.24856662511303593, 2.7292029275603404, 0.6329537808759865, 2.0116188507983845, 1.1155185425197958, 1.1106749463240706, 3.645987252335318, 5.798808850757247, 0.09029887287810467, 4.244519161960109, 0.9949500476942861, 3.064103655386985, 1.9879616149963866, 4.816932390164121, 4.831440022189175, 5.289243615739371, 2.631243951578844, 3.7720510066974455, 5.860800199070469, 1.8951928541102805, 4.462884367743525, 5.656111702491241, 3.44973317321282, 4.86894335572972, 0.22816335280303343, 3.9268279261580084, 5.790487562519997, 8.107777232295938, 3.9505406373603553, 11.847081266270653, 10.747381593798258, 4.765863670378079, 7.343552235265011, 5.565214798791277, 3.1776955143734495, 7.8431138663850195, 10.846784272436182, 4.019983808408949, 13.592694025257561, 10.983392311005444, 3.5915801966053422, 6.641392294325913, 11.4721942

CPU times: user 5.86 s, sys: 93.8 ms, total: 5.95 s
Wall time: 6.05 s
[3.903830787217805, 0.4512367765267806, 0.016009798903499805, 1.6537441818076855, 1.0905402312800145, 0.76273396946068, 2.8143731678454893, 1.2837909708126467, 2.209635799378307, 2.900856085453678, 2.2279257863694806, 2.280493506100762, 0.46958859703830336, 4.5850487147096475, 1.934538139165406, 4.29560250177814, 1.93588311201195, 0.9363815759252119, 3.1828734185709533, 5.103361513822417, 1.2584528547525824, 0.9067261671579272, 4.041083028064883, 6.422776263435786, 2.6031386770923737, 3.880590807691768, 1.599081381503773, 0.09364202869807645, 8.972433609284694, 8.753839189945559, 1.2502406383330409, 6.696700433840412, 7.515408020959027, 4.278921224893738, 2.9875428943051947, 6.639562058675256, 5.8127426038366234, 10.380470784451623, 7.810391721747051, 4.286672967670718, 6.978481841895154, 12.8242867314241, 5.950439805494963, 5.338721110177639, 7.398554785220435, 5.160854767914753, 15.027244901194925, 9.46322983448210

CPU times: user 6.61 s, sys: 188 ms, total: 6.8 s
Wall time: 7.12 s
[1.7482901661008396, 0.5791504420746278, 0.611581007446688, 1.3676378557184017, 0.5199313932950166, 0.09195603685298809, 3.4999986895669997, 2.9489090876525133, 2.388773152358191, 5.437150185398541, 2.9305156246082436, 0.3878286565530535, 4.225156738146916, 0.5415159026812177, 3.6739830867159777, 2.2955875110508757, 5.556149120341697, 0.19812263119157764, 3.958995232787643, 6.00828195708168, 4.819115098205207, 4.0078920270926535, 7.8794965574107065, 0.3091640170256078, 0.62959336185124, 7.957619580160522, 5.770174133929076, 1.3323621179019436, 9.243650983073264, 1.680541484283431, 0.996198295423578, 10.469604357608878, 8.816640896709387, 1.6461119849213925, 8.174020043094192, 10.116453518892879, 0.4776517637636818, 4.088456842425861, 7.333314279807041, 0.6403665324942897, 4.282807384524576, 4.8899599144115635, 3.8632524537780832, 14.699702070256249, 10.78725567475972, 3.9869803009271676, 8.444087612923132, 6.8976457321

CPU times: user 8.03 s, sys: 125 ms, total: 8.16 s
Wall time: 9.16 s
[0.624663709364595, 0.2004447642598938, 0.5467842709677133, 0.024395536593131324, 1.3279021026517568, 0.6645156176682901, 4.513030915470276, 1.4266903505352477, 2.425682763101525, 3.654691875526911, 0.6268778950065789, 0.4424339342361583, 1.3911203573965312, 2.7887691592249997, 0.6935141514777152, 2.804814903958982, 4.0370782315012415, 1.506075743586195, 0.6193932686251827, 5.514556991114839, 2.253575354094274, 2.3584591374255273, 2.839389869290426, 6.069889490901871, 2.7854699004011914, 1.1007932930657789, 0.48228943190811013, 1.4015445044847312, 1.3190604082242186, 6.95782995796264, 1.1500804413288748, 1.565564311259641, 6.497663467281447, 1.2954183573983558, 9.386973815594867, 10.194687611862053, 5.408915862848023, 3.4456790551132666, 11.158295072927995, 0.04612728975172098, 5.968364047619389, 10.832408508552163, 0.6362841769091978, 10.336771254815769, 9.1843157862539, 0.8816826588643503, 15.668326682205143, 12.361

CPU times: user 9.12 s, sys: 172 ms, total: 9.3 s
Wall time: 9.56 s
[1.5592948189050353, 0.06920867714829831, 0.0870375494218145, 3.2131599102617128, 0.21841671187269232, 1.9848285697684989, 5.429162000821303, 1.9836289868081902, 2.619052818889956, 5.10435980334306, 1.6090310189796018, 1.889347465126742, 0.4181330565519257, 0.6866323223745414, 0.6824787764912411, 2.3633800816097303, 5.430079665982033, 2.7729737259034883, 3.1756041667285206, 1.5312398201877104, 5.959828685400184, 4.659386553173474, 0.9361741844544035, 1.4926117052494758, 4.467613322177748, 0.5882899651697631, 1.6975064357051863, 4.1818811347355025, 5.325755182040467, 1.084241272116203, 1.1655405089357476, 7.743632131086712, 6.145166008733037, 4.116904966925575, 11.57062583931034, 6.476379113503662, 3.3789607232854615, 11.558916243064367, 7.405824143445648, 4.142656519173097, 8.57013312279447, 8.96238682882679, 3.851536334100632, 8.360081706067405, 8.688511039720868, 1.1760731499478383, 6.634834460325768, 8.7389098028149

CPU times: user 5.77 s, sys: 125 ms, total: 5.89 s
Wall time: 6.33 s
[1.4213720087078718, 0.3102876256070034, 0.4128917631710534, 5.619696660997455, 1.464967625108165, 1.7067062982454955, 3.4072712696273166, 0.029221048515562214, 0.9207544234138347, 2.048597554905968, 3.7092444452055418, 3.8124080013768187, 0.8266250782214001, 3.6835710403422794, 4.369324438374849, 1.836365444923906, 3.5071041671548326, 5.658529168568901, 0.9933122871231683, 5.150941174601009, 5.3437950038129385, 5.169734074292398, 5.144829104367927, 6.760741630370926, 4.522594639159932, 6.907952216984462, 4.687920329990723, 3.4624030420443974, 8.539252579426712, 0.027952951191014107, 3.2721628593128895, 2.468627062154038, 4.286257125254808, 2.2319261557392105, 3.7506229770254222, 9.356985033713862, 3.645025984725338, 6.050043537082301, 7.450025721075731, 3.232994985465272, 5.357276151125856, 6.2215018759872125, 5.5193181992146245, 8.281288128331123, 14.20414036675284, 5.347963062204276, 12.974336837430101, 11.98962583

KeyboardInterrupt: 

[0.5755400701037081, 0.7521436214596872, 0.2599349229763338, 1.6286105704171536, 1.7918296726833363, 1.0946244511510657, 1.8998182561184576, 0.3749384957809152, 0.40088079289811335, 5.888146873891344, 1.6411151618156865, 0.5803112503859583, 4.266332639496849, 2.5247866116445152, 0.1934917937960302, 1.7030416015687537, 1.2113439298842683, 1.1404850190505653, 1.605973652282841, 6.177652265067087, 0.27315298726283843, 5.499567136329991, 6.438054929356213, 1.8166254644331075, 5.9930170559760665, 0.45890466743636676, 0.21456847614860686, 4.959651801343075, 2.9412317417621487, 1.4469305018348169, 4.19554730489042, 8.933540733668792, 1.0042962356216407, 2.256776372395918, 4.260298130612998, 8.020353290238049, 0.598300891097429, 3.939171294158485, 12.091592656364092, 1.1860875476209631, 12.409190765331973, 6.098582293177952, 5.421656117192249, 12.847397986244413, 5.11769630373775, 1.7253920055818477, 9.538533379225683, 13.801145957190002, 4.7765765405901615, 9.220063172378143, 13.7319075852162

In [None]:
def plot_results(outs):
    costs = np.array(outs)
    print('min:', np.min(costs))
    print('mean:', np.mean(costs))
    print('median:', np.median(costs))
    print('variance:',np.var(costs))
    sns.distplot(costs, kde=False);

In [None]:
print('Nguyen4')
plot_results(results[0])

In [None]:
print('Nguyen7')
plot_results(results[1])

In [None]:
print('Nguyen10')
plot_results(results[2])

In [None]:
print('Korns12')
plot_results(results[3])