# RQ2a: What is the effect of prevalence shifts on the quality of calibration?
This notebook generates Figure 5. It assesses re-calibration methods in our use cases.

In [1]:
import os
import sys

sys.path.insert(0, os.path.abspath('..'))
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import plotly.subplots
import plotly.graph_objects as go
import torch
from progiter import ProgIter

import IPython
from src.prev.calibration import CalibrationMethod, calc_calibration_metrics, calibrate_logits_fast
from src.prev.data_loading import get_values, Kind, Split, all_tasks, example_tasks
from src.prev.plotting import plot_aggregate_results, Confidence, box_plot, add_little_guys, add_ir_annotation
from src.prev.scaling import scale_prevalences_ir

# Data and result folders are siblings of the notebook's working directory.
current_path = os.getcwd()
_project_root = Path(current_path).parent
DATA_PATH = _project_root / 'data'
RESULT_PATH = _project_root / 'results'
assert DATA_PATH.exists() and RESULT_PATH.exists()
torch.manual_seed(seed=0)

# Imbalance ratios 1.0, 1.5, ..., 10.0 (19 values, kept as numpy floats).
IRS = list(np.arange(1, 10.5, 0.5))

# Set to False for a full rerun that includes all tasks. Note that logits are not
# provided for all tasks for licensing considerations, so they would have to be
# generated from scratch.
EXAMPLE_TASKS_ONLY = True
rel_tasks = example_tasks if EXAMPLE_TASKS_ONLY else all_tasks

In [2]:
# Load the stored values of every relevant task (original paper predictions).
data = {
    task: get_values(task, DATA_PATH, proj='mic23_predictions_original_0')
    for task in ProgIter(rel_tasks, desc='Loading data')
}

Loading data 100.00% 30/30... rate=5.55 Hz, eta=0:00:00, total=0:00:05


In [3]:
# Additional predictions, stored per task and per imbalance ratio:
#   train_adapted_data[task][ir]           <- project 'mic23_predictions_extension_balanced_0_<ir>'
#   train_adapted_estimated_data[task][ir] <- project 'mic23_predictions_extension_balanced_estimated_0_<ir>'
train_adapted_data = {}
train_adapted_estimated_data = {}
for task in ProgIter(rel_tasks, desc='Loading data'):
    balanced_by_ir = {}
    estimated_by_ir = {}
    for ir in IRS:
        balanced_by_ir[ir] = get_values(
            task, DATA_PATH, proj=f'mic23_predictions_extension_balanced_0_{ir}')
        estimated_by_ir[ir] = get_values(
            task, DATA_PATH, proj=f'mic23_predictions_extension_balanced_estimated_0_{ir}')
    train_adapted_data[task] = balanced_by_ir
    train_adapted_estimated_data[task] = estimated_by_ir

Loading data 100.00% 30/30... rate=0.08 Hz, eta=0:00:00, total=0:06:12


In [4]:
# Prevalence estimation results written by notebook 3, indexed by (task, ir).
estimated_prevalences = (
    pd.read_pickle(RESULT_PATH / '24_prev_estimation_df.pkl')
    .set_index(['task', 'ir'])
)


def get_estimated_prevalences(task: str, target_ir: float, method: str = 'ACC') -> torch.Tensor:
    """Look up the stored prevalence estimate for one task and imbalance ratio.

    :param task: task identifier, first level of the ``estimated_prevalences`` index
    :param target_ir: imbalance ratio, second level of the ``estimated_prevalences`` index
    :param method: column of ``estimated_prevalences`` to read (defaults to ``'ACC'``)
    :return: the stored entry converted with ``torch.tensor``
    """
    method_column = estimated_prevalences[method]
    stored_estimate = method_column.loc[task, target_ir]
    return torch.tensor(stored_estimate)

In [5]:
# Research Question 2 a): What are the effects of prevalence shifts on the quality of calibration

# Method groups consulted by the steps below. They are loop-invariant, so they are defined once.
# STEP 1: methods evaluated on the 'extension_balanced' predictions ...
methods_on_adapted_data = (
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHTED,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_TEMPERATURE_SCALING_REWEIGHTED,
)
# ... and methods evaluated on the 'extension_balanced_estimated' predictions.
methods_on_adapted_estimated_data = (
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHETD_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_TEMPERATURE_SCALING_REWEIGHTED_ACC,
)
# STEP 2: methods that receive the class frequencies of the rescaled APP_TEST labels as prior ...
methods_with_observed_prior = (
    CalibrationMethod.AFFINE_REWEIGHTED,
    CalibrationMethod.TEMPERATURE_SCALING_REWEIGHTED,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHTED,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_TEMPERATURE_SCALING_REWEIGHTED,
)
# ... and methods that receive the stored prevalence estimate as prior.
methods_with_estimated_prior = (
    CalibrationMethod.AFFINE_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_TEMPERATURE_SCALING_REWEIGHTED_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHETD_ACC,
)
# STEP 3: adapted-training methods that are additionally post-calibrated.
methods_adapted_and_post_calibrated = (
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHTED,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_TEMPERATURE_SCALING_REWEIGHTED,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHETD_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_TEMPERATURE_SCALING_REWEIGHTED_ACC,
)

# One record per (calibration method, imbalance ratio, task).
calibration_ir_results = []
for cal_method in ProgIter(list(CalibrationMethod)):
    for ir in IRS:
        # reset for every imbalance ratio; not read anywhere in this cell
        calibrated_test_data: Dict[str, Dict[Kind, Dict[Split, torch.Tensor]]] = {}
        for task in rel_tasks:
            # STEP 1: pick the predictions for this method and rescale APP_TEST to the target IR
            if cal_method in methods_on_adapted_data:
                task_data = train_adapted_data[task][ir]
            elif cal_method in methods_on_adapted_estimated_data:
                task_data = train_adapted_estimated_data[task][ir]
            else:
                task_data = data[task]
            task_logits = task_data[Kind.LOGITS]
            task_labels = task_data[Kind.LABELS]
            app_test_logits, app_test_classes = scale_prevalences_ir(
                logits=task_logits[Split.APP_TEST],
                classes=task_labels[Split.APP_TEST],
                ir=ir)
            # DEV_CAL and DEV_TEST are passed through unchanged; only APP_TEST is replaced
            mod_data = {
                Kind.LOGITS: {Split.DEV_CAL: task_logits[Split.DEV_CAL],
                              Split.DEV_TEST: task_logits[Split.DEV_TEST],
                              Split.APP_TEST: app_test_logits},
                Kind.LABELS: {Split.DEV_CAL: task_labels[Split.DEV_CAL],
                              Split.DEV_TEST: task_labels[Split.DEV_TEST],
                              Split.APP_TEST: app_test_classes},
            }

            # STEP 2: determine prior knowledge
            if cal_method in methods_with_observed_prior:
                # class counts of the rescaled APP_TEST labels, normalised to sum to one
                # (original note: scaling for convergence stability)
                class_counts = torch.bincount(app_test_classes)
                prior = class_counts / class_counts.sum()
            elif cal_method in methods_with_estimated_prior:
                prior = get_estimated_prevalences(task=task, target_ir=ir)
            else:
                prior = None  # by default, we know nothing

            # STEP 3: re-calibrate
            if cal_method in methods_adapted_and_post_calibrated:
                # the method name decides which re-weighted post calibration is applied
                if 'AFFINE' in cal_method.name:
                    post_calibration = CalibrationMethod.AFFINE_REWEIGHTED
                else:
                    post_calibration = CalibrationMethod.TEMPERATURE_SCALING_REWEIGHTED
                calibrated_logits = calibrate_logits_fast(data=mod_data,
                                                          calibration=post_calibration,
                                                          prior=prior)
            elif 'ADAPTED_TRAIN_WEIGHTS' in cal_method.name:
                # adapted without post calibration: the logits are used as they are
                calibrated_logits = {Split.DEV_TEST: mod_data[Kind.LOGITS][Split.DEV_TEST],
                                     Split.APP_TEST: mod_data[Kind.LOGITS][Split.APP_TEST]}
            else:
                calibrated_logits = calibrate_logits_fast(data=mod_data,
                                                          calibration=cal_method,
                                                          prior=prior)

            # STEP 4: calculate calibration metrics
            # suppress plotting from the metrics reloaded
            with IPython.utils.io.capture_output():
                dev_metrics = calc_calibration_metrics(logits=calibrated_logits[Split.DEV_TEST],
                                                       labels=task_labels[Split.DEV_TEST])
                app_metrics = calc_calibration_metrics(logits=calibrated_logits[Split.APP_TEST],
                                                       labels=app_test_classes)

            # one flat record: identifiers, then dev_*, app_* and diff_* (= app - dev) columns
            record = {'calibration': cal_method.name, 'ir': ir, 'task': task}
            for prefix, metrics in (('dev', dev_metrics), ('app', app_metrics)):
                for metric_name, metric_value in metrics.items():
                    record[f'{prefix}_{metric_name}'] = metric_value
            for metric_name in dev_metrics:
                record[f'diff_{metric_name}'] = app_metrics[metric_name] - dev_metrics[metric_name]
            calibration_ir_results.append(record)

ir_df = pd.DataFrame(calibration_ir_results)
ir_df.to_csv(RESULT_PATH / '24_recalibration_results.csv')

 100.00% 12/12... rate=0.07 Hz, eta=0:00:00, total=0:03:01


In [2]:
# plotting (if results are already stored)
p = RESULT_PATH / '24_recalibration_results.csv'
# False: plot the 'app_<metric>' columns; True: plot the 'diff_<metric>' columns
DELTA = False

# Groups of calibration methods; each group yields one line plot + box plot pair per metric.
plot_methods_all = [
    CalibrationMethod.NONE,
    CalibrationMethod.TEMPERATURE_SCALING,
    CalibrationMethod.TEMPERATURE_SCALING_REWEIGHTED,
    CalibrationMethod.AFFINE,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS,
    CalibrationMethod.AFFINE_ACC,
    CalibrationMethod.AFFINE_REWEIGHTED,
]
plot_methods_unweighted = [
    CalibrationMethod.NONE,
    CalibrationMethod.TEMPERATURE_SCALING,
    CalibrationMethod.AFFINE,
]
plot_methods_retrained = [
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHTED,
]
plot_methods_proposed = [
    CalibrationMethod.TEMPERATURE_SCALING_REWEIGHTED,
    CalibrationMethod.AFFINE_REWEIGHTED,
    CalibrationMethod.AFFINE_ACC,
]

plot_methods_main = plot_methods_all
plot_methods_supp = [
    CalibrationMethod.NONE,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS,
    CalibrationMethod.AFFINE_ACC,
    CalibrationMethod.AFFINE_REWEIGHTED,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHETD_ACC,
    CalibrationMethod.ADAPTED_TRAIN_WEIGHTS_AND_AFFINE_SCALING_REWEIGHTED,
]

# Trace collections consumed by the Figure 5 cell; only the CWCE and BS traces are collected.
left_plots = []
left_plots_bs = []
right_plots = []
right_plots_bs = []
line_trace_sinks = {'cwce': left_plots, 'bs': left_plots_bs}
box_trace_sinks = {'cwce': right_plots, 'bs': right_plots_bs}

# (variant label, methods to plot) in processing order
plot_variants = [
    ('main', plot_methods_all),
    ('main', plot_methods_unweighted),
    ('main', plot_methods_retrained),
    ('main', plot_methods_proposed),
    ('main', plot_methods_main),
    ('supp', plot_methods_supp),
]
# y-axis bounds of the 'main' variant per metric (currently none are set)
main_bounds = {'rbs': None, 'bs': None, 'nbs': None, 'cwce': None}
# (image source, paper x position, x anchor) of the two images placed below the line plot
little_guy_specs = [
    ("../data/little_guys_balanced.png", 0, "left"),
    ("../data/little_guys_high_ir.png", 0.48, "right"),
]

if p.exists():
    ir_df = pd.read_csv(p)
    for metric in ['rbs', 'cwce', 'bs', 'nbs']:

        def line_values_extraction_func(results_df, line_identifier, delta):
            """Return the absolute metric values of one method as a (task x ir) array.

            Reads the module-level loop variable ``metric`` at call time.
            """
            metric_col = f'diff_{metric}' if delta else f'app_{metric}'
            method_rows = results_df[results_df['calibration'] == line_identifier.name]
            per_task = method_rows[['task', 'ir', metric_col]].pivot(index='task', columns='ir',
                                                                      values=metric_col)
            return per_task.abs().values

        for variant, plot_methods in plot_variants:
            bound = main_bounds[metric] if variant == 'main' else None
            plots_base_path = RESULT_PATH / f'24_calibration_deterioration_{metric}_{DELTA}_{variant}'
            fig = plotly.subplots.make_subplots(rows=1, cols=2,
                                                y_title=f'<b>{metric.upper()}</b>',
                                                horizontal_spacing=0.04,
                                                shared_yaxes=True)
            fig.update_layout(template='plotly')

            # left panel: line plot over the imbalance ratios
            line_fig_data = plot_aggregate_results(info_df=ir_df, line_ids=plot_methods,
                                                   file=None,
                                                   line_values_extraction_func=line_values_extraction_func,
                                                   bound=bound,
                                                   ci=Confidence.STD,
                                                   y_axis_title=metric.upper(),
                                                   ).data
            if metric in line_trace_sinks:
                line_trace_sinks[metric].append(line_fig_data)
            for line_trace in line_fig_data:
                line_trace['showlegend'] = False
                fig.add_trace(line_trace, row=1, col=1)

            # right panel: box plot of the rows with ir == 10.0, one column per calibration method
            value_col = f'diff_{metric}' if DELTA else f'app_{metric}'
            rows_at_ir_10 = ir_df[ir_df['ir'] == 10.0][[value_col, 'calibration', 'task']]
            per_method = rows_at_ir_10.pivot(index='task', columns='calibration', values=value_col)
            method_display_names = {cal.name: cal.value for cal in CalibrationMethod}
            box_fig_data = box_plot(df=per_method.rename(columns=method_display_names),
                                    line_ids=plot_methods).data
            if metric in box_trace_sinks:
                box_trace_sinks[metric].append(box_fig_data)
            for box_trace in box_fig_data:
                fig.add_trace(box_trace, row=1, col=2)

            # layout stuff
            fig.update_layout(autosize=False, width=1800, height=500,
                              legend={"font": {"size": 30}, "itemsizing": 'constant'})
            # fix distances
            fig.update_layout(margin=dict(t=0, b=0))
            fig['layout']['xaxis']['title'] = "<b>Imbalance ratio</b>"
            fig['layout']['yaxis']['range'] = bound
            fig['layout']['font']['size'] = 20
            for annotation in fig['layout']['annotations']:
                annotation['font'] = dict(size=30)
            fig['layout']['font']['family'] = "NewComputerModern10"
            joint_path = plots_base_path.parent / (plots_base_path.name + '_joint_white_background')
            fig.update_layout(
                images=[
                    go.layout.Image(
                        source=image_source,  # URL or local path to the image
                        xref="paper", yref="paper",
                        x=image_x, y=-0.15,
                        sizex=0.15, sizey=0.15,  # size of the image
                        xanchor=image_xanchor, yanchor="middle"
                    )
                    for image_source, image_x, image_xanchor in little_guy_specs
                ]
            )
            # export of the per-metric figures is disabled:
            # fig.write_image(joint_path.with_suffix(suffix='.svg'))
            # fig.write_image(joint_path.with_suffix(suffix='.png'))
            # fig.write_html(joint_path.with_suffix(suffix='.html'))

## Generate Figure 5

In [4]:
# Figure 5: 2 x 2 grid, CWCE in the first row and BS in the second row;
# line plots in the left column, box plots in the right column.
fig = plotly.subplots.make_subplots(rows=2, cols=2,
                                    horizontal_spacing=0.04,
                                    vertical_spacing=0.02,
                                    shared_yaxes=False,
                                    shared_xaxes=True)
fig.update_layout(template='plotly_white')


def _add_every_other_trace(traces, row, col):
    """Add the traces at even positions of ``traces`` to ``fig``, hiding their legend entries."""
    for position, trace in enumerate(traces):
        if position % 2:
            continue
        trace['showlegend'] = False
        fig.add_trace(trace, row=row, col=col)


# Original author's note: shaded regions are plotted first, then the lines in reversed order,
# so that the Affine and No re-calibration lines are not covered. The forward pass takes the
# even positions of the trace list and the reversed pass takes the even positions of the
# reversed list.
for line_traces, target_row in ((left_plots[0], 1), (left_plots_bs[0], 2)):
    _add_every_other_trace(line_traces, row=target_row, col=1)
    _add_every_other_trace(reversed(line_traces), row=target_row, col=1)

# box plots: the first row keeps its legend entries, the second row hides them
for box_trace in right_plots[0]:
    fig.add_trace(box_trace, row=1, col=2)
for box_trace in right_plots_bs[0]:
    box_trace['showlegend'] = False
    fig.add_trace(box_trace, row=2, col=2)

# the last BS line trace is added once more, after everything else
last_bs_trace = left_plots_bs[0][-1]
fig.add_trace(last_bs_trace, row=2, col=1)

# fixed y-ranges: first-row line plot, then both second-row panels
fig['layout']['yaxis']['range'] = [0, 0.25]
for axis_key in ('yaxis3', 'yaxis4'):
    fig['layout'][axis_key]['range'] = [0, 0.23]
fig.update_layout(autosize=False, width=1800, height=800,
                  legend={"font": {"size": 30}, "itemsizing": 'constant'}, margin=dict(t=10))

# fonts, axis titles and annotations
fig['layout']['font']['family'] = "NewComputerModern10"
fig['layout']['xaxis3']['title'] = "<b>Imbalance ratio</b>"
fig['layout']['yaxis']['title'] = "<b>CWCE</b>"
fig['layout']['yaxis3']['title'] = "<b>BS</b>"
for annotated_panel in (2, 4):
    add_ir_annotation(fig, annotated_panel)
fig['layout']['font']['size'] = 20
add_little_guys(fig, 3, size=0.08)
fig.update_annotations(font_size=20, font_color='black', bgcolor='white')

In [6]:
# Export Figure 5 as static images (png, svg, pdf) and as a standalone HTML page.
for image_suffix in ('png', 'svg', 'pdf'):
    fig.write_image(RESULT_PATH / f"24_fig_5.{image_suffix}")
fig.write_html(RESULT_PATH / "24_fig_5.html")