Skip to content

Commit

Permalink
Fixed #10
Browse files Browse the repository at this point in the history
  • Loading branch information
rhiever committed Nov 12, 2015
1 parent 5b42c3b commit a1c95cb
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tpot/tpot.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,38 @@ def optimize(self, features, classes, feature_names=None):
except KeyboardInterrupt:
self.optimized_pipeline = self.hof[0]

def predict(self, training_features, training_classes, testing_features):
'''
Uses the optimized pipeline to predict the classes for a feature set.
'''
if self.optimized_pipeline == None:
raise Exception('A pipeline has not yet been optimized. '
'Please call the optimize() function first.')

self.best_features_cache = {}

training_data = pd.DataFrame(training_features)
training_data['class'] = training_classes
training_data['group'] = 'training'

testing_data = pd.DataFrame(testing_features)
testing_data['class'] = 0
testing_data['group'] = 'testing'

training_testing_data = pd.concat([training_data, testing_data])
most_frequent_class = Counter(training_classes).most_common(1)[0][0]
training_testing_data['guess'] = most_frequent_class

for column in training_testing_data.columns.values:
if type(column) != str:
training_testing_data.rename(columns={column: str(column).zfill(5)}, inplace=True)

# Transform the tree expression in a callable function
func = self.toolbox.compile(expr=self.optimized_pipeline)

result = func(training_testing_data)
return result[result['group'] == 'testing', 'guess'].values

def score(self, training_features, training_classes, testing_features, testing_classes):
'''
Estimates the testing accuracy of the optimized pipeline.
Expand Down

0 comments on commit a1c95cb

Please sign in to comment.