# Test-time recalibration of conformal predictors
This notebook reproduces the figures of the paper from the cached experiment results (CSV files under `experiments/cache/`).

In [1]:
from pathlib import Path
import numpy as np
import pandas as pd

import plotly
import plotly.express as px

In [2]:
def update_yaxes_matches(fig: "plotly.graph_objects.Figure", num_facets: int):
    r"""Unlink the y-axes of a faceted figure so each facet can be customized independently.

    Plotly Express links facet axes through the ``matches`` layout property. This
    clears ``matches`` on ``yaxis``, ``yaxis2``, ..., ``yaxis{num_facets}`` so that
    per-facet settings (e.g. ``range``) are not propagated to the other facets.

    Args:
        fig: figure whose ``layout`` holds the facet y-axes; modified in place.
        num_facets: number of facet y-axes to unlink (``yaxis`` counts as the first).
    """
    for i in range(1, num_facets + 1):
        # The first axis is stored as 'yaxis' (reference id 'y' -- 'y1' is not a valid id),
        # the following ones as 'yaxis2', 'yaxis3', ...
        axis_name = "yaxis" if i == 1 else f"yaxis{i}"
        # None removes the link; matching an axis to itself is not a meaningful constraint.
        fig.layout[axis_name].matches = None

## TPS under distribution shift for fixed $\alpha=0.1$

### Figure 2

In [3]:
tps_bars_data_file = "../experiments/cache/tps_recalibration_results.csv"  # cached TPS results for Figure 2

# Figure 2 settings: `alpha` selects the miscoverage level that is plotted.
# `arch` and `use_classes_of_dataset` are not read by the plotting cells in this notebook.
arch, alpha = 'resnet50', 0.1
use_classes_of_dataset = None

In [4]:
# Load the cached TPS results and keep the rows shown in Figure 2: only the selected
# `alpha`, without the nonliving26 dataset, ordered for plotting with a fresh 0..n-1 index.
tps_bars_df = (
    pd.read_csv(tps_bars_data_file)
    .drop(columns=["Classifier"])
    .query("alpha == @alpha and Dataset != 'nonliving26'")
    .sort_values(
        by=['shift_type', 'Dataset', 'alpha', 'Method'],
        ascending=[False, False, True, True],
        ignore_index=True,
    )
)

In [16]:
# Figure 2: TPS coverage after recalibration at the fixed level `alpha`,
# one bar per method and one facet per dataset (wrapped into rows of 3).
fig = px.bar(
    tps_bars_df,
    x='Method', y='predicted_coverage',
    barmode='group',
    facet_row_spacing=0.15,  # vertical gap between the wrapped facet rows
    facet_col='Dataset', facet_col_wrap=3, facet_col_spacing=0.08,
    category_orders={
        "Dataset": ["ImageNetV2", "ImageNet-Sketch", "ImageNet-R", "entity13", "entity30", "living17"]
    },
    title="TPS coverage under distribution shift w/ recalibration",
    labels={'predicted_coverage': 'achieved coverage'},
)

# Facet titles read "Dataset=<name>"; keep only the dataset name.
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

# Unlink the facet y-axes and show tick labels on each, so every facet can be zoomed separately.
fig.update_yaxes(matches=None, showticklabels=True)

# Hand-tuned per-facet y-ranges; an integer `selector` picks the n-th y-axis of the layout.
# NOTE(review): Plotly Express appears to number wrapped facets from the bottom-left, which
# would make 0-2 the bottom row (entity13, entity30, living17) and 3-5 the top row
# (ImageNetV2, ImageNet-Sketch, ImageNet-R) -- confirm on the rendered figure.
fig.update_yaxes(range=[0.58, 0.93], selector=0)
fig.update_yaxes(range=[0.53, 0.93], selector=1)
fig.update_yaxes(range=[0.52, 0.93], selector=2)
fig.update_yaxes(range=[0.80, 0.91], selector=3)
fig.update_yaxes(range=[0.35, 0.93], selector=4)
fig.update_yaxes(range=[0.30, 0.95], selector=5)

# Target coverage 1 - alpha, derived from the `alpha` parameter instead of a hard-coded 0.9
# so the reference line stays correct if `alpha` changes. add_hline's default
# row/col="all" draws it on every facet.
fig.add_hline(y=1 - alpha, line_dash="dash", line_color='red', line_width=2, opacity=1)

fig.update_layout(height=800)

fig.show()

## TPS, APS, and RAPS under distribution shift for varying $\alpha$

### Figure 3

In [18]:
# Figure 3 settings. `kreg` and `lamda` select the RAPS regularization setting that is
# matched against the `raps_kreg` / `raps_lambda` columns when filtering RAPS rows.
# (`lamda` is spelled this way because `lambda` is a Python keyword.)
arch, alpha = 'resnet50', 0.1
kreg, lamda = 2, 0.2
use_classes_of_dataset, use_platt_scaling = None, False

In [32]:
# Cached TPS/APS/RAPS recalibration results for Figure 3. The path is relative to the
# notebook, like the other cells' `../experiments/cache/...` paths, instead of a
# machine-specific absolute path (assumes the notebook sits one level below the repo root).
rcp_lines_df = pd.read_csv("../experiments/cache/recalibration_results.csv")

In [35]:
# Filter to the rows shown in Figure 3:
#  1. only the plotted methods;
#  2. without the unstable combination CHR- on TPS for ImageNet-R;
#  3. RAPS rows only for the selected (kreg, lamda); APS and TPS rows pass through.
rcp_lines_df = (
    rcp_lines_df
    .loc[lambda d: d['Method'].isin(['$y=x$', 'original', 'QTC', 'QTC-SC', 'QTC-ST', 'CHR-'])]
    .loc[lambda d: ~(
        (d['ConformalPredictor'] == 'TPS')
        & (d['Dataset'] == 'ImageNet-R')
        & (d['Method'] == 'CHR-')
    )]
    .loc[lambda d: (
        ((d['raps_kreg'] == kreg) & (d['raps_lambda'] == lamda))
        | d['ConformalPredictor'].isin(['APS', 'TPS'])
    )]
)

# Achieved vs. target coverage: one column of facets per dataset, one row per conformal predictor.
fig = px.line(
    rcp_lines_df,
    x='1_alpha', y='predicted_coverage',
    color='Method', markers=True,
    facet_col='Dataset', facet_col_spacing=0.08,
    facet_row='ConformalPredictor',
    category_orders={
        "Method": ['$y=x$', 'original', 'QTC', 'QTC-SC', 'QTC-ST', 'CHR-'],
        "ConformalPredictor": ["TPS", "APS", "RAPS"],
        "Dataset": ["ImageNetV2", "ImageNet-Sketch", "ImageNet-R"],
    },
    title="TPS, APS, and RAPS coverage under distribution shift w/ and w/o recalibration",
    labels={
        '1_alpha': r'$1-\alpha$',
        'predicted_coverage': 'achieved coverage',
    },
)

# Strip the "column=" prefix from facet titles, give every facet its own y-axis, then render.
(
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    .update_yaxes(matches=None, showticklabels=True)
    .update_layout(height=800)
    .show()
)

## QTC vs. covariate-shift-based methods
This section shows the results of the experiments on the distribution shifts of the [DomainNet dataset](https://paperswithcode.com/dataset/domainnet).

### Figure 4

In [49]:
# Figure 4 data: DomainNet results for both source settings, stacked into one frame.
dom_info_df = pd.concat(
    [
        pd.read_csv(csv_path)
        for csv_path in (
            "../experiments/cache/domainnet_all2info.csv",
            "../experiments/cache/domainnet_real2info.csv",
        )
    ],
    ignore_index=True,
)
dom_info_df['1_alpha'] = 1.0 - dom_info_df['alpha']

# Reference curve y = x: built from the QTC rows, with coverage set to the target 1 - alpha.
yx_sub_df = (
    dom_info_df.loc[dom_info_df['Method'] == 'QTC']
    .drop(columns=['predicted_size', 'calibrated_tau', 'predicted_beta', 'predicted_tau'])
    .assign(Method=r"$y=x$", predicted_coverage=lambda d: d['1_alpha'])
)

# Coverage without recalibration: built from the QTC-SC rows' `original_coverage` column.
orig_sub_df = (
    dom_info_df.loc[dom_info_df['Method'] == 'QTC-SC']
    .drop(columns=['predicted_size', 'calibrated_tau', 'predicted_beta', 'predicted_tau'])
    .assign(Method='original', predicted_coverage=lambda d: d['original_coverage'])
)

# The synthetic rows go first, followed by all measured rows.
dom_info_df = pd.concat([yx_sub_df, orig_sub_df, dom_info_df], ignore_index=True)

#### Figure 4 (top row)

In [52]:
# Drop the PS-C / PS-M / PS-R variants from `dom_info_df` itself; the bottom-row cell
# starts from this filtered frame.
dom_info_df = dom_info_df[~dom_info_df['Method'].isin(['PS-C', 'PS-M', 'PS-R'])]

# Figure 4 (top row): achieved coverage vs. target 1 - alpha, one facet per source domain.
fig = px.line(
    dom_info_df,
    x='1_alpha', y='predicted_coverage',
    color='Method', markers=True,
    facet_col='source_data', facet_col_spacing=0.08,
    category_orders={
        "Method": ['$y=x$', 'original', 'QTC', 'QTC-SC', 'QTC-ST', 'PS-W', 'WSCI'],
        "source_data": ["DomainNetAll", "DomainNetReal"],
    },
    title="QTC comparison to covariate shift based methods on DomainNet",
    labels={
        '1_alpha': r'$1-\alpha$',
        'predicted_coverage': 'achieved coverage',
    },
)

# Strip the "source_data=" prefix from facet titles, unlink the y-axes, then render.
(
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    .update_yaxes(matches=None, showticklabels=True)
    .update_layout(height=800)
    .show()
)

#### Figure 4 (bottom row)

In [57]:
# Figure 4 (bottom row): average prediction-set size (log scale) per method. The synthetic
# '$y=x$' and 'original' rows have no `predicted_size` (it was dropped when they were
# built), so they are excluded together with the PS-C / PS-M / PS-R variants.
dom_info_size_df = dom_info_df.query("Method not in ['$y=x$', 'original', 'PS-C', 'PS-M', 'PS-R']")

method_order = ['$y=x$', 'original', 'QTC', 'QTC-SC', 'QTC-ST', 'PS-W', 'WSCI']
# Plotly Express hands out colors in the order it encounters categories present in the
# data, so dropping the two reference methods would shift every remaining method's color
# relative to the top row. Pin each method to the palette slot of its position in
# `method_order`, which presumably matches the top-row figure (default 'plotly' template,
# all methods present there) -- confirm if the template or method set changes.
method_colors = dict(zip(method_order, px.colors.qualitative.Plotly))

fig = px.line(
    dom_info_size_df,
    log_y=True,
    x='1_alpha', y='predicted_size', color='Method', markers=True,
    color_discrete_map=method_colors,
    facet_col='source_data', facet_col_spacing=0.08,
    category_orders={
        "Method": method_order,
        "source_data": ["DomainNetAll", "DomainNetReal"],
    },
    title="QTC comparison to covariate shift based methods on DomainNet",
    labels={
        '1_alpha': r'$1-\alpha$',
        'predicted_size': 'avg. set size',
    },
)

# Facet titles read "source_data=<name>"; keep only the name.
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

# Independent y-axes per facet, each with its own tick labels.
fig.update_yaxes(matches=None, showticklabels=True)

fig.update_layout(height=800)

fig.show()