# RQ2b: What is the effect of prevalence shifts on the quality of decision rules?
This notebook generates Figure 8, which quantifies the implications of prevalence shifts for optimal decision thresholds in our use cases.

In [1]:
import os
import sys

sys.path.insert(0, os.path.abspath('..'))
from pathlib import Path

import torch
import pandas as pd
import numpy as np
from progiter import ProgIter

from src.prev.calibration import CalibrationMethod, calibrate_logits_fast
from src.prev.scaling import scale_prevalences_ir
from src.prev.plotting import plot_aggregate_results, Confidence, multiplot
from src.prev.thresholding import find_best_thresholds, ThresholdingMethod
from src.prev.metrics import Metric, compute_all_metrics
from src.prev.quantification import adjust_priors_qp
from src.prev.data_loading import get_values, Kind, Split, all_tasks, binary_tasks, example_tasks, example_binary_tasks

import torch.nn.functional as F

# data and result folders are expected next to the current working directory, inside its parent folder
current_path = os.getcwd()
project_root = Path(current_path).parent
DATA_PATH = project_root / 'data'
RESULT_PATH = project_root / 'results'
assert DATA_PATH.exists()
assert RESULT_PATH.exists()
# fix the torch random seed so that repeated runs give the same results
torch.manual_seed(seed=0)

# Set to False for a full rerun that includes all tasks. Note that logits are not provided
# for all tasks for licensing considerations, so they need to be generated from scratch first.
EXAMPLE_TASKS_ONLY = True
rel_tasks, rel_binary_tasks = (
    (example_tasks, example_binary_tasks) if EXAMPLE_TASKS_ONLY else (all_tasks, binary_tasks)
)

<torch._C.Generator at 0x7f907c0d5d30>

In [2]:
# load the values of every relevant task, keyed by task
data = dict()
for task_name in ProgIter(rel_tasks, desc='Loading data'):
    # 'mic23_predictions_original_0' holds the original paper predictions
    data[task_name] = get_values(task_name, DATA_PATH, proj='mic23_predictions_original_0')

Loading data 100.00% 30/30... rate=5.43 Hz, eta=0:00:00, total=0:00:05


In [3]:
# imbalance ratios to sweep: 1.0, 1.5, ..., 10.0 (stop is one step past 10 so that 10 itself is included)
ir_range = np.arange(1, 10.5, 0.5)
def thresholds_across_ir(task_data, calibration: CalibrationMethod = CalibrationMethod.NONE,
                          thresholding: ThresholdingMethod = ThresholdingMethod.ARGMAX):
    """
    Computes the values of metrics for optimal thresholds computed for the application test set and
    for thresholds chosen by the given thresholding method, for a given task across imbalance ratios.

    For every imbalance ratio in the module-level ``ir_range`` the application test set is re-sampled with
    ``scale_prevalences_ir``, the data is passed through ``calibrate_logits_fast`` and all metrics are
    computed twice on the re-sampled application test set: once with thresholds optimized on that very
    set, and once with the thresholds selected via ``thresholding``.

    :param task_data: logits and labels for all splits, indexed as ``task_data[Kind][Split]``
    :param calibration: calibration method; supported are NONE (no prior), AFFINE_REWEIGHTED (uses the exact
        prevalences of the re-sampled application test set as prior) and AFFINE_ACC (uses the prevalences
        estimated by ``adjust_priors_qp`` as prior)
    :param thresholding: thresholding method to be used; ARGMAX uses a fixed threshold of 0.5 for every metric,
        DEV_TEST optimizes the thresholds on the development test set
    :return results: dictionary of lists with one entry per imbalance ratio. The keys returned by
        ``compute_all_metrics`` (expected to be ``Metric`` members) hold the values obtained with the thresholds
        optimized on the application test set, the keys ``"reference " + metric.value`` hold the values
        obtained with the thresholds selected via ``thresholding``
    :raises ValueError: if the calibration or the thresholding method is not supported
    """
    # initialize the results dictionary
    results = {m: [] for m in Metric}
    results.update({"reference " + m.value: [] for m in Metric})

    # extract the minority class for F1 computation; the calibration split is not affected by the
    # prevalence scaling, so this is computed once instead of once per imbalance ratio
    val_prevalences = torch.bincount(task_data[Kind.LABELS][Split.DEV_CAL])
    min_class = torch.argmin(val_prevalences).item()
    max_class = torch.argmax(val_prevalences).item()
    # catches case where for balanced task, min class was used as max class in scaling
    if min_class == max_class:
        min_class = 1

    for ir in ir_range:
        # scale prevalences in the deployment set according to imbalance ratio
        app_test_logits, app_test_classes = scale_prevalences_ir(logits=task_data[Kind.LOGITS][Split.APP_TEST],
                                                                 classes=task_data[Kind.LABELS][Split.APP_TEST],
                                                                 ir=ir)
        # create a data dictionary with the modified deployment test set
        mod_data = {Kind.LOGITS: {Split.DEV_CAL: task_data[Kind.LOGITS][Split.DEV_CAL],
                                  Split.DEV_TEST: task_data[Kind.LOGITS][Split.DEV_TEST],
                                  Split.APP_TEST: app_test_logits},
                    Kind.LABELS: {Split.DEV_CAL: task_data[Kind.LABELS][Split.DEV_CAL],
                                  Split.DEV_TEST: task_data[Kind.LABELS][Split.DEV_TEST],
                                  Split.APP_TEST: app_test_classes}}
        # compute the relative frequency of each class in the app test set with modified prevalence
        exact_prevalence = torch.bincount(mod_data[Kind.LABELS][Split.APP_TEST]) / len(mod_data[Kind.LABELS][Split.APP_TEST])
        # estimated version of these prevalences
        estimated_prevalence = torch.Tensor(
            adjust_priors_qp(torch.softmax(mod_data[Kind.LOGITS][Split.DEV_TEST], dim=1),
                             mod_data[Kind.LABELS][Split.DEV_TEST],
                             torch.softmax(mod_data[Kind.LOGITS][Split.APP_TEST], dim=1),
                             mod_data[Kind.LABELS][Split.APP_TEST]))

        # select the prior that is handed to the calibration
        if calibration == CalibrationMethod.AFFINE_REWEIGHTED:
            prior = exact_prevalence
        elif calibration == CalibrationMethod.AFFINE_ACC:
            prior = estimated_prevalence
        elif calibration == CalibrationMethod.NONE:
            prior = None
        else:
            raise ValueError('calibration method not supported')
        calibrated_logits = calibrate_logits_fast(data=mod_data, calibration=calibration, prior=prior)

        if thresholding == ThresholdingMethod.ARGMAX:
            # use 0.5 as threshold (argmax)
            thresholds = {m: 0.5 for m in Metric}
        elif thresholding == ThresholdingMethod.DEV_TEST:
            # find optimal thresholds on dev test
            thresholds = find_best_thresholds(labels=mod_data[Kind.LABELS][Split.DEV_TEST],
                                              logits=calibrated_logits[Split.DEV_TEST], min_class=min_class,
                                              priors=exact_prevalence.numpy(),
                                              est_priors=estimated_prevalence.numpy())
        else:
            raise ValueError('invalid thresholding method')
        # find optimal thresholds on app test
        optimal_thresholds = find_best_thresholds(labels=mod_data[Kind.LABELS][Split.APP_TEST],
                                                  logits=calibrated_logits[Split.APP_TEST], min_class=min_class,
                                                  priors=exact_prevalence.numpy(),
                                                  est_priors=estimated_prevalence.numpy())
        # softmax score of class 0 on app test; it does not depend on the metric, so it is computed
        # once here instead of once per threshold
        app_test_class0_scores = F.softmax(calibrated_logits[Split.APP_TEST], dim=1)[:, 0]
        # compute predictions on app test using the two sets of thresholds: a prediction is True
        # whenever the class 0 score falls below the threshold of the respective metric
        new_app_test_preds = {key: app_test_class0_scores < thresholds[key] for key in thresholds.keys()}
        optimal_app_test_preds = {key: app_test_class0_scores < optimal_thresholds[key] for key in
                                  optimal_thresholds.keys()}
        # NOTE(review): both calls below pass exact_prevalence as estimated_priors although
        # estimated_prevalence is available -- confirm that this is intended
        # compute metrics values for predictions made using optimal app test thresholds
        optimal_metrics = compute_all_metrics(mod_data[Kind.LABELS][Split.APP_TEST],
                                              mod_data[Kind.LOGITS][Split.APP_TEST], optimal_app_test_preds,
                                              min_class=min_class, exact_priors=exact_prevalence,
                                              estimated_priors=exact_prevalence)
        # compute metrics values for predictions made using the other thresholds
        dev_threshold_metrics = compute_all_metrics(mod_data[Kind.LABELS][Split.APP_TEST],
                                                    mod_data[Kind.LOGITS][Split.APP_TEST], new_app_test_preds,
                                                    min_class=min_class, exact_priors=exact_prevalence,
                                                    estimated_priors=exact_prevalence)
        # append the computed metrics values to the results
        for k in optimal_metrics.keys():
            results[k].append(optimal_metrics[k])
        for k in dev_threshold_metrics.keys():
            results['reference ' + k.value].append(dev_threshold_metrics[k])
    return results

In [4]:
# sweep all combinations of calibration and thresholding method and store one result file per combination
for cal in (CalibrationMethod.NONE, CalibrationMethod.AFFINE_REWEIGHTED, CalibrationMethod.AFFINE_ACC):
    for thresholding in (ThresholdingMethod.DEV_TEST, ThresholdingMethod.ARGMAX):
        collector = {task: {} for task in rel_tasks}
        for task in ProgIter(rel_binary_tasks):
            # metric values for thresholds computed on app test and the selected thresholds, across IRs
            task_results = thresholds_across_ir(task_data=data[task], calibration=cal,
                                                thresholding=thresholding)
            collector[task].update(task_results)
        # build a data frame with one row per binary task and one column per result key
        result_df = pd.DataFrame(rel_binary_tasks, columns=["name"])
        first_task = rel_binary_tasks[0]
        for key in collector[first_task].keys():
            result_df[key] = [torch.tensor(collector[task][key]) for task in rel_binary_tasks]
        # save the data frame with the results
        result_file = "24_decision_rule_results_" + cal.value + "_" + thresholding.value + ".pkl"
        result_df.to_pickle(RESULT_PATH / result_file)

 100.00% 24/24... rate=0.05 Hz, eta=0:00:00, total=0:08:09
 100.00% 24/24... rate=0.09 Hz, eta=0:00:00, total=0:04:36
 100.00% 24/24... rate=0.05 Hz, eta=0:00:00, total=0:08:02
 100.00% 24/24... rate=0.09 Hz, eta=0:00:00, total=0:04:36
 100.00% 24/24... rate=0.05 Hz, eta=0:00:00, total=0:07:58
 100.00% 24/24... rate=0.09 Hz, eta=0:00:00, total=0:04:36


# Generate Figure 8

In [3]:
# one aggregate subplot is collected per (thresholding, calibration) combination
subplts = []
metrics = [Metric.ACCURACY, Metric.F1, Metric.MCC, Metric.BALANCED_ACC, Metric.EC_EST, Metric.EC_ADJUSTED]
for thresholding in [ThresholdingMethod.ARGMAX, ThresholdingMethod.DEV_TEST]:
    for cal in [CalibrationMethod.NONE, CalibrationMethod.AFFINE_REWEIGHTED]:
        # load the results data frame that belongs to this combination
        result_file = "24_decision_rule_results_" + cal.value + "_" + thresholding.value + ".pkl"
        result_df = pd.read_pickle(RESULT_PATH / result_file)
        # subtract the reference values from every metric except EC_EST, which is handled below
        for met in (m for m in metrics if m != Metric.EC_EST):
            result_df[met.value] = result_df[met] - result_df['reference ' + met.value]

        # without re-calibration EC_EST comes from the same results, otherwise from the AFFINE_ACC results
        if cal == CalibrationMethod.NONE:
            ec_est_source = result_df
        else:
            ec_est_file = ("24_decision_rule_results_" + CalibrationMethod.AFFINE_ACC.value + "_"
                           + thresholding.value + ".pkl")
            ec_est_source = pd.read_pickle(RESULT_PATH / ec_est_file)
        result_df[Metric.EC_EST.value] = (ec_est_source[Metric.EC_EST]
                                          - ec_est_source['reference ' + Metric.EC_EST.value])

        # for the argmax decision only four metrics are plotted, otherwise all of them
        plot_metrics = (
            [Metric.ACCURACY, Metric.F1, Metric.MCC, Metric.BALANCED_ACC]
            if thresholding == ThresholdingMethod.ARGMAX
            else metrics
        )
        # create the aggregate subplot and keep it for the multiplot
        subplts.append(
            plot_aggregate_results(result_df, metrics=plot_metrics, file=None, delta=False, ci=Confidence.STD,
                                   bound=[0, 0.2], opacity=0.15, plot_lines_later=True)
        )

In [4]:
# all four subplots share the same y axis range
sub_y_ranges = {subplot_idx: [0, 0.2] for subplot_idx in range(4)}
# assemble the 2x2 multiplot (Figure 8 of the paper)
fig = multiplot(
    rows=2,
    cols=2,
    row_titles=["Argmax decision", "Optimized threshold"],
    column_titles=["w/o re-calibration", "with re-calibration"],
    y_title="Difference to optimum",
    subplts=subplts,
    horizontal_spacing=0.03,
    vertical_spacing=0.03,
    legend_index=2,
    sub_y_ranges=sub_y_ranges,
    height=800,
    shared_xaxes=True,
    little_guys=True,
    icon_axes=[3, 4],
    icon_size=0.08,
)
# label the x axes of the bottom row only
for bottom_col in (1, 2):
    fig.update_xaxes(title_text="Imbalance ratio", title_standoff=0, row=2, col=bottom_col)
fig.update_layout(margin=dict(t=40, r=45))

In [9]:
# export the figure in all required formats (static images and an interactive HTML file)
name = "24_fig_8"
for image_suffix in (".png", ".pdf"):
    fig.write_image(RESULT_PATH / (name + image_suffix))
fig.write_html(RESULT_PATH / (name + ".html"))
fig.write_image(RESULT_PATH / (name + ".svg"))

# Compute EC results from section 3.3

In [8]:
# load the argmax results without and with (AFFINE_REWEIGHTED) re-calibration
thresholding = ThresholdingMethod.ARGMAX
no_recal_file = "24_decision_rule_results_" + CalibrationMethod.NONE.value + "_" + thresholding.value + ".pkl"
recal_file = ("24_decision_rule_results_" + CalibrationMethod.AFFINE_REWEIGHTED.value + "_"
              + thresholding.value + ".pkl")
no_recal_result_df = pd.read_pickle(RESULT_PATH / no_recal_file)
recal_result_df = pd.read_pickle(RESULT_PATH / recal_file)

In [9]:
# turn every entry of the "reference EC" column into a numpy array, for both result frames
for ec_frame in (no_recal_result_df, recal_result_df):
    ec_frame["reference EC"] = ec_frame["reference EC"].apply(lambda row: torch.Tensor(row).numpy())

In [10]:
# relative reduction of the reference EC achieved by re-calibration, per task and imbalance ratio
ec_without_recal = no_recal_result_df["reference EC"]
ec_with_recal = recal_result_df["reference EC"]
diffs = (ec_without_recal - ec_with_recal) / ec_without_recal

In [11]:
ir_values = [1, 4, 7, 10]
# average the relative differences over all tasks, then pick the entries of the selected imbalance ratios
mean_diffs = np.array(diffs.tolist()).mean(axis=0)
selected_diffs = mean_diffs[np.where(np.isin(ir_range, ir_values))]
print(f"For IR values: {ir_values}, the % changes in EC values are: {selected_diffs * 100}")

For IR values: [1, 4, 7, 10], the % changes in EC values are: [ 0.94896266 17.98756884 32.10647671 41.6311924 ]
