In [12]:
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

import tengp

# load data
X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# define function set
def pdivide(a, b):
    return np.divide(a, b, out=np.zeros_like(a), where=b!=0)

def plog(a):
    return np.log(a, out=np.zeros_like(a), where=a>0)

funset = tengp.FunctionSet()
funset.add(np.add, 2)
funset.add(np.multiply, 2)
funset.add(pdivide, 2)
funset.add(plog, 1)

# define cost function
def cost_function(y, y_pred):
    labels = np.array(y_pred).argmax(axis=0)
    return -accuracy_score(y, labels)

# tie everything together
params = tengp.Parameters(4, 3, n_columns=25, n_rows=1, function_set=funset, use_tensorflow=False)

res = tengp.simple_es(X_train, y_train, cost_function, params, target_fitness=-1, random_state=42)
# or one can use CMA-ES:
# hof, res = tengp.cma_es(...)

# evaluate the best individual
y_pred = res[0].transform(X_test)
labels = np.array(y_pred).argmax(axis=1)
print('Accuracy on test: {:.2}'.format(accuracy_score(y_test, labels)))

Accuracy on test: 0.37


In [13]:
y_pred

array([[ 1.18476194, -0.52360532,  3.26990828],
       [ 1.10830599, -0.59031447,  3.02922251],
       [ 1.42709566, -0.33750581,  4.16658043],
       [ 1.16639616, -0.53922839,  3.210402  ],
       [ 1.30144219, -0.42967416,  3.6745923 ],
       [ 1.04519092, -0.64894762,  2.84394143],
       [ 1.08786297, -0.60893199,  2.96792475],
       [ 1.31661598, -0.41808239,  3.73077496],
       [ 1.20266599, -0.50860643,  3.32898012],
       [ 1.12819195, -0.57253087,  3.09006446],
       [ 1.2538034 , -0.46696553,  3.50364339],
       [ 0.9003873 , -0.79807746,  2.4605559 ],
       [ 1.06683471, -0.62845113,  2.90616606],
       [ 0.92650579, -0.76948216,  2.52566854],
       [ 0.97622796, -0.71720634,  2.65442474],
       [ 1.22012848, -0.49419101,  3.38762295],
       [ 1.2538034 , -0.46696553,  3.50364339],
       [ 1.08786297, -0.60893199,  2.96792475],
       [ 1.10830599, -0.59031447,  3.02922251],
       [ 1.23716835, -0.480322  ,  3.44584223],
       [ 0.87336224, -0.82855206,  2.394