Skip to content

Commit

Permalink
Fixed preproc test
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert Meyer committed Feb 20, 2018
1 parent c9b52de commit 671a2b9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion scripts/do_cross_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main():

gc.collect()

- regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=4096,
+ regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=10000,
max_features=0.2, n_jobs=-1, verbose=1,
random_state=42)

Expand Down
4 changes: 2 additions & 2 deletions trufflepig/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def main():
if not tpmo.model_exists(current_datetime, model_directoy):
post_frame = tpgd.load_or_scrape_training_data(steem_kwargs, data_directory,
current_datetime=current_datetime,
- days=7,
+ days=8,
offset_days=8,
ncores=32)

Expand All @@ -75,7 +75,7 @@ def main():
else:
post_frame = None

- regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=4096,
+ regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=10000,
max_features=0.2, n_jobs=-1, verbose=1,
random_state=42)

Expand Down
10 changes: 6 additions & 4 deletions trufflepig/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,11 @@ def load_or_train_pipeline(post_frame, directory, current_datetime=None,
def find_truffles(post_frame, pipeline, min_max_reward=(1.0, 10), min_votes=5, k=10):
logger.info('Looking for truffles and filtering preprocessed data further. '
'min max reward {} and min votes {}'.format(min_max_reward, min_votes))
- post_frame = post_frame.loc[(post_frame.reward >= min_max_reward[0]) &
- (post_frame.reward <= min_max_reward[1]) &
- (post_frame.votes >= min_votes)]
+ to_drop = post_frame.loc[(post_frame.reward < min_max_reward[0]) |
+ (post_frame.reward > min_max_reward[1]) |
+ (post_frame.votes < min_votes)]
+
+ post_frame.drop(to_drop.index, inplace=True)

logger.info('Predicting truffles')
predicted_rewards_and_votes = pipeline.predict(post_frame)
Expand Down Expand Up @@ -405,4 +407,4 @@ def log_pipeline_info(pipeline):
feature_importance_string = 'Feature importances \n'
for kdx, importance in enumerate(pipeline.named_steps['regressor'].feature_importances_):
feature_importance_string += '{:03d}: {:.3f}\n'.format(kdx, importance)
- logger.info(feature_importance_string)
+ logger.info(feature_importance_string)

0 comments on commit 671a2b9

Please sign in to comment.