In [None]:
import os
# Expose only GPU 0 to this process; this has to happen before torch initializes CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

from tqdm import tqdm
import torch
from torch import nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import datasets
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Compose, Resize, Lambda, ToTensor, Grayscale, ToPILImage
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# matplotlib 3.6 renamed the 'seaborn' style to 'seaborn-v0_8' and 3.8 removed the old name,
# so prefer the new name and fall back for older matplotlib versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
import PIL

from imagenet_c import corrupt
from CLIP import clip as clip_utils
from utils import load_datasets_ImageNet_two_transforms

# Reproducibility: one seed shared by the torch and numpy global RNGs and a numpy Generator.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
rng = np.random.default_rng(SEED)

# Results go to a 'results' folder in the parent of the current working directory.
# Data paths are built from the $DSDIR and $WORK environment variables; if a variable is
# unset, os.path.expandvars leaves it unexpanded in the string.
path_results = f"{os.path.dirname(os.getcwd())}/results"
path_dataset = os.path.expandvars('$DSDIR/imagenet')  # alternative location: '/scratchf/'
path_imagenet_labels = os.path.expandvars('$WORK/DATA/LOC_synset_mapping.txt')

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

In [None]:
BATCH_SIZE = 2 ** 10  # used by both DataLoaders and passed to the dataset-loading helper

In [None]:
# LOAD CLASSIFIER: torchvision ResNet-50 (ImageNet-1k V2 weights), the preprocessing bundled
# with those weights, and the model switched to eval mode (nn.Module.eval returns the module).
weights = ResNet50_Weights.IMAGENET1K_V2
preprocess_classif = weights.transforms()
classifier = resnet50(weights=weights).to(device).eval()

# LOAD CLIP: ViT-B/32 model plus its own image preprocessing.
clip, preprocess_clip = clip_utils.load("ViT-B/32", device=device)

# LOAD DATA: the helper receives both preprocessings; presumably each item is
# ((classifier_input, label), (clip_input, label)), as unpacked for the validation loader below.
dataset_train, dataset_val = load_datasets_ImageNet_two_transforms(
    path_dataset, BATCH_SIZE, preprocess_classif, preprocess_clip
)
# Shared loader settings; no shuffling so batch order follows dataset order.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, shuffle=False)
dataloader_train = DataLoader(dataset_train, **_loader_kwargs)
dataloader_val = DataLoader(dataset_val, **_loader_kwargs)
# Map class index -> human-readable label, one line of the synset mapping file per class.
# Each line presumably looks like "n01440764 tench, Tinca tinca" (synset id, a space, names):
# keep everything after the first space. Strip the newline explicitly instead of slicing off
# the last character, so a final line without a trailing newline keeps its last letter.
with open(path_imagenet_labels, encoding='utf-8') as f:
    idx_to_label = {i: line.rstrip('\r\n').partition(' ')[2] for i, line in enumerate(f)}

In [None]:
# Per-sample classification logs on the validation set, cached as CSV so the full forward
# pass only runs once. Columns:
#   MSP             max softmax probability (classifier confidence)
#   TCP             softmax probability assigned to the true class
#   well_classified top-1 prediction equals the label
#   pred_in_top5    label is among the 5 highest logits
path_logs_classif_val = os.path.join(path_results, 'logs_classif_val.csv')

if os.path.exists(path_logs_classif_val):
    print('Load logs.')
    df = pd.read_csv(path_logs_classif_val, index_col=0)

else:
    print('Compute classification logs.')
    # Accumulate per-batch CPU tensors and build the frame once. Writing cell ranges into an
    # empty (object-dtype) DataFrame was slow and produced different dtypes than the CSV
    # reload path above.
    logs = {'MSP': [], 'TCP': [], 'well_classified': [], 'pred_in_top5': []}
    for (x_classif, y_classif), (x_clip, y_clip) in tqdm(dataloader_val):
        assert (y_classif == y_clip).all(), "data for classifier and CLIP is not the same"
        x_classif = x_classif.to(device, non_blocking=True)
        y_classif = y_classif.to(device, non_blocking=True)

        with torch.no_grad():
            logits = classifier(x_classif)
        probas = torch.nn.functional.softmax(logits, dim=1)
        top5 = logits.topk(5, dim=1).indices  # 5 best class indices per sample

        logs['MSP'].append(probas.max(dim=1).values.cpu())
        # gather picks each row's true-class probability without mixing CPU/GPU index tensors
        logs['TCP'].append(probas.gather(1, y_classif.unsqueeze(1)).squeeze(1).cpu())
        logs['well_classified'].append((logits.argmax(dim=1) == y_classif).cpu())
        logs['pred_in_top5'].append((top5 == y_classif.unsqueeze(1)).any(dim=1).cpu())

    df = pd.DataFrame({name: torch.cat(chunks).numpy() for name, chunks in logs.items()})
    assert len(df) == len(dataset_val), "every validation sample must be logged exactly once"

    os.makedirs(path_results, exist_ok=True)  # to_csv fails if the folder does not exist yet
    df.to_csv(path_logs_classif_val)
    print('Classification logs saved')

In [None]:
def _error_indices(column):
    """Return, as a list, the df index labels where the boolean log `column` is False."""
    return df.index[df[column].eq(False)].tolist()


# Validation samples with a wrong top-1 prediction, and those whose label is not even in the
# top-5. Subset treats these as positions, relying on df's 0..N-1 index matching dataset order.
idx_errors_top1 = _error_indices('well_classified')
idx_errors_top5 = _error_indices('pred_in_top5')

dataset_val_errors_top1 = Subset(dataset_val, idx_errors_top1)
dataset_val_errors_top5 = Subset(dataset_val, idx_errors_top5)

# Examples of misclassified images (true label not among the top-5 predictions)

In [None]:
# Grid of random validation images whose true label is not in the classifier's top-5,
# titled with the true and predicted (top-1) label.
rows, cols = 3, 3
n_show = min(rows * cols, len(dataset_val_errors_top5))  # don't fail if there are few errors
# Distinct samples: torch.randint could draw the same image more than once.
sample_idxs = torch.randperm(len(dataset_val_errors_top5))[:n_show].tolist()

# Undo the classifier input normalization (mean/std of the weights' transform) for display.
# Assumes the first element of each item was produced by preprocess_classif — TODO confirm.
norm_mean = torch.tensor(preprocess_classif.mean).view(3, 1, 1)
norm_std = torch.tensor(preprocess_classif.std).view(3, 1, 1)

fig, axes = plt.subplots(rows, cols, figsize=(12, 12))
for ax in axes.flat:
    ax.axis("off")
for ax, sample_idx in zip(axes.flat, sample_idxs):
    (img, label), _ = dataset_val_errors_top5[sample_idx]
    with torch.no_grad():  # inference only; no autograd graph needed
        pred = classifier(img.to(device).unsqueeze(0)).argmax(dim=1).item()
    ax.set_title(f'real: {idx_to_label[label]}\npred: {idx_to_label[pred]}', fontsize=10)
    ax.imshow((img.cpu() * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0))
plt.show()

# Selective classification

In [None]:
# Baseline selective classifier: accept a prediction only when its max softmax probability
# (MSP) is strictly above a threshold, then measure coverage, accuracy and selective risk
# (error rate on the accepted samples) for 1000 thresholds in [0, 1].
domain_cutoff_baseline = np.linspace(0, 1, 1000)
msp_values = df['MSP'].astype(float)
correct_top1 = df['well_classified'].astype(float)  # 1.0 / 0.0; robust to object-dtype logs

coverage_baseline = np.zeros_like(domain_cutoff_baseline)
acc_baseline = np.zeros_like(domain_cutoff_baseline)
for i, cut in enumerate(domain_cutoff_baseline):
    idx_domain = msp_values > cut
    coverage_baseline[i] = idx_domain.mean()
    acc_baseline[i] = correct_top1[idx_domain].mean()  # NaN when nothing is accepted
# Previously allocated but never filled (stayed all zeros).
risk_baseline = 1.0 - acc_baseline

# plot
fig, ax1 = plt.subplots(1, 1, figsize=(10, 5))
ax1.set_title('coverage vs. accuracy')
sc = ax1.scatter(coverage_baseline, acc_baseline, c=domain_cutoff_baseline, cmap='viridis')
fig.colorbar(sc, ax=ax1, label='MSP threshold')
ax1.set_xlabel('coverage')
ax1.set_ylabel('accuracy');

# Per-class accuracy and confidence (MSP)

In [None]:
# Per-class top-1 accuracy and mean MSP: does the classifier's confidence track its accuracy
# class by class?
N_CLASSES = 1000  # ImageNet-1k
# Ground-truth label per validation sample, built once instead of twice per class.
# Assumes df rows follow dataset order (unshuffled validation loader) — TODO confirm.
targets_val = np.asarray(dataset_val.imagenet_data_1.targets)
per_class = (
    df[['well_classified', 'MSP']]
    .astype(float)               # bool -> 0/1; also robust to object-dtype columns
    .groupby(targets_val)
    .mean()
    .reindex(range(N_CLASSES))   # a class with no samples gets NaN, as with an empty mask
)
acc_per_class = per_class['well_classified'].tolist()
mean_msp_per_class = per_class['MSP'].tolist()

fig, ax = plt.subplots()
ax.hist(acc_per_class)
ax.set(xlabel='accuracy per class', ylabel='number of classes')

fig, ax = plt.subplots()
ax.hist(mean_msp_per_class)
ax.set(xlabel='mean MSP per class', ylabel='number of classes')

fig, ax = plt.subplots()
ax.scatter(mean_msp_per_class, acc_per_class)
ax.set(xlabel='mean MSP per class', ylabel='accuracy per class')
plt.show()