In [4]:
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch

from collections import defaultdict
import os.path as osp
from typing import Union, Dict, List

from wilds.datasets.wilds_dataset import WILDSDataset

def get_eval_meta_args(logs, y_index):
    test_metadata= np.array(logs["test-metas"][0])
    y_true = test_metadata[:,y_index]
    test_logits = np.array(logs["test-logits"][0][0])
    y_pred = np.argmax(test_logits,axis=1)
    y_true = torch.tensor(y_true)
    y_pred= torch.tensor(y_pred)
    test_metadata= torch.tensor(test_metadata)

    return y_pred, y_true, test_metadata

def get_results(filename_format, dataset:WILDSDataset, meta_metrics:Union[List[str],None]=None):
    y_index=dataset.metadata_fields.index("y")

    no_diversity=False
    ensemble_size=2
    res= defaultdict(list)

    for seed in range(3):
        filename= filename_format.format(seed=seed)
        filename = osp.join(filename,"summary.json")

        if not(osp.exists(filename)):
            continue

        with open(filename) as f:
            logs=json.load(f)

            res["test_acc_ensemble"].append(logs["ensemble-test-acc"])
            #res["best_single_model_test_acc"].append(max(logs["test-acc"]))
            for i in range(ensemble_size):
                res[f"m_{i+1}_test_acc"].append(logs["test-acc"][i])
            res["test_similarity"].append(logs["test_similarity"][0][1])
            res["unlabeled_final_similarity"].append(logs["unlabeled_final_similarty"][0][1])

            ## worst group eval
            y_pred, y_true, test_metadata = get_eval_meta_args(logs=logs, y_index=y_index)

            eval_res = dataset.eval(y_pred=y_pred, y_true=y_true, metadata=test_metadata)
            for meta_metric in meta_metrics:
                res[meta_metric].append(eval_res[0][meta_metric])

    
    return res

def display_results(filename_format:str,title:str, dataset:WILDSDataset, meta_metrics:Union[List[str],None]=None ):

    res= get_results(filename_format=filename_format, dataset=dataset, meta_metrics=meta_metrics )
    df = pd.DataFrame(res)
    if df.empty:
        return
    results = df.aggregate(["mean","std"])

    print(title)
    pd.options.display.float_format = "{:,.3f}".format
    display(results)


## WATERBIRDS RESULTS

In [5]:
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset

def display_all_waterbirds_results():

    dataset=WaterbirdsDataset(root_dir="./datasets/", download=False)
    meta_metrics=["acc_wg"]
    waterbirds_filename_format= {
        "d_bat": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        "erm": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=True_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        "np_d_bat": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        "np_erm": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep300/ep=300_lrmax=0.001_alpha=0.0_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=True_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    }
    
    for name, filename_format in waterbirds_filename_format.items():
        title = str.upper(name)
        display_results(filename_format=filename_format, title=title, dataset=dataset, meta_metrics=meta_metrics)


display_all_waterbirds_results()


D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,acc_wg
mean,0.894,0.898,0.886,0.969,0.968,0.628
std,0.002,0.004,0.003,0.006,0.006,0.022


ERM


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,acc_wg
mean,0.885,0.885,0.884,0.978,0.979,0.641
std,0.002,0.0,0.002,0.006,0.005,0.013


NP_D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,acc_wg
mean,0.66,0.664,0.661,0.871,0.861,0.083
std,0.012,0.013,0.004,0.004,0.019,0.009


NP_ERM


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,acc_wg
mean,0.664,0.623,0.719,0.802,0.806,0.128
std,,,,,,


## CAMELYON17 RESULTS

In [11]:
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset


def display_all_camelyon_results():

    dataset=Camelyon17Dataset(root_dir="./datasets/", download=False)
    meta_metrics=[]
    camelyon_filename_format= {
        "d_bat_is_test": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/camelyon17/perturb=ood_is_test/resnet50_pretrained=True/ep60/ep=60_lrmax=0.001_alpha=1e-06_dataset=camelyon17_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        "d_bat_is_not_test": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/camelyon17/perturb=ood_is_not_test/resnet50_pretrained=True/ep60/ep=60_lrmax=0.001_alpha=1e-06_dataset=camelyon17_perturb_type=ood_is_not_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        "erm": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/camelyon17/perturb=ood_is_test/resnet50_pretrained=True/ep60/ep=60_lrmax=0.001_alpha=0.0_dataset=camelyon17_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=True_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    }
    
    print("CAMLEYON17 RESULTS")
    for name, filename_format in camelyon_filename_format.items():
        title = str.upper(name)
        display_results(filename_format=filename_format, title=title, dataset=dataset, meta_metrics=meta_metrics)


display_all_camelyon_results()


CAMLEYON17 RESULTS
D_BAT_IS_TEST


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity
mean,0.867,0.793,0.922,0.833,0.733
std,0.01,0.003,0.008,0.006,0.012


D_BAT_IS_NOT_TEST


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity
mean,0.921,0.888,0.892,0.854,0.824
std,,,,,


ERM


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity
mean,0.937,0.929,0.935,0.955,0.943
std,,,,,


In [9]:
### Testing
## TEST METRICS
test_acc_ensemble = logs["ensemble-test-acc"]
best_single_model_test_acc = max(logs["test-acc"])
test_acc_of_sub_ensembles = logs["test_acc_ensemble_per_ens_size"]

ensemble_size = 2
best_val_acc_per_model = {f"m{i}": max([x[1] for x in logs[f"m{i}"]["valid-acc"]]) for i in range(1,ensemble_size+1)}
best_val_acc_per_model

NameError: name 'logs' is not defined