In [1]:
from utils.yace import YACE

In [2]:
import torch
import torch.nn as nn
from sklearn.metrics import log_loss

In [3]:
import numpy as np

def set_seed(seed_value, use_cuda):
    """Seed NumPy, PyTorch and Python's ``random`` and request deterministic kernels.

    Args:
        seed_value: Seed handed to every generator. When ``None`` the function
            returns immediately and leaves all global state untouched.
        use_cuda: When truthy, the CUDA generators are seeded as well and
            cuDNN is switched to deterministic mode.

    Side effects:
        Calls ``torch.use_deterministic_algorithms(True)``; the setting stays
        active for the whole process until a caller turns it off again.
    """
    # Imported here so the function does not depend on a notebook cell that
    # imports ``random`` having been executed beforehand.
    import random

    if seed_value is None:
        return

    np.random.seed(seed_value)  # NumPy global generator
    torch.manual_seed(seed_value)  # PyTorch CPU generator
    random.seed(seed_value)  # Python's built-in generator
    torch.use_deterministic_algorithms(True)

    if use_cuda:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True

In [4]:
import os, sys, random, argparse, os.path as osp
import pandas as pd
import numpy as np
import torch
import medmnist
from utils.data_handling import get_medmnist_loaders, get_medmnist_test_loader
from utils.data_handling import get_class_loaders, get_class_test_loader
from utils.evaluation import evaluate_cls
from utils.get_model_v2 import get_arch
from tqdm import trange, tqdm
import torchvision.transforms as tr

In [5]:
def test_one_epoch(model, loader, device):
    model.to(device)
    model.eval()
    probs_all, labels_all = [], []

    for i_batch, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        logits = model(inputs)  # bs x n_classes
        probs = logits.softmax(dim=1).detach().cpu().numpy()
        labels = labels.numpy()
        if labels.ndim == 0:  # for 1-element batches labels degenerates to a scalar
            labels = np.expand_dims(labels, 0)
        probs_all.extend(probs)
        labels_all.extend(list(labels))

    return np.stack(probs_all), np.array(labels_all).squeeze()

def test_cls(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` under ``torch.inference_mode``.

    Returns:
        The ``(probs, labels)`` pair produced by ``test_one_epoch``.
    """
    with torch.inference_mode():
        outputs = test_one_epoch(model, test_loader, device)

    # Only the local name is unbound here; the caller's reference to the
    # model is unaffected.
    del model
    torch.cuda.empty_cache()

    probs, labels = outputs
    return probs, labels

In [6]:
from torchmetrics.classification import MulticlassCalibrationError

In [7]:
import warnings


def warn(*args, **kwargs):
    """Drop-in replacement for ``warnings.warn`` that discards every warning."""
    return None


# Globally silence warnings by swapping out the function itself.
warnings.warn = warn

In [8]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

## Multi-heads with noisy wCE

In [9]:
def test_one_epoch_multihead(model, loader, device):
    ## IMPORTANT: Note that multi-head models in test time return softmax-activated tensors and not logits
    ## which is shitty because I cannot apply TC
    model.to(device)
    model.eval()

    probs_all, labels_all = [], []

    for i_batch, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        probs = model(inputs).detach().cpu().numpy()

        labels = labels.numpy()
        if labels.ndim == 0:  # for 1-element batches labels degenerates to a scalar
            labels = np.expand_dims(labels, 0)
        probs_all.extend(probs)
        labels_all.extend(list(labels))


    return np.stack(probs_all), np.array(labels_all).squeeze()

In [10]:
# def test_one_epoch_multihead(model, loader, device):
#     model.to(device)
#     model.eval()

#     probs_all, labels_all = [], []

#     for i_batch, (inputs, labels) in enumerate(loader):
#         inputs = inputs.to(device)
#         logits = model(inputs)  # bs x n_heads x n_classes

# #       # average over the logits? --- TERRIBLE IDEA
# #         logits = torch.mean(logits, dim=1).squeeze()
# #         probs = logits.softmax(dim=1).detach().cpu().numpy()
#         # or average over the softmaxes?
    
#         probs_h0 = logits[:,0,:].softmax(dim=1).detach().cpu().numpy()
#         probs_h1 = logits[:,1,:].softmax(dim=1).detach().cpu().numpy()
#         probs = (probs_h0+probs_h1)/2


#         labels = labels.numpy()
#         if labels.ndim == 0:  # for 1-element batches labels degenerates to a scalar
#             labels = np.expand_dims(labels, 0)
#         probs_all.extend(probs)
#         labels_all.extend(list(labels))


#     return np.stack(probs_all), np.array(labels_all).squeeze()

In [11]:
def test_cls_multihead(model, test_loader, device):
    """Evaluate a multi-head ``model`` on ``test_loader`` under ``torch.inference_mode``.

    Returns:
        The ``(probs, labels)`` pair produced by ``test_one_epoch_multihead``.
    """
    with torch.inference_mode():
        outputs = test_one_epoch_multihead(model, test_loader, device)

    # Only the local name is unbound here; the caller's reference to the
    # model is unaffected.
    del model
    torch.cuda.empty_cache()

    probs, labels = outputs
    return probs, labels

In [12]:
from utils.data_handling import get_class_test_loader

In [13]:
from tqdm.notebook import tqdm
from utils.calib_tools import ace, tace, sce, ece as ece_

In [14]:
# Quick single-model check: choose a dataset, build its test loader and an
# untrained ResNet-18 with the matching number of classes.
dataset = 'chaoyang'
method = 'ce'
assert dataset in ['chaoyang', 'mhist', 'kather', 'busi',
                   'breakhist_40x', 'breakhist_100x', 'breakhist_200x', 'breakhist_400x']

data_path = osp.join('data', dataset)
csv_test = osp.join('data', 'test_' + dataset + '.csv')

# Kather images are evaluated at 150x150; every other dataset at 224x224.
if dataset == 'kather':
    tg_size = 150, 150
else:
    tg_size = 224, 224

test_loader = get_class_test_loader(csv_test, data_path, tg_size, batch_size=64, num_workers=6)
num_classes = len(test_loader.dataset.classes)
model = get_arch('resnet18', num_classes)

In [15]:
# Load the fold-0 ResNet-18 checkpoint stored under the 'bl' folder of the
# selected dataset. The path follows `dataset`, so the weights always match
# the loader and class count the model was built with.
checkpoint_path = osp.join('experiments', dataset, 'bl', 'r18_f0', 'model_checkpoint.pth')
state = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state['model_state_dict'])

<All keys matched successfully>

In [16]:
from sklearn.metrics import accuracy_score as acc
from sklearn.metrics import roc_auc_score as auc
from sklearn.metrics import log_loss as error

# from utils.temperature_calibration import ModelWithTemperature
from utils.data_handling import get_medmnist_loaders

In [17]:
def print_results_multi_fold(dataset=None, model_name='resnet18', n_bins=15, method='bl',
                             with_temp=False, with_ens=False, yace_metric=acc, blur=0):
    """Evaluate the five per-fold checkpoints of one method and print averaged test metrics.

    Args:
        dataset: Dataset name; must be one of the names asserted below.
        model_name: Architecture passed to ``get_arch``; also selects the
            checkpoint folder prefix (e.g. ``r18_f`` for ``resnet18``).
        n_bins: Number of bins used by the ECE metric.
        method: Sub-folder of ``experiments/<dataset>`` holding the checkpoints.
        with_temp: When true, each model is wrapped in ``ModelWithTemperature``
            (fitted on the validation loader) before testing.
        with_ens: When true, the mean of the folds' probabilities is evaluated
            as well and its metrics are printed and returned.
        yace_metric: Metric callable forwarded to ``YACE``.
        blur: Forwarded to ``get_class_test_loader``; not used for datasets
            whose name contains ``'mnist'``.

    Returns:
        ``[mean AUC, mean ACC, mean ECE, mean CE]`` over the folds, each scaled
        by 100. With ``with_ens=True`` a tuple ``(that list, [AUC, ACC, ECE, CE]
        of the ensemble)`` is returned instead.

    Raises:
        ValueError: If ``model_name`` has no known checkpoint folder prefix.
    """
    assert dataset is not None
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    seed = 0
    set_seed(seed, use_cuda)
    if model_name in ['convnext', 'swin']:
        # set_seed() switched deterministic algorithms on; it is turned back
        # off for these two architectures while cuDNN stays deterministic.
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = True

    assert dataset in ['chaoyang', 'kather', 'kvasir', 'pathmnist', 'retinamnist', 'dermamnist',
                       'busi', 'gbcu', 'mhist',
                       'breakhist_40x', 'breakhist_100x', 'breakhist_200x', 'breakhist_400x']

    # Resolve the checkpoint folder prefix first so an unsupported model name
    # fails with a clear message before any data or checkpoint is loaded.
    fold_prefixes = {
        'resnet18': 'r18_f',
        'resnet34': 'r34_f',
        'convnext': 'cnx_f',
        'resnet50': 'r50_f',
        'mobilenet_v2': 'm2_f',
        'swin': 'swt_f',
    }
    if model_name not in fold_prefixes:
        raise ValueError('Unsupported model_name {!r}; expected one of {}'.format(
            model_name, sorted(fold_prefixes)))

    if 'mnist' in dataset:
        tg_size = 28 if model_name != 'convnext' else 32
        test_loader = get_medmnist_test_loader(dataset, batch_size=128, num_workers=6, tg_size=tg_size)
        num_classes = len(medmnist.INFO[dataset]['label'])
        if with_temp:
            _, val_loader = get_medmnist_loaders(dataset, batch_size=128, num_workers=6)
    else:
        data_path = osp.join('data', dataset)
        csv_test = osp.join('data', 'test_' + dataset + '.csv')
        tg_size = 224, 224
        if dataset == 'kather':
            tg_size = 150, 150
        test_loader = get_class_test_loader(csv_test, data_path, tg_size, blur=blur, batch_size=64, num_workers=6)
        num_classes = len(test_loader.dataset.classes)
        if with_temp:
            csv_train, csv_val = csv_test.replace('test', 'train'), csv_test.replace('test', 'val')
            _, val_loader = get_class_loaders(csv_train, csv_val, data_path, tg_size,
                                              batch_size=64, num_workers=6, see_classes=False)

    load_path = osp.join('experiments', dataset, method)

    ########################
    # results for individual models
    ########################
    load_path_this = osp.join(load_path, fold_prefixes[model_name])
    model = get_arch(model_name, num_classes)
    checkpoint_list = [osp.join(load_path_this + str(i), 'model_checkpoint.pth') for i in [0, 1, 2, 3, 4]]
    states = [torch.load(c, map_location=device) for c in checkpoint_list]

    all_probs = []
    # Do inference for each fold, reusing one model object and swapping weights.
    with torch.inference_mode():
        for state in states:
            model.load_state_dict(state['model_state_dict'])
            if with_temp:
                # NOTE(review): the import of ModelWithTemperature is commented
                # out in this notebook, so it must be brought into scope before
                # with_temp=True can work.
                tempered_model = ModelWithTemperature(model, num_classes, log=False).set_temperature(
                    val_loader, cross_validate='ece')
                probs, labels = test_cls(tempered_model, test_loader, device)
            else:
                probs, labels = test_cls(model, test_loader, device)
            all_probs.append(probs)

    ece = MulticlassCalibrationError(num_classes=num_classes, n_bins=n_bins, norm='l1')

    accs, aucs, eces, yaces, ces = [], [], [], [], []
    for fold_probs in all_probs:
        preds = np.argmax(fold_probs, axis=1)
        test_auc, _, test_acc, _, _ = evaluate_cls(labels, preds, fold_probs, print_conf=False)
        e = ece(torch.from_numpy(fold_probs), torch.from_numpy(labels)).item()
        ce = log_loss(labels, fold_probs)
        y = YACE(labels, fold_probs, metric=yace_metric)

        aucs.append(100 * test_auc)
        accs.append(100 * test_acc)
        eces.append(100 * e)
        yaces.append(100 * y)
        ces.append(100 * ce)

    ########################
    # print average results
    ########################
    print('AUC={:.2f}+/-{:.2f}, ACC={:.2f}+/-{:.2f}, ECE={:.2f}+/-{:.2f}, '
          'YACE={:.2f}+/-{:.2f}, CE={:.2f}+/-{:.2f}'.format(
              np.mean(aucs), np.std(aucs), np.mean(accs), np.std(accs), np.mean(eces), np.std(eces),
              np.mean(yaces), np.std(yaces), np.mean(ces), np.std(ces)))

    mtrcs = [np.mean(aucs), np.mean(accs), np.mean(eces), np.mean(ces)]
    if not with_ens:
        return mtrcs

    ########################
    # results for ensemble
    ########################
    ens_probs = np.mean(all_probs, axis=0)
    ens_preds = np.argmax(ens_probs, axis=1)
    ens_auc, _, ens_acc, _, _ = evaluate_cls(labels, ens_preds, ens_probs, print_conf=False)
    ens_ece = ece(torch.from_numpy(ens_probs), torch.from_numpy(labels)).item()
    # YACE must be computed on the ensemble probabilities, not on the
    # probabilities of whichever fold happened to be evaluated last.
    ens_yace = YACE(labels, ens_probs, metric=yace_metric)
    ens_ce = log_loss(labels, ens_probs)

    print(30 * '-')
    print('DEEP ENSEMBLES:')
    print('AUC={:.2f}, ACC={:.2f}, ECE={:.2f}, YACE={:.2f}, CE={:.2f}'.format(
        100 * ens_auc, 100 * ens_acc, 100 * ens_ece, 100 * ens_yace, 100 * ens_ce))

    return mtrcs, [100 * ens_auc, 100 * ens_acc, 100 * ens_ece, 100 * ens_ce]

In [18]:
def print_results_multi_head(dataset=None, model_name='resnet18', bal=False,
                             n_bins=15, yace_metric=acc, nh=2, blur=0):
    """Evaluate the five per-fold multi-head checkpoints and print averaged test metrics.

    Args:
        dataset: Dataset name; must be one of the names asserted below.
        model_name: Architecture passed to ``get_arch``; also selects the
            checkpoint folder prefix (e.g. ``r18_f`` for ``resnet18``).
        bal: When true, checkpoints are read from ``2hs_bal_new`` instead of
            ``2hs_new``.
        n_bins: Number of bins used by the ECE metric.
        yace_metric: Metric callable forwarded to ``YACE``.
        nh: Number of heads; selects the ``<nh>hs_new`` checkpoint folder and
            is passed to ``get_arch`` as ``n_heads``.
        blur: Forwarded to ``get_class_test_loader``; not used for datasets
            whose name contains ``'mnist'``.

    Returns:
        ``[mean AUC, mean ACC, mean ECE, mean CE]`` over the folds, each scaled
        by 100.

    Raises:
        ValueError: If ``model_name`` has no known checkpoint folder prefix.
    """
    assert dataset is not None

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    # reproducibility
    seed = 0
    set_seed(seed, use_cuda)
    if model_name in ['convnext', 'swin']:
        # set_seed() switched deterministic algorithms on; it is turned back
        # off for these two architectures while cuDNN stays deterministic.
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = True

    assert dataset in ['chaoyang', 'kather', 'kvasir', 'pathmnist', 'retinamnist', 'dermamnist',
                       'busi', 'gbcu', 'mhist',
                       'breakhist_40x', 'breakhist_100x', 'breakhist_200x', 'breakhist_400x']

    # Resolve the checkpoint folder prefix first so an unsupported model name
    # fails with a clear message before any data or checkpoint is loaded.
    fold_prefixes = {
        'resnet18': 'r18_f',
        'resnet34': 'r34_f',
        'convnext': 'cnx_f',
        'resnet50': 'r50_f',
        'mobilenet_v2': 'm2_f',
        'swin': 'swt_f',
    }
    if model_name not in fold_prefixes:
        raise ValueError('Unsupported model_name {!r}; expected one of {}'.format(
            model_name, sorted(fold_prefixes)))

    if 'mnist' in dataset:
        tg_size = 28 if model_name != 'convnext' else 32
        test_loader = get_medmnist_test_loader(dataset, batch_size=128, num_workers=6, tg_size=tg_size)
        num_classes = len(medmnist.INFO[dataset]['label'])
    else:
        data_path = osp.join('data', dataset)
        csv_test = osp.join('data', 'test_' + dataset + '.csv')
        tg_size = 224, 224
        if dataset == 'kather':
            tg_size = 150, 150
        test_loader = get_class_test_loader(csv_test, data_path, tg_size, blur=blur, batch_size=64, num_workers=6)
        num_classes = len(test_loader.dataset.classes)

    # Checkpoints are read from the '<nh>hs_new' folder (marked "new w_max=4"
    # by the original author; the older '<nh>hs' folder is no longer used).
    load_path = osp.join('experiments', dataset, '{}hs_new'.format(str(nh)))
    if bal:
        # NOTE(review): this substitution only matches the two-head folder, so
        # `bal` has no effect when nh != 2 - confirm that this is intended.
        load_path = load_path.replace('2hs_new', '2hs_bal_new')
    load_path_this = osp.join(load_path, fold_prefixes[model_name])

    ########################
    # results for multi-head
    ########################
    model = get_arch(model_name, num_classes, n_heads=nh, spe=False)
    checkpoint_list = [osp.join(load_path_this + str(i), 'model_checkpoint.pth') for i in [0, 1, 2, 3, 4]]
    states = [torch.load(c, map_location=device) for c in checkpoint_list]

    all_probs = []
    # Do inference for each fold, reusing one model object and swapping weights.
    with torch.inference_mode():
        for state in states:
            model.load_state_dict(state['model_state_dict'])
            probs, labels = test_cls_multihead(model, test_loader, device)
            all_probs.append(probs)

    # Honour the caller's n_bins (default 15) instead of a hard-coded value.
    ece = MulticlassCalibrationError(num_classes=num_classes, n_bins=n_bins, norm='l1')

    accs, aucs, eces, yaces, ces = [], [], [], [], []
    for fold_probs in all_probs:
        preds = np.argmax(fold_probs, axis=1)
        test_auc, _, test_acc, _, _ = evaluate_cls(labels, preds, fold_probs, print_conf=False)
        e = ece(torch.from_numpy(fold_probs), torch.from_numpy(labels)).item()
        y = YACE(labels, fold_probs, metric=yace_metric)
        ce = log_loss(labels, fold_probs)

        aucs.append(100 * test_auc)
        accs.append(100 * test_acc)
        eces.append(100 * e)
        yaces.append(100 * y)
        ces.append(100 * ce)

    ########################
    # print average results
    ########################
    print('AUC={:.2f}+/-{:.2f}, ACC={:.2f}+/-{:.2f}, ECE={:.2f}+/-{:.2f}, '
          'YACE={:.2f}+/-{:.2f}, CE={:.2f}+/-{:.2f}'.format(
              np.mean(aucs), np.std(aucs), np.mean(accs), np.std(accs), np.mean(eces), np.std(eces),
              np.mean(yaces), np.std(yaces), np.mean(ces), np.std(ces)))

    mtrcs = [np.mean(aucs), np.mean(accs), np.mean(eces), np.mean(ces)]
    return mtrcs

In [19]:
from sklearn.metrics import accuracy_score as acc

In [None]:
def print_all(dataset, model_name, yace_metric=acc):
    """Run and print every evaluated strategy for one dataset/architecture pair.

    Args:
        dataset: Dataset name forwarded to the per-strategy evaluation functions.
        model_name: Architecture name forwarded to the same functions.
        yace_metric: Metric callable forwarded as ``yace_metric``.

    Returns:
        Dict mapping the keys ``std``, ``ens``, ``ls_005``, ``mbls_6``,
        ``mxp_02``, ``dca``, ``mh``, ``pmh`` and ``pmh4`` to the metric lists
        returned by the evaluation functions.
    """
    def _banner(title, fill='=', blank_before=False):
        # Three-line heading: rule, title, rule (optionally after a blank gap).
        if blank_before:
            print('\n')
        print(60 * fill)
        print(title)
        print(60 * fill)

    metrcs = dict()

    # Plain single-head models plus their deep ensemble.
    _banner('Single Head (avg) and Ensembles')
    metrcs['std'], metrcs['ens'] = print_results_multi_fold(
        dataset, model_name, with_ens=True, yace_metric=yace_metric)

    # Calibration baselines; the dict key doubles as the checkpoint folder name.
    baselines = (
        ('ls_005', 'LS gamma=0.05: Single Head (avg) and Ensembles'),  # LS
        ('mbls_6', 'MbLS m=6: Single Head (avg) and Ensembles'),  # MbLS
        ('mxp_02', 'MixUp gamma=0.2: Single Head (avg) and Ensembles'),  # MIXUP
        ('dca', 'DCA: Single Head (avg) and Ensembles'),  # DCA
    )
    for key, title in baselines:
        _banner(title, blank_before=True)
        metrcs[key] = print_results_multi_fold(
            dataset, model_name, method=key, with_ens=False, yace_metric=yace_metric)

    # Two-Headed
    _banner('Unperturbed ByCephal', blank_before=True)
    metrcs['mh'] = print_results_multi_head(dataset, model_name, bal=True, yace_metric=yace_metric)

    # Two-Headed
    _banner('Perturbed ByCephal', fill='*')
    metrcs['pmh'] = print_results_multi_head(dataset, model_name, yace_metric=yace_metric)

    # 4-Headed
    _banner('Perturbed 4Cephal', fill='*')
    metrcs['pmh4'] = print_results_multi_head(dataset, model_name, yace_metric=yace_metric, nh=4)

    return metrcs

In [None]:
from sklearn.metrics import balanced_accuracy_score as balacc
from sklearn.metrics import matthews_corrcoef as mcc

# Metric handed to YACE by the print_all(...) calls in this notebook.
yace_metric = balacc
# yace_metric = mcc

In [None]:
def fill_rows(col_model_dict, col1, col2, col3, col4, strings):
    """Substitute four placeholders in each of the nine row templates.

    Args:
        col_model_dict: Maps the keys ``std``, ``ens``, ``ls_005``, ``mbls_6``,
            ``mxp_02``, ``dca``, ``mh``, ``pmh`` and ``pmh4`` to indexable
            sequences of numbers.
        col1, col2, col3: Placeholders replaced by the values at index 1, 2
            and 3 of the sequence, formatted with two decimals.
        col4: Placeholder replaced by the last value of the sequence, formatted
            with one decimal.
        strings: Exactly nine row templates, in the key order listed above.

    Returns:
        Tuple with the nine filled-in strings, in the same order.
    """
    # Unpacking into nine names enforces the expected number of templates.
    row0, row1, row2, row3, row4, row5, row6, row7, row8 = strings
    templates = (row0, row1, row2, row3, row4, row5, row6, row7, row8)
    keys = ('std', 'ens', 'ls_005', 'mbls_6', 'mxp_02', 'dca', 'mh', 'pmh', 'pmh4')

    # (placeholder, index into the metric sequence, number format)
    slots = (
        (col1, 1, '{:.2f}'),
        (col2, 2, '{:.2f}'),
        (col3, 3, '{:.2f}'),
        (col4, -1, '{:.1f}'),
    )

    filled = []
    for key, row in zip(keys, templates):
        for placeholder, index, fmt in slots:
            row = row.replace(placeholder, fmt.format(col_model_dict[key][index]))
        filled.append(row)
    return tuple(filled)

In [None]:
def add_acc_ece_combined_ranks(dic):
    """Append per-metric ranks and their mean to every entry of ``dic``.

    Entries are ranked against each other on the value at index 1 (accuracy,
    higher is better), index 2 (ECE, lower is better) and index 3 (NLL, lower
    is better). Rank 1 is the best.

    Args:
        dic: Maps a name to a list of metric values; the lists are extended
            in place.

    Returns:
        The same ``dic`` object, each list now ending with
        ``[acc_rank, ece_rank, nll_rank, mean_of_the_three_ranks]``.
    """
    def _ranks(position, higher_is_better):
        # Order the names by the metric at `position` and number them from 1.
        scores = {name: values[position] for name, values in dic.items()}
        ordered = sorted(scores, key=scores.get, reverse=higher_is_better)
        return {name: place for place, name in enumerate(ordered, 1)}

    acc_rank = _ranks(1, True)
    ece_rank = _ranks(2, False)
    nll_rank = _ranks(3, False)

    for name in dic:
        trio = [acc_rank[name], ece_rank[name], nll_rank[name]]
        dic[name].extend(trio + [np.mean(trio)])
    return dic

# CHAOYANG

In [None]:
# Dataset evaluated by the print_all(...) calls of this section.
dataset='chaoyang'

In [None]:
import warnings


def warn(*args, **kwargs):
    """Drop-in replacement for ``warnings.warn`` that discards every warning."""
    return None


# Globally silence warnings by swapping out the function itself.
warnings.warn = warn

In [None]:
# m_r18 = \
# print_all(dataset, 'resnet18', yace_metric=yace_metric) 

In [None]:
# m_mobilenet = \
# print_all(dataset, 'mobilenet_v2', yace_metric=yace_metric) 

In [None]:
# m_r34 = \
# print_all(dataset, 'resnet34', yace_metric=yace_metric) 

In [None]:
# Metrics of every strategy for ResNet-50 on the current dataset.
m_r50 = print_all(dataset, 'resnet50', yace_metric=yace_metric)

In [None]:
# Metrics of every strategy for ConvNeXt on the current dataset.
m_cvx = print_all(dataset, 'convnext', yace_metric=yace_metric)

In [None]:
# Metrics of every strategy for the Swin model on the current dataset.
m_swt = print_all(dataset, 'swin', yace_metric=yace_metric)

In [None]:
# m_r18_with_ranks = add_acc_ece_combined_ranks(m_r18)
# add_acc_ece_combined_ranks extends the metric lists in place and returns
# the same dict, so m_r50 / m_cvx / m_swt themselves carry the ranks afterwards.
m_r50_with_ranks = add_acc_ece_combined_ranks(m_r50)
m_cvx_with_ranks = add_acc_ece_combined_ranks(m_cvx)
m_swt_with_ranks = add_acc_ece_combined_ranks(m_swt)

In [None]:
# LaTeX row templates, one per strategy. xx1..x12 are placeholders that
# fill_rows() later replaces with numbers (four per architecture).
oneh_str = r'\textbf{OneH}     & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
DE_str = r'\textbf{D-Ens}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
LS_str = r'\textbf{LS}       & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
MbLS_str = r'\textbf{MbLS}     & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
Mxp_str = r'\textbf{MixUp}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
DCA_str = r'\textbf{DCA}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
mh_str = r'\textbf{MH}       & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
p2h_str = r'\textbf{P2H}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
p4h_str = r'\textbf{P4H}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'

In [None]:
# Fill the row templates one architecture at a time: ResNet-50 takes
# placeholders xx1-xx4, ConvNeXt xx5-xx8 and Swin xx9-x12.
strings = [oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str]

for _model_metrics, _cols in (
    (m_r50, ('xx1', 'xx2', 'xx3', 'xx4')),
    (m_cvx, ('xx5', 'xx6', 'xx7', 'xx8')),
    (m_swt, ('xx9', 'x10', 'x11', 'x12')),
):
    strings = fill_rows(_model_metrics, _cols[0], _cols[1], _cols[2], _cols[3], strings)

oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = strings

In [None]:
# Table caption. The first sentence ends with a literal backslash followed by
# a newline, which is kept in the emitted LaTeX.
caption = (
    r'\caption{Results on the \textbf{Chaoyang dataset} with different architectures and strategies.'
    '\\\n'
    r'For each model, \unl{\textbf{best}} and \textbf{second best} ranks are marked.}\label{chaoyang}'
)

In [None]:
# Emit the LaTeX table. Raw strings keep every backslash literal, which
# avoids the invalid escape sequences ('\m', '\[') of ordinary literals.
print(r'\begin{table}[!t]')
print(r'\renewcommand{\arraystretch}{1.03}')
print(r'\setlength\tabcolsep{1.00pt}')
print(r'\begin{center}')
print(r'\begin{tabular}{c cccc cccc cccc}')
print(r'& \multicolumn{4}{c}{\textbf{ResNet50}} & \multicolumn{4}{c}{\textbf{ConvNeXt}} & \multicolumn{4}{c}{\textbf{Swin-Transformer}} \\')
print(r'\cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$  &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$    &  ACC$^\uparrow$ &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$\\')
# Single-model strategies, separated by one rule each.
print(r'\midrule')
print(oneh_str)
print(r'\midrule')
print(LS_str)
print(r'\midrule')
print(MbLS_str)
print(r'\midrule')
print(Mxp_str)
print(r'\midrule')
print(DCA_str)
# A double rule sets the deep-ensemble row apart.
print(r'\midrule')
print(r'\midrule')
print(DE_str)
# A double rule precedes the multi-head rows.
print(r'\midrule')
print(r'\midrule')
print(mh_str)
print(r'\midrule')
print(p2h_str)
print(r'\midrule')
print(p4h_str)
print(r'\bottomrule')
print(r'\\[-0.25cm]')
print(r'\end{tabular}')
print(caption)
print(r'\end{center}')
print(r'\vspace{-1cm}')
print(r'\end{table}')

# Kvasir

In [None]:
# Dataset evaluated by the print_all(...) calls of this section.
dataset='kvasir'

In [None]:
# m_r18 = \
# print_all(dataset, 'resnet18', yace_metric=yace_metric) # new 4/2

In [None]:
# m_mobilenet = \
# print_all(dataset, 'mobilenet_v2', yace_metric=yace_metric)  # new 4/2

In [None]:
# m_r34 = \
# print_all(dataset, 'resnet34', yace_metric=yace_metric)  # new 4/2

In [None]:
# Metrics of every strategy for ConvNeXt on the current dataset.
# Author's note: new w_max=4/2
m_cvx = print_all(dataset, 'convnext', yace_metric=yace_metric)

In [None]:
# Metrics of every strategy for ResNet-50 on the current dataset.
m_r50 = print_all(dataset, 'resnet50', yace_metric=yace_metric)

In [None]:
# Metrics of every strategy for the Swin model on the current dataset.
m_swt = print_all(dataset, 'swin', yace_metric=yace_metric)

In [None]:
# LaTeX row templates, one per strategy. xx1..x12 are placeholders that
# fill_rows() later replaces with numbers (four per architecture).
oneh_str = r'\textbf{OneH}     & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
DE_str = r'\textbf{D-Ens}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
LS_str = r'\textbf{LS}       & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
MbLS_str = r'\textbf{MbLS}     & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
Mxp_str = r'\textbf{MixUp}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
DCA_str = r'\textbf{DCA}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
mh_str = r'\textbf{MH}       & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
p2h_str = r'\textbf{P2H}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
p4h_str = r'\textbf{P4H}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'

In [None]:
# add_acc_ece_combined_ranks extends the metric lists in place and returns
# the same dict, so m_r50 / m_cvx / m_swt themselves carry the ranks afterwards.
m_r50_with_ranks = add_acc_ece_combined_ranks(m_r50)
m_cvx_with_ranks = add_acc_ece_combined_ranks(m_cvx)
m_swt_with_ranks = add_acc_ece_combined_ranks(m_swt)

In [None]:
# Fill the row templates one architecture at a time: ResNet-50 takes
# placeholders xx1-xx4, ConvNeXt xx5-xx8 and Swin xx9-x12.
strings = [oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str]

for _model_metrics, _cols in (
    (m_r50, ('xx1', 'xx2', 'xx3', 'xx4')),
    (m_cvx, ('xx5', 'xx6', 'xx7', 'xx8')),
    (m_swt, ('xx9', 'x10', 'x11', 'x12')),
):
    strings = fill_rows(_model_metrics, _cols[0], _cols[1], _cols[2], _cols[3], strings)

oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = strings

In [None]:
# Table caption. The first sentence ends with a literal backslash followed by
# a newline, which is kept in the emitted LaTeX.
caption = (
    r'\caption{Results on the \textbf{Kvasir dataset} with different architectures and strategies.'
    '\\\n'
    r'For each model, \unl{\textbf{best}} and \textbf{second best} ranks are marked.}\label{kvasir}'
)

In [None]:
# Emit the LaTeX table. Raw strings keep every backslash literal, which
# avoids the invalid escape sequences ('\m', '\[') of ordinary literals.
print(r'\begin{table}[!t]')
print(r'\renewcommand{\arraystretch}{1.03}')
print(r'\setlength\tabcolsep{1.00pt}')
print(r'\begin{center}')
print(r'\begin{tabular}{c cccc cccc cccc}')
print(r'& \multicolumn{4}{c}{\textbf{ResNet50}} & \multicolumn{4}{c}{\textbf{ConvNeXt}} & \multicolumn{4}{c}{\textbf{Swin-Transformer}} \\')
print(r'\cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$  &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$    &  ACC$^\uparrow$ &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$\\')
# Single-model strategies, separated by one rule each.
print(r'\midrule')
print(oneh_str)
print(r'\midrule')
print(LS_str)
print(r'\midrule')
print(MbLS_str)
print(r'\midrule')
print(Mxp_str)
print(r'\midrule')
print(DCA_str)
# A double rule sets the deep-ensemble row apart.
print(r'\midrule')
print(r'\midrule')
print(DE_str)
# A double rule precedes the multi-head rows.
print(r'\midrule')
print(r'\midrule')
print(mh_str)
print(r'\midrule')
print(p2h_str)
print(r'\midrule')
print(p4h_str)
print(r'\bottomrule')
print(r'\\[-0.25cm]')
print(r'\end{tabular}')
print(caption)
print(r'\end{center}')
print(r'\vspace{-1cm}')
print(r'\end{table}')

# Pathmnist

In [None]:
# Dataset evaluated by the print_all(...) calls of this section.
dataset='pathmnist'

In [None]:
# m_r18 = \
# print_all(dataset, 'resnet18', yace_metric=yace_metric) # new w_max=4/2

In [None]:
# m_mobilenet = \
# print_all(dataset, 'mobilenet_v2', yace_metric=yace_metric) # new w_max=4/2

In [None]:
# m_r34 = \
# print_all(dataset, 'resnet34', yace_metric=yace_metric) # new w_max=4/2

In [None]:
# Metrics of every strategy for ResNet-50 on the current dataset.
m_r50 = print_all(dataset, 'resnet50', yace_metric=yace_metric)

In [None]:
# Metrics of every strategy for ConvNeXt on the current dataset.
# Author's note: new w_max=4/2
m_cvx = print_all(dataset, 'convnext', yace_metric=yace_metric)

In [None]:
# Metrics of every strategy for the Swin model on the current dataset.
m_swt = print_all(dataset, 'swin', yace_metric=yace_metric)

In [None]:
# add_acc_ece_combined_ranks extends the metric lists in place and returns
# the same dict, so m_r50 / m_cvx / m_swt themselves carry the ranks afterwards.
m_r50_with_ranks = add_acc_ece_combined_ranks(m_r50)
m_cvx_with_ranks = add_acc_ece_combined_ranks(m_cvx)
m_swt_with_ranks = add_acc_ece_combined_ranks(m_swt)

In [None]:
# LaTeX row templates, one per strategy. xx1..x12 are placeholders that
# fill_rows() later replaces with numbers (four per architecture).
oneh_str = r'\textbf{OneH}     & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
DE_str = r'\textbf{D-Ens}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
LS_str = r'\textbf{LS}       & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
MbLS_str = r'\textbf{MbLS}     & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
Mxp_str = r'\textbf{MixUp}    & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
DCA_str = r'\textbf{DCA}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
mh_str = r'\textbf{MH}       & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
p2h_str = r'\textbf{P2H}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'
p4h_str = r'\textbf{P4H}      & xx1 & xx2 & xx3 & xx4 & xx5 & xx6 & xx7 & xx8 & xx9 & x10 & x11 & x12 \\'

In [None]:
# Fill the row templates one architecture at a time: ResNet-50 takes
# placeholders xx1-xx4, ConvNeXt xx5-xx8 and Swin xx9-x12.
strings = [oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str]

for _model_metrics, _cols in (
    (m_r50, ('xx1', 'xx2', 'xx3', 'xx4')),
    (m_cvx, ('xx5', 'xx6', 'xx7', 'xx8')),
    (m_swt, ('xx9', 'x10', 'x11', 'x12')),
):
    strings = fill_rows(_model_metrics, _cols[0], _cols[1], _cols[2], _cols[3], strings)

oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = strings

In [None]:
# Table caption. The first sentence ends with a literal backslash followed by
# a newline, which is kept in the emitted LaTeX.
caption = (
    r'\caption{Results on the \textbf{PathMnist dataset} with different architectures and strategies.'
    '\\\n'
    r'For each model, \unl{\textbf{best}} and \textbf{second best} ranks are marked.}\label{pathmnist}'
)

In [None]:
# Emit the LaTeX table. Raw strings keep every backslash literal, which
# avoids the invalid escape sequences ('\m', '\[') of ordinary literals.
print(r'\begin{table}[!t]')
print(r'\renewcommand{\arraystretch}{1.03}')
print(r'\setlength\tabcolsep{1.00pt}')
print(r'\begin{center}')
print(r'\begin{tabular}{c cccc cccc cccc}')
print(r'& \multicolumn{4}{c}{\textbf{ResNet50}} & \multicolumn{4}{c}{\textbf{ConvNeXt}} & \multicolumn{4}{c}{\textbf{Swin-Transformer}} \\')
print(r'\cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$  &  ACC$^\uparrow$  &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$    &  ACC$^\uparrow$ &  ECE$_\downarrow$  &  NLL$_\downarrow$    &  Rank$_\downarrow$\\')
# Single-model strategies, separated by one rule each.
print(r'\midrule')
print(oneh_str)
print(r'\midrule')
print(LS_str)
print(r'\midrule')
print(MbLS_str)
print(r'\midrule')
print(Mxp_str)
print(r'\midrule')
print(DCA_str)
# A double rule sets the deep-ensemble row apart.
print(r'\midrule')
print(r'\midrule')
print(DE_str)
# A double rule precedes the multi-head rows.
print(r'\midrule')
print(r'\midrule')
print(mh_str)
print(r'\midrule')
print(p2h_str)
print(r'\midrule')
print(p4h_str)
print(r'\bottomrule')
print(r'\\[-0.25cm]')
print(r'\end{tabular}')
print(caption)
print(r'\end{center}')
print(r'\vspace{-1cm}')
print(r'\end{table}')

# Kather

In [None]:
# Dataset evaluated by the print_all(...) calls of this section.
dataset='kather'

In [None]:
import warnings


def warn(*args, **kwargs):
    """Accept any arguments and do nothing; used as a silent stand-in for ``warnings.warn``."""
    return None


# Replace the module-level function so code calling warnings.warn(...) emits nothing.
# The same replacement is already installed by an earlier cell of this notebook.
warnings.warn = warn

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'resnet18'.
m_r18 = print_all(
    dataset,
    'resnet18',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'mobilenet_v2'.
# Original note: new w_max=4/2
m_mobilenet = print_all(
    dataset,
    'mobilenet_v2',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'resnet34'.
# Original note: new w_max=4/2
m_r34 = print_all(
    dataset,
    'resnet34',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'resnet50'.
# Original note: old w_max=3/2
m_r50 = print_all(
    dataset,
    'resnet50',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'convnext'.
# Original note: new w_max=4/2
m_cvx = print_all(
    dataset,
    'convnext',
    yace_metric=yace_metric,
)

In [None]:
# Row templates for the results table: one string per training strategy.
# Every template is its row label followed by the same five cells of placeholder
# tokens (xx1 ... x15); those tokens are the ones later passed to fill_rows.
# NOTE(review): the lone backslash after 'x12' (a LaTeX control space) looks like the
# remnant of a former row terminator -- confirm it is intended.
_template_tail = (
    '& xx1\\,$|$\\,xx2\\,$|$\\,xx3  '
    '&  xx4\\,$|$\\,xx5 $|$\\,xx6  '
    '&  xx7\\,$|$\\,xx8\\,$|$\\,xx9  '
    '&  x10\\,$|$\\,x11\\,$|$\\,x12 \\  '
    '&  x13 $|$ x14 $|$ x15 \\\\'
)

# The FL row is disabled:
# FL_str    ='\\textbf{FL}       & xx1\\,$|$\\,xx2\\,$|$\\,xx3  &  xx4\\,$|$\\,xx5 $|$\\,xx6  &  xx7\\,$|$\\,xx8\\,$|$\\,xx9  &  x10\\,$|$\\,x11\\,$|$\\,x12 \\\\'#  &  x13 $|$ x14 $|$ x15 \\'
oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = (
    row_label + _template_tail
    for row_label in (
        '\\textbf{OneH}     ',
        '\\textbf{D-Ens}      ',
        '\\textbf{LS}       ',
        '\\textbf{MbLS}     ',
        '\\textbf{MixUp}    ',
        '\\textbf{DCA}     ',
        '\\textbf{MH}       ',
        '\\textbf{P2H}      ',
        '\\textbf{P4H}      ',
    )
)

In [None]:
# Pass each architecture's results through add_acc_ece_combined_ranks (defined
# elsewhere in the notebook) and keep the outputs under *_with_ranks names.
# NOTE(review): no cell of this Kather section assigns m_swt (only m_r18, m_mobilenet,
# m_r34, m_r50 and m_cvx are set here); presumably it is left over from another
# dataset's section, so m_swt_with_ranks may not describe Kather -- confirm.
# NOTE(review): the Kather table cells that follow do not read any *_with_ranks
# name -- confirm this cell is still needed.
m_r50_with_ranks = add_acc_ece_combined_ranks(m_r50)
m_cvx_with_ranks = add_acc_ece_combined_ranks(m_cvx)
m_swt_with_ranks = add_acc_ece_combined_ranks(m_swt)

In [None]:
# Collect the row templates, then hand fill_rows (defined elsewhere in the notebook)
# one architecture at a time, in table-column order, together with that
# architecture's three placeholder tokens; its return value replaces the rows.
# NOTE(review): other fill_rows calls in this notebook pass four placeholder tokens
# before `strings`; confirm the helper still accepts this three-token form.
strings = [oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str]

for arch_metrics, placeholders in (
    (m_mobilenet, ('xx1', 'xx2', 'xx3')),
    (m_r18, ('xx4', 'xx5', 'xx6')),
    (m_r34, ('xx7', 'xx8', 'xx9')),
    (m_r50, ('x10', 'x11', 'x12')),
    (m_cvx, ('x13', 'x14', 'x15')),
):
    strings = fill_rows(arch_metrics, *placeholders, strings)

# Unpack the processed rows back into their individual names.
oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = strings

In [None]:
# LaTeX caption printed below the Kather table.
# The text is split across adjacent raw-string literals: inside a raw string a
# backslash followed by a newline is NOT a line continuation, so the previous
# form embedded literal '\' characters and line breaks in the caption.
caption = (
    r'\caption{Results on the \textbf{Kather dataset} with different architectures and strategies. '
    r'For each model, \unl{\textbf{best}} and \textbf{second best} are marked. '
    r'Performance is \textbf{NLL} $|$ \textbf{Accuracy} $|$ \textbf{ECE} ($\times100$) averaged over 5 training runs.}\label{kather}'
)

In [None]:
# Print the LaTeX source of the results table, one line per call.
# All backslashes are written as '\\' so no literal depends on an invalid escape
# sequence ('\m', '\['); the printed text is the same as before.
print('\\begin{table}[!t]')
print('\\renewcommand{\\arraystretch}{1.25}')
print('\\setlength\\tabcolsep{5.5pt}')
print('\\begin{center}')
print('\\begin{tabular}{cccccc}')
print('                  &  \\textbf{MobileNet}          &    \\textbf{ResNet18}        &    \\textbf{ResNet34}      &  \\textbf{ResNet50} &  \\textbf{ConvNeXt}\\\\')
print('\\midrule')
print(oneh_str)
print('\\midrule')
print(LS_str)
print('\\midrule')
print(MbLS_str)
print('\\midrule')
print(Mxp_str)
print('\\midrule')
# print(FL_str)
# print('\\midrule')
print(DCA_str)
# A doubled \midrule sets the D-Ens row apart from the rows above and below it.
print('\\midrule')
print('\\midrule')
print(DE_str)
print('\\midrule')
print('\\midrule')
print(mh_str)
print('\\midrule')
print(p2h_str)
print('\\midrule')
print(p4h_str)
print('\\bottomrule')
print('\\\\[-0.25cm]')
print('\\end{tabular}')
print(caption)
print('\\end{center}')
print('\\vspace{-1cm}')
print('\\end{table}')

# Retinamnist

In [None]:
# Dataset name passed to the print_all calls of this section.
dataset = "retinamnist"

In [None]:
# m_r18 = \
# print_all(dataset, 'resnet18', yace_metric=yace_metric) # new w_max=4/2

In [None]:
# m_mobilenet = \
# print_all(dataset, 'mobilenet_v2', yace_metric=yace_metric) # new w_max=4/2

In [None]:
# m_r34 = \
# print_all(dataset, 'resnet34', yace_metric=yace_metric) # new w_max=4/2

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'resnet50'.
m_r50 = print_all(
    dataset,
    'resnet50',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'convnext'.
# Original note: new w_max=4/2
m_cvx = print_all(
    dataset,
    'convnext',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'swin'.
m_swt = print_all(
    dataset,
    'swin',
    yace_metric=yace_metric,
)

In [None]:
# Row templates for the results table: one string per training strategy.
# Every template is its row label followed by the same five cells of placeholder
# tokens (xx1 ... x15).
# NOTE(review): the lone backslash after 'x12' (a LaTeX control space) looks like the
# remnant of a former row terminator -- confirm it is intended.
_template_tail = (
    '& xx1\\,$|$\\,xx2\\,$|$\\,xx3  '
    '&  xx4\\,$|$\\,xx5 $|$\\,xx6  '
    '&  xx7\\,$|$\\,xx8\\,$|$\\,xx9  '
    '&  x10\\,$|$\\,x11\\,$|$\\,x12 \\  '
    '&  x13 $|$ x14 $|$ x15 \\\\'
)

# The FL row is disabled:
# FL_str    ='\\textbf{FL}       & xx1\\,$|$\\,xx2\\,$|$\\,xx3  &  xx4\\,$|$\\,xx5 $|$\\,xx6  &  xx7\\,$|$\\,xx8\\,$|$\\,xx9  &  x10\\,$|$\\,x11\\,$|$\\,x12 \\\\'#  &  x13 $|$ x14 $|$ x15 \\'
oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = (
    row_label + _template_tail
    for row_label in (
        '\\textbf{OneH}     ',
        '\\textbf{D-Ens}      ',
        '\\textbf{LS}       ',
        '\\textbf{MbLS}     ',
        '\\textbf{MixUp}    ',
        '\\textbf{DCA}     ',
        '\\textbf{U2H}       ',
        '\\textbf{P2H}      ',
        '\\textbf{P4H}      ',
    )
)

In [None]:
# Pass the ResNet50, ConvNeXt and Swin results, in that order, through
# add_acc_ece_combined_ranks (defined elsewhere in the notebook) and keep each
# output under the matching *_with_ranks name.
m_r50_with_ranks, m_cvx_with_ranks, m_swt_with_ranks = map(
    add_acc_ece_combined_ranks,
    (m_r50, m_cvx, m_swt),
)

# Dermamnist

In [None]:
# Dataset name passed to the print_all calls of this section.
dataset = "dermamnist"

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'resnet18'.
# Original note: new w_max=4/2
m_r18 = print_all(
    dataset,
    'resnet18',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'mobilenet_v2'.
# Original note: new w_max=4/2
m_mobilenet = print_all(
    dataset,
    'mobilenet_v2',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'resnet34'.
# Original note: new w_max=4/2
m_r34 = print_all(
    dataset,
    'resnet34',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'resnet50'.
m_r50 = print_all(
    dataset,
    'resnet50',
    yace_metric=yace_metric,
)

In [None]:
# Keep what print_all (defined elsewhere in the notebook) returns for 'convnext'.
# Original note: new w_max=4/2
m_cvx = print_all(
    dataset,
    'convnext',
    yace_metric=yace_metric,
)

In [None]:
# Row templates for the results table: one string per training strategy.
# Every template is its row label followed by the same five cells of placeholder
# tokens (xx1 ... x15); those tokens are the ones later passed to fill_rows.
# NOTE(review): the lone backslash after 'x12' (a LaTeX control space) looks like the
# remnant of a former row terminator -- confirm it is intended.
_template_tail = (
    '& xx1\\,$|$\\,xx2\\,$|$\\,xx3  '
    '&  xx4\\,$|$\\,xx5 $|$\\,xx6  '
    '&  xx7\\,$|$\\,xx8\\,$|$\\,xx9  '
    '&  x10\\,$|$\\,x11\\,$|$\\,x12 \\  '
    '&  x13 $|$ x14 $|$ x15 \\\\'
)

# The FL row is disabled:
# FL_str    ='\\textbf{FL}       & xx1\\,$|$\\,xx2\\,$|$\\,xx3  &  xx4\\,$|$\\,xx5 $|$\\,xx6  &  xx7\\,$|$\\,xx8\\,$|$\\,xx9  &  x10\\,$|$\\,x11\\,$|$\\,x12 \\\\'#  &  x13 $|$ x14 $|$ x15 \\'
oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = (
    row_label + _template_tail
    for row_label in (
        '\\textbf{OneH}     ',
        '\\textbf{D-Ens}      ',
        '\\textbf{LS}       ',
        '\\textbf{MbLS}     ',
        '\\textbf{MixUp}    ',
        '\\textbf{DCA}     ',
        '\\textbf{U2H}       ',
        '\\textbf{P2H}      ',
        '\\textbf{P4H}      ',
    )
)

In [None]:
# Collect the row templates, then hand fill_rows (defined elsewhere in the notebook)
# one architecture at a time, in table-column order, together with that
# architecture's three placeholder tokens; its return value replaces the rows.
# NOTE(review): other fill_rows calls in this notebook pass four placeholder tokens
# before `strings`; confirm the helper still accepts this three-token form.
strings = [oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str]

for arch_metrics, placeholders in (
    (m_mobilenet, ('xx1', 'xx2', 'xx3')),
    (m_r18, ('xx4', 'xx5', 'xx6')),
    (m_r34, ('xx7', 'xx8', 'xx9')),
    (m_r50, ('x10', 'x11', 'x12')),
    (m_cvx, ('x13', 'x14', 'x15')),
):
    strings = fill_rows(arch_metrics, *placeholders, strings)

# Unpack the processed rows back into their individual names.
oneh_str, DE_str, LS_str, MbLS_str, Mxp_str, DCA_str, mh_str, p2h_str, p4h_str = strings

In [None]:
# LaTeX caption printed below the table of this section, which runs with
# dataset='dermamnist'. The caption previously named the RetinaMnist dataset and
# reused \label{pathmnist}, the label already used by the PathMnist table; both
# now refer to DermaMnist.
# The text is split across adjacent raw-string literals: inside a raw string a
# backslash followed by a newline is NOT a line continuation, so the previous
# form embedded literal '\' characters and line breaks in the caption.
# NOTE(review): the first metric is named 'ECEx' here but 'NLL' in the Kather
# caption -- confirm which one the filled-in cells actually show.
caption = (
    r'\caption{Results on the \textbf{DermaMnist dataset} with different architectures and strategies. '
    r'For each model, \unl{\textbf{best}} and \textbf{second best} are marked. '
    r'Performance is \textbf{ECEx} $|$ \textbf{Accuracy} $|$ \textbf{ECE} ($\times100$) averaged over 5 training runs.}\label{dermamnist}'
)

In [None]:
# Print the LaTeX source of the results table, one line per call.
# All backslashes are written as '\\' so no literal depends on an invalid escape
# sequence ('\m', '\['); the printed text is the same as before.
print('\\begin{table}[!t]')
print('\\renewcommand{\\arraystretch}{1.25}')
print('\\setlength\\tabcolsep{5.5pt}')
print('\\begin{center}')
print('\\begin{tabular}{cccccc}')
print('                  &  \\textbf{MobileNet}          &    \\textbf{ResNet18}        &    \\textbf{ResNet34}      &  \\textbf{ResNet50} &  \\textbf{ConvNeXt}\\\\')
print('\\midrule')
print(oneh_str)
print('\\midrule')
print(LS_str)
print('\\midrule')
print(MbLS_str)
print('\\midrule')
print(Mxp_str)
print('\\midrule')
# print(FL_str)
# print('\\midrule')
print(DCA_str)
# A doubled \midrule sets the D-Ens row apart from the rows above and below it.
print('\\midrule')
print('\\midrule')
print(DE_str)
print('\\midrule')
print('\\midrule')
print(mh_str)
print('\\midrule')
print(p2h_str)
print('\\midrule')
print(p4h_str)
print('\\bottomrule')
print('\\\\[-0.25cm]')
print('\\end{tabular}')
print(caption)
print('\\end{center}')
print('\\vspace{-1cm}')
print('\\end{table}')