In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

from tqdm import tqdm
import torch
from torch import nn
from torchvision.models import (
    resnet18, ResNet18_Weights,
    resnet34, ResNet34_Weights,
    resnet50, ResNet50_Weights,
    resnet101, ResNet101_Weights,
    vit_b_16, ViT_B_16_Weights,
    vit_b_32, ViT_B_32_Weights,
    vgg16, VGG16_Weights, 
    vgg16_bn, VGG16_BN_Weights,
    convnext_tiny, ConvNeXt_Tiny_Weights,
    convnext_base, ConvNeXt_Base_Weights,
    efficientnet_v2_s, EfficientNet_V2_S_Weights,
    efficientnet_v2_m, EfficientNet_V2_M_Weights
)
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Compose, Resize, Lambda, ToTensor, Grayscale, ToPILImage
import timm
from timm.data import resolve_data_config, create_transform
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8*" and later
# releases dropped the old names, so use whichever one this installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
import PIL

# Reproducibility: a dedicated NumPy Generator, the legacy NumPy global RNG and torch.
rng = np.random.default_rng(0)
np.random.seed(0)
torch.manual_seed(0)

# Results are read from a `results` folder located next to the working directory.
path_results = f"{os.path.dirname(os.getcwd())}/results"

# Data locations are expanded from the $DSDIR / $WORK environment variables
# (the original comment mentions '/scratchf/' as another location).
path_dataset, path_imagenet_labels, path_imagenet100_id = map(
    os.path.expandvars,
    (
        '$DSDIR/imagenet',
        '$WORK/DATA/LOC_synset_mapping.txt',
        '$WORK/DATA/imagenet100.txt',
    ),
)

# Use the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)



In [None]:
# Loader batch size and the model analysed by the rest of the notebook.
BATCH_SIZE = 64
MODEL = 'ResNet50_V2'

# TORCHVISION
# Registry: model name -> (torchvision builder, pretrained weights enum).
# NOTE(review): the trailing "better" / "same/worse" tags come from the original
# author; what is being compared is not stated in this notebook.
models_and_weights_torchvision = dict(
    ResNet18=(resnet18, ResNet18_Weights.IMAGENET1K_V1),  # same/worse
    ResNet34=(resnet34, ResNet34_Weights.IMAGENET1K_V1),  # same/worse
    ResNet50=(resnet50, ResNet50_Weights.IMAGENET1K_V1),  # same/worse
    ResNet50_V2=(resnet50, ResNet50_Weights.IMAGENET1K_V2),  # better
    ResNet101=(resnet101, ResNet101_Weights.IMAGENET1K_V1),  # same/worse
    ResNet101_V2=(resnet101, ResNet101_Weights.IMAGENET1K_V2),  # better
    ViT_B_16=(vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1),  # better
    ViT_B_16_SWAG_E2E=(vit_b_16, ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1),  # same/worse
    ViT_B_16_SWAG_LINEAR=(vit_b_16, ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1),  # same/worse
    ViT_B_32=(vit_b_32, ViT_B_32_Weights.IMAGENET1K_V1),  # better
    VGG16=(vgg16, VGG16_Weights.IMAGENET1K_V1),  # same/worse
    VGG16_BN=(vgg16_bn, VGG16_BN_Weights.IMAGENET1K_V1),  # same/worse
    ConvNeXt_Tiny=(convnext_tiny, ConvNeXt_Tiny_Weights.IMAGENET1K_V1),  # better
    ConvNeXt_Base=(convnext_base, ConvNeXt_Base_Weights.IMAGENET1K_V1),  # better
    EfficientNet_V2_S=(efficientnet_v2_s, EfficientNet_V2_S_Weights.IMAGENET1K_V1),  # better
    EfficientNet_V2_M=(efficientnet_v2_m, EfficientNet_V2_M_Weights.IMAGENET1K_V1),  # better
)

# TIMM
# Model names passed to timm.create_model. The *_in21k entries are flagged by the
# original author: their IN21K predictions still need converting to IN1K.
models_timm = [
    'vit_base_patch16_224',             # same/worse
    'vit_base_patch16_224_in21k',       # NEED TO CONVERT IN21K PREDICTIONS TO IN1K
    'vit_base_patch16_224_miil',        # same/worse
    'vit_base_patch16_224_miil_in21k',  # NEED TO CONVERT IN21K PREDICTIONS TO IN1K
    'vit_base_patch16_384',             # same/worse
    'vit_base_patch32_224',             # same/worse
    'vit_base_patch32_224_in21k',       # NEED TO CONVERT IN21K PREDICTIONS TO IN1K
    'vit_base_patch32_384',             # same/worse
]

# Work out which library provides MODEL. TORCHVISION_OR_TIMM is used below to
# choose the loading code path and the sub-folder of the cached classifier outputs.
if MODEL in models_and_weights_torchvision:
    TORCHVISION_OR_TIMM = 'torchvision'
elif MODEL in models_timm:
    TORCHVISION_OR_TIMM = 'timm'
else:
    # Fail with an actionable message instead of a bare ValueError.
    raise ValueError(
        f"Unknown MODEL {MODEL!r}; expected one of "
        f"{sorted(models_and_weights_torchvision) + sorted(models_timm)}"
    )

In [None]:
# LOAD CLASSIFIER
# Build the pretrained network (eval mode, moved to `device`) together with the
# preprocessing pipeline that belongs to its pretrained weights.
if TORCHVISION_OR_TIMM == 'torchvision':
    model, weights = models_and_weights_torchvision[MODEL]
    classifier = model(weights=weights)
    classifier = classifier.eval().to(device)
    transforms = weights.transforms()
elif TORCHVISION_OR_TIMM == 'timm':
    classifier = timm.create_model(MODEL, pretrained=True)
    classifier = classifier.eval().to(device)
    # timm derives the input transform from the model's own data configuration.
    data_config = timm.data.resolve_data_config({}, model=classifier)
    transforms = timm.data.create_transform(**data_config)


# LOAD DATA
# Both splits use the same preprocessing and the same loader settings.
dataset_train, dataset_val = (
    ImageFolder(f'{path_dataset}/{split}', transform=transforms)
    for split in ('train', 'val')
)

# SHUFFLE FALSE IMPORTANT! The validation loader is later consumed in order to
# fill rows of df_val, so batches must arrive in dataset order.
dataloader_train, dataloader_val = (
    DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, shuffle=False)
    for dataset in (dataset_train, dataset_val)
)
# Parse the synset mapping file into two lookups:
#   id_to_idx:    9-character synset id -> class index (line order in the file)
#   idx_to_label: class index -> human-readable label (text after the separator)
id_to_idx = {}
idx_to_label = {}
with open(path_imagenet_labels) as f:
    # Skip blank lines (e.g. a trailing empty line) so they neither create a bogus
    # '' entry nor shift the class indices.
    for i, line in enumerate(raw for raw in f if raw.strip()):
        # Strip the line terminator explicitly: slicing with [:-1] would eat the
        # last label character when the final line has no trailing newline.
        line = line.rstrip('\r\n')
        id_to_idx[line[:9]] = i
        idx_to_label[i] = line[10:]

# LOAD CLASSIF OUTPUTS
# Cached per-sample classifier outputs, one CSV per split; the first CSV column
# is used as the DataFrame index.
classif_outputs_dir = f'{path_results}/classif_outputs/{TORCHVISION_OR_TIMM}'
df_train, df_val = (
    pd.read_csv(f'{classif_outputs_dir}/classif_outputs_{MODEL}_{split}.csv', index_col=0)
    for split in ('train', 'val')
)

# Stats per class

In [None]:
# Split the per-class statistics are computed on, and the column used to assign
# each sample to a class (model prediction or ground-truth label).
compute_stats_from = 'train'
# compute_stats_from = 'val'
use_pred_or_label = 'pred'
# use_pred_or_label = 'label'

# Only the two cached splits are supported.
if compute_stats_from not in ('train', 'val'):
    raise NotImplementedError()
df = df_train if compute_stats_from == 'train' else df_val

# Class index of every row, according to the chosen column.
targets = df[use_pred_or_label]

# Per-class statistics as arrays of length 1000, position = class index.
# A single groupby pass replaces 6000 boolean scans of the whole frame.
# Classes with no sample come out as NaN after the reindex (and std is NaN for a
# class with a single sample), exactly as with the per-class masks.
class_ids = np.arange(1000)

msp_stats = df['MSP'].groupby(targets).agg(['mean', 'std', 'min', 'max']).reindex(class_ids)
mean_msp_per_class = msp_stats['mean'].to_numpy(dtype=float)
std_msp_per_class = msp_stats['std'].to_numpy(dtype=float)
min_msp_per_class = msp_stats['min'].to_numpy(dtype=float)
max_msp_per_class = msp_stats['max'].to_numpy(dtype=float)

# Cast to float first so the mean of a boolean column is the fraction of True values.
acc_per_class = (
    df['well_classified'].astype(float).groupby(targets).mean()
    .reindex(class_ids).to_numpy(dtype=float)
)
mean_tcp_per_class = df['TCP'].groupby(targets).mean().reindex(class_ids).to_numpy(dtype=float)

# `df` was only an alias of the selected split; drop the name.
del df

# Distribution of the per-class accuracy.
fig_acc, ax_acc = plt.subplots()
ax_acc.hist(acc_per_class)
ax_acc.set_xlabel('accuracy per class')

# Distribution of the per-class mean MSP.
fig_msp, ax_msp = plt.subplots()
ax_msp.hist(mean_msp_per_class)
ax_msp.set_xlabel('mean MSP per class')

# Per-class accuracy against per-class mean MSP.
fig_rel, ax_rel = plt.subplots()
ax_rel.scatter(mean_msp_per_class, acc_per_class)
ax_rel.set_xlabel('mean MSP')
ax_rel.set_ylabel('acc')

# Overlaid distributions of the per-class minimum and maximum MSP.
fig_range, ax_range = plt.subplots()
for extreme_msp, legend_label in ((min_msp_per_class, 'min MSP'), (max_msp_per_class, 'max MSP')):
    ax_range.hist(extreme_msp, alpha=0.5, label=legend_label)
ax_range.legend()

In [None]:
# Per-class accuracy on each split (length-1000 arrays, position = class index),
# computed with one groupby per split instead of 1000 boolean masks.
# Classes with no sample in a split are NaN.
class_ids = np.arange(1000)

targets = df_train[use_pred_or_label]
acc_per_class_train = (
    df_train['well_classified'].astype(float).groupby(targets).mean()
    .reindex(class_ids).to_numpy(dtype=float)
)

targets = df_val[use_pred_or_label]
acc_per_class_val = (
    df_val['well_classified'].astype(float).groupby(targets).mean()
    .reindex(class_ids).to_numpy(dtype=float)
)

# One point per class: train accuracy against validation accuracy.
plt.figure()
plt.scatter(acc_per_class_train, acc_per_class_val)
plt.xlabel('acc_per_class_train')
plt.ylabel('acc_per_class_val')

# Selective classification

In [None]:
df = df_val

# baseline: max softmax
# Sweep 1000 MSP thresholds. For each one record the fraction of samples kept
# (coverage) and the accuracy among the kept samples (NaN when nothing is kept).
domain_cutoff_baseline = np.linspace(0, 1, 1000)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
coverage_baseline, acc_baseline = map(np.array, zip(*(
    (kept.mean(), df.loc[kept, 'well_classified'].mean())
    for kept in (df['MSP'] > cut for cut in domain_cutoff_baseline)
)))

# plot
fig, ax1 = plt.subplots(figsize=(10, 5))
ax1.set_title(f'{MODEL} - coverage vs. accuracy')
sc = ax1.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='viridis')
fig.colorbar(sc, ax=ax1, label='MSP threshold')
ax1.set_xlabel('coverage')
ax1.set_ylabel('accuracy')

## compare 2 classes

In [None]:
# Coverage vs. accuracy restricted to the samples predicted as one class, drawn
# for two classes on the same axes (one colormap / colorbar per class).
targets = df['pred']

# baseline: max softmax -- the same thresholds are used for every class
domain_cutoff_baseline = np.linspace(0, 1, 1000)

fig, (ax1) = plt.subplots(1, 1, figsize=(10, 5))
ax1.set_title(f'{MODEL} - coverage vs. accuracy')
ax1.set_xlabel('coverage')
ax1.set_ylabel('accuracy')

for class_idx, cmap in ((1, 'viridis'), (2, 'plasma')):
    # The class membership mask does not depend on the threshold: build it once.
    in_class = targets == class_idx
    coverage_baseline = np.zeros_like(domain_cutoff_baseline)
    risk_baseline = np.zeros_like(domain_cutoff_baseline)
    acc_baseline = np.zeros_like(domain_cutoff_baseline)
    for i, cut in enumerate(domain_cutoff_baseline):
        idx_domain = in_class & (df['MSP'] > cut)
        # Coverage is relative to the samples of this class only.
        coverage_baseline[i] = idx_domain[in_class].mean()
        acc_baseline[i] = df.loc[idx_domain, 'well_classified'].mean()
    # Overall validation accuracy of the class, for reference.
    print(acc_per_class_val[class_idx])
    # plot
    sc = ax1.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap=cmap)
    fig.colorbar(sc, ax=ax1, label=f'MSP threshold class {class_idx}')

# Multiple classes

In [None]:
# Class index of every validation sample (here: the predicted class) as a plain
# Python list; the following cells use it to index the per-class arrays.
use_pred_or_label = 'pred'
targets = df_val[use_pred_or_label].to_numpy().tolist()

### scaling uses MSP or acc

In [None]:
nb_classes = 1000

fig, ax = plt.subplots(figsize=(15, 5))
ax.set_title(f'{MODEL} - coverage vs. accuracy')

# baseline: max softmax
# For each of 100 MSP thresholds: coverage = fraction of validation samples kept,
# accuracy = mean of `well_classified` over the kept samples.
domain_cutoff_baseline = np.linspace(0, 1, 100)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
coverage_baseline, acc_baseline = map(np.array, zip(*(
    (kept.mean(), df_val.loc[kept, 'well_classified'].mean())
    for kept in (df_val['MSP'] > cut for cut in domain_cutoff_baseline)
)))
# plot
sc = ax.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='Greys', vmin=0, vmax=1)
fig.colorbar(sc, ax=ax, label='MSP threshold')
ax.set_xlabel('coverage')
ax.set_ylabel('accuracy')


# domain_cutoff_baseline = np.linspace(0, 1, 100)
# coverage_baseline = np.zeros_like(domain_cutoff_baseline)
# risk_baseline = np.zeros_like(domain_cutoff_baseline)
# acc_baseline = np.zeros_like(domain_cutoff_baseline)
# for i, cut in enumerate(domain_cutoff_baseline):
#     k = mean_msp_per_class[targets] / mean_msp_per_class.max()
#     idx_domain = k * df_val['MSP'] > cut
#     coverage_baseline[i] = idx_domain.mean()
#     acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# # plot
# sc = ax.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='Purples', vmin=0, vmax=1, marker='+')
# fig.colorbar(sc, ax=ax, label='MSP threshold - scale with mean class MSP')
# ax.set_xlabel('coverage')
# ax.set_ylabel('accuracy')


# Variant: multiply each sample's MSP by the accuracy of its class (relative to
# the best class accuracy) before thresholding.
domain_cutoff_baseline = np.linspace(0, 1, 100)
coverage_baseline = np.zeros_like(domain_cutoff_baseline)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
acc_baseline = np.zeros_like(domain_cutoff_baseline)
# The scaling factor does not depend on the threshold, so it is computed once.
# np.nanmax: a class without any sample has a NaN accuracy, and a plain max()
# would turn every factor into NaN and empty the whole curve.
# k = 1 + 2*(acc_per_class[targets] - acc_per_class.mean())
k = acc_per_class[targets] / np.nanmax(acc_per_class)
scaled_msp = k * df_val['MSP']
for i, cut in enumerate(domain_cutoff_baseline):
    idx_domain = scaled_msp > cut
    coverage_baseline[i] = idx_domain.mean()
    acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# plot
sc = ax.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='Greens', vmin=0, vmax=1, marker='+')
fig.colorbar(sc, ax=ax, label='MSP threshold - scale with mean class acc')
ax.set_xlabel('coverage')
ax.set_ylabel('accuracy')


# domain_cutoff_baseline = np.linspace(0, 1, 100)
# coverage_baseline = np.zeros_like(domain_cutoff_baseline)
# risk_baseline = np.zeros_like(domain_cutoff_baseline)
# acc_baseline = np.zeros_like(domain_cutoff_baseline)
# for i, cut in enumerate(domain_cutoff_baseline):
#     k = 1 / (acc_per_class[targets] / mean_msp_per_class[targets])
#     idx_domain = k * df_val['MSP'] > cut
#     coverage_baseline[i] = idx_domain.mean()
#     acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# # plot
# sc = ax.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='Reds', vmin=0, vmax=1, marker='+')
# fig.colorbar(sc, ax=ax, label='MSP threshold - tests')
# ax.set_xlabel('coverage')
# ax.set_ylabel('accuracy')


### normalization

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))
ax.set_title(f'{MODEL} - coverage vs. accuracy')

# baseline: max softmax
# Reference curve: keep the validation samples whose raw MSP exceeds each threshold.
domain_cutoff_baseline = np.linspace(0, 1, 100)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
coverage_baseline, acc_baseline = map(np.array, zip(*(
    (kept.mean(), df_val.loc[kept, 'well_classified'].mean())
    for kept in (df_val['MSP'] > cut for cut in domain_cutoff_baseline)
)))
# plot
sc = ax.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='Greys', vmin=0, vmax=1)
fig.colorbar(sc, ax=ax, label='MSP threshold')
ax.set_xlabel('coverage')
ax.set_ylabel('accuracy')


# domain_cutoff_baseline = np.linspace(-1, 1, 100)
# coverage_baseline = np.zeros_like(domain_cutoff_baseline)
# risk_baseline = np.zeros_like(domain_cutoff_baseline)
# acc_baseline = np.zeros_like(domain_cutoff_baseline)
# for i, cut in enumerate(domain_cutoff_baseline):
#     idx_domain = (df_val['MSP'] - mean_msp_per_class[targets]) / std_msp_per_class[targets] > cut
#     coverage_baseline[i] = idx_domain.mean()
#     acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# # plot
# sc = ax.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='Purples', vmin=0, vmax=1, marker='+')
# fig.colorbar(sc, ax=ax, label='MSP threshold - select per class (normalize mean std)')
# ax.set_xlabel('coverage')
# ax.set_ylabel('accuracy')


# Variant: min-max normalise each sample's MSP with the min / max MSP of its class
# before thresholding.
domain_cutoff_baseline = np.linspace(0, 1, 100)
coverage_baseline = np.zeros_like(domain_cutoff_baseline)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
acc_baseline = np.zeros_like(domain_cutoff_baseline)
# The normalised score does not depend on the threshold, so it is computed once
# instead of on every iteration.
# NOTE(review): a class whose min and max MSP coincide yields a zero denominator
# (inf / NaN score), as in the original expression.
min_per_sample = min_msp_per_class[targets]
range_per_sample = max_msp_per_class[targets] - min_per_sample
normalized_msp = (df_val['MSP'] - min_per_sample) / range_per_sample
for i, cut in enumerate(domain_cutoff_baseline):
    idx_domain = normalized_msp > cut
    coverage_baseline[i] = idx_domain.mean()
    acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# plot
sc = ax.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='Blues', vmin=0, vmax=1, marker='+')
fig.colorbar(sc, ax=ax, label='MSP threshold - select per class (normalize min max)')
ax.set_xlabel('coverage')
ax.set_ylabel('accuracy')

# vs. Geifman 2017 Selective Classification for Deep Neural Networks

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title(f'{MODEL} - coverage vs. risk')

# baseline: max softmax
# Same threshold sweep as before; it is drawn as risk = 1 - accuracy.
domain_cutoff_baseline = np.linspace(0, 1, 100)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
coverage_baseline, acc_baseline = map(np.array, zip(*(
    (kept.mean(), df_val.loc[kept, 'well_classified'].mean())
    for kept in (df_val['MSP'] > cut for cut in domain_cutoff_baseline)
)))
# plot
sc = ax.scatter(coverage_baseline, 1-acc_baseline, c=domain_cutoff_baseline, cmap='Greys', vmin=0, vmax=1)
fig.colorbar(sc, ax=ax, label='MSP threshold')
ax.set_xlabel('coverage')
ax.set_ylabel('risk')

# Variant: MSP multiplied by the accuracy of the sample's class, drawn as risk.
domain_cutoff_baseline = np.linspace(0, 1, 100)
coverage_baseline = np.zeros_like(domain_cutoff_baseline)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
acc_baseline = np.zeros_like(domain_cutoff_baseline)
# Threshold-independent: look the class accuracy up once, not on every iteration.
k = acc_per_class[targets]
scaled_msp = k * df_val['MSP']
for i, cut in enumerate(domain_cutoff_baseline):
    idx_domain = scaled_msp > cut
    coverage_baseline[i] = idx_domain.mean()
    acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# plot
sc = ax.scatter(coverage_baseline, 1 - acc_baseline, c=domain_cutoff_baseline, cmap='Greens', vmin=0, vmax=1, marker='+')
fig.colorbar(sc, ax=ax, label='MSP threshold - scale with mean class acc')


# Reference (coverage, risk) points hard-coded for the 'Geifman 2017' comparison.
coverage = np.array([0.2585, 0.4878, 0.6502, 0.7676, 0.8677, 0.9614])
risk = np.array([0.0164, 0.0474, 0.0988, 0.1475, 0.1955, 0.2451])
ax.scatter(coverage, risk, label='Geifman 2017')
ax.legend()

# vs. Feng 2023 "Towards better selective classification" on ImageNet-100

In [None]:
# Class indices of the ImageNet-100 subset, read as one synset id per line.
imagenet100_idx = []
with open(path_imagenet100_id) as f:
    for line in f:
        # strip() rather than line[:-1]: the last line may lack a trailing newline
        # (slicing would then truncate the id), and blank lines are ignored.
        synset_id = line.strip()
        if synset_id:
            imagenet100_idx.append(id_to_idx[synset_id])

# Rows of df_val whose ground-truth label belongs to the subset, and the plain
# accuracy on those rows (displayed as the cell output).
idx_classes_imagenet100 = df_val.index[df_val['label'].isin(imagenet100_idx)]
df_val.loc[idx_classes_imagenet100, 'well_classified'].mean()

In [None]:
fig, (ax) = plt.subplots(1, 1, figsize=(10, 5))
ax.set_title(f'{MODEL} - coverage vs. risk - ImageNet-100')

# baseline: max softmax
domain_cutoff_baseline = np.linspace(0, 1, 100)
coverage_baseline = np.zeros_like(domain_cutoff_baseline)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
acc_baseline = np.zeros_like(domain_cutoff_baseline)
# Membership in the ImageNet-100 subset does not depend on the threshold:
# build the row mask once instead of calling isin() on every iteration.
in_imagenet100 = df_val.index.isin(idx_classes_imagenet100)
for i, cut in enumerate(domain_cutoff_baseline):
    idx_domain = (df_val['MSP'] > cut) & in_imagenet100
    # Coverage is measured among the subset rows only.
    coverage_baseline[i] = idx_domain[in_imagenet100].mean()
    acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# plot
sc = ax.scatter(coverage_baseline, 1 - acc_baseline, c=domain_cutoff_baseline, cmap='Greys', vmin=0, vmax=1)
fig.colorbar(sc, ax=ax, label='MSP threshold')
ax.set_xlabel('coverage')
ax.set_ylabel('risk')


# Variant on the ImageNet-100 rows: MSP multiplied by the accuracy of the sample's class.
domain_cutoff_baseline = np.linspace(0, 1, 100)
coverage_baseline = np.zeros_like(domain_cutoff_baseline)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
acc_baseline = np.zeros_like(domain_cutoff_baseline)
# Both the subset mask and the scaled score are threshold-independent, so they
# are computed once outside the loop.
in_imagenet100 = df_val.index.isin(idx_classes_imagenet100)
k = acc_per_class[targets]
scaled_msp = k * df_val['MSP']
for i, cut in enumerate(domain_cutoff_baseline):
    idx_domain = (scaled_msp > cut) & in_imagenet100
    # Coverage is measured among the subset rows only.
    coverage_baseline[i] = idx_domain[in_imagenet100].mean()
    acc_baseline[i] = df_val.loc[idx_domain, 'well_classified'].mean()
# plot
sc = ax.scatter(coverage_baseline, 1 - acc_baseline, c=domain_cutoff_baseline, cmap='Greens', vmin=0, vmax=1, marker='+')
fig.colorbar(sc, ax=ax, label='MSP threshold - select per class (acc3)')
ax.set_xlabel('coverage')
ax.set_ylabel('risk')


# Hard-coded reference risks (given in percent, converted to fractions) at
# coverages 1.0, 0.9, ..., 0.3 for the three methods compared in this section.
coverage = np.arange(100, 20, -10) / 100
risk_vanilla = np.array([14.32, 8.96, 4.99, 2.83, 1.7, 1.08, 0.77, 0.60]) / 100
risk_SN = np.array([13.77, 9.44, 6.0, 3.38, 1.99, 1.05, 0.58, 1.04]) / 100
risk_SNSR = np.array([13.77, 7.89, 4.47, 2.21, 1.57, 0.85, 0.53, 0.64]) / 100

for method_name, method_risk in (('vanilla', risk_vanilla), ('SN', risk_SN), ('SNSR', risk_SNSR)):
    ax.scatter(coverage, method_risk, label=method_name)
ax.legend()

### Check labels from ImageNet100 are correct and check if images come from validation set

In [None]:
# Visual sanity check: show validation images of one class with their real and
# predicted labels.
imagenet100_labels = [idx_to_label[i] for i in imagenet100_idx]

class_chosen = [i for i in idx_to_label if idx_to_label[i] == 'Doberman, Doberman pinscher'][0]
# Subset expects positions inside dataset_val, so take them from the dataset's own
# targets rather than from the DataFrame index labels.
imagenet100_idx_class_chosen = np.flatnonzero(np.array(dataset_val.targets) == class_chosen).tolist()

dataset_imagenet100_val = Subset(dataset_val, imagenet100_idx_class_chosen)

figure = plt.figure(figsize=(10, 20))
cols, rows = 5, 10
# Never index past the end of the subset if the class has fewer images than the grid.
nb_shown = min(cols * rows, len(dataset_imagenet100_val))
# Inference only: no autograd graph is needed for the predictions.
with torch.no_grad():
    for i in range(1, nb_shown + 1):
        sample_idx = i - 1
        img, label = dataset_imagenet100_val[sample_idx]
        pred = classifier(img.to(device).unsqueeze(0)).argmax(dim=1).item()
        figure.add_subplot(rows, cols, i)
        plt.title(f'real: {idx_to_label[label]}\npred: {idx_to_label[pred]}', fontsize=10)
        plt.axis("off")
        # NOTE(review): 0.2 / 0.4 look like a rough inverse of the input
        # normalisation -- confirm. Clamp to [0, 1], the range imshow displays.
        plt.imshow((0.2*img+0.4).clamp(0, 1).permute(1, 2, 0), cmap="gray")
plt.show()

# Metrics

In [None]:
import sys
sys.path.append('./benchmarking-uncertainty-estimation-performance-main/utils')
from uncertainty_metrics import AUROC, coverage_for_desired_accuracy, ECE_calc, AURC_calc

def metrics_calculations(samples_certainties, num_bins_ece=15):
    """Collect the uncertainty-estimation metrics for one set of confidence scores.

    Args:
        samples_certainties: 2-D tensor with one row per sample. Column 1 is summed
            to count the correct predictions; the whole tensor is forwarded to the
            ``uncertainty_metrics`` helpers. The callers in this notebook put the
            confidence score in column 0 and sort rows by it in descending order.
        num_bins_ece: number of bins passed to ``ECE_calc``; it also appears in the
            name of the ECE entry.

    Returns:
        dict with the keys 'Accuracy' (in percent), 'AUROC',
        'Coverage_for_Accuracy_99', f'ECE_{num_bins_ece}' and 'AURC'.
    """
    # Note: the certainty scores are assumed to be probabilities.
    nb_samples = samples_certainties.shape[0]
    nb_correct = samples_certainties[:, 1].sum()
    return {
        'Accuracy': (nb_correct / nb_samples).item() * 100,
        'AUROC': AUROC(samples_certainties),
        'Coverage_for_Accuracy_99': coverage_for_desired_accuracy(
            samples_certainties, accuracy=0.99, start_index=200
        ),
        # ECE_calc returns (ece, mce); only the first element is reported.
        f'ECE_{num_bins_ece}': ECE_calc(samples_certainties, num_bins=num_bins_ece)[0].item(),
        'AURC': AURC_calc(samples_certainties),
    }

In [None]:
# Compare the metrics and the accuracy/coverage trade-off of the raw MSP with
# those of the MSP rescaled by the (relative) accuracy of the predicted class.
use_pred_or_label = 'pred'
targets = df_val[use_pred_or_label].tolist()
# np.nanmax: a class without any sample has a NaN accuracy, which would make a
# plain max() -- and therefore every scaling factor -- NaN.
k = acc_per_class[targets] / np.nanmax(acc_per_class)

# Correctness column shared by both confidence scores.
well_classified = torch.tensor(df_val['well_classified'].to_numpy()).unsqueeze(1)

plt.figure()
for score_name, scores in (('original', df_val['MSP']), ('after scaling MSP', k * df_val['MSP'])):
    # Column 0: confidence score, column 1: correctness; rows sorted by decreasing confidence.
    samples_certainties = torch.cat((torch.tensor(scores.to_numpy()).unsqueeze(1), well_classified), dim=1)
    indices_sorting_by_confidence = torch.argsort(samples_certainties[:, 0], descending=True)
    samples_certainties = samples_certainties[indices_sorting_by_confidence]
    res = metrics_calculations(samples_certainties)
    print(score_name, res)

    # Coverage reachable for each target accuracy above the overall accuracy.
    accuracies = []
    coverages = []
    for acc in np.arange(res['Accuracy']/100+0.001, 1, 0.001):
        cov = coverage_for_desired_accuracy(samples_certainties, accuracy=acc, start_index=200)
        accuracies.append(acc)
        coverages.append(cov)
    plt.plot(coverages, accuracies, label=score_name)
plt.legend();

# Test threshold on logit

In [None]:
df = df_val

# Run the classifier over the validation loader, in order, and store for every
# sample its largest logit and the max-softmax recomputed from the logits.
# Rows are addressed by label with .loc, whose slices include both ends.
i = 0
with torch.no_grad():
    for x, y in dataloader_val:
        logits = classifier(x.to(device))
        batch_len = logits.shape[0]
        last_row = i + batch_len - 1
        df.loc[i:last_row, 'max_logit'] = logits.max(dim=1).values.cpu().numpy()
        df.loc[i:last_row, 'msp_calc'] = torch.softmax(logits, dim=1).max(dim=1).values.cpu().numpy()
        i += batch_len

In [None]:
# baseline: max softmax
# Threshold sweep on the recomputed MSP: coverage and accuracy of the kept samples.
domain_cutoff_baseline = np.linspace(0, 1, 1000)
risk_baseline = np.zeros_like(domain_cutoff_baseline)
coverage_baseline, acc_baseline = map(np.array, zip(*(
    (kept.mean(), df.loc[kept, 'well_classified'].mean())
    for kept in (df['msp_calc'] > cut for cut in domain_cutoff_baseline)
)))

# max logit
# Same sweep on the largest logit, with thresholds spanning its observed range.
domain_cutoff = np.linspace(df['max_logit'].min(), df['max_logit'].max(), 1000)
risk_ = np.zeros_like(domain_cutoff)
coverage, acc = map(np.array, zip(*(
    (kept.mean(), df.loc[kept, 'well_classified'].mean())
    for kept in (df['max_logit'] > cut for cut in domain_cutoff)
)))

# plot
fig, ax1 = plt.subplots(figsize=(10, 5))
ax1.set_title(f'{MODEL} - coverage vs. accuracy')
sc1 = ax1.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='viridis')
sc2 = ax1.scatter(coverage, acc, c=domain_cutoff, cmap='plasma')
fig.colorbar(sc1, ax=ax1, label='MSP threshold')
fig.colorbar(sc2, ax=ax1, label='max_logit threshold')
ax1.set_xlabel('coverage')
ax1.set_ylabel('accuracy')