Skip to content

Commit

Permalink
Tabular Predictor Model Parameter (#210)
Browse files Browse the repository at this point in the history
* Added model parameter to tabular predictor so users can specify the model they want to get predictions for rather than all only allowing the best model results.

* Added model_names variable to tabular predictor

* Minor syntax update
  • Loading branch information
Innixma committed Jan 15, 2020
1 parent e345c0b commit 722a1b6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
15 changes: 11 additions & 4 deletions autogluon/task/tabular_prediction/predictor.py
Expand Up @@ -27,6 +27,8 @@ class TabularPredictor(BasePredictor):
Name of table column that contains data from the variable to predict (often referred to as: labels, response variable, target variable, dependent variable, Y, etc).
feature_types : dict
Inferred data type of each predictive variable (i.e. column of training data table used to predict `label_column`).
model_names : list
List of model names trained during `fit()`.
model_performance : dict
Maps names of trained models to their predictive performance values attained on the validation dataset during `fit()`.
class_labels : list
Expand Down Expand Up @@ -63,17 +65,20 @@ def __init__(self, learner):
self.eval_metric = self._learner.objective_func
self.label_column = self._learner.label
self.feature_types = self._trainer.feature_types_metadata
self.model_names = self._trainer.get_model_names_all()
self.model_performance = self._trainer.model_performance
self.class_labels = self._learner.class_labels

def predict(self, dataset, as_pandas=False, use_pred_cache=False, add_to_pred_cache=False):
def predict(self, dataset, model=None, as_pandas=False, use_pred_cache=False, add_to_pred_cache=False):
""" Use trained models to produce predicted labels (in classification) or response values (in regression).
Parameters
----------
dataset : :class:`TabularDataset` or `pandas.DataFrame`
The dataset to make predictions for. Should contain same column names as training Dataset and follow same format
(may contain extra columns that won't be used by Predictor, including the label-column itself).
model : str (optional)
The name of the model to get predictions from. Defaults to None, which uses the highest scoring model on the validation set.
as_pandas : bool (optional)
Whether to return the output as a pandas Series (True) or numpy array (False)
use_pred_cache : bool (optional)
Expand All @@ -91,16 +96,18 @@ def predict(self, dataset, as_pandas=False, use_pred_cache=False, add_to_pred_ca
if isinstance(dataset, pd.Series):
raise TypeError("dataset must be TabularDataset or pandas.DataFrame, not pandas.Series. \
To predict on just single example (ith row of table), use dataset.iloc[[i]] rather than dataset.iloc[i]")
return self._learner.predict(X_test=dataset, as_pandas=as_pandas, use_pred_cache=use_pred_cache, add_to_pred_cache=add_to_pred_cache)
return self._learner.predict(X_test=dataset, model=model, as_pandas=as_pandas, use_pred_cache=use_pred_cache, add_to_pred_cache=add_to_pred_cache)

def predict_proba(self, dataset, as_pandas=False):
def predict_proba(self, dataset, model=None, as_pandas=False):
""" Use trained models to produce predicted class probabilities rather than class-labels (if task is classification).
Parameters
----------
dataset : :class:`TabularDataset` or `pandas.DataFrame`
The dataset to make predictions for. Should contain same column names as training Dataset and follow same format
(may contain extra columns that won't be used by Predictor, including the label-column itself).
model : str (optional)
The name of the model to get prediction probabilities from. Defaults to None, which uses the highest scoring model on the validation set.
as_pandas : bool (optional)
Whether to return the output as a pandas object (True) or numpy array (False).
Pandas object is a DataFrame if this is a multiclass problem, otherwise it is a Series.
Expand All @@ -113,7 +120,7 @@ def predict_proba(self, dataset, as_pandas=False):
if isinstance(dataset, pd.Series):
raise TypeError("dataset must be TabularDataset or pandas.DataFrame, not pandas.Series. \
To predict on just single example (ith row of table), use dataset.iloc[[i]] rather than dataset.iloc[i]")
return self._learner.predict_proba(X_test=dataset, as_pandas=as_pandas)
return self._learner.predict_proba(X_test=dataset, model=model, as_pandas=as_pandas)

def evaluate(self, dataset, silent=False):
""" Report the predictive performance evaluated for a given Dataset.
Expand Down
8 changes: 4 additions & 4 deletions autogluon/utils/tabular/ml/learner/abstract_learner.py
Expand Up @@ -71,7 +71,7 @@ def fit(self, X: DataFrame, X_test: DataFrame = None, scheduler_options=None, hy
raise NotImplementedError

# TODO: Add pred_proba_cache functionality as in predict()
def predict_proba(self, X_test: DataFrame, as_pandas=False, inverse_transform=True, sample=None):
def predict_proba(self, X_test: DataFrame, model=None, as_pandas=False, inverse_transform=True, sample=None):
##########
# Enable below for local testing # TODO: do we want to keep sample option?
if sample is not None:
Expand All @@ -80,7 +80,7 @@ def predict_proba(self, X_test: DataFrame, as_pandas=False, inverse_transform=Tr
trainer = self.load_trainer()

X_test = self.transform_features(X_test)
y_pred_proba = trainer.predict_proba(X_test)
y_pred_proba = trainer.predict_proba(X_test, model=model)
if inverse_transform:
y_pred_proba = self.label_cleaner.inverse_transform_proba(y_pred_proba)
if as_pandas:
Expand All @@ -93,7 +93,7 @@ def predict_proba(self, X_test: DataFrame, as_pandas=False, inverse_transform=Tr
# TODO: Add decorators for cache functionality, return core code to previous state
# use_pred_cache to check for a cached prediction of rows, can dramatically speedup repeated runs
# add_to_pred_cache will update pred_cache with new predictions
def predict(self, X_test: DataFrame, as_pandas=False, sample=None, use_pred_cache=False, add_to_pred_cache=False):
def predict(self, X_test: DataFrame, model=None, as_pandas=False, sample=None, use_pred_cache=False, add_to_pred_cache=False):
pred_cache = None
if use_pred_cache or add_to_pred_cache:
try:
Expand All @@ -110,7 +110,7 @@ def predict(self, X_test: DataFrame, as_pandas=False, sample=None, use_pred_cach
X_test_cache_miss = X_test

if len(X_test_cache_miss) > 0:
y_pred_proba = self.predict_proba(X_test=X_test_cache_miss, inverse_transform=False, sample=sample)
y_pred_proba = self.predict_proba(X_test=X_test_cache_miss, model=model, inverse_transform=False, sample=sample)
if self.trainer_problem_type is not None:
problem_type = self.trainer_problem_type
else:
Expand Down
12 changes: 8 additions & 4 deletions autogluon/utils/tabular/ml/trainer/abstract_trainer.py
Expand Up @@ -575,16 +575,20 @@ def generate_stack_log_reg(self, X, y, level, k_fold=0, stack_name=None):

self.train_multi(X_train=X, y_train=y, X_test=None, y_test=None, models=[stacker_model_lr], hyperparameter_tune=False, feature_prune=False, stack_name=stack_name, kfolds=k_fold, level=level)

def predict(self, X):
if self.model_best is not None:
def predict(self, X, model=None):
if model is not None:
return self.predict_model(X, model)
elif self.model_best is not None:
return self.predict_model(X, self.model_best)
elif self.model_best_core is not None:
return self.predict_model(X, self.model_best_core)
else:
raise Exception('Trainer has no fit models to predict with.')

def predict_proba(self, X):
if self.model_best is not None:
def predict_proba(self, X, model=None):
if model is not None:
return self.predict_proba_model(X, model)
elif self.model_best is not None:
return self.predict_proba_model(X, self.model_best)
elif self.model_best_core is not None:
return self.predict_proba_model(X, self.model_best_core)
Expand Down

0 comments on commit 722a1b6

Please sign in to comment.