In [73]:
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch

In [95]:
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
from wilds.datasets.wilds_dataset import WILDSDataset

### Testing

In [None]:
## TEST METRICS
# NOTE(review): `logs` is not assigned at notebook scope in any visible cell
# (it only exists as a local inside `results`) — confirm it is loaded before
# this cell is run, otherwise this raises NameError on a fresh kernel.
ensemble_size = 2

test_acc_ensemble = logs["ensemble-test-acc"]
best_single_model_test_acc = max(logs["test-acc"])
test_acc_of_sub_ensembles = logs["test_acc_ensemble_per_ens_size"]

# For every member "m1".."m<ensemble_size>", keep the largest value found at
# index 1 of its "valid-acc" entries (presumably the accuracy — confirm).
best_val_acc_per_model = {
    f"m{member}": max(entry[1] for entry in logs[f"m{member}"]["valid-acc"])
    for member in range(1, ensemble_size + 1)
}
best_val_acc_per_model

In [107]:
from collections import defaultdict
import os.path as osp
from typing import Union, Dict, List

def get_eval_meta_args(logs, y_index):
    """Extract prediction, label and metadata tensors from a loaded run summary.

    Args:
        logs: Mapping loaded from a run summary; must provide the keys
            ``"test-metas"`` and ``"test-logits"``.
        y_index: Column of the metadata array that holds the label.

    Returns:
        Tuple ``(y_pred, y_true, test_metadata)`` of torch tensors, where
        ``y_pred`` is the row-wise argmax of the selected logits.
    """
    metadata_np = np.array(logs["test-metas"][0])
    # Only the logits stored at [0][0] are used for the predictions.
    logits_np = np.array(logs["test-logits"][0][0])

    predictions = torch.tensor(logits_np.argmax(axis=1))
    labels = torch.tensor(metadata_np[:, y_index])
    return predictions, labels, torch.tensor(metadata_np)

def results(filename_format, dataset:WILDSDataset, meta_metrics:Union[List[str],None]=None, ensemble_size:int=2, seeds=range(3)):
    """Collect per-seed test metrics from ``summary.json`` files into lists.

    Args:
        filename_format: Format string containing a ``{seed}`` placeholder that
            resolves to a run directory holding a ``summary.json`` file.
        dataset: Dataset object; its ``metadata_fields`` is used to locate the
            ``"y"`` column and its ``eval`` method computes the meta metrics.
        meta_metrics: Names of entries to read from the first element of the
            value returned by ``dataset.eval``. ``None`` (the default) or an
            empty list means no such metrics are collected.
        ensemble_size: Number of per-model test accuracies to record
            (``m_1_test_acc`` .. ``m_<ensemble_size>_test_acc``).
        seeds: Iterable of seed values substituted into ``filename_format``.

    Returns:
        A ``defaultdict(list)`` mapping metric name to one value per seed, in
        the order of ``seeds``.

    Raises:
        FileNotFoundError: If a ``summary.json`` file does not exist.
        KeyError: If a summary lacks one of the expected keys.
    """
    y_index = dataset.metadata_fields.index("y")
    # The default is None; normalise it so the loop below never iterates None.
    meta_metrics = list(meta_metrics) if meta_metrics is not None else []
    res = defaultdict(list)

    for seed in seeds:
        summary_path = osp.join(filename_format.format(seed=seed), "summary.json")

        # Only the load needs the file handle; close it before evaluating.
        with open(summary_path) as f:
            logs = json.load(f)

        res["test_acc_ensemble"].append(logs["ensemble-test-acc"])
        res["best_single_model_test_acc"].append(max(logs["test-acc"]))
        for i in range(ensemble_size):
            res[f"m_{i+1}_test_acc"].append(logs["test-acc"][i])
        res["test_similarity"].append(logs["test_similarity"][0][1])
        # The misspelled key is the one present in the summary files; keep it.
        res["unlabeled_final_similarity"].append(logs["unlabeled_final_similarty"][0][1])

        ## worst group eval
        if meta_metrics:
            y_pred, y_true, test_metadata = get_eval_meta_args(logs=logs, y_index=y_index)
            eval_res = dataset.eval(y_pred=y_pred, y_true=y_true, metadata=test_metadata)
            for meta_metric in meta_metrics:
                res[meta_metric].append(eval_res[0][meta_metric])

    return res

## WATERBIRDS OOD_IS_TEST REPRODUCTION RESULTS

In [109]:
# D-BAT runs (alpha=0.0001, no_diversity=False): one result directory per seed,
# selected through the `{seed}` placeholder.
filename_format = "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0001_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=False_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False"
dataset = WaterbirdsDataset(root_dir="./datasets/", download=False)
res = results(filename_format, dataset, meta_metrics=["acc_wg"])
waterbirds_df = pd.DataFrame(res)
# Show three decimals in the rendered summary table.
pd.set_option("display.float_format", "{:,.3f}".format)
# Mean and standard deviation of every collected metric across seeds.
waterbirds_df.agg(["mean", "std"])

Unnamed: 0,test_acc_ensemble,best_single_model_test_acc,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,acc_wg
mean,0.894,0.898,0.898,0.886,0.969,0.968,0.628
std,0.002,0.004,0.004,0.003,0.006,0.006,0.022


## WATERBIRDS ERM REPRODUCTION RESULTS

In [111]:
# ERM baseline runs (alpha=0.0, no_diversity=True): one result directory per
# seed, selected through the `{seed}` placeholder.
filename_format = "/datasets/home/hbenoit/D-BAT-exp/results_reproduction/waterbird/perturb=ood_is_test/resnet50_pretrained=True/ep300/ep=300_lrmax=0.001_alpha=0.0_dataset=waterbird_perturb_type=ood_is_test_model=resnet50_pretrained=True_scheduler=none_seed={seed}_opt=sgd_ensemble_size=2_no_diversity=True_dbat_loss_type=v1_weight_decay=0.0001_no_nesterov_majority_only=False"
dataset = WaterbirdsDataset(root_dir="./datasets/", download=False)
res = results(filename_format, dataset, meta_metrics=["acc_wg"])
waterbirds_df = pd.DataFrame(res)
# Show three decimals in the rendered summary table.
pd.set_option("display.float_format", "{:,.3f}".format)
# Mean and standard deviation of every collected metric across seeds.
waterbirds_df.agg(["mean", "std"])

Unnamed: 0,test_acc_ensemble,best_single_model_test_acc,m_1_test_acc,m_2_test_acc,test_similarity,unlabeled_final_similarity,acc_wg
mean,0.885,0.885,0.885,0.884,0.978,0.979,0.641
std,0.002,0.0,0.0,0.002,0.006,0.005,0.013
