
Commit

Merge pull request #408 from HealthCatalyst/407
Tiny PR
Aylr committed Oct 20, 2017
2 parents a24a272 + 97ea851 commit 9d89f89
Showing 3 changed files with 85 additions and 57 deletions.
45 changes: 0 additions & 45 deletions healthcareai/common/output.py

This file was deleted.

65 changes: 65 additions & 0 deletions healthcareai/common/trainer_output.py
@@ -0,0 +1,65 @@
"""
outputs.py is a decorator function that prints to stdout.
This eliminates lots of boilerplate code in superviseModelTrainer.
"""

import inspect
from functools import partial, wraps


def trainer_output(func=None, *, debug=False):
    """
    Trainer output decorator for functions that train models.

    This decorator can be applied to any model-training method of
    SupervisedModelTrainer. It prints helpful information to the console,
    such as the model type and the training results.

    Args:
        func (function): The function to decorate.
        debug (bool): If True, also print the decorated function's
            signature, arguments, and return value.

    Returns:
        function: the wrapped function when applied directly, or a
        partial of trainer_output when called with keyword arguments
        only, so the decorator works both as @trainer_output and as
        @trainer_output(debug=True).
    """

    if func is None:
        # func is None only when the decorator is called with keyword
        # arguments, e.g. @trainer_output(debug=True). Return a partial
        # so the decorated function is supplied on the next call; this is
        # a handy way of writing decorators that take extra arguments.
        return partial(trainer_output, debug=debug)

    # Wrap the function so that if debug is True we can print its inputs
    # and outputs. The @wraps decorator copies the wrapped function's
    # metadata, such as __name__ and its docstring, onto the wrapper.
    @wraps(func)
    def wrap(self, *args, **kwargs):
        # Build a human-readable algorithm name from the function name,
        # then use self's advanced trainer to report the model type
        # (regression or classification).
        algorithm_name = " ".join(func.__name__.split("_")).title()
        print("Training: {}, Type: {}".format(
            algorithm_name,
            self._advanced_trainer.model_type))
        trained_model = func(self, *args, **kwargs)
        trained_model.print_training_results()

        # If debug is True, print the function's name, signature,
        # arguments, and return value.
        if debug:
            print("Function Name: {}, Function Signature: {}, "
                  "Function Args: {} {}, Function Return: {}".format(
                      func.__name__,
                      inspect.signature(func),
                      args,
                      kwargs,
                      trained_model))

        return trained_model

    return wrap
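For reference, here is a minimal sketch of how the decorator composes in both its bare and keyword-argument forms. DummyModel and DummyTrainer are hypothetical stand-ins for TrainedSupervisedModel and SupervisedModelTrainer, not part of this PR; only the trainer_output import path comes from the diff above.

```python
# Hypothetical usage sketch: DummyModel and DummyTrainer are stand-ins,
# not part of healthcareai.
from types import SimpleNamespace

from healthcareai.common.trainer_output import trainer_output


class DummyModel(object):
    def print_training_results(self):
        print("metrics: ...")


class DummyTrainer(object):
    def __init__(self):
        # The decorator reads self._advanced_trainer.model_type.
        self._advanced_trainer = SimpleNamespace(model_type='classification')

    @trainer_output
    def random_forest_classification(self):
        return DummyModel()

    @trainer_output(debug=True)
    def logistic_regression(self):
        return DummyModel()


trainer = DummyTrainer()

# Prints "Training: Random Forest Classification, Type: classification"
# followed by the model's training results.
trainer.random_forest_classification()

# Additionally prints the "Function Name: ..." debug line.
trainer.logistic_regression()
```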
32 changes: 20 additions & 12 deletions healthcareai/supervised_model_trainer.py
@@ -5,7 +5,7 @@
import healthcareai.common.cardinality_checks as hcai_ordinality
from healthcareai.advanced_supvervised_model_trainer import AdvancedSupervisedModelTrainer
from healthcareai.common.get_categorical_levels import get_categorical_levels
-from healthcareai.common.output import trainer_output
+from healthcareai.common.trainer_output import trainer_output


class SupervisedModelTrainer(object):
@@ -101,7 +101,10 @@ def knn(self):
        Returns:
            TrainedSupervisedModel: A trained supervised model.
        """
-        return self._advanced_trainer.knn(scoring_metric='roc_auc', hyperparameter_grid=None, randomized_search=True)
+        return self._advanced_trainer.knn(
+            scoring_metric='roc_auc',
+            hyperparameter_grid=None,
+            randomized_search=True)

    @trainer_output
    def random_forest_regression(self):
@@ -110,8 +113,10 @@ def random_forest_regression(self):
        Returns:
            TrainedSupervisedModel: A trained supervised model.
        """
-        return self._advanced_trainer.random_forest_regressor(trees=200, scoring_metric='neg_mean_squared_error',
-                                                               randomized_search=True)
+        return self._advanced_trainer.random_forest_regressor(
+            trees=200,
+            scoring_metric='neg_mean_squared_error',
+            randomized_search=True)

    @trainer_output
    def random_forest_classification(self, feature_importance_limit=15, save_plot=False):
@@ -125,15 +130,17 @@ def random_forest_classification(self, feature_importance_limit=15, save_plot=False):
        Returns:
            TrainedSupervisedModel: A trained supervised model.
        """
-        model = self._advanced_trainer.random_forest_classifier(trees=200,
-                                                                scoring_metric='roc_auc',
-                                                                randomized_search=True)
+        model = self._advanced_trainer.random_forest_classifier(
+            trees=200,
+            scoring_metric='roc_auc',
+            randomized_search=True)

        # Save or show the feature importance graph
-        hcai_tsm.plot_rf_features_from_tsm(model,
-                                           self._advanced_trainer.x_train,
-                                           feature_limit=feature_importance_limit,
-                                           save=save_plot)
+        hcai_tsm.plot_rf_features_from_tsm(
+            model,
+            self._advanced_trainer.x_train,
+            feature_limit=feature_importance_limit,
+            save=save_plot)

        return model

@@ -144,7 +151,8 @@ def logistic_regression(self):
        Returns:
            TrainedSupervisedModel: A trained supervised model.
        """
-        return self._advanced_trainer.logistic_regression(randomized_search=False)
+        return self._advanced_trainer.logistic_regression(
+            randomized_search=False)

    @trainer_output
    def linear_regression(self):
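The reformatted methods above are unchanged in behavior; a call site looks something like the sketch below. The CSV path, column names, and constructor arguments are illustrative assumptions, not part of this PR; check the healthcareai README for the exact signature.

```python
# Assumed quickstart-style usage; the CSV path, column names, and
# constructor parameters are illustrative and may differ from the
# installed healthcareai version.
import pandas as pd

import healthcareai

dataframe = pd.read_csv('diabetes.csv')  # hypothetical training data

trainer = healthcareai.SupervisedModelTrainer(
    dataframe=dataframe,
    predicted_column='ThirtyDayReadmitFLG',
    model_type='classification',
    grain_column='PatientEncounterID')

# Each call below fires the trainer_output decorator's console report.
trained_rf = trainer.random_forest_classification()
trained_lr = trainer.logistic_regression()
```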
