Improve AbstractTrainer Docs
Innixma committed Aug 9, 2022
1 parent 7116279 commit d9cf062
Showing 1 changed file with 101 additions and 4 deletions.
105 changes: 101 additions & 4 deletions core/src/autogluon/core/trainer/abstract_trainer.py
@@ -17,7 +17,8 @@

from .utils import process_hyperparameters
from ..augmentation.distill_utils import format_distillation_labels, augment_data
from ..constants import AG_ARGS, BINARY, MULTICLASS, REGRESSION, REFIT_FULL_NAME, REFIT_FULL_SUFFIX
from ..constants import AG_ARGS, BINARY, MULTICLASS, REGRESSION, QUANTILE, SOFTCLASS, REFIT_FULL_NAME, REFIT_FULL_SUFFIX
from ..metrics import Scorer
from ..models import AbstractModel, BaggedEnsembleModel, StackerEnsembleModel, WeightedEnsembleModel, GreedyWeightedEnsembleModel, SimpleWeightedEnsembleModel
from ..utils import default_holdout_frac, get_pred_from_proba, generate_train_test_split, infer_eval_metric, compute_permutation_feature_importance, extract_column, compute_weighted_metric
from ..utils.exceptions import TimeLimitExceeded, NotEnoughMemoryError, NoValidFeatures, NoGPUError, NotEnoughCudaMemoryError
@@ -37,14 +38,90 @@
# TODO: Try midstack Semi-Supervised. Just take final models and re-train them, use bagged preds for SS rows. This would be very cheap and easy to try.
# TODO: Move to autogluon.core
class AbstractTrainer:
"""
AbstractTrainer contains the logic to train a variety of models under various constraints and to automatically generate a multi-layer stack ensemble.
Beyond the basic functionality, it also supports model refitting, distillation, pseudo-labelling, unlabeled data, and much more.
It is not recommended to use Trainer directly. Instead, use Predictor or Learner, which use Trainer internally.
This documentation is for developers. Users should avoid this class.
Due to the complexity of the logic within this class, a text description will not give the full picture.
It is recommended to carefully read the code and use a debugger to understand how it works.
AbstractTrainer makes far fewer assumptions about the problem than Learner and Predictor.
It expects ambiguities such as problem_type, feature_metadata, and num_classes to have already been resolved upstream.

Parameters
----------
path : str
Path to save and load trainer artifacts to disk.
Path should end in `/`.
problem_type : str
One of ['binary', 'multiclass', 'regression', 'quantile', 'softclass']
num_classes : int
The number of classes in the problem.
If problem_type is in ['regression', 'quantile'], this must be None.
If problem_type is 'binary', this must be 2.
If problem_type is in ['multiclass', 'softclass'], this must be >= 2.
feature_metadata : FeatureMetadata
FeatureMetadata for X. Sent to each model during fit.
eval_metric : Scorer, default = None
Metric to optimize. If None, a default metric is used depending on the problem_type.
quantile_levels : List[float], default = None
The quantile levels to estimate. Only used when problem_type='quantile'; must be None otherwise.
# TODO: Add full documentation; not documented in Predictor.
low_memory : bool, default = True
Deprecated parameter, likely to be removed in future versions.
If True, caches models to disk separately instead of containing all models within memory.
If False, a variety of bugs will occur.
k_fold : int, default = 0
If <2, then non-bagged mode is used.
If >= 2, then bagged mode is used with num_bag_folds == k_fold for each model.
Bagged mode changes the way models are trained and ensembled.
Bagged mode enables multi-layer stacking and repeated bagging.
n_repeats : int, default = 1
The maximum number of bagging repeats to perform when in bagged mode.
Larger values take linearly longer to train and infer, but slightly improve quality.
sample_weight : str, default = None
Name of the sample weight column in X.
weight_evaluation : bool, default = False
If True, the eval_metric is calculated with sample_weight incorporated into the score.
save_data : bool, default = True
Whether to cache the data (X, y, X_val, y_val) to disk.
Required for a variety of advanced post-fit functionality.
It is recommended to keep as True.
random_state : int, default = 0
Random state for data splitting in bagged mode.
verbosity : int, default = 2
Verbosity levels range from 0 to 4 and control how much information is printed.
Higher levels correspond to more detailed print statements (you can set verbosity = 0 to suppress warnings).
If using logging, you can alternatively control the amount of information printed via `logger.setLevel(L)`,
where `L` ranges from 0 to 50 (Note: higher values of `L` correspond to fewer print statements, the opposite of verbosity levels).
"""
trainer_file_name = 'trainer.pkl'
trainer_info_name = 'info.pkl'
trainer_info_json_name = 'info.json'
distill_stackname = 'distill' # name of stack-level for distilled student models

def __init__(self, path: str, problem_type: str, eval_metric=None,
num_classes=None, quantile_levels=None, low_memory=False, feature_metadata=None, k_fold=0, n_repeats=1,
sample_weight=None, weight_evaluation=False, save_data=False, random_state=0, verbosity=2):
def __init__(self,
path: str,
*,
problem_type: str,
num_classes: int,
feature_metadata: FeatureMetadata,
eval_metric: Scorer = None,
quantile_levels: List[float] = None,
low_memory: bool = True,
k_fold: int = 0,
n_repeats: int = 1,
sample_weight: str = None,
weight_evaluation: bool = False,
save_data: bool = True,
random_state: int = 0,
verbosity: int = 2):
# TODO: Make path == self.path_root, change logic so it doesn't assume learner exists.
self._validate_num_classes(num_classes=num_classes, problem_type=problem_type)
self._validate_quantile_levels(quantile_levels=quantile_levels, problem_type=problem_type)
self.path = path
self.problem_type = problem_type
self.feature_metadata = feature_metadata
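
The refactored signature above makes every argument after `path` keyword-only and removes the defaults for `problem_type`, `num_classes`, and `feature_metadata`. Below is a minimal usage sketch, assuming a hypothetical concrete subclass `MyTrainer` (end users would normally go through Predictor or Learner instead; `feature_metadata=None` is for illustration only):

from autogluon.core.trainer.abstract_trainer import AbstractTrainer

class MyTrainer(AbstractTrainer):  # hypothetical concrete subclass, for illustration only
    pass

trainer = MyTrainer(
    'ag_models/',            # path: must end in '/'
    problem_type='binary',   # everything after `path` is now keyword-only
    num_classes=2,           # validated in __init__: must be 2 when problem_type='binary'
    feature_metadata=None,   # normally a FeatureMetadata object resolved upstream
    k_fold=8,                # >= 2 enables bagged mode
    verbosity=2,             # 0-4; higher prints more information
)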
@@ -2598,3 +2675,23 @@ def _get_feature_prune_proxy_model(self, proxy_model_class: Union[AbstractModel,
return proxy_model
best_candidate_model_rows = candidate_model_rows.loc[candidate_model_rows['score_val'] == candidate_model_rows['score_val'].max()]
return self.load_model(best_candidate_model_rows.loc[best_candidate_model_rows['fit_time'].idxmin()]['model'])
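
For context on the tie-breaking above: among the candidate models tied for the best validation score, the fastest-fitting one is loaded. A self-contained pandas sketch of the same selection logic on toy data (not AutoGluon's real leaderboard):

import pandas as pd

candidate_model_rows = pd.DataFrame({
    'model': ['LightGBM', 'CatBoost', 'XGBoost'],
    'score_val': [0.92, 0.95, 0.95],   # CatBoost and XGBoost tie on validation score
    'fit_time': [12.0, 30.0, 18.0],
})

# Keep only the rows tied for the best validation score...
best = candidate_model_rows.loc[candidate_model_rows['score_val'] == candidate_model_rows['score_val'].max()]
# ...then break the tie by picking the model with the smallest fit time.
print(best.loc[best['fit_time'].idxmin()]['model'])  # -> XGBoost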

@staticmethod
def _validate_num_classes(num_classes: int, problem_type: str):
if problem_type == BINARY:
assert num_classes is not None and num_classes == 2, f"num_classes must be 2 when problem_type='{problem_type}' (num_classes={num_classes})"
elif problem_type in [MULTICLASS, SOFTCLASS]:
assert num_classes is not None and num_classes >= 2, f"num_classes must be >=2 when problem_type='{problem_type}' (num_classes={num_classes})"
elif problem_type in [REGRESSION, QUANTILE]:
assert num_classes is None, f"num_classes must be None when problem_type='{problem_type}' (num_classes={num_classes})"
else:
raise AssertionError(f"Unknown problem_type: '{problem_type}'. Valid problem types: {[BINARY, MULTICLASS, REGRESSION, SOFTCLASS, QUANTILE]}")

@staticmethod
def _validate_quantile_levels(quantile_levels: List[float], problem_type: str):
if problem_type == QUANTILE:
assert quantile_levels is not None, f"quantile_levels must not be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
assert isinstance(quantile_levels, list), f"quantile_levels must be a list (quantile_levels={quantile_levels})"
assert len(quantile_levels) > 0, f"quantile_levels must not be an empty list (quantile_levels={quantile_levels})"
else:
assert quantile_levels is None, f"quantile_levels must be None when problem_type='{problem_type}' (quantile_levels={quantile_levels})"
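
Because both validators are static methods, their behavior can be sanity-checked without constructing a trainer. A quick sketch (the lowercase strings match the constants imported at the top of the file):

from autogluon.core.trainer.abstract_trainer import AbstractTrainer

# Valid combinations pass silently:
AbstractTrainer._validate_num_classes(num_classes=2, problem_type='binary')
AbstractTrainer._validate_quantile_levels(quantile_levels=[0.1, 0.5, 0.9], problem_type='quantile')

# Invalid combinations raise an AssertionError with a descriptive message:
try:
    AbstractTrainer._validate_num_classes(num_classes=3, problem_type='binary')
except AssertionError as e:
    print(e)  # num_classes must be 2 when problem_type='binary' (num_classes=3)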
