Skip to content

Commit

Permalink
Fixed #52
Browse files Browse the repository at this point in the history
  • Loading branch information
rhiever committed Dec 8, 2015
1 parent 3f95931 commit 5620a70
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tpot/tpot.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def fit(self, features, classes, feature_names=None):

training_indeces, testing_indeces = next(iter(StratifiedShuffleSplit(training_testing_data['class'].values,
n_iter=1,
train_size=0.75)))
train_size=0.75,
test_size=0.25)))

training_testing_data.loc[training_indeces, 'group'] = 'training'
training_testing_data.loc[testing_indeces, 'group'] = 'testing'
Expand Down Expand Up @@ -427,7 +428,7 @@ def export(self, output_file_name):
pipeline_text += '''
# NOTE: Make sure that the class is labeled 'class' in the data file
tpot_data = pd.read_csv('PATH/TO/DATA/FILE', sep='COLUMN_SEPARATOR')
training_indeces, testing_indeces = next(iter(StratifiedShuffleSplit(tpot_data['class'].values, n_iter=1, train_size=0.75)))
training_indeces, testing_indeces = next(iter(StratifiedShuffleSplit(tpot_data['class'].values, n_iter=1, train_size=0.75, test_size=0.25)))
'''
# Replace the function calls with their corresponding Python code
Expand Down Expand Up @@ -1452,6 +1453,7 @@ def float_range(value):
training_indeces, testing_indeces = next(iter(StratifiedShuffleSplit(input_data['class'].values,
n_iter=1,
train_size=0.75,
test_size=0.25,
random_state=random_state)))

training_features = input_data.loc[training_indeces].drop('class', axis=1).values
Expand Down

0 comments on commit 5620a70

Please sign in to comment.