
Model Selection

Alejandro Moreo edited this page Feb 12, 2024 · 10 revisions


Since quantification is a supervised machine learning task, the performance of quantification methods can strongly depend on a good choice of model hyper-parameters. The process whereby those hyper-parameters are chosen is known as Model Selection. It typically consists of testing different settings and picking the one that performs best on a held-out validation set, in terms of a given evaluation measure.
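A minimal sketch of this hold-out procedure, written with plain scikit-learn on synthetic data. The classifier (LogisticRegression), the candidate values of C, and the evaluation measure (F1) are illustrative assumptions, not choices prescribed by QuaPy:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# synthetic labelled data, split into a training part and a held-out validation part
X, y = make_classification(n_samples=1000, random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.3, stratify=y, random_state=0)

# test each candidate setting on the validation set
scores = {}
for C in np.logspace(-3, 3, 7):
    clf = LogisticRegression(C=C).fit(X_tr, y_tr)
    scores[C] = f1_score(y_val, clf.predict(X_val))

# pick the best-performing setting and (optionally) refit on all labelled data
best_C = max(scores, key=scores.get)
final_model = LogisticRegression(C=best_C).fit(X, y)
print(f'best C={best_C}, validation F1={scores[best_C]:.3f}')
```

Note that here the "evaluation measure" is a classification one; the rest of this page argues that, for quantification, it should instead be a quantification-oriented loss.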

Targeting a Quantification-oriented loss

The task being optimized determines the evaluation protocol, i.e., the criteria according to which the performance of any given method for solving that task is to be assessed. As a task in its own right, quantification should impose its own model selection strategies, i.e., strategies aimed at finding configurations specifically suited to the task of quantification.

Quantification has long been regarded as an add-on of classification, and thus the model selection strategies customarily adopted in classification have simply been applied to quantification (see the next section). Moreo and Sebastiani (Re-Assessing the "Classify and Count" Quantification Method, ECIR 2021: Advances in Information Retrieval, pp. 75–91) have argued that specific model selection strategies should be adopted for quantification. That is, model selection strategies for quantification should target quantification-oriented losses and be tested in a variety of scenarios exhibiting different degrees of prior probability shift.
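Why test under different degrees of prior probability shift? A short worked example, assuming a binary "classify and count" (CC) quantifier built on a classifier with a hypothetical true-positive rate of 0.9 and false-positive rate of 0.1. The expected CC estimate of the positive prevalence p is tpr·p + fpr·(1 − p), so its error changes with p:

```python
def cc_estimate(p, tpr, fpr):
    # expected "classify and count" estimate of the positive prevalence p,
    # for a classifier with the given true- and false-positive rates
    return tpr * p + fpr * (1 - p)


tpr, fpr = 0.9, 0.1  # hypothetical classifier rates
for p in [0.0, 0.25, 0.5, 0.75, 1.0]:
    p_hat = cc_estimate(p, tpr, fpr)
    print(f'true={p:.2f} estimated={p_hat:.2f} abs.error={abs(p_hat - p):.2f}')
```

The error vanishes at p = 0.5 and grows to 0.10 at p = 0 and p = 1. A validation set with a single, balanced prevalence would hide this bias entirely, whereas samples spanning many prevalence values expose it.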

The class qp.model_selection.GridSearchQ implements a grid-search exploration over the space of hyper-parameter combinations that evaluates each combination of hyper-parameters by means of a given quantification-oriented error metric (e.g., any of the error functions implemented in qp.error) and according to a sample generation protocol.
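To make the idea concrete, the following sketch mimics in plain NumPy what a quantification-oriented grid search does: each hyper-parameter combination is scored by the mean absolute error between true and estimated prevalences, over samples drawn at controlled prevalence values. This is not GridSearchQ's implementation. The toy data, the threshold-based classify-and-count quantifier, and the helper names (make_cc_quantifier, generate_samples, grid_search_q) are hypothetical, and, for brevity, the quantifier has no training step:

```python
from itertools import product

import numpy as np


def make_toy_data(n, rng):
    # toy binary problem: a 1-D score, positives centred at +1, negatives at -1
    y = rng.integers(0, 2, size=n)
    x = rng.normal(loc=2.0 * y - 1.0, scale=1.0)
    return x, y


def make_cc_quantifier(threshold):
    # hypothetical "classify and count" quantifier with a single hyper-parameter
    def quantify(x):
        return float(np.mean(x > threshold))  # estimated positive prevalence
    return quantify


def generate_samples(x, y, prevalences, sample_size, rng):
    # a simple prevalence-controlled protocol: one sample per prevalence value
    pos, neg = np.flatnonzero(y == 1), np.flatnonzero(y == 0)
    samples = []
    for p in prevalences:
        n_pos = int(round(p * sample_size))
        idx = np.concatenate([rng.choice(pos, size=n_pos),
                              rng.choice(neg, size=sample_size - n_pos)])
        samples.append((x[idx], p))
    return samples


def grid_search_q(param_grid, samples):
    # try every combination in the grid and keep the one with the lowest mean
    # absolute error (for two classes, the AE of the prevalence vector is |p_hat - p|)
    names = list(param_grid)
    best_params, best_error = None, np.inf
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        quantify = make_cc_quantifier(**params)
        error = np.mean([abs(quantify(xs) - p) for xs, p in samples])
        if error < best_error:
            best_params, best_error = params, error
    return best_params, best_error


rng = np.random.default_rng(0)
x_val, y_val = make_toy_data(2000, rng)
samples = generate_samples(x_val, y_val, np.linspace(0, 1, 11), 100, rng)
best_params, best_error = grid_search_q({'threshold': [-1.0, -0.5, 0.0, 0.5, 1.0]}, samples)
print(best_params, round(best_error, 4))
```

Generating the samples once, outside the search loop, means every combination is compared on exactly the same samples; fixing the random seed makes the search reproducible.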

The following is an example (also included in the examples folder) of model selection for quantification:

import quapy as qp
from quapy.protocol import APP
from quapy.method.aggregative import DMy
from sklearn.linear_model import LogisticRegression
import numpy as np

"""
In this example, we show how to perform model selection on a DistributionMatching quantifier.
"""

model = DMy(LogisticRegression())

qp.environ['SAMPLE_SIZE'] = 100
qp.environ['N_JOBS'] = -1  # explore hyper-parameters in parallel

training, test = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5).train_test

# The model will be returned by the fit method of GridSearchQ.
# Every combination of hyper-parameters will be evaluated by confronting the
# quantifier thus configured against a series of samples generated by means
# of a sample generation protocol. For this example, we will use the
# artificial-prevalence protocol (APP), that generates samples with prevalence
# values in the entire range of values from a grid (e.g., [0, 0.1, 0.2, ..., 1]).
# We devote 30% of the training set to this exploration.
training, validation = training.split_stratified(train_prop=0.7)
protocol = APP(validation)

# We will explore a classification-dependent hyper-parameter (e.g., the 'C'
# hyper-parameter of LogisticRegression) and a quantification-dependent hyper-parameter
# (e.g., the number of bins in a DistributionMatching quantifier).
# Classifier-dependent hyper-parameters have to be marked with a prefix "classifier__"
# in order to let the quantifier know this hyper-parameter belongs to its underlying
# classifier.
param_grid = {
    'classifier__C': np.logspace(-3, 3, 7),
    'nbins': [8, 16, 32, 64],
}

model = qp.model_selection.GridSearchQ(
    model=model,
    param_grid=param_grid,
    protocol=protocol,
    error='mae',  # the error to optimize is the MAE (a quantification-oriented loss)
    refit=True,  # retrain on the whole labelled set once done
    verbose=True  # show information as the process goes on
).fit(training)

print(f'model selection ended: best hyper-parameters={model.best_params_}')
model = model.best_model_

# evaluation in terms of MAE
# we use the same evaluation protocol (APP) on the test set
mae_score = qp.evaluation.evaluate(model, protocol=APP(test), error_metric='mae')

print(f'MAE={mae_score:.5f}')

In this example, the system outputs:

[GridSearchQ]: starting model selection with self.n_jobs =-1
[GridSearchQ]: hyperparams={'classifier__C': 0.01, 'nbins': 64}	 got mae score 0.04021 [took 1.1356s]
[GridSearchQ]: hyperparams={'classifier__C': 0.01, 'nbins': 32}	 got mae score 0.04286 [took 1.2139s]
[GridSearchQ]: hyperparams={'classifier__C': 0.01, 'nbins': 16}	 got mae score 0.04888 [took 1.2491s]
[GridSearchQ]: hyperparams={'classifier__C': 0.001, 'nbins': 8}	 got mae score 0.05163 [took 1.5372s]
[...]
[GridSearchQ]: hyperparams={'classifier__C': 1000.0, 'nbins': 32}	 got mae score 0.02445 [took 2.9056s]
[GridSearchQ]: optimization finished: best params {'classifier__C': 100.0, 'nbins': 32} (score=0.02234) [took 7.3114s]
[GridSearchQ]: refitting on the whole development set
model selection ended: best hyper-parameters={'classifier__C': 100.0, 'nbins': 32}
MAE=0.03102
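The artificial-prevalence protocol used above explores prevalence values on a grid such as [0, 0.1, ..., 1]. With more than two classes, a natural reading is the set of all prevalence vectors whose entries lie on that grid and sum to 1. A minimal sketch of that enumeration follows. The function name app_prevalence_grid is hypothetical, and this is not QuaPy's own APP code:

```python
from itertools import product

import numpy as np


def app_prevalence_grid(n_classes, n_points=11):
    # hypothetical helper: all prevalence vectors whose entries lie on the grid
    # {0, 1/(n_points-1), ..., 1} and sum to 1
    steps = n_points - 1
    grid = []
    # pick integer step counts for the first n_classes-1 classes; the last
    # class receives the remainder, so each vector sums to exactly 1
    for counts in product(range(steps + 1), repeat=n_classes - 1):
        rest = steps - sum(counts)
        if rest >= 0:
            grid.append(np.array(counts + (rest,)) / steps)
    return np.array(grid)


print(len(app_prevalence_grid(2)))  # 11 binary prevalence points
print(len(app_prevalence_grid(3)))  # 66 points on the 3-class simplex
```

Working with integer step counts and dividing only at the end avoids floating-point sums (e.g., 0.1 + 0.2) drifting slightly above 1. The number of grid points grows quickly with the number of classes, so exhaustive grids become expensive beyond a few classes.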

Targeting a Classification-oriented loss

Optimizing a model for quantification can be computationally costly. For aggregative methods, one could instead optimize the hyper-parameters of the underlying classifier for classification. Although this is theoretically suboptimal, many articles in the quantification literature have opted for this strategy.

In QuaPy, this is achieved by simply instantiating the classifier learner as a GridSearchCV from scikit-learn. The following code illustrates how to do that:

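import numpy as np
import quapy as qp
from quapy.method.aggregative import DMy
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

dataset = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5)
learner = GridSearchCV(LogisticRegression(), cv=5,  # optimizes accuracy, a classification loss
                       param_grid={'C': np.logspace(-4, 5, 10), 'class_weight': ['balanced', None]})
model = DMy(learner).fit(dataset.training)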

However, this is conceptually flawed, since the model should be optimized for the task at hand (quantification), and not for a surrogate task (classification), i.e., the model should be requested to deliver low quantification errors, rather than low classification errors.
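A small worked example of why the two objectives can disagree: a classifier can win on accuracy yet lose on quantification error. The confusion matrices below are invented for illustration, and the prevalence estimate is plain classify-and-count:

```python
def accuracy_and_cc_error(tp, fp, fn, tn):
    n = tp + fp + fn + tn
    accuracy = (tp + tn) / n
    true_prev = (tp + fn) / n  # actual proportion of positives
    cc_prev = (tp + fp) / n    # "classify and count" estimate
    return accuracy, round(abs(cc_prev - true_prev), 4)


# hypothetical confusion matrices on the same 100 items (50 positives)
print(accuracy_and_cc_error(tp=50, fp=10, fn=0, tn=40))   # classifier A
print(accuracy_and_cc_error(tp=40, fp=10, fn=10, tn=40))  # classifier B
```

Classifier A reaches 0.90 accuracy but overestimates the positive prevalence by 0.10, because all its errors go in the same direction. Classifier B has only 0.80 accuracy, yet its false positives and false negatives cancel out, giving an exact prevalence estimate. Selecting by accuracy would pick A; selecting by a quantification loss would pick B.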
