# RQ2c: What is the effect of prevalence shifts on the quality of performance assessment?
This notebook generates Figure 9, which shows the impact of prevalence shifts on evaluation metrics in our use cases.

In [1]:
import os
import sys

sys.path.insert(0, os.path.abspath('..'))
from pathlib import Path
from progiter import ProgIter

import torch
import pandas as pd
import numpy as np

from quapy.method.aggregative import ACC

from src.prev.data_loading import get_values, Kind, Split, all_tasks, example_tasks
from src.prev.calibration import calibrate_logits_fast, CalibrationMethod
from src.prev.scaling import scale_prevalences_ir
from src.prev.plotting import plot_aggregate_results, Confidence, box_plot, multiplot
from src.prev.metrics import Metric, compute_all_metrics, compute_metric
from src.prev.quantification import adjust_priors_qp

# the data and results folders are looked up in the parent of the current working directory
current_path = os.getcwd()
parent_dir = Path(current_path).parent
DATA_PATH = parent_dir / 'data'
RESULT_PATH = parent_dir / 'results'
# both folders have to exist before anything is loaded or written
assert all(folder.exists() for folder in (DATA_PATH, RESULT_PATH))
# fix the torch random seed
torch.manual_seed(seed=0)

# Set to False for a full rerun that includes all tasks. Note that logits are not provided
# for all tasks for licensing considerations, so they have to be generated from scratch.
EXAMPLE_TASKS_ONLY = True
rel_tasks = example_tasks if EXAMPLE_TASKS_ONLY else all_tasks

<torch._C.Generator at 0x7f4a340da050>

In [2]:
# load the stored values of every relevant task (original paper predictions), keyed by task
data = {
    t: get_values(t, DATA_PATH, proj='mic23_predictions_original_0')
    for t in ProgIter(rel_tasks, desc='Loading data')
}

Loading data 100.00% 30/30... rate=5.46 Hz, eta=0:00:00, total=0:00:05


In [36]:
# imbalance ratios at which the deployment set is evaluated: 1.0, 1.5, ..., 10.0
IRS = list(np.arange(1, 10.5, 0.5))


def metrics_across_ir(task_data, calibration: CalibrationMethod = CalibrationMethod.NONE):
    """Compute metric values across the imbalance ratios in ``IRS`` for a given task.

    For every imbalance ratio the deployment test split (``Split.APP_TEST``) is rescaled with
    ``scale_prevalences_ir``, the logits are passed through ``calibrate_logits_fast`` and the metrics
    are computed on the rescaled deployment test split as well as on the development test split
    (``Split.DEV_TEST``), which serves as the reference.

    :param task_data: nested mapping indexed as ``task_data[Kind][Split]`` that holds the logits
        (``Kind.LOGITS``) and labels (``Kind.LABELS``) of the task.
    :param calibration: ``CalibrationMethod.NONE`` or ``CalibrationMethod.AFFINE_REWEIGHTED``.
    :return: dictionary of lists. The ``Metric`` members are the keys for the values on the rescaled
        deployment set, the strings ``"reference " + metric.value`` are the keys for the values on the
        development test set. One value per imbalance ratio is appended for every metric returned by
        ``compute_all_metrics``.
    :raises ValueError: if ``calibration`` is neither of the two supported methods.
    """
    # reject unsupported calibration methods before any (expensive) computation is started
    if calibration not in (CalibrationMethod.NONE, CalibrationMethod.AFFINE_REWEIGHTED):
        raise ValueError(f'invalid calibration method: {calibration}')
    recalibrate = calibration == CalibrationMethod.AFFINE_REWEIGHTED

    # initialize results dictionary
    results = {m: [] for m in Metric}
    results.update({"reference " + m.value: [] for m in Metric})

    # the development splits are not rescaled, so they are looked up once
    dev_cal_logits = task_data[Kind.LOGITS][Split.DEV_CAL]
    dev_cal_labels = task_data[Kind.LABELS][Split.DEV_CAL]
    dev_test_logits = task_data[Kind.LOGITS][Split.DEV_TEST]
    dev_test_labels = task_data[Kind.LABELS][Split.DEV_TEST]

    # extract the minority class for F1 computation; it only depends on the (unscaled) calibration
    # split and is therefore identical for all imbalance ratios
    val_prevalences = torch.bincount(dev_cal_labels)
    min_class = torch.argmin(val_prevalences).item()
    max_class = torch.argmax(val_prevalences).item()
    # catches case where for balanced task, min class was used as max class in scaling
    if min_class == max_class:
        min_class = 1

    # iterate over the imbalance ratio range
    for ir in IRS:
        # scale prevalences in the deployment set according to imbalance ratio
        app_test_logits, app_test_classes = scale_prevalences_ir(logits=task_data[Kind.LOGITS][Split.APP_TEST],
                                                                 classes=task_data[Kind.LABELS][Split.APP_TEST],
                                                                 ir=ir)
        # create a data dictionary with the modified deployment test set
        mod_data = {Kind.LOGITS: {Split.DEV_CAL: dev_cal_logits,
                                  Split.DEV_TEST: dev_test_logits,
                                  Split.APP_TEST: app_test_logits},
                    Kind.LABELS: {Split.DEV_CAL: dev_cal_labels,
                                  Split.DEV_TEST: dev_test_labels,
                                  Split.APP_TEST: app_test_classes}}

        # compute the exact prevalence in the scaled deployment set
        exact_prevalence = torch.bincount(app_test_classes) / len(app_test_classes)

        # compute EC estimated separate - both calibration and EC adjustment rely on prevalence estimation
        estimated_prevalence = torch.Tensor(
            adjust_priors_qp(torch.softmax(dev_test_logits, dim=1),
                             dev_test_labels,
                             torch.softmax(app_test_logits, dim=1),
                             app_test_classes, method=ACC))

        # calibrate logits; re-calibration uses the exact deployment prevalence as prior
        prior = exact_prevalence if recalibrate else None
        calibrated_logits = calibrate_logits_fast(data=mod_data, calibration=calibration, prior=prior)

        # compute predictions on scaled deployment set and development test set according to argmax decision rule
        new_app_test_preds = torch.argmax(calibrated_logits[Split.APP_TEST], dim=1)
        dev_test_preds = torch.argmax(calibrated_logits[Split.DEV_TEST], dim=1)

        # compute the metrics on the deployment and development sets; on the deployment set the exact
        # prevalence is passed for both prior arguments, the development set uses the estimated prevalence
        dep_metrics = compute_all_metrics(app_test_classes, calibrated_logits[Split.APP_TEST],
                                          new_app_test_preds,
                                          min_class=min_class, exact_priors=exact_prevalence,
                                          estimated_priors=exact_prevalence)
        dev_metrics = compute_all_metrics(dev_test_labels, calibrated_logits[Split.DEV_TEST],
                                          dev_test_preds,
                                          min_class=min_class, exact_priors=exact_prevalence,
                                          estimated_priors=estimated_prevalence)

        # recompute EC estimated separate since re-calibration has to rely on estimation as well!
        if recalibrate:
            calibrated_logits_est = calibrate_logits_fast(data=mod_data, calibration=calibration,
                                                          prior=estimated_prevalence)
            estimated_app_test_preds = torch.argmax(calibrated_logits_est[Split.APP_TEST], dim=1)
            dep_metrics[Metric.EC_EST] = compute_metric(Metric.EC_ADJUSTED, app_test_classes,
                                                        estimated_app_test_preds, exact_priors=exact_prevalence,
                                                        min_class=min_class)
            estimated_dev_test_preds = torch.argmax(calibrated_logits_est[Split.DEV_TEST], dim=1)
            dev_metrics[Metric.EC_EST] = compute_metric(Metric.EC_EST, dev_test_labels,
                                                        estimated_dev_test_preds,
                                                        min_class=min_class, estimated_priors=estimated_prevalence)

        # append metrics to the results dictionary
        for k in dep_metrics.keys():
            results[k].append(dep_metrics[k])
            results['reference ' + k.value].append(dev_metrics[k])
    return results

In [37]:
# for each calibration method: evaluate every task across the imbalance ratios and pickle the result
for cal in (CalibrationMethod.NONE, CalibrationMethod.AFFINE_REWEIGHTED):
    # metric curves of every task, keyed by task
    collector = {}
    for t in ProgIter(rel_tasks):
        collector[t] = dict(metrics_across_ir(task_data=data[t], calibration=cal))
    # one row per task, one column per result key (the first task defines the columns)
    result_df = pd.DataFrame(rel_tasks, columns=["name"])
    for key in collector[rel_tasks[0]]:
        result_df[key] = [torch.tensor(collector[t][key]) for t in rel_tasks]
    # store the dataframe under a file name that contains the calibration method
    out_file = RESULT_PATH / ("24_metric_performance_" + cal.value + ".pkl")
    result_df.to_pickle(out_file)

 100.00% 30/30... rate=1.42 Hz, eta=0:00:00, total=0:00:21
 100.00% 30/30... rate=0.97 Hz, eta=0:00:00, total=0:00:30


## Plot of Figure 9

In [3]:
# subplots are collected per calibration method in the order: line plot, box plot
subplts = []
# the metrics to be plotted
metrics = [Metric.ACCURACY, Metric.F1, Metric.MCC, Metric.BALANCED_ACC, Metric.EC_EST, Metric.EC_ADJUSTED]
for cal in (CalibrationMethod.NONE, CalibrationMethod.AFFINE_REWEIGHTED):
    # load the results stored for this calibration method
    result_df = pd.read_pickle(RESULT_PATH / ("24_metric_performance_" + cal.value + ".pkl"))
    # table of the per-task values at the last imbalance ratio, filled metric by metric
    ir_10_metrics = pd.DataFrame(result_df['name'])
    for m in metrics:
        # difference between the metric value and its reference value
        result_df[m.value] = result_df[m] - result_df["reference " + m.value]
        # absolute difference at the last entry of every curve
        ir_10_metrics[m.value] = np.abs(np.stack(result_df[m.value].values))[:, -1]
    # line plot of the differences across imbalance ratios
    subplts.append(plot_aggregate_results(result_df, metrics=metrics, file=None,
                                          delta=False, ci=Confidence.STD,
                                          opacity=0.15, line_width=4, bound=[0, 0.3], plot_lines_later=True))
    # box plot of the absolute differences at the last imbalance ratio
    subplts.append(box_plot(ir_10_metrics, metrics))

In [5]:
# y axis ranges per subplot index: even indices get [0, 0.3], odd indices get [-0.02, 0.48]
sub_y_ranges = {idx: ([0, 0.3] if idx % 2 == 0 else [-0.02, 0.48]) for idx in range(4)}
# assemble Figure 9 from the subplots (one row per calibration setting)
fig = multiplot(
    rows=2,
    cols=2,
    row_titles=["w/o re-calibration", "with re-calibration"],
    y_title="Absolute difference to metric score on D<sub>dev</sub> set",
    subplts=subplts,
    horizontal_spacing=0.03,
    vertical_spacing=0.03,
    legend_index=1,
    sub_x_axis_titles={2: "Imbalance ratio"},
    sub_y_ranges=sub_y_ranges,
    height=800,
    shared_yaxes=False,
    little_guys=True,
    icon_axes=[3],
    icon_size=0.08,
    ir_axes=[2, 4],
    ir_values=[10, 10],
)
# last expression of the cell, so the notebook displays the figure
fig

In [6]:
# save the figure as PNG, PDF and HTML in the results folder
for writer, file_name in ((fig.write_image, "24_fig_9.png"),
                          (fig.write_image, "24_fig_9.pdf"),
                          (fig.write_html, "24_fig_9.html")):
    writer(RESULT_PATH / file_name)