Skip to content

Commit

Permalink
Fixed #8
Browse files — browse the repository at this point in the history
  • Loading branch information
rhiever committed Nov 12, 2015
1 parent 4fb50e8 commit e21d255
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tpot/tpot.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -23,14 +23,14 @@
import random
import hashlib
from itertools import combinations
from collections import Counter

import numpy as np
import pandas as pd

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import StratifiedShuffleSplit
from sklearn.cross_validation import StratifiedKFold

from deap import algorithms
from deap import base
Expand All @@ -43,7 +43,7 @@ class TPOT:
optimized_pipeline = None
best_features_cache = {}

def __init__(self, population_size=100, generations=1000,
def __init__(self, population_size=100, generations=100,
mutation_rate=0.9, crossover_rate=0.05):
'''
Sets up the genetic programming algorithm for pipeline optimization.
Expand Down Expand Up @@ -95,7 +95,6 @@ def optimize(self, features, classes, feature_names=None):

training_testing_data = pd.DataFrame(data=features, columns=feature_names)
training_testing_data['class'] = classes
training_testing_data['guess'] = 0

for column in training_testing_data.columns.values:
if type(column) != str:
Expand All @@ -113,6 +112,10 @@ def optimize(self, features, classes, feature_names=None):

training_testing_data.loc[training_indeces, 'group'] = 'training'
training_testing_data.loc[testing_indeces, 'group'] = 'testing'

# Default the basic guess to the most frequent class
most_frequent_class = Counter(training_testing_data.loc[training_indeces, 'class'].values).most_common(1)[0][0]
training_testing_data['guess'] = most_frequent_class

self.toolbox.register('evaluate', self.evaluate_individual, training_testing_data=training_testing_data)

Expand Down Expand Up @@ -155,7 +158,8 @@ def score(self, training_features, training_classes, testing_features, testing_c
testing_data['group'] = 'testing'

training_testing_data = pd.concat([training_data, testing_data])
training_testing_data['guess'] = 0
most_frequent_class = Counter(training_classes).most_common(1)[0][0]
training_testing_data['guess'] = most_frequent_class

for column in training_testing_data.columns.values:
if type(column) != str:
Expand All @@ -178,7 +182,6 @@ def decision_tree(input_df, max_features, max_depth):
input_df = input_df.copy()

if len(input_df.columns) == 3:
input_df['guess'] = 0
return input_df

training_features = input_df.loc[input_df['group'] == 'training'].drop(['class', 'group', 'guess'], axis=1).values
Expand Down Expand Up @@ -218,7 +221,6 @@ def random_forest(input_df, num_trees, max_features):
input_df = input_df.copy()

if len(input_df.columns) == 3:
input_df['guess'] = 0
return input_df

training_features = input_df.loc[input_df['group'] == 'training'].drop(['class', 'group', 'guess'], axis=1).values
Expand Down

0 comments on commit e21d255

Please sign in to comment.