In [1]:
import pandas as pd
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import f1_score
from pandas.api.types import is_numeric_dtype

from genetic_decision_tree import GeneticDecisionTree

# List the OpenML datasets

In [2]:
real_files = [
    'soybean',
    'micro-mass',
    'mfeat-karhunen',
    'Amazon_employee_access',
    'abalone',
    'cnae-9',
    'semeion',
    'vehicle',
    'satimage',
    'analcatdata_authorship',
    'breast-w',
    'SpeedDating',
    'eucalyptus',
    'vowel',
    'wall-robot-navigation',
    'credit-approval',
    'artificial-characters',
    'splice',
    'har',
    'cmc',
    'segment',
    'JapaneseVowels',
    'jm1',
    'gas-drift',
    'mushroom',
    'irish',
    'profb',
    'adult',
    'higgs',
    'anneal',
    'credit-g',
    'blood-transfusion-service-center',
    'monks-problems-2',
    'tic-tac-toe',
    'qsar-biodeg',
    'wdbc',
    'phoneme',
    'diabetes',
    'ozone-level-8hr',
    'hill-valley',
    'kc2',
    'eeg-eye-state',
    'climate-model-simulation-crashes',
    'spambase',
    'ilpd',
    'one-hundred-plants-margin',
    'banknote-authentication',
    'mozilla4',
    'electricity',
    'madelon',
    'scene',
    'musk',
    'nomao',
    'bank-marketing',
    'MagicTelescope',
    'Click_prediction_small',
    'PhishingWebsites',
    'nursery',
    'page-blocks',
    'hypothyroid',
    'yeast',
    'kropt',
    'CreditCardSubset',
    'shuttle',
    'Satellite',
    'baseball',
    'mc1',
    'pc1',
    'cardiotocography',
    'kr-vs-k',
    'volcanoes-a1',
    'wine-quality-white',
    'car-evaluation',
    'solar-flare',
    'allbp',
    'allrep',
    'dis',
    'car',
    'steel-plates-fault'
    ]


In [3]:
# This defines a function to test a single file from OpenML. It creates
# a standard decision tree as well as four variations on the GeneticDecisionTree
# based on whether it peforms mutations and/or combinations. 

def test_dataset(dataset_name):
    def test_model(clf):
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        return f1_score(y_test, y_pred, average='macro')    
    
    np.random.seed(0)

    # Load the data
    data = fetch_openml(dataset_name, version=1, parser='auto') 
    df = pd.DataFrame(data['data'])
    y_true = data['target']
    display(df)
    
    # One-hot encode categorical columns unless there are too many unique values,
    # in which case, we drop the column.
    drop_cols = []
    for col_name in df.columns:
        if (not is_numeric_dtype(df[col_name])) and (df[col_name].nunique() > 10):
            drop_cols.append(col_name)
    df = df.drop(columns=drop_cols)
    if len(df.columns) == 0:
        print("All columns are categorical with many unique values")
        return None   
    df = pd.get_dummies(df)        

    # Divide the data into train and test
    X_train, X_test, y_train, y_test = train_test_split(df, y_true, test_size=0.3, random_state=42)

    # Fit and evaluate a standard decision tree
    clf = DecisionTreeClassifier(max_depth=4)
    dt_score = test_model(clf)
    print("DT:", dt_score)

    # Fit and evaluate a GeneticDecisionTree based only on random trees
    np.random.seed(0)
    max_iterations = 4
    gdt = GeneticDecisionTree(
        max_depth=4, max_iterations=max_iterations, allow_mutate=False, allow_combine=False, n_jobs=-1, verbose=True)
    score1 = test_model(gdt)
    print("Genetic DT:", score1)

    # Fit and evaluate a GeneticDecisionTree allowing mutations of strong trees
    gdt = GeneticDecisionTree(
        max_depth=4, max_iterations=max_iterations, allow_mutate=True, allow_combine=False, n_jobs=-1, verbose=True)
    score2 = test_model(gdt)
    print("Genetic DT:", score2)    

    # Fit and evaluate a GeneticDecisionTree allowing combinations of pairs of strong trees
    gdt = GeneticDecisionTree(
        max_depth=4, max_iterations=max_iterations, allow_mutate=False, allow_combine=True, n_jobs=-1, verbose=True)
    score3 = test_model(gdt)
    print("Genetic DT:", score3)

    # Fit and evaluate a GeneticDecisionTree allowing both mutations and combinations
    gdt = GeneticDecisionTree(
        max_depth=4, max_iterations=max_iterations, allow_mutate=True, allow_combine=True, n_jobs=-1, verbose=True)
    score4 = test_model(gdt)
    print("Genetic DT:", score4)

    return [dt_score, score1, score2, score3, score4]

In [None]:
# Loop through the real datasets on OpenML and display, for each,
# how standard and genetic decision trees compare. 

display_rows = []
display_dt = None

for file_name in real_files:
    print(".................................................................")
    print(file_name)
    results = test_dataset(file_name)
    if not results:
        continue
    display_rows.append([file_name] + results)
    display_dt = pd.DataFrame(display_rows, columns=[
        'File Name', 
        "DT", 
        "GDT (random only)", 
        "GDT (with mutations)", 
        "GDT (with combinations)", 
        "GDT (with both)"])
    display(display_dt)
    
print()
print("Final Results")
display(display_dt)

.................................................................
soybean


Unnamed: 0,date,plant-stand,precip,temp,hail,crop-hist,area-damaged,severity,seed-tmt,germination,...,int-discolor,sclerotia,fruit-pods,fruit-spots,seed,mold-growth,seed-discolor,seed-size,shriveling,roots
0,october,normal,gt-norm,norm,yes,same-lst-yr,low-areas,pot-severe,none,90-100,...,none,absent,norm,dna,norm,absent,absent,norm,absent,norm
1,august,normal,gt-norm,norm,yes,same-lst-two-yrs,scattered,severe,fungicide,80-89,...,none,absent,norm,dna,norm,absent,absent,norm,absent,norm
2,july,normal,gt-norm,norm,yes,same-lst-yr,scattered,severe,fungicide,lt-80,...,none,absent,norm,dna,norm,absent,absent,norm,absent,norm
3,july,normal,gt-norm,norm,yes,same-lst-yr,scattered,severe,none,80-89,...,none,absent,norm,dna,norm,absent,absent,norm,absent,norm
4,october,normal,gt-norm,norm,yes,same-lst-two-yrs,scattered,pot-severe,none,lt-80,...,none,absent,norm,dna,norm,absent,absent,norm,absent,norm
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
678,april,,,,,,upper-areas,,,,...,,,,,,,,,,
679,april,lt-normal,,lt-norm,,diff-lst-year,scattered,,,,...,,,dna,,,,,,,rotted
680,june,lt-normal,,lt-norm,,diff-lst-year,scattered,,,,...,,,dna,,,,,,,rotted
681,april,lt-normal,,lt-norm,,same-lst-yr,whole-field,,,,...,,,dna,,,,,,,rotted


DT: 0.36059592200066853

Iteration: 1
Top (training) scores so far: ['0.411', '0.402', '0.399', '0.397', '0.396', '0.396', '0.396', '0.381', '0.380', '0.379']

Iteration: 2
Top (training) scores so far: ['0.411', '0.402', '0.399', '0.397', '0.396', '0.396', '0.396', '0.381', '0.380', '0.379']

Iteration: 3
Top (training) scores so far: ['0.411', '0.407', '0.404', '0.402', '0.399', '0.397', '0.396', '0.396', '0.396', '0.392']

Iteration: 4
Top (training) scores so far: ['0.437', '0.411', '0.407', '0.404', '0.402', '0.399', '0.397', '0.397', '0.396', '0.396']
Genetic DT: 0.463165050741131

Iteration: 1
Top (training) scores so far: ['0.421', '0.419', '0.406', '0.391', '0.390', '0.383', '0.383', '0.378', '0.377', '0.375']
Number in top 20 based on mutation: 3

Iteration: 2
Top (training) scores so far: ['0.424', '0.424', '0.424', '0.424', '0.424', '0.424', '0.424', '0.424', '0.424', '0.424']
Number in top 20 based on mutation: 20

Iteration: 3
Top (training) scores so far: ['0.424', '0.42

Unnamed: 0,File Name,DT,GDT (random only),GDT (with mutations),GDT (with combinations),GDT (with both)
0,soybean,0.360596,0.463165,0.443073,0.444158,0.498024


.................................................................
micro-mass


Unnamed: 0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,...,V1291,V1292,V1293,V1294,V1295,V1296,V1297,V1298,V1299,V1300
0,0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.000000,0.000000,...,0.0,0,0,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000
1,0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.000000,10162.613281,...,0.0,0,0,0.0,0.0,0.0,11427.124023,0.0,0.0,0.000000
2,0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.000000,0.000000,...,0.0,0,0,0.0,0.0,0.0,12094.415039,0.0,0.0,0.000000
3,0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.000000,0.000000,...,0.0,0,0,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000
4,0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.000000,15409.125977,...,0.0,0,0,0.0,0.0,0.0,0.000000,0.0,0.0,35418.402344
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
355,0,0.0,0.0,0.0,0.0,0,0.0,0.0,31279.455078,0.000000,...,0.0,0,0,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000
356,0,0.0,0.0,0.0,0.0,0,0.0,0.0,31064.574219,0.000000,...,0.0,0,0,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000
357,0,0.0,0.0,0.0,0.0,0,0.0,0.0,17840.945313,0.000000,...,0.0,0,0,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000
358,0,0.0,0.0,0.0,0.0,0,0.0,0.0,82506.257813,0.000000,...,0.0,0,0,0.0,0.0,0.0,0.000000,0.0,0.0,0.000000


DT: 0.4296456142843611

Iteration: 1
Top (training) scores so far: ['0.809', '0.807', '0.792', '0.783', '0.741', '0.731', '0.720', '0.709', '0.695', '0.685']

Iteration: 2
Top (training) scores so far: ['0.809', '0.807', '0.804', '0.792', '0.783', '0.741', '0.731', '0.720', '0.709', '0.705']

Iteration: 3
Top (training) scores so far: ['0.809', '0.807', '0.804', '0.792', '0.783', '0.741', '0.741', '0.731', '0.727', '0.720']

Iteration: 4
Top (training) scores so far: ['0.809', '0.807', '0.804', '0.792', '0.783', '0.741', '0.741', '0.731', '0.727', '0.720']
Genetic DT: 0.6748186375028481

Iteration: 1
Top (training) scores so far: ['0.813', '0.804', '0.798', '0.798', '0.748', '0.718', '0.716', '0.714', '0.698', '0.695']
Number in top 20 based on mutation: 0

Iteration: 2
Top (training) scores so far: ['0.825', '0.813', '0.806', '0.806', '0.806', '0.806', '0.806', '0.806', '0.806', '0.805']
Number in top 20 based on mutation: 13

Iteration: 3
Top (training) scores so far: ['0.825', '0.82

Unnamed: 0,File Name,DT,GDT (random only),GDT (with mutations),GDT (with combinations),GDT (with both)
0,soybean,0.360596,0.463165,0.443073,0.444158,0.498024
1,micro-mass,0.429646,0.674819,0.664863,0.705724,0.694583


.................................................................
mfeat-karhunen


Unnamed: 0,att1,att2,att3,att4,att5,att6,att7,att8,att9,att10,...,att55,att56,att57,att58,att59,att60,att61,att62,att63,att64
0,-10.297008,-11.666789,11.560669,-2.081316,4.044656,4.086815,-2.558072,-8.476935,2.138135,3.503082,...,1.078083,0.921927,0.496387,-0.643667,0.284104,0.286555,0.348625,1.814691,-1.351353,-0.473910
1,-5.036009,-12.885333,0.161155,0.592460,3.123534,4.220469,-6.411771,-6.335328,-0.244622,1.346073,...,0.942353,2.938791,1.429883,-2.336344,1.281628,-0.098321,0.582357,0.485792,0.642451,0.613107
2,-9.639157,-6.655898,0.388687,-1.717650,0.300346,3.400769,-7.240785,-1.659405,-0.874005,4.153403,...,-0.413174,-0.023028,-0.025265,1.259838,-0.441360,-0.960094,1.995843,1.097748,0.827182,-1.767840
3,-6.650375,-7.043851,4.104350,-2.342780,3.494658,3.924822,-9.874812,-6.556576,-1.364269,1.153308,...,-0.961236,-1.043815,-0.204508,-1.981150,0.982438,-0.144233,-1.449328,-0.913552,-0.771735,0.304992
4,-10.664524,-10.974133,0.194391,0.453885,2.193088,-3.304663,-8.376592,-4.241146,2.964818,-0.949622,...,0.152957,1.448160,-1.254907,-3.481295,-0.563889,1.529335,0.510399,0.298318,-0.943213,1.149847
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,-2.415248,-6.619806,5.053538,6.662300,12.136673,-1.447842,-2.321873,4.042169,-2.981806,-0.106785,...,-1.438355,-0.714285,0.017051,0.460572,-0.951763,0.241901,-0.399051,-0.304857,-0.068411,-1.049052
1996,5.892684,-8.185875,1.819305,6.871263,1.021332,-0.869375,-6.759738,-3.891993,-4.781352,3.355656,...,-0.672254,1.273016,0.227573,0.444086,1.439473,-0.405706,0.378187,-0.128056,0.925637,1.798053
1997,1.881613,-9.650881,0.317780,0.655888,7.882648,1.740497,0.026943,-4.412813,-3.403179,-0.614610,...,-0.121590,-1.622687,0.309964,0.473773,0.916683,0.971719,0.689472,-0.439637,0.287013,-0.420793
1998,-1.530886,-10.183775,-1.055864,4.956079,11.729954,1.480784,-2.806543,0.602515,-5.411981,-2.165543,...,-0.220936,-0.466334,0.128358,-0.888494,-0.014442,-0.780897,1.000286,1.405214,0.435514,-0.225426


DT: 0.5012456246144168

Iteration: 1
Top (training) scores so far: ['0.657', '0.652', '0.648', '0.646', '0.638', '0.637', '0.635', '0.622', '0.621', '0.616']

Iteration: 2
Top (training) scores so far: ['0.675', '0.657', '0.652', '0.649', '0.648', '0.646', '0.638', '0.637', '0.637', '0.635']

Iteration: 3
Top (training) scores so far: ['0.675', '0.657', '0.652', '0.649', '0.648', '0.646', '0.638', '0.637', '0.637', '0.635']

Iteration: 4
Top (training) scores so far: ['0.675', '0.657', '0.652', '0.649', '0.648', '0.646', '0.644', '0.643', '0.638', '0.637']
Genetic DT: 0.6389853119156605

Iteration: 1
Top (training) scores so far: ['0.648', '0.647', '0.647', '0.644', '0.643', '0.631', '0.630', '0.628', '0.626', '0.623']
Number in top 20 based on mutation: 0

Iteration: 2
Top (training) scores so far: ['0.656', '0.654', '0.653', '0.653', '0.652', '0.652', '0.652', '0.652', '0.652', '0.652']
Number in top 20 based on mutation: 20

Iteration: 3
Top (training) scores so far: ['0.690', '0.66