In [2]:
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch

from collections import defaultdict
import os.path as osp
from typing import Union, Dict, List

from wilds.datasets.wilds_dataset import WILDSDataset

def get_eval_meta_args(logs, y_index, m_idx):
    """Build the tensors passed to ``dataset.eval`` for one ensemble member.

    Args:
        logs: Parsed ``summary.json`` content. ``logs["test-metas"][0]`` and
            ``logs["test-logits"][m_idx][0]`` are read.
        y_index: Column of the metadata array that is returned as ``y_true``.
        m_idx: Index into ``logs["test-logits"]`` selecting the model.

    Returns:
        Tuple ``(y_pred, y_true, test_metadata)`` of torch tensors, where
        ``y_pred`` is the arg-max of the selected logits along axis 1.
    """
    metadata_array = np.array(logs["test-metas"][0])
    label_column = metadata_array[:, y_index]

    member_logits = np.array(logs["test-logits"][m_idx][0])
    predicted_classes = np.argmax(member_logits, axis=1)

    return (
        torch.tensor(predicted_classes),
        torch.tensor(label_column),
        torch.tensor(metadata_array),
    )

def get_results(filename_format, dataset:"WILDSDataset", seeds:List[int], meta_metrics:Union[List[str],None]=None, ensemble_size:int=2):
    """Collect per-seed metrics from the ``summary.json`` of each run.

    Args:
        filename_format: Run-directory template containing a ``{seed}``
            placeholder; ``summary.json`` is looked up inside that directory.
        dataset: Dataset object providing ``metadata_fields`` and ``eval``.
        seeds: Seeds to look up. Seeds whose ``summary.json`` does not exist
            are skipped without error.
        meta_metrics: Keys of the first element returned by ``dataset.eval``
            to record per model. ``None`` or an empty list records none and
            skips the evaluation entirely.
        ensemble_size: Number of ensemble members to report (defaults to 2,
            the value that used to be hard-coded).

    Returns:
        ``defaultdict(list)`` mapping a metric name to one value per seed
        that was found.
    """
    y_index = dataset.metadata_fields.index("y")
    # The signature allows None, so normalise it instead of iterating over None.
    meta_metrics = list(meta_metrics) if meta_metrics else []
    res = defaultdict(list)

    for seed in seeds:
        filename = osp.join(filename_format.format(seed=seed), "summary.json")
        if not osp.exists(filename):
            continue

        # Only parsing needs the file handle; release it before evaluating.
        with open(filename) as f:
            logs = json.load(f)

        res["test_acc_ensemble"].append(logs["ensemble-test-acc"])
        for i in range(ensemble_size):
            res[f"m_{i+1}_test_acc"].append(logs["test-acc"][i])
        res["test_similarity"].append(logs["test_similarity"][0][1])
        # NOTE(review): the key is spelled "similarty" here; presumably it
        # matches what the training script writes -- confirm before renaming.
        res["unlabeled_final_similarity"].append(logs["unlabeled_final_similarty"][0][1])

        if not meta_metrics:
            # Nothing would be recorded, so skip the per-model evaluation.
            continue

        ## worst group eval
        for m_idx in range(ensemble_size):
            y_pred, y_true, test_metadata = get_eval_meta_args(logs=logs, y_index=y_index, m_idx=m_idx)
            eval_res = dataset.eval(y_pred=y_pred, y_true=y_true, metadata=test_metadata)
            for meta_metric in meta_metrics:
                res[f"m_{m_idx+1}_{meta_metric}"].append(eval_res[0][meta_metric])

    return res

def display_results(filename_format:str,title:str, seeds:List[int], dataset:"WILDSDataset", meta_metrics:Union[List[str],None]=None ):
    """Print ``title`` and display the mean/std over seeds of one experiment.

    Args:
        filename_format: Run-directory template with a ``{seed}`` placeholder,
            forwarded to ``get_results``.
        title: Heading printed above the table.
        seeds: Seeds to aggregate over.
        dataset: Dataset object forwarded to ``get_results``.
        meta_metrics: Extra per-model metrics forwarded to ``get_results``;
            ``None`` is treated as "no extra metrics".

    Returns:
        None. The aggregated table is rendered with IPython's ``display``.
    """
    # Forward an empty list rather than None so the declared default is usable.
    requested_metrics = meta_metrics if meta_metrics is not None else []
    res = get_results(filename_format=filename_format, dataset=dataset, meta_metrics=requested_metrics, seeds=seeds)
    df = pd.DataFrame(res)
    if df.empty:
        # Say so explicitly: a silently missing table is easy to overlook
        # when several experiments are displayed in a row.
        print(f"{title}: no results found for seeds {seeds}, skipping")
        return
    results = df.aggregate(["mean","std"])

    print(title)
    # Note: this sets the pandas display option globally, not just for this table.
    pd.options.display.float_format = "{:,.3f}".format
    display(results)


In [10]:
# Scratch check: load the summary of a single run (seed 0) and inspect the
# shape of the logits stored for the second model.
res_path = "/datasets/home/liang/D-BAT-exp/results_reproduction_waterbird_cc/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed=0_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=True"
filename = osp.join(res_path, "summary.json")

# `logs` is left in the notebook namespace for later cells.
with open(filename) as f:
    logs = json.load(f)

# Index 1 selects the second entry of "test-logits"; [0] its first element.
test_logits = np.array(logs["test-logits"][1][0])

test_logits.shape

(512, 2)

## WATERBIRDS RESULTS

In [14]:
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset

def display_all_waterbirds_results():
    """Display mean/std tables for the Waterbirds runs.

    Shows the 300-epoch runs trained with alpha=0.0001 first, then the 30-
    and 300-epoch runs trained with alpha=0.1. Each table aggregates seeds
    0, 1 and 2 and includes the ``acc_wg`` metric per model.
    """
    dataset=WaterbirdsDataset(root_dir="./datasets/", download=False)
    meta_metrics=["acc_wg"]

    def show_all(filename_formats):
        # One table per experiment, titled with the upper-cased experiment name.
        for name, filename_format in filename_formats.items():
            title = str.upper(name)
            display_results(filename_format=filename_format, title=title, dataset=dataset, meta_metrics=meta_metrics, seeds=[0,1,2])

    print("EPOCH 300, alpha 0.0001\n")
    waterbirds_filename_format= {
        "d_bat": "/datasets/home/liang/D-BAT-exp/results_reproduction1/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        "np_d_bat": "/datasets/home/liang/D-BAT-exp/results_reproduction1/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    }

    # Disabled entries (no_diversity=True, alpha=0.0), kept for reference:
    #"np_erm": "/datasets/home/liang/D-BAT-exp/results_reproduction1/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep300/ep=300_lrmax=0.001_alpha=0.0_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=True_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    #"erm": "/datasets/home/liang/D-BAT-exp/results_reproduction1/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=True_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",

    show_all(waterbirds_filename_format)

    for epoch in [30, 300]:
        print(f"EPOCH {epoch}, alpha 0.1\n")
        # `{{seed}}` stays a literal `{seed}` placeholder after f-string
        # formatting, so the seed can be filled in later via str.format.
        waterbirds_filename_format= {
            "d_bat": f"/datasets/home/hbenoit/D-BAT-exp/alpha=0.1_train/results_reproduction2/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep{epoch}/ep={epoch}_lrmax=0.001_alpha=0.1_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
            "np_d_bat": f"/datasets/home/hbenoit/D-BAT-exp/alpha=0.1_train/results_reproduction2/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep{epoch}/ep={epoch}_lrmax=0.001_alpha=0.1_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        }

        show_all(waterbirds_filename_format)


display_all_waterbirds_results()


EPOCH 300, alpha 0.0001

D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.876,0.863,0.885,0.954,0.943,0.697,0.645
std,0.002,0.005,0.005,0.009,0.014,0.02,0.02


NP_D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.629,0.625,0.632,0.864,0.855,0.103,0.098
std,0.013,0.017,0.01,0.003,0.003,0.013,0.02


EPOCH 30, alpha 0.1

D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.82,0.855,0.728,0.797,0.039,0.679,0.5
std,0.009,0.006,0.023,0.016,0.008,0.013,0.062


NP_D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.623,0.645,0.613,0.79,0.784,0.056,0.103
std,0.028,0.073,0.047,0.027,0.022,0.041,0.046


EPOCH 300, alpha 0.1

D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.812,0.863,0.672,0.736,0.013,0.697,0.427
std,0.011,0.005,0.049,0.052,0.002,0.02,0.081


NP_D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.638,0.616,0.636,0.719,0.161,0.16,0.269
std,0.001,0.016,0.004,0.058,0.144,0.063,0.018


# WATERBIRDS_CC RESULTS

In [17]:
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset

def display_all_waterbirds_cc_results():
    """Display mean/std tables for the Waterbirds runs stored with ``majority_only=True``.

    Shows the 300-epoch runs trained with alpha=0.0001 first, then the 30-
    and 300-epoch runs trained with alpha=0.1. Each table aggregates seeds
    0, 1 and 2 and includes the ``acc_wg`` metric per model.
    """
    cc_dataset = WaterbirdsDataset(root_dir="./datasets/", download=False)
    worst_group_metrics = ["acc_wg"]

    def render(formats_by_name):
        # One table per experiment, titled with the upper-cased experiment name.
        for run_name, run_format in formats_by_name.items():
            display_results(
                filename_format=run_format,
                title=run_name.upper(),
                dataset=cc_dataset,
                meta_metrics=worst_group_metrics,
                seeds=[0, 1, 2],
            )

    # alpha=0.0001 runs
    print("EPOCH 300, ALPHA 0.0001")
    render({
        "d_bat": "/datasets/home/liang/D-BAT-exp/results_reproduction_waterbird_cc/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=True",
        "d_bat_np": "/datasets/home/liang/D-BAT-exp/results_reproduction_waterbird_cc/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=True",
    })

    # alpha=0.1 runs; `{{seed}}` survives the f-string as a `{seed}` placeholder.
    for epoch in (30, 300):
        print(f"EPOCH {epoch}, alpha 0.1\n")
        render({
            "d_bat": f"/datasets/home/hbenoit/D-BAT-exp/alpha=0.1_train/results_reproduction2/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep{epoch}/ep={epoch}_lrmax=0.001_alpha=0.1_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=True",
            "d_bat_np": f"/datasets/home/hbenoit/D-BAT-exp/alpha=0.1_train/results_reproduction2/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep{epoch}/ep={epoch}_lrmax=0.001_alpha=0.1_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=True",
        })


display_all_waterbirds_cc_results()

EPOCH 300, ALPHA 0.0001
D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.68076,0.6478,0.71959,0.89213,0.90127,0.28755,0.30769
std,0.00597,0.00793,0.00693,0.01344,0.01803,0.03026,0.01282


D_BAT_NP


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.58244,0.56173,0.61103,0.88293,0.89003,0.08974,0.04701
std,0.0089,0.00991,0.02415,0.03146,0.04605,0.02564,0.0074


EPOCH 30, alpha 0.1

D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.75469,0.649,0.78409,0.7224,0.0838,0.31136,0.65385
std,0.01103,0.00872,0.01238,0.02685,0.05454,0.03898,0.01981


D_BAT_NP


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.54275,0.52134,0.60983,0.77758,0.74265,0.13187,0.13492
std,0.02023,0.0154,0.11071,0.16521,0.15591,0.01648,0.09356


EPOCH 300, alpha 0.1

D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.7609,0.64785,0.74209,0.67156,0.02895,0.28571,0.60256
std,0.0209,0.0078,0.00541,0.05121,0.00767,0.02747,0.07718


D_BAT_NP


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.5955,0.5596,0.66471,0.74264,0.23065,0.09402,0.15385
std,0.01905,0.00527,0.03908,0.03169,0.0703,0.0148,0.0


## CAMELYON17 RESULTS

In [3]:
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset


def display_all_camelyon_results():
    """Display mean/std tables for the Camelyon17 runs (seeds 0, 1 and 2).

    No extra per-model metrics are requested (``meta_metrics`` is empty), so
    the tables only contain the columns ``display_results`` always reports.
    """
    dataset=Camelyon17Dataset(root_dir="./datasets/", download=False)
    meta_metrics=[]

    # Earlier result locations, kept for reference:
    #camelyon_filename_format= {
    #    "d_bat_is_test": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/camelyon17/perturb=ood_is_test/resnet50_pretrained=True/ep60/ep=60_lrmax=0.001_alpha=1e-06_dataset=camelyon17_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    #    "d_bat_is_not_test": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/camelyon17/perturb=ood_is_not_test/resnet50_pretrained=True/ep60/ep=60_lrmax=0.001_alpha=1e-06_dataset=camelyon17_perturb_type=ood_is_not_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    #    "erm": "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/camelyon17/perturb=ood_is_test/resnet50_pretrained=True/ep60/ep=60_lrmax=0.001_alpha=0.0_dataset=camelyon17_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=True_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    #}

    camelyon_filename_format= {
        "d_bat_np":"/datasets/home/liang/D-BAT-exp/results_reproduction1/camelyon17/perturb=ood_is_test/resnet50_pretrained=False/ep60/ep=60_lrmax=0.001_alpha=1e-06_dataset=camelyon17_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
        "d_bat": "/datasets/home/liang/D-BAT-exp/results_reproduction1/camelyon17/perturb=ood_is_test/resnet50_pretrained=True/ep60/ep=60_lrmax=0.001_alpha=1e-06_dataset=camelyon17_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False",
    }

    # Heading spelling corrected (was "CAMLEYON17").
    print("CAMELYON17 RESULTS")
    for name, filename_format in camelyon_filename_format.items():
        title = str.upper(name)
        display_results(filename_format=filename_format, title=title, dataset=dataset, meta_metrics=meta_metrics, seeds=[0,1,2])


display_all_camelyon_results()


CAMLEYON17 RESULTS
D_BAT_NP


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity
mean,0.726,0.581,0.813,0.652,0.791
std,0.023,0.059,0.014,0.082,0.019


D_BAT


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity
mean,0.918,0.812,0.936,0.831,0.928
std,0.006,0.014,0.004,0.013,0.007


# MULTI HEADS WATERBIRDS-CC

In [5]:
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset

def display_all_waterbirds_cc_results():
    """Display mean/std tables for the multi-head runs (ensemble sizes 2 and 3).

    All runs are the 300-epoch, alpha=0.1, non-pretrained configuration; each
    table aggregates seeds 0, 1 and 2 and includes ``acc_wg`` per model.

    NOTE(review): this definition reuses the name of the function defined in
    the WATERBIRDS_CC RESULTS cell and shadows it once this cell has run;
    consider giving it a distinct name.
    NOTE(review): ``display_results`` is called without an ensemble size, so
    the ``heads=3`` table presumably reports only two models -- confirm.
    """
    dataset=WaterbirdsDataset(root_dir="./datasets/", download=False)
    meta_metrics=["acc_wg"]

    for heads in [2,3]:
        # Every iteration uses the same experiment name, so print the head
        # count first; otherwise the tables cannot be told apart.
        print(f"EPOCH 300, alpha 0.1, HEADS {heads}\n")
        # `{{seed}}` survives the f-string as a literal `{seed}` placeholder.
        waterbirds_filename_format= {
            "d_bat_np": f"/datasets/home/hbenoit/D-BAT-exp/alpha=0.1_train/grey/h{heads}/waterbird/perturb=ood_is_test/resnet50_pretrained=False/ep300/ep=300_lrmax=0.001_alpha=0.1_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=False_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size={heads}_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=True",
        }

        for name, filename_format in waterbirds_filename_format.items():
            title = str.upper(name)
            display_results(filename_format=filename_format, title=title, dataset=dataset, meta_metrics=meta_metrics, seeds=[0,1,2])


display_all_waterbirds_cc_results()

D_BAT_NP


Unnamed: 0,test_acc_ensemble,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,m_1_acc_wg,m_2_acc_wg
mean,0.606,0.567,0.676,0.706,0.234,0.073,0.171
std,0.011,0.011,0.025,0.009,0.073,0.03,0.052


In [9]:
### Testing
## TEST METRICS
# Relies on `logs` already being in the notebook namespace; it is assigned by
# the cell near the top that loads a single summary.json.
test_acc_ensemble = logs["ensemble-test-acc"]
best_single_model_test_acc = max(logs["test-acc"])
test_acc_of_sub_ensembles = logs["test_acc_ensemble_per_ens_size"]

ensemble_size = 2
# Per model ("m1", "m2", ...): the largest second element among the
# "valid-acc" entries.
best_val_acc_per_model = {
    model_key: max(entry[1] for entry in logs[model_key]["valid-acc"])
    for model_key in (f"m{i}" for i in range(1, ensemble_size + 1))
}
best_val_acc_per_model

NameError: name 'logs' is not defined