In [1]:
import torch
import torch.nn as nn
from sklearn.metrics import log_loss

In [2]:
import numpy as np

def set_seed(seed_value, use_cuda):
    if seed_value is not None:
        np.random.seed(seed_value)  # cpu vars
        torch.manual_seed(seed_value)  # cpu  vars
        random.seed(seed_value)  # Python
        torch.use_deterministic_algorithms(True)

        if use_cuda:
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)
            torch.backends.cudnn.deterministic = True

In [3]:
import os, sys, random, argparse, os.path as osp
import pandas as pd
import numpy as np
import torch
import medmnist
from utils.data_handling import get_medmnist_loaders, get_medmnist_test_loader
from utils.data_handling import get_class_loaders, get_class_test_loader
from utils.evaluation import evaluate_cls
from utils.get_model_v2 import get_arch
from tqdm import trange, tqdm
import torchvision.transforms as tr

In [4]:
from torchmetrics.classification import MulticlassCalibrationError
from utils.data_handling import get_class_test_loader
from tqdm.notebook import tqdm


from sklearn.metrics import accuracy_score as acc
from sklearn.metrics import roc_auc_score as auc
from sklearn.metrics import log_loss as error

# from utils.temperature_calibration import ModelWithTemperature
from utils.data_handling import get_medmnist_loaders
from sklearn.metrics import balanced_accuracy_score as balacc
from sklearn.metrics import matthews_corrcoef as mcc

In [5]:
def test_one_epoch(model, loader, device):
    model.to(device)
    model.eval()
    probs_all, labels_all = [], []

    for i_batch, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        logits = model(inputs)  # bs x n_classes
        probs = logits.softmax(dim=1).detach().cpu().numpy()
        labels = labels.numpy()
        if labels.ndim == 0:  # for 1-element batches labels degenerates to a scalar
            labels = np.expand_dims(labels, 0)
        probs_all.extend(probs)
        labels_all.extend(list(labels))

    return np.stack(probs_all), np.array(labels_all).squeeze()

def test_cls(model, test_loader, device):

    with torch.inference_mode():
        probs, labels = test_one_epoch(model, test_loader, device)

    del model
    torch.cuda.empty_cache()

    return probs, labels

In [6]:
def test_one_epoch_multihead(model, loader, device):
    ## IMPORTANT: Note that multi-head models in test time return softmax-activated tensors and not logits
    ## which is shitty because I cannot apply TS directly
    model.to(device)
    model.eval()

    probs_all, labels_all = [], []

    for i_batch, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        probs = model(inputs).detach().cpu().numpy()

        labels = labels.numpy()
        if labels.ndim == 0:  # for 1-element batches labels degenerates to a scalar
            labels = np.expand_dims(labels, 0)
        probs_all.extend(probs)
        labels_all.extend(list(labels))

    return np.stack(probs_all), np.array(labels_all).squeeze()

def test_cls_multihead(model, test_loader, device):
    with torch.inference_mode():
        probs, labels = test_one_epoch_multihead(model, test_loader, device)

    del model
    torch.cuda.empty_cache()

    return probs, labels

In [7]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

In [8]:
from sklearn.metrics import accuracy_score as acc

In [9]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

In [10]:
def print_results_multi_fold(dataset=None, model_name='convnext', n_bins=15, method='sl1h', with_ens=False):
    assert dataset is not None
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    seed=0
    set_seed(seed, use_cuda)
    if model_name in ['convnext', 'swin'] :
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = True

    assert dataset in ['chaoyang', 'kvasir', 'pathmnist']

    if dataset=='pathmnist':
        tg_size = 28 if model_name !='convnext' else 32
        test_loader = get_medmnist_test_loader(dataset, batch_size=128, num_workers=6, tg_size=tg_size)
        class_names = list(medmnist.INFO[dataset]['label'].keys())
        num_classes = len(class_names)
            
    else:
        data_path = osp.join('data', dataset)
        csv_test = osp.join('data', 'test_'+dataset+'.csv')
        tg_size = 224,224
        test_loader  = get_class_test_loader(csv_test, data_path, tg_size, batch_size=64, num_workers=6)
        num_classes = len(test_loader.dataset.classes)
        
    load_path = osp.join('experiments', dataset, method)
    
    ########################
    # results for individual models
    ########################    
    if model_name=='convnext':
        load_path_this = osp.join(load_path, 'cnx_f')        
    elif model_name=='resnet50':
        load_path_this = osp.join(load_path, 'r50_f')
    elif model_name=='swin':
        load_path_this = osp.join(load_path, 'swt_f')
    model = get_arch(model_name, num_classes)
    checkpoint_list = [osp.join(load_path_this + str(i), 'model_checkpoint.pth') for i in [0, 1, 2, 3, 4]]
    states = [torch.load(c,  map_location=device) for c in checkpoint_list]
    
    all_probs = []
    # Do inference for each model
    with torch.inference_mode():
        for i in range(len(states)):
            state=states[i]
            model.load_state_dict(state['model_state_dict'])
            probs, labels = test_cls(model, test_loader, device)
            all_probs.append(probs)
    
    ece = MulticlassCalibrationError(num_classes=num_classes, n_bins=n_bins, norm='l1')
    

    accs, eces, ces= [],[],[]
    for probs in all_probs:
        preds = np.argmax(probs, axis=1)
        test_auc, test_f1, test_acc, test_auc_all, test_f1_all = evaluate_cls(labels, preds, probs, 
                                                                              print_conf=False)
        e = ece(torch.from_numpy(probs), torch.from_numpy(labels)).item()        
        ce = log_loss(labels,probs)
        
        accs.append(100*test_acc)
        eces.append(100*e)
        ces.append(100*ce)
 
    ########################
    # print average results
    ######################## 
    print('ACC={:.2f}+/-{:.2f}, ECE={:.2f}+/-{:.2f}, NLL={:.2f}+/-{:.2f}'.format(
        np.mean(accs), np.std(accs), np.mean(eces), np.std(eces), np.mean(ces), np.std(ces)))

    if with_ens:
        ########################
        # results for ensemble
        ########################  
        ens_probs = np.mean(all_probs, axis=0)
        ens_preds = np.argmax(ens_probs, axis=1)
        test_auc, test_f1, test_acc, test_auc_all, test_f1_all = evaluate_cls(labels, ens_preds, ens_probs, 
                                                                              print_conf=False)
        e = ece(torch.from_numpy(ens_probs), torch.from_numpy(labels)).item()
        ce = log_loss(labels,ens_probs)

        print(30*'-')
        print('DEEP ENSEMBLES:')
        print('ACC={:.2f}, ECE={:.2f}, NLL={:.2f}'.format(100*test_acc,100*e,100*ce))
        
    mtrcs = [np.mean(accs), np.mean(eces), np.mean(ces), 
             np.std(accs),  np.std(eces),  np.std(ces)]

    if with_ens:
        return mtrcs, [100*test_acc, 100*e, 100*ce]
    return mtrcs

In [11]:
def print_results_multi_head(dataset=None, model_name='convnext', method='4lmh', n_bins=15, nh=2):
    assert dataset is not None
    
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    # reproducibility
    seed=0
    set_seed(seed, use_cuda)
    if model_name in ['convnext', 'swin'] :
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = True

    assert dataset in ['chaoyang', 'kvasir', 'pathmnist']
    

    if dataset=='pathmnist':
        tg_size = 28 if model_name !='convnext' else 32
        test_loader = get_medmnist_test_loader(dataset, batch_size=128, num_workers=6, tg_size=tg_size)
        class_names = list(medmnist.INFO[dataset]['label'].keys())
        num_classes = len(class_names)

    else:
        data_path = osp.join('data', dataset)
        csv_test = osp.join('data', 'test_'+dataset+'.csv')
        tg_size = 224,224   
        test_loader  = get_class_test_loader(csv_test, data_path, tg_size, batch_size=64, num_workers=6)
        num_classes = len(test_loader.dataset.classes)
        class_names = ['C{}_probs'.format(i) for i in range(num_classes)]

    
    load_path = osp.join('experiments', dataset, method)
    if '2h' in method: nh = 2
    elif '4h' in method: nh = 4
    else: return 'need multi-head model here'
    
    if model_name=='convnext':
        load_path_this = osp.join(load_path, 'cnx_f')
    elif model_name=='resnet50':
        load_path_this = osp.join(load_path, 'r50_f')
    elif model_name=='swin':
        load_path_this = osp.join(load_path, 'swt_f')

    ########################
    # results for multi-head
    ########################
    model = get_arch(model_name, num_classes, n_heads=nh)
    checkpoint_list = [osp.join(load_path_this + str(i), 'model_checkpoint.pth') for i in [0, 1, 2, 3, 4]]
    states = [torch.load(c,  map_location=device) for c in checkpoint_list]
    all_probs = []
    # Do inference on val data
    with torch.inference_mode():
        for i in range(len(states)):
            state=states[i]        
            model.load_state_dict(state['model_state_dict'])
            probs, labels = test_cls_multihead(model, test_loader, device)
            all_probs.append(probs)
    
    ece = MulticlassCalibrationError(num_classes=num_classes, n_bins=15, norm='l1')

    
    accs, eces, ces= [],[],[]
    for probs in all_probs:
        preds = np.argmax(probs, axis=1)
        test_auc, test_f1, test_acc, test_auc_all, test_f1_all = evaluate_cls(labels, preds, probs, 
                                                                              print_conf=False)
        e = ece(torch.from_numpy(probs), torch.from_numpy(labels)).item()
        ce = log_loss(labels,probs)                
        accs.append(100*test_acc)
        eces.append(100*e)
        ces.append(100*ce)
        
    ########################
    # print average results
    ######################## 
    print('ACC={:.2f}+/-{:.2f}, ECE={:.2f}+/-{:.2f}, NLL={:.2f}+/-{:.2f}'.format(
    np.mean(accs), np.std(accs), np.mean(eces), np.std(eces), np.mean(ces), np.std(ces)))
    
    mtrcs = [np.mean(accs), np.mean(eces), np.mean(ces), 
             np.std(accs),  np.std(eces),  np.std(ces)]
    return mtrcs

In [12]:
def print_all(dataset, model_name):
    metrcs = dict()
    print(60*'=')
    print('Single Head (avg) and Ensembles')
    print(60*'=')

    metrcs['sl1h'], metrcs['ens'] = \
    print_results_multi_fold(dataset, model_name, with_ens=True)

    # LS
    print('\n')
    print(60*'=')
    print('LS gamma=0.05: Single Head (avg) and Ensembles')
    print(60*'=')
    metrcs['ls_005'] = \
    print_results_multi_fold(dataset, model_name, method='ls_005', with_ens=False)

    # MbLS
    print('\n')
    print(60*'=')
    print('MbLS m=6: Single Head (avg) and Ensembles')
    print(60*'=')
    metrcs['mbls_6'] = \
    print_results_multi_fold(dataset, model_name, method='mbls_6', with_ens=False)
    
    # MIXUP
    print('\n')   
    print(60*'=')
    print('MixUp gamma=0.2: Single Head (avg) and Ensembles')
    print(60*'=')
    metrcs['mxp_02'] = \
    print_results_multi_fold(dataset, model_name, method='mxp_02', with_ens=False)

    # DCA
    print('\n')   
    print(60*'=')
    print('DCA: Single Head (avg) and Ensembles')
    print(60*'=')
    metrcs['dca'] = \
    print_results_multi_fold(dataset, model_name, method='dca', with_ens=False)

    # 2hsl
    print('\n') 
    print(60*'=')
    print('2hsl')
    print(60*'=')
    metrcs['2hsl'] = \
    print_results_multi_head(dataset, model_name, method='2hsl')

    # 2hml
    print(60*'*')
    print('2hml')
    print(60*'*')
    metrcs['2hml'] = \
    print_results_multi_head(dataset, model_name, method='2hml')

    # 4hml   
    print(60*'*')
    print('4hml')
    print(60*'*')
    metrcs['4hml'] = \
    print_results_multi_head(dataset, model_name, method='4hml')
    
    return metrcs

In [25]:
def fill_rows(col_model_dict, col1, col2, col3, col4, s1, s2, s3, strings):
    sl1h_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, h2sl_str, h2ml_str, h4ml_str = strings
    a, b, c = 0, 1, 2
    x, y, z = 3, 4, 5
    sl1h_str=sl1h_str.replace(col1, '{:.2f}'.format(col_model_dict['sl1h'][a]))
    sl1h_str=sl1h_str.replace(col2, '{:.2f}'.format(col_model_dict['sl1h'][b]))
    sl1h_str=sl1h_str.replace(col3,  '{:.2f}'.format(col_model_dict['sl1h'][c]))
    sl1h_str=sl1h_str.replace(col4,  '{:.1f}'.format(col_model_dict['sl1h'][-1]))
    sl1h_str=sl1h_str.replace(s1, '{:.2f}'.format(col_model_dict['sl1h'][x]))
    sl1h_str=sl1h_str.replace(s2, '{:.2f}'.format(col_model_dict['sl1h'][y]))
    sl1h_str=sl1h_str.replace(s3,  '{:.2f}'.format(col_model_dict['sl1h'][z]))
    
    DE_str=DE_str.replace(col1, '{:.2f}'.format(col_model_dict['ens'][a]))
    DE_str=DE_str.replace(col2, '{:.2f}'.format(col_model_dict['ens'][b]))
    DE_str=DE_str.replace(col3,  '{:.2f}'.format(col_model_dict['ens'][c]))
    DE_str=DE_str.replace(col4,  '{:.1f}'.format(col_model_dict['ens'][-1]))

    LS_str=LS_str.replace(col1, '{:.2f}'.format(col_model_dict['ls_005'][a]))
    LS_str=LS_str.replace(col2, '{:.2f}'.format(col_model_dict['ls_005'][b]))
    LS_str=LS_str.replace(col3,  '{:.2f}'.format(col_model_dict['ls_005'][c]))
    LS_str=LS_str.replace(col4,  '{:.1f}'.format(col_model_dict['ls_005'][-1]))
    LS_str=LS_str.replace(s1, '{:.2f}'.format(col_model_dict['ls_005'][x]))
    LS_str=LS_str.replace(s2, '{:.2f}'.format(col_model_dict['ls_005'][y]))
    LS_str=LS_str.replace(s3,  '{:.2f}'.format(col_model_dict['ls_005'][z]))    

    MbLS_str=MbLS_str.replace(col1, '{:.2f}'.format(col_model_dict['mbls_6'][a]))
    MbLS_str=MbLS_str.replace(col2, '{:.2f}'.format(col_model_dict['mbls_6'][b]))
    MbLS_str=MbLS_str.replace(col3,  '{:.2f}'.format(col_model_dict['mbls_6'][c]))
    MbLS_str=MbLS_str.replace(col4,  '{:.1f}'.format(col_model_dict['mbls_6'][-1]))
    MbLS_str=MbLS_str.replace(s1, '{:.2f}'.format(col_model_dict['mbls_6'][x]))
    MbLS_str=MbLS_str.replace(s2, '{:.2f}'.format(col_model_dict['mbls_6'][y]))
    MbLS_str=MbLS_str.replace(s3,  '{:.2f}'.format(col_model_dict['mbls_6'][z]))   
    
    Mxp_str=Mxp_str.replace(col1, '{:.2f}'.format(col_model_dict['mxp_02'][a]))
    Mxp_str=Mxp_str.replace(col2, '{:.2f}'.format(col_model_dict['mxp_02'][b]))
    Mxp_str=Mxp_str.replace(col3,  '{:.2f}'.format(col_model_dict['mxp_02'][c]))
    Mxp_str=Mxp_str.replace(col4,  '{:.1f}'.format(col_model_dict['mxp_02'][-1]))
    Mxp_str=Mxp_str.replace(s1, '{:.2f}'.format(col_model_dict['mxp_02'][x]))
    Mxp_str=Mxp_str.replace(s2, '{:.2f}'.format(col_model_dict['mxp_02'][y]))
    Mxp_str=Mxp_str.replace(s3,  '{:.2f}'.format(col_model_dict['mxp_02'][z]))   
    
    DCA_str=DCA_str.replace(col1, '{:.2f}'.format(col_model_dict['dca'][a]))
    DCA_str=DCA_str.replace(col2, '{:.2f}'.format(col_model_dict['dca'][b]))
    DCA_str=DCA_str.replace(col3,  '{:.2f}'.format(col_model_dict['dca'][c]))
    DCA_str=DCA_str.replace(col4,  '{:.1f}'.format(col_model_dict['dca'][-1]))
    DCA_str=DCA_str.replace(s1, '{:.2f}'.format(col_model_dict['dca'][x]))
    DCA_str=DCA_str.replace(s2, '{:.2f}'.format(col_model_dict['dca'][y]))
    DCA_str=DCA_str.replace(s3,  '{:.2f}'.format(col_model_dict['dca'][z]))   
    
    h2sl_str=h2sl_str.replace(col1, '{:.2f}'.format(col_model_dict['2hsl'][a]))
    h2sl_str=h2sl_str.replace(col2, '{:.2f}'.format(col_model_dict['2hsl'][b]))
    h2sl_str=h2sl_str.replace(col3,  '{:.2f}'.format(col_model_dict['2hsl'][c]))
    h2sl_str=h2sl_str.replace(col4,  '{:.1f}'.format(col_model_dict['2hsl'][-1]))
    h2sl_str=h2sl_str.replace(s1, '{:.2f}'.format(col_model_dict['2hsl'][x]))
    h2sl_str=h2sl_str.replace(s2, '{:.2f}'.format(col_model_dict['2hsl'][y]))
    h2sl_str=h2sl_str.replace(s3,  '{:.2f}'.format(col_model_dict['2hsl'][z]))  
    
    h2ml_str=h2ml_str.replace(col1, '{:.2f}'.format(col_model_dict['2hml'][a]))
    h2ml_str=h2ml_str.replace(col2, '{:.2f}'.format(col_model_dict['2hml'][b]))
    h2ml_str=h2ml_str.replace(col3,  '{:.2f}'.format(col_model_dict['2hml'][c]))
    h2ml_str=h2ml_str.replace(col4,  '{:.1f}'.format(col_model_dict['2hml'][-1]))
    h2ml_str=h2ml_str.replace(s1, '{:.2f}'.format(col_model_dict['2hml'][x]))
    h2ml_str=h2ml_str.replace(s2, '{:.2f}'.format(col_model_dict['2hml'][y]))
    h2ml_str=h2ml_str.replace(s3,  '{:.2f}'.format(col_model_dict['2hml'][z]))
    
    h4ml_str=h4ml_str.replace(col1, '{:.2f}'.format(col_model_dict['4hml'][a]))
    h4ml_str=h4ml_str.replace(col2, '{:.2f}'.format(col_model_dict['4hml'][b]))
    h4ml_str=h4ml_str.replace(col3,  '{:.2f}'.format(col_model_dict['4hml'][c])) 
    h4ml_str=h4ml_str.replace(col4,  '{:.1f}'.format(col_model_dict['4hml'][-1])) 
    h4ml_str=h4ml_str.replace(s1, '{:.2f}'.format(col_model_dict['4hml'][x]))
    h4ml_str=h4ml_str.replace(s2, '{:.2f}'.format(col_model_dict['4hml'][y]))
    h4ml_str=h4ml_str.replace(s3,  '{:.2f}'.format(col_model_dict['4hml'][z]))   
    
    return sl1h_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, h2sl_str, h2ml_str, h4ml_str

In [14]:
def add_acc_ece_nll_combined_ranks(dic):
    dic_acc = {key:value[0] for (key,value) in dic.items()} #acc
    # False for lower is better, True for higher is better
    dic_rank_acc = {key: rank for rank, key in enumerate(sorted(dic_acc, key=dic_acc.get, reverse=True), 1)}
       
    dic_ece = {key:value[1] for (key,value) in dic.items()}
    dic_rank_ece = {key: rank for rank, key in enumerate(sorted(dic_ece, key=dic_ece.get, reverse=False), 1)}

    dic_nll = {key:value[2] for (key,value) in dic.items()}
    # False for lower is better, True for higher is better
    dic_rank_nll = {key: rank for rank, key in enumerate(sorted(dic_nll, key=dic_nll.get, reverse=False), 1)}

    
    for key in dic:
        dic[key].extend([dic_rank_acc[key], dic_rank_ece[key], dic_rank_nll[key],
                         np.mean([dic_rank_acc[key], dic_rank_ece[key], dic_rank_nll[key]])])    
    return dic

In [15]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

# CHAOYANG

In [16]:
dataset='chaoyang'

In [17]:
m_r50 = \
print_all(dataset, 'resnet50') 

Single Head (avg) and Ensembles
ACC=80.71+/-0.64, ECE=5.79+/-1.75, NLL=53.46+/-3.24
------------------------------
DEEP ENSEMBLES:
ACC=82.19, ECE=2.42, NLL=46.64


LS gamma=0.05: Single Head (avg) and Ensembles
ACC=74.81+/-1.88, ECE=2.55+/-0.49, NLL=64.27+/-2.82


MbLS m=6: Single Head (avg) and Ensembles
ACC=75.02+/-1.97, ECE=3.26+/-0.97, NLL=63.86+/-4.29


MixUp gamma=0.2: Single Head (avg) and Ensembles
ACC=76.00+/-1.51, ECE=3.67+/-0.73, NLL=62.72+/-3.00


DCA: Single Head (avg) and Ensembles
ACC=76.17+/-0.87, ECE=5.75+/-1.68, NLL=62.13+/-2.30


2hsl
ACC=80.97+/-0.87, ECE=4.36+/-1.88, NLL=51.42+/-2.52
************************************************************
2hml
************************************************************
ACC=80.28+/-0.81, ECE=4.49+/-1.29, NLL=51.86+/-2.17
************************************************************
4hml
************************************************************
ACC=81.13+/-0.82, ECE=3.09+/-0.76, NLL=49.44+/-0.96


In [18]:
m_cvx = \
print_all(dataset, 'convnext')

Single Head (avg) and Ensembles
ACC=81.91+/-0.46, ECE=6.94+/-1.49, NLL=50.98+/-3.44
------------------------------
DEEP ENSEMBLES:
ACC=82.98, ECE=5.21, NLL=46.08


LS gamma=0.05: Single Head (avg) and Ensembles
ACC=79.59+/-0.62, ECE=6.13+/-2.17, NLL=55.65+/-0.79


MbLS m=6: Single Head (avg) and Ensembles
ACC=79.53+/-0.93, ECE=2.94+/-0.76, NLL=53.44+/-1.76


MixUp gamma=0.2: Single Head (avg) and Ensembles
ACC=79.95+/-0.94, ECE=6.20+/-2.28, NLL=55.58+/-1.07


DCA: Single Head (avg) and Ensembles
ACC=78.28+/-0.62, ECE=3.69+/-0.73, NLL=57.78+/-2.28


2hsl
ACC=81.94+/-0.29, ECE=4.30+/-1.17, NLL=46.71+/-0.83
************************************************************
2hml
************************************************************
ACC=81.97+/-0.31, ECE=3.66+/-1.15, NLL=45.96+/-0.89
************************************************************
4hml
************************************************************
ACC=82.17+/-0.21, ECE=1.79+/-0.30, NLL=44.73+/-0.51


In [19]:
m_swt = \
print_all(dataset, 'swin') 

Single Head (avg) and Ensembles
ACC=83.09+/-0.51, ECE=8.73+/-0.45, NLL=52.75+/-1.81
------------------------------
DEEP ENSEMBLES:
ACC=83.50, ECE=6.79, NLL=44.80


LS gamma=0.05: Single Head (avg) and Ensembles
ACC=79.76+/-0.73, ECE=3.98+/-0.88, NLL=55.37+/-2.18


MbLS m=6: Single Head (avg) and Ensembles
ACC=80.24+/-0.73, ECE=5.06+/-1.65, NLL=54.18+/-2.64


MixUp gamma=0.2: Single Head (avg) and Ensembles
ACC=80.25+/-0.54, ECE=3.89+/-0.71, NLL=54.62+/-1.87


DCA: Single Head (avg) and Ensembles
ACC=79.12+/-0.79, ECE=7.91+/-1.69, NLL=59.91+/-3.56


2hsl
ACC=82.90+/-0.63, ECE=8.20+/-1.76, NLL=54.19+/-8.04
************************************************************
2hml
************************************************************
ACC=82.79+/-0.43, ECE=5.01+/-1.24, NLL=46.12+/-0.88
************************************************************
4hml
************************************************************
ACC=82.89+/-0.44, ECE=4.80+/-1.58, NLL=46.70+/-1.84


In [20]:
# m_r18_with_ranks = add_acc_ece_combined_ranks(m_r18)
m_r50_with_ranks = add_acc_ece_nll_combined_ranks(m_r50)
m_cvx_with_ranks = add_acc_ece_nll_combined_ranks(m_cvx)
m_swt_with_ranks = add_acc_ece_nll_combined_ranks(m_swt)

In [21]:
sl1h_str  ='\\textbf{SL1H}     & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
DE_str    ='\\textbf{D-Ens}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\\\'
LS_str    ='\\textbf{LS}       & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
MbLS_str  ='\\textbf{MbLS}     & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
Mxp_str   ='\\textbf{MixUp}    & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
DCA_str   ='\\textbf{DCA}      & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
h2sl_str  ='\\textbf{2HSL}     & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
h2ml_str   ='\\textbf{2HML}      & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
h4ml_str   ='\\textbf{4HML}      & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'

In [26]:
strings = [sl1h_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, h2sl_str, h2ml_str, h4ml_str]

strings = fill_rows(m_r50, 'xx1', 'xx2', 'xx3', 'xx4', 'ss1', 'ss2', 'ss3', strings)
strings = fill_rows(m_cvx, 'xx5', 'xx6', 'xx7', 'xx8', 'ss5', 'ss6', 'ss7', strings)
strings = fill_rows(m_swt, 'xx9', 'x10', 'x11', 'x12', 'ss9', 's10', 's11', strings)

sl1h_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, h2sl_str, h2ml_str, h4ml_str = strings

In [27]:
caption = r'\caption{Results on the \textbf{Chaoyang dataset} , with standard deviation for 5 training runs.\
For each model, \unl{\textbf{best}} and \textbf{second best} ranks are marked.}\label{chaoyang_dispersion}'

In [28]:
print('\\begin{sidewaystable}')
print('\\renewcommand{\\arraystretch}{1.03}')
print('\\setlength\\tabcolsep{1.00pt}')
print('\\centering')
print(caption)
print('\\smallskip')
print('\\begin{tabular}{c cccc cccc cccc}')
print('& \\multicolumn{4}{c}{\\textbf{ResNet50}} & \\multicolumn{4}{c}{\\textbf{ConvNeXt}} & \\multicolumn{4}{c}{\\textbf{Swin-Transformer}} \\\\')
print('\\cmidrule(lr){2-5} \\cmidrule(lr){6-9} \\cmidrule(lr){10-13} &  ACC$^\\uparrow$  &  ECE$_\\downarrow$  &  NLL$_\\downarrow$    &  Rank$_\\downarrow$  &  ACC$^\\uparrow$  &  ECE$_\\downarrow$  &  NLL$_\\downarrow$    &  Rank$_\\downarrow$    &  ACC$^\\uparrow$ &  ECE$_\\downarrow$  &  NLL$_\\downarrow$    &  Rank$_\\downarrow$\\\\')
print('\\midrule')
print(sl1h_str)
print('\midrule')
print(LS_str)
print('\midrule')
print(MbLS_str)
print('\midrule')
print(Mxp_str)
print('\midrule')
print(DCA_str)
print('\midrule')
print('\midrule')
print(DE_str)
print('\midrule')
print('\midrule')
print(h2sl_str)
print('\midrule')
print(h2ml_str)
print('\midrule')
print(h4ml_str)
print('\\bottomrule')
print('\\\[-0.25cm]')
print('\\end{tabular}')

\begin{sidewaystable}
\renewcommand{\arraystretch}{1.03}
\setlength\tabcolsep{1.00pt}
\centering
\caption{Results on the \textbf{Chaoyang dataset} , with standard deviation for 5 training runs.\
For each model, \unl{\textbf{best}} and \textbf{second best} ranks are marked.}\label{chaoyang_dispersion}
\smallskip
\begin{tabular}{c cccc cccc cccc}
& \multicolumn{4}{c}{\textbf{ResNet50}} & \multicolumn{4}{c}{\textbf{ConvNeXt}} & \multicolumn{4}{c}{\textbf{Swin-Transformer}} \\
\cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$  &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$    &  ACC$^\uparrow$ &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$\\
\midrule
\textbf{SL1H}     & 80.71$\pm$0.64 & 5.79$\pm$1.75 & 53.46$\pm$3.24 & 6.0 & 81.91$\pm$0.46 & 6.94$\pm$1.49 & 50.98$\pm$3.44 & 6.3 & 83.09$\pm$0.51 & 8.73$\pm$0.45 & 52.75$\pm$1.81 & 5.0 \\
\midrule

# KVASIR

In [29]:
dataset='kvasir'

In [30]:
m_cvx = \
print_all(dataset, 'convnext')

Single Head (avg) and Ensembles
ACC=90.02+/-0.27, ECE=5.18+/-0.52, NLL=35.59+/-2.23
------------------------------
DEEP ENSEMBLES:
ACC=90.76, ECE=3.34, NLL=29.74


LS gamma=0.05: Single Head (avg) and Ensembles
ACC=88.24+/-0.56, ECE=6.97+/-1.60, NLL=42.09+/-2.50


MbLS m=6: Single Head (avg) and Ensembles
ACC=88.62+/-0.22, ECE=8.55+/-2.06, NLL=43.07+/-2.05


MixUp gamma=0.2: Single Head (avg) and Ensembles
ACC=87.58+/-0.58, ECE=8.96+/-2.81, NLL=48.88+/-3.67


DCA: Single Head (avg) and Ensembles
ACC=85.27+/-0.89, ECE=4.11+/-0.94, NLL=46.78+/-2.91


2hsl
ACC=90.21+/-0.14, ECE=2.63+/-0.45, NLL=28.69+/-0.81
************************************************************
2hml
************************************************************
ACC=89.92+/-0.31, ECE=1.49+/-0.28, NLL=28.15+/-0.18
************************************************************
4hml
************************************************************
ACC=90.10+/-0.29, ECE=1.65+/-0.42, NLL=28.01+/-0.87


In [31]:
m_r50 = \
print_all(dataset, 'resnet50')

Single Head (avg) and Ensembles
ACC=89.87+/-0.18, ECE=6.32+/-0.24, NLL=41.88+/-1.65
------------------------------
DEEP ENSEMBLES:
ACC=90.76, ECE=3.83, NLL=32.09


LS gamma=0.05: Single Head (avg) and Ensembles
ACC=88.13+/-0.46, ECE=14.63+/-1.99, NLL=53.96+/-2.54


MbLS m=6: Single Head (avg) and Ensembles
ACC=88.20+/-0.65, ECE=16.92+/-1.05, NLL=57.48+/-2.87


MixUp gamma=0.2: Single Head (avg) and Ensembles
ACC=87.60+/-0.50, ECE=10.28+/-2.41, NLL=50.69+/-2.65


DCA: Single Head (avg) and Ensembles
ACC=87.14+/-0.64, ECE=3.84+/-0.76, NLL=40.50+/-2.38


2hsl
ACC=89.76+/-0.27, ECE=4.52+/-0.93, NLL=34.34+/-3.15
************************************************************
2hml
************************************************************
ACC=90.05+/-0.40, ECE=3.62+/-0.78, NLL=31.37+/-1.97
************************************************************
4hml
************************************************************
ACC=89.99+/-0.25, ECE=2.22+/-0.53, NLL=30.02+/-1.10


In [32]:
m_swt = \
print_all(dataset, 'swin')

Single Head (avg) and Ensembles
ACC=90.07+/-0.43, ECE=5.81+/-0.72, NLL=38.01+/-3.61
------------------------------
DEEP ENSEMBLES:
ACC=90.53, ECE=3.94, NLL=29.36


LS gamma=0.05: Single Head (avg) and Ensembles
ACC=88.74+/-0.68, ECE=9.20+/-1.77, NLL=43.46+/-2.20


MbLS m=6: Single Head (avg) and Ensembles
ACC=89.15+/-0.49, ECE=8.19+/-0.45, NLL=41.85+/-1.70


MixUp gamma=0.2: Single Head (avg) and Ensembles
ACC=89.23+/-0.32, ECE=2.11+/-0.25, NLL=35.52+/-1.36


DCA: Single Head (avg) and Ensembles
ACC=87.62+/-0.82, ECE=4.38+/-1.62, NLL=38.44+/-1.16


2hsl
ACC=90.40+/-0.17, ECE=3.65+/-0.67, NLL=29.14+/-1.31
************************************************************
2hml
************************************************************
ACC=90.19+/-0.33, ECE=2.73+/-0.64, NLL=28.66+/-1.31
************************************************************
4hml
************************************************************
ACC=90.00+/-0.32, ECE=1.82+/-0.35, NLL=27.96+/-0.87


In [33]:
m_r50_with_ranks = add_acc_ece_nll_combined_ranks(m_r50)
m_cvx_with_ranks = add_acc_ece_nll_combined_ranks(m_cvx)
m_swt_with_ranks = add_acc_ece_nll_combined_ranks(m_swt)

In [34]:
sl1h_str  ='\\textbf{SL1H}     & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
DE_str    ='\\textbf{D-Ens}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\\\'
LS_str    ='\\textbf{LS}       & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
MbLS_str  ='\\textbf{MbLS}     & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
Mxp_str   ='\\textbf{MixUp}    & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
DCA_str   ='\\textbf{DCA}      & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
h2sl_str  ='\\textbf{2HSL}     & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
h2ml_str   ='\\textbf{2HML}      & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'
h4ml_str   ='\\textbf{4HML}      & xx1$\pm$ss1 & xx2$\pm$ss2 & xx3$\pm$ss3 & xx4 & xx5$\pm$ss5 & xx6$\pm$ss6 & xx7$\pm$ss7 & xx8 & xx9$\pm$ss9 & x10$\pm$s10 & x11$\pm$s11 & x12 \\\\'

In [35]:
strings = [sl1h_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, h2sl_str, h2ml_str, h4ml_str]

strings = fill_rows(m_r50, 'xx1', 'xx2', 'xx3', 'xx4', 'ss1', 'ss2', 'ss3', strings)
strings = fill_rows(m_cvx, 'xx5', 'xx6', 'xx7', 'xx8', 'ss5', 'ss6', 'ss7', strings)
strings = fill_rows(m_swt, 'xx9', 'x10', 'x11', 'x12', 'ss9', 's10', 's11', strings)

sl1h_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, h2sl_str, h2ml_str, h4ml_str = strings

In [36]:
caption = r'\caption{Results on the \textbf{Kvasir dataset} , with standard deviation for 5 training runs.\
For each model, \unl{\textbf{best}} and \textbf{second best} ranks are marked.}\label{kvasir_dispersion}'

In [37]:
print('\\bigskip\\bigskip  % provide some separation between the two tables')
print(caption)
print('\\smallskip')
print('\\smallskip')
print('\\begin{tabular}{c cccc cccc cccc}')
print('& \\multicolumn{4}{c}{\\textbf{ResNet50}} & \\multicolumn{4}{c}{\\textbf{ConvNeXt}} & \\multicolumn{4}{c}{\\textbf{Swin-Transformer}} \\\\')
print('\\cmidrule(lr){2-5} \\cmidrule(lr){6-9} \\cmidrule(lr){10-13} &  ACC$^\\uparrow$  &  ECE$_\\downarrow$  &  NLL$_\\downarrow$    &  Rank$_\\downarrow$  &  ACC$^\\uparrow$  &  ECE$_\\downarrow$  &  NLL$_\\downarrow$    &  Rank$_\\downarrow$    &  ACC$^\\uparrow$ &  ECE$_\\downarrow$  &  NLL$_\\downarrow$    &  Rank$_\\downarrow$\\\\')
print('\\midrule')
print(sl1h_str)
print('\midrule')
print(LS_str)
print('\midrule')
print(MbLS_str)
print('\midrule')
print(Mxp_str)
print('\midrule')
print(DCA_str)
print('\midrule')
print('\midrule')
print(DE_str)
print('\midrule')
print('\midrule')
print(h2sl_str)
print('\midrule')
print(h2ml_str)
print('\midrule')
print(h4ml_str)
print('\\bottomrule')
print('\\\[-0.25cm]')
print('\\end{tabular}')
print('\\end{sidewaystable}')

\bigskip\bigskip  % provide some separation between the two tables
\caption{Results on the \textbf{Kvasir dataset} , with standard deviation for 5 training runs.\
For each model, \unl{\textbf{best}} and \textbf{second best} ranks are marked.}\label{kvasir_dispersion}
\smallskip
\smallskip
\begin{tabular}{c cccc cccc cccc}
& \multicolumn{4}{c}{\textbf{ResNet50}} & \multicolumn{4}{c}{\textbf{ConvNeXt}} & \multicolumn{4}{c}{\textbf{Swin-Transformer}} \\
\cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$  &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$    &  ACC$^\uparrow$ &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$\\
\midrule
\textbf{SL1H}     & 89.87$\pm$0.18 & 6.32$\pm$0.24 & 41.88$\pm$1.65 & 5.3 & 90.02$\pm$0.27 & 5.18$\pm$0.52 & 35.59$\pm$2.23 & 5.0 & 90.07$\pm$0.43 & 5.81$\pm$0.72 & 38.01$\pm$3.61 & 5.7 \\
\midrule
\textbf{LS}       & 88