DOC TunedThresholdClassifierCV: use business scoring directly in GridSearchCV #29025

Open · wants to merge 1 commit into base: main

Conversation

@adrinjalali (Member)

I found it odd that we care about the business score, but we optimize for a very different metric.

cc @glemaitre

✔️ Linting Passed

All linting checks passed. Generated for commit 6d02d9f.

@glemaitre (Member)

This is a trap. You should not do that. You should optimise over a proper scoring rule to get the best probability estimate. Your business metric is just a decision made on top of the probability. Optimising the business metric does not give you the best probability estimate if it does not correspond to a proper scoring rule.

@glemaitre (Member)

So I agree this might be counter-intuitive, but this is what we should encourage as good practice.

@glemaitre (Member)

And as a consequence, it actually means that our default choice of scoring in GridSearchCV is really not good. We should instead use the log loss for classification and most probably the mean squared error for regression as defaults.

Maybe @ogrisel has some additional details to provide.

@@ -601,7 +601,7 @@ def business_metric(y_true, y_pred, amount):
 
 logistic_regression = make_pipeline(StandardScaler(), LogisticRegression())
 param_grid = {"logisticregression__C": np.logspace(-6, 6, 13)}
-model = GridSearchCV(logistic_regression, param_grid, scoring="neg_log_loss").fit(
+model = GridSearchCV(logistic_regression, param_grid, scoring=business_scorer).fit(
     data_train, target_train
Review comment on this diff (Member):

You need to pass the amount_train here, otherwise it fails.

@adrinjalali (Member, Author)

> This is a trap. You should not do that. You should optimise over a proper scoring rule to get the best probability estimate. Your business metric is just a decision made on top of the probability. Optimising the business metric does not give you the best probability estimate if it does not correspond to a proper scoring rule.

I don't really understand this, @glemaitre. The estimator itself is using a very different loss to solve the problem. Out of the models I train, in this use case, I want the one that maximizes my monetary reward. So I don't see how the probability estimate is relevant here.

@glemaitre (Member)

> Out of the models I train, in this use case, I want the one that maximizes my monetary reward.

Nope. If you do that, you are going to pick a model that might not be calibrated. Potentially, you might have equivalent models with the same monetary reward, but some are calibrated and others are not. In terms of validation curves, you can think of the monetary metric as being flat across models, while the proper scoring rule shows a peak among those models. To illustrate, you can have a look at the last figure of this notebook from @ogrisel:

https://gist.github.com/ogrisel/8502eb455cd38d41e92fee31863ffea7

You can think of the business metric like accuracy. The ranking metric shows the bump but with a large standard deviation, while the thresholded metric is completely flat. If I have to give an over-simplified explanation: the ranking metric loses some information compared to a proper scoring rule, and the thresholded metric loses even more. This translates into a larger std. dev. and a plateau.

So we need a really good example to describe and show these aspects and their impact on the choice of the "default" metrics in the grid search.

@adrinjalali (Member, Author)

When it comes to making a decision based on the interpretation of predict_proba, I understand why the model being calibrated is important. But if I'm finding a threshold other than 0.5 and making a decision, and the model is not calibrated but with the newly found threshold I get what I want, then why would I care about calibration?

It also makes me think that calibration and finding a threshold are very similar. You could either calibrate your model, or find a different threshold.

I'm not getting the intuition you have, I think.

@glemaitre (Member)

> But if I'm finding a threshold other than 0.5 and making a decision, and the model is not calibrated but with the newly found threshold I get what I want, then why would I care about calibration?

Because there are use cases where you are interested in both the probability and the decision threshold. Tuning hyperparameters on a proper scoring rule ensures the first, while post-tuning the decision threshold gives you the second. And tuning only with the business metric does not mean that you get a better model for decisions than the sequential tuning.

Regarding a use case where you might be interested in both aspects: in a recommender system, you might want a calibrated probability of someone clicking on an ad, and your decision to show the ad could be linked to other data such as the diversity of ads on the page, etc. Another one could be weather forecasting, where one wants the real probability that it rains but also a decision such as "should I take an umbrella or a jacket" :).

In the end, grid-searching on a proper scoring rule and a subsequent decision threshold optimization will give you the best of both worlds: a calibrated model optimized to take decisions.

> It also makes me think that calibration and finding a threshold are very similar. You could either calibrate your model, or find a different threshold.

On a calibrated model, the best decision threshold is not necessarily at 0.5, so you need to post-tune it on your decision metric.
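
As a small illustration of that last point, assuming constant misclassification costs and zero cost for correct decisions (a simplification, since the costs in this example depend on the transaction amount): for a perfectly calibrated classifier, the expected-cost-minimizing decision is to predict the positive class whenever $p > c_{FP} / (c_{FP} + c_{FN})$, which equals 0.5 only when a false positive and a false negative cost the same.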

@adrinjalali (Member, Author)

That might be true for a different use case, but it doesn't apply here.

This is the script I used for my talk, and I honestly don't see why, in this particular case, one would do otherwise:

# %%
import numpy as np
from sklearn.datasets import fetch_openml

# %%
credit_card = fetch_openml(data_id=1597, as_frame=True, parser="pandas")
credit_card.frame.info()

# %%
columns_to_drop = ["Class"]
data = credit_card.frame.drop(columns=columns_to_drop)
target = credit_card.frame["Class"].astype(int)

# %%
target.value_counts(normalize=True)

# %%
target.value_counts()

# %%
import matplotlib.pyplot as plt

fraud = target == 1
amount_fraud = data["Amount"][fraud]
_, ax = plt.subplots()
ax.hist(amount_fraud, bins=100)
ax.set_title("Amount of fraud transaction")
_ = ax.set_xlabel("Amount (€)")
plt.show()


# %%
# Gain/cost in € of accepting or refusing each transaction, given its amount.
def business_metric(y_true, y_pred, amount):
    mask_true_positive = (y_true == 1) & (y_pred == 1)
    mask_true_negative = (y_true == 0) & (y_pred == 0)
    mask_false_positive = (y_true == 0) & (y_pred == 1)
    mask_false_negative = (y_true == 1) & (y_pred == 0)
    legitimate_refuse = mask_false_positive.sum() * -5
    fraudulent_refuse = (mask_true_positive.sum() * 50) + amount[
        mask_true_positive
    ].sum()
    fraudulent_accept = -amount[mask_false_negative].sum()
    legitimate_accept = (amount[mask_true_negative] * 0.02).sum()
    return fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept


# %%
import sklearn
from sklearn.metrics import make_scorer

sklearn.set_config(enable_metadata_routing=True)
business_scorer = make_scorer(business_metric).set_score_request(amount=True)

# %%
amount = credit_card.frame["Amount"].to_numpy()

# %%
from sklearn.model_selection import train_test_split

data_train, data_test, target_train, target_test, amount_train, amount_test = (
    train_test_split(
        data, target, amount, stratify=target, test_size=0.5, random_state=42
    )
)

# %%
from sklearn.dummy import DummyClassifier

easy_going_classifier = DummyClassifier(strategy="constant", constant=0)
easy_going_classifier.fit(data_train, target_train)
benefit_cost_tolerant = business_scorer(
    easy_going_classifier, data_test, target_test, amount=amount_test
)
print(f"Benefit/cost of our easy-going classifier: {benefit_cost_tolerant:,.2f}€")

# %%
intolerant_classifier = DummyClassifier(strategy="constant", constant=1)
intolerant_classifier.fit(data_train, target_train)
benefit_cost_intolerant = business_scorer(
    intolerant_classifier, data_test, target_test, amount=amount_test
)
print(f"Benefit/cost of our intolerant classifier: {benefit_cost_intolerant:,.2f}€")

# %%
from sklearn.metrics import get_scorer

balanced_accuracy_scorer = get_scorer("balanced_accuracy")
tolerant_balanced_accuracy = balanced_accuracy_scorer(
    easy_going_classifier, data_test, target_test
)
print(
    "Balanced accuracy of our easy-going classifier: "
    f"{tolerant_balanced_accuracy:.3f}"
)
intolerant_balanced_accuracy = balanced_accuracy_scorer(
    intolerant_classifier, data_test, target_test
)
print(
    "Balanced accuracy of our intolerant classifier: "
    f"{intolerant_balanced_accuracy:.3f}"
)

# %%
accuracy_scorer = get_scorer("accuracy")
tolerant_accuracy = accuracy_scorer(
    easy_going_classifier, data_test, target_test
)
print(
    "Accuracy of our easy-going classifier: "
    f"{tolerant_accuracy:.3f}"
)
intolerant_accuracy = accuracy_scorer(
    intolerant_classifier, data_test, target_test
)
print(
    "Accuracy of our intolerant classifier: "
    f"{intolerant_accuracy:.3f}"
)


# %%
import pandas as pd
scores = pd.DataFrame(
    {"Name": [], "Benefit/Cost": [], "Balanced Accuracy": [], "Accuracy": []}
)

def add_score(name, business_score, balanced_accuracy_score, accuracy_score):
    return pd.concat(
        [scores, pd.DataFrame(
            {"Name": [name],
             "Benefit/Cost": [business_score],
             "Balanced Accuracy": [balanced_accuracy_score],
             "Accuracy": [accuracy_score]})]
    ).reset_index(drop=True)

scores = add_score(
    "Tolerant",
    benefit_cost_tolerant,
    tolerant_balanced_accuracy,
    tolerant_accuracy,
)
scores = add_score(
    "Intolerant",
    benefit_cost_intolerant,
    intolerant_balanced_accuracy,
    intolerant_accuracy
)
scores

# %%
def handle_scores(model, name):
    business_score = business_scorer(
        model, data_test, target_test, amount=amount_test
    )
    balanced_accuracy_score = balanced_accuracy_scorer(model, data_test, target_test)
    accuracy_score = accuracy_scorer(model, data_test, target_test)
    global scores
    scores = add_score(name, business_score, balanced_accuracy_score, accuracy_score)
    with pd.option_context(
            'display.max_rows', None,
            'display.max_columns', None,
            'display.width', 199
    ):
        print(scores)


# %%
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

logistic_regression = make_pipeline(StandardScaler(), LogisticRegression())
param_grid = {"logisticregression__C": np.logspace(-6, 6, 13)}
model = GridSearchCV(
    logistic_regression, param_grid, scoring=business_scorer, n_jobs=-1
).fit(
    data_train, target_train, amount=amount_train
)

handle_scores(model, "Searched logistic regression")
model.best_estimator_

# %%
from sklearn.model_selection import TunedThresholdClassifierCV

tuned_model = TunedThresholdClassifierCV(
    estimator=model.best_estimator_,
    scoring=business_scorer,
    thresholds=40,
    n_jobs=-1,
).fit(data_train, target_train, amount=amount_train)

handle_scores(tuned_model, "Tuned searched logistic regression")
print(f"Best threshold: {tuned_model.best_threshold_}")
# %%
tuned_model = TunedThresholdClassifierCV(
    estimator=logistic_regression,
    scoring=business_scorer,
    thresholds=40,
    n_jobs=1,
)
# Force the threshold search to use predict_proba (sets a private attribute).
tuned_model._response_method = "predict_proba"

param_grid = {"estimator__logisticregression__C": np.logspace(-6, 6, 13)}
model = GridSearchCV(tuned_model, param_grid, scoring=business_scorer, n_jobs=-1).fit(
    data_train, target_train, amount=amount_train
)

handle_scores(model, "Searched tuned logistic regression")
model.best_estimator_

# %%
print(f"Best threshold: {model.best_estimator_.best_threshold_}")

And the resulting table at the end is:

                                 Name  Benefit/Cost  Balanced Accuracy  Accuracy
0                            Tolerant   221445.0720           0.500000  0.998273
1                          Intolerant  -668903.2400           0.500000  0.001727
2        Searched logistic regression   260787.2106           0.814960  0.999199
3  Tuned searched logistic regression   268816.0748           0.900009  0.998862
Best threshold: 0.02564102564098705
                                 Name  Benefit/Cost  Balanced Accuracy  Accuracy
0                            Tolerant   221445.0720           0.500000  0.998273
1                          Intolerant  -668903.2400           0.500000  0.001727
2        Searched logistic regression   260787.2106           0.814960  0.999199
3  Tuned searched logistic regression   268816.0748           0.900009  0.998862
4  Searched tuned logistic regression   274137.1150           0.892129  0.999333
Best threshold: 0.001980122980204754

And it is clear that the second model is actually less calibrated than the first one; I guess one can say that since the threshold is an order of magnitude smaller. But for this particular example, I don't see why not.

@glemaitre (Member)

OK so I don't have all the answers here.

One thing that I can see now is that the first "Tuned searched" model has a C of 100 while the "Searched tuned" model has a C of 1e-5. So the second model is way more regularized than the first one. My next step would be to show the validation curve to get some insights into the trade-off we are facing here. I'm also thinking that we should look at the stability of the C coefficients during the cross-validation.
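
A minimal sketch of such validation curves, reusing the names from the script above (the grid and plotting details here are just one possible choice):

import matplotlib.pyplot as plt

# Compare cross-validated scores across the C grid for the two scorings,
# using the cv_results_ of two grid searches.
param_grid = {"logisticregression__C": np.logspace(-6, 6, 13)}
Cs = param_grid["logisticregression__C"]
fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharex=True)
for ax, (name, scoring, fit_params) in zip(
    axes,
    [
        ("business metric", business_scorer, {"amount": amount_train}),
        ("neg log loss", "neg_log_loss", {}),
    ],
):
    search = GridSearchCV(
        logistic_regression, param_grid, scoring=scoring, n_jobs=-1
    ).fit(data_train, target_train, **fit_params)
    ax.errorbar(
        Cs,
        search.cv_results_["mean_test_score"],
        yerr=search.cv_results_["std_test_score"],
    )
    ax.set(xscale="log", xlabel="C", ylabel=f"CV {name}")
plt.show()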

@lorentzenchr (Member) commented May 17, 2024

> But if I'm finding a threshold other than 0.5 and making a decision, and the model is not calibrated but with the newly found threshold I get what I want, then why would I care about calibration?

In the end, if you get what you want, that's fine. Questions are then whether there is a better solution and whether your methodology gives systematically good results (or just by chance). To answer these questions we should provide good and sound examples.

I see 3 fundamental issues with grid searching (GS) a logistic regression (LR) penalty using a score other than a proper scoring rule, in particular a score implying a threshold:

  1. Misuse of predict_proba instead of the threshold
    The threshold implied by a naive predict is fixed at 0.5. You then (hyper)search for a parameter (the penalty in our case) such that your GS score is optimized. The only thing poor LR can change is predict_proba, but what you really want to change is the threshold. One could say it a little more pointedly:
    • You misused the probability at the expense of the decision threshold.
    • LR is really a regression and one should only ever use predict_proba.
  2. Focus on the small/tiny decision boundary:
    The result is that predict_proba around 0.5 has a huge impact, and it is not reliable at all some distance away from 0.5. You should be very careful about changing the threshold afterwards because you have completely lost the absolute scale of predict_proba. This is not a setup that I would recommend for production.
  3. Contradiction between the score/loss used by LR and the one used by the grid search.
    LR minimizes the log loss and thereby tries its best to produce good results for predict_proba. But your GS tampers with it and forces predict_proba to serve a different purpose.

My opinion is that we should very much advertise the following:

  • Accept that there are 2 distinct problems and they are better addressed one after the other instead of mixed up (but there are models like SVM that mix them up implicitly).
  • The first problem is to predict reliable probabilities. Good news: lots of statistical tools we can use.
  • The second problem is then to make a decision based on those probability predictions. This is very much use case dependent and usually amounts to finding/deciding for a threshold of your probability predictions.

EDIT for the interested reader:
It gets fun when looking at point of views in use cases.

  • If it is a "repeated setting", like a bank selling credits, then a law-of-large-numbers argument kicks in and the best action/decision is the Bayes act $\mathrm{argmin}_{action} E[\mathrm{cost}(Y, action)|X]$, for which we need $P(Y|X)$. This is what we need to predict with a model, and the above reasoning applies.
  • If you only take action once, e.g. a customer looking for an (optional) insurance, it's less clear what to base a decision on. A possibility is the most likely cost or, to sleep better, the cost that happens in the 5% worst scenarios (5% value at risk, read quantile).

@ogrisel (Member) commented May 20, 2024

> But if I'm finding a threshold other than 0.5 and making a decision, and the model is not calibrated but with the newly found threshold I get what I want, then why would I care about calibration?

If you never use or show the values computed by predict_proba before thresholding to the operator of the system, then maybe it's ok to only tune for the business metric everywhere (hparams and decision threshold).

But if you want to use the predict_proba values one way or another, e.g. to disable automated decisions in favour of a safer fallback (e.g. manual inspection and processing) for the individuals with the most uncertain predictions, then it's important to have the best predict_proba values possible (in conjunction with the best business metric value for the thresholded decisions).

In that case, I would argue for first getting the best model in terms of predict_proba by optimizing a proper scoring rule (e.g. log loss or Brier score), and then, as a second step, finding the best decision threshold based on the business metric.

In practice, I think this two-stage (log loss first, business metric second) approach should yield a test business metric very similar to that of a pipeline that is entirely tuned (both hparams and decision threshold) for the business metric only. Therefore, I would also advise the two-stage approach by default.
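
A minimal sketch of that two-stage approach, reusing the estimator, scorer, and data names from the script above (the grid and number of thresholds are illustrative):

# Stage 1: select hyperparameters with a proper scoring rule (log loss).
search = GridSearchCV(
    logistic_regression,
    {"logisticregression__C": np.logspace(-6, 6, 13)},
    scoring="neg_log_loss",
    n_jobs=-1,
).fit(data_train, target_train)

# Stage 2: tune only the decision threshold on the business metric; the
# `amount` metadata is routed to the scorer as in the script above.
two_stage_model = TunedThresholdClassifierCV(
    estimator=search.best_estimator_,
    scoring=business_scorer,
    thresholds=100,
    n_jobs=-1,
).fit(data_train, target_train, amount=amount_train)

business_scorer(two_stage_model, data_test, target_test, amount=amount_test)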

@glemaitre (Member) commented May 20, 2024

Basically, I like the discussion here because there are messages that we should convey in the example and that we probably are not writing down now. We should improve the example in this direction.

I just tried to understand a bit more what is going on regarding the trade-off between the losses, and here are the grid-search scores depending on whether we use the business score or the log loss:

[figure: cross-validated business metric across the C grid]

[figure: cross-validated log loss across the C grid]

So indeed, if we are only interested in the decision, without any regard for the probabilities, a crafted business metric can have a higher score in regions where the proper scoring rule is not good.

Here, what I find particularly interesting is that a highly regularized model performs better with the business metric. I would like to get more insights into what that really means. The fact that we use "Amount" both in the data and in the metric has an impact here; does it mean that the metadata used in the metric would be enough to define our model?

The fact that we see something similar for the balanced accuracy really intrigues me:

[figure: cross-validated balanced accuracy across the C grid]

The fact that you have an increase of the score with a plateau, then a decrease before the second increase (which I would expect), is really interesting.

I think that here we should advocate for the default and explain that we want the best of both worlds. However, I think this is an interesting topic to navigate and a good reason to improve a potential grid-search example that discusses these specific aspects and the choice of the metric to optimize during the search.

@ogrisel (Member) commented May 20, 2024

I don't see the trade-off in the plots you give: C=1e3 is always best (or near-optimal) for all metrics. Model selection based only on your business metric seems to lead to under-specification: many very different models are near-optimal.

The highly non-monotonic relationship between C and the balanced accuracy / business metric is indeed intriguing, but we would need to repeat this experiment with cross-validation (or just a bootstrap of the test metric) to check whether or not this is just sampling noise on a small dataset.

EDIT: I typed C=1e-3 instead of C=1e3!
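
One minimal way to do the bootstrap part, reusing business_metric, the fitted grid-search model, and the test split from the script above (200 resamples is an arbitrary choice):

rng = np.random.default_rng(0)
y_true = target_test.to_numpy()
y_pred = model.predict(data_test)  # predictions of the selected model

boot_scores = []
for _ in range(200):
    # Resample the test set with replacement and recompute the business metric.
    idx = rng.integers(0, len(y_true), size=len(y_true))
    boot_scores.append(business_metric(y_true[idx], y_pred[idx], amount_test[idx]))

print(f"{np.mean(boot_scores):,.0f} ± {np.std(boot_scores):,.0f}")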

@glemaitre (Member) commented May 20, 2024

> I don't see the trade-off in the plots you give: C=1e-3 is always best (or near optimal) for all metrics.

For the log-loss, the optimal is at 1e3 while it will be 1e-5 for the others.

@ogrisel (Member) commented May 20, 2024

> while it will be 1e-5 for the others.

Given the shape of the plots, I would really double-check the uncertainty of that assertion via cross-validation / a bootstrap of the test metrics. The difference does not seem significant at all.

EDIT: the best model is around C=1e3 or more (low regularization) for all metrics in your experiment. The neg log loss is also higher-is-better: it is the negated negative log-likelihood.

@glemaitre (Member)

> EDIT: the best model is around C=1e3 or more (low regularization) for all metrics in your experiment. The neg log loss is also higher-is-better: it is the negated negative log-likelihood.

That's what I said: 1e3 for the log loss, while high regularization (1e-5) for the thresholded metrics. But you are right regarding the std. dev.: the difference between 1e-5 and 1e3 is within the noise level for the thresholded metrics.

@ogrisel (Member) commented May 21, 2024

So it seems that a majority of us agree to recommend tuning probabilistic classifiers on proper scoring rules, to first select the model that outputs the most meaningful un-thresholded probabilistic predictions, and then, as a second stage, to only optimize the decision threshold for the business metric of interest.

However, I think the point Adrin raised needs to be explicitly addressed in our doc and/or example.

Note that for non-probabilistic classifiers such as Support Vector Machines, it makes perfect sense to directly optimize for the business metric everywhere. I wouldn't be surprised if, for "good" business metrics, this led to results similar to first hparam-tuning for the model with the best ROC AUC, followed by threshold-tuning on the business metric.
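
As a sketch of what "business metric everywhere" could look like for such a classifier, reusing the scorer and data names from the script above (LinearSVC is just an arbitrary choice of non-probabilistic classifier, and the C grid is illustrative):

from sklearn.svm import LinearSVC

svm = make_pipeline(StandardScaler(), LinearSVC(dual="auto"))
svm_search = GridSearchCV(
    svm,
    {"linearsvc__C": np.logspace(-3, 3, 7)},
    scoring=business_scorer,
    n_jobs=-1,
).fit(data_train, target_train, amount=amount_train)

# The threshold is tuned on decision_function, since LinearSVC has no
# predict_proba.
svm_tuned = TunedThresholdClassifierCV(
    estimator=svm_search.best_estimator_,
    scoring=business_scorer,
    n_jobs=-1,
).fit(data_train, target_train, amount=amount_train)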
