From 671a2b9002f63b1f25d85240b0773e47e191c5f7 Mon Sep 17 00:00:00 2001 From: Robert Meyer Date: Tue, 20 Feb 2018 14:18:30 +0100 Subject: [PATCH] Fixed preproc test --- scripts/do_cross_val.py | 2 +- trufflepig/main.py | 4 ++-- trufflepig/model.py | 10 ++++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/scripts/do_cross_val.py b/scripts/do_cross_val.py index 85a15e2..fff2240 100644 --- a/scripts/do_cross_val.py +++ b/scripts/do_cross_val.py @@ -29,7 +29,7 @@ def main(): gc.collect() - regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=4096, + regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=10000, max_features=0.2, n_jobs=-1, verbose=1, random_state=42) diff --git a/trufflepig/main.py b/trufflepig/main.py index ff58198..5634b58 100644 --- a/trufflepig/main.py +++ b/trufflepig/main.py @@ -65,7 +65,7 @@ def main(): if not tpmo.model_exists(current_datetime, model_directoy): post_frame = tpgd.load_or_scrape_training_data(steem_kwargs, data_directory, current_datetime=current_datetime, - days=7, + days=8, offset_days=8, ncores=32) @@ -75,7 +75,7 @@ def main(): else: post_frame = None - regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=4096, + regressor_kwargs = dict(n_estimators=256, max_leaf_nodes=10000, max_features=0.2, n_jobs=-1, verbose=1, random_state=42) diff --git a/trufflepig/model.py b/trufflepig/model.py index 6db6e79..d2e4145 100644 --- a/trufflepig/model.py +++ b/trufflepig/model.py @@ -364,9 +364,11 @@ def load_or_train_pipeline(post_frame, directory, current_datetime=None, def find_truffles(post_frame, pipeline, min_max_reward=(1.0, 10), min_votes=5, k=10): logger.info('Looking for truffles and filtering preprocessed data further. ' 'min max reward {} and min votes {}'.format(min_max_reward, min_votes)) - post_frame = post_frame.loc[(post_frame.reward >= min_max_reward[0]) & - (post_frame.reward <= min_max_reward[1]) & - (post_frame.votes >= min_votes)] + to_drop = post_frame.loc[(post_frame.reward < min_max_reward[0]) | + (post_frame.reward > min_max_reward[1]) | + (post_frame.votes < min_votes)] + + post_frame.drop(to_drop.index, inplace=True) logger.info('Predicting truffles') predicted_rewards_and_votes = pipeline.predict(post_frame) @@ -405,4 +407,4 @@ def log_pipeline_info(pipeline): feature_importance_string = 'Feature importances \n' for kdx, importance in enumerate(pipeline.named_steps['regressor'].feature_importances_): feature_importance_string += '{:03d}: {:.3f}\n'.format(kdx, importance) - logger.info(feature_importance_string) \ No newline at end of file + logger.info(feature_importance_string)