In [6]:
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch

from collections import defaultdict
import os.path as osp
from typing import Union, Dict, List
from itertools import combinations


def _pairwise_mean(values):
    """Return the mean of ``values``, or NaN (without a NumPy warning) when empty."""
    if len(values) == 0:
        return np.nan
    return np.array(values).mean()


def get_results(filename_format, seeds:List[int], ensemble_size:int, inverse:bool, show_val:bool):
    """Collect per-seed test metrics from ``summary.json`` files.

    For each seed, ``filename_format.format(seed=seed)`` is used as a run
    directory and the ``summary.json`` file inside it is loaded. Seeds whose
    summary file does not exist are skipped silently, so the returned lists
    may be shorter than ``seeds`` (or empty).

    Args:
        filename_format: Format string containing a ``{seed}`` placeholder
            that expands to the run directory.
        seeds: Seeds to look up.
        ensemble_size: Number of ensemble members to read from each summary.
        inverse: Currently unused; kept so existing callers keep working.
        show_val: Currently unused; kept so existing callers keep working.

    Returns:
        A ``defaultdict(list)`` with one entry per seed found, under the keys
        ``"test_acc_ensemble"``, ``"test_m_{k}_acc"`` for ``k`` in
        ``1..ensemble_size``, ``"test_similarity"`` and
        ``"unlabeled_final_similarity"``. The two similarity entries are the
        mean over all unordered member pairs ``(i, j)`` with ``i < j``; they
        are NaN when ``ensemble_size`` is below 2 (no pairs exist).
    """
    res = defaultdict(list)

    # The set of member pairs only depends on ensemble_size, so build it once.
    pairwise_indexes = list(combinations(range(ensemble_size), 2))

    for seed in seeds:
        filename = osp.join(filename_format.format(seed=seed), "summary.json")

        if not osp.exists(filename):
            continue

        with open(filename) as f:
            logs = json.load(f)

        res["test_acc_ensemble"].append(logs["ensemble-test-acc"])

        for i in range(ensemble_size):
            res[f"test_m_{i+1}_acc"].append(logs["test-acc"][i])

        test_sim = []
        unlabeled_sim = []
        for i, j in pairwise_indexes:
            test_sim.append(logs["test_similarity"][i][j])
            # "similarty" (sic): this is the key spelling read from the summary file.
            unlabeled_sim.append(logs["unlabeled_final_similarty"][i][j])

        res["test_similarity"].append(_pairwise_mean(test_sim))
        res["unlabeled_final_similarity"].append(_pairwise_mean(unlabeled_sim))

    return res

def display_results(filename_format:str,title:str, seeds:List[int], ensemble_size=2 , inverse = False, show_val = False):
    """Aggregate the per-seed metrics of one configuration into a mean/std table.

    Args:
        filename_format: Format string with a ``{seed}`` placeholder, forwarded
            to ``get_results``.
        title: Not used by this function; accepted for the caller's convenience.
        seeds: Seeds to look up.
        ensemble_size: Number of ensemble members, forwarded to ``get_results``.
        inverse: Forwarded to ``get_results``.
        show_val: Forwarded to ``get_results``.

    Returns:
        ``(table, True)`` where ``table`` has the rows ``mean`` and ``std``
        computed across seeds, or ``(empty_frame, False)`` when no results
        were collected for any seed.
    """
    collected = get_results(
        filename_format=filename_format,
        seeds=seeds,
        ensemble_size=ensemble_size,
        inverse=inverse,
        show_val=show_val,
    )
    per_seed = pd.DataFrame(collected)

    # Guard clause inverted: summarise only when at least one seed was found.
    if not per_seed.empty:
        summary = per_seed.aggregate(["mean", "std"])
        return summary, True

    return per_seed, False


# Office-Home: pretrained (P) vs. non-pretrained (NP)

In [18]:
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset

def display_office_home_results():
    """Show mean/std result tables for the Office-Home runs.

    Loops over every (alpha, epoch, model) combination listed below, asks
    ``display_results`` for the aggregated metrics over seeds 0-2, and prints a
    header plus the table for each combination that has results. Combinations
    for which ``display_results`` reports nothing to show are skipped silently.

    Side effect: sets ``pd.options.display.float_format`` to three decimals the
    first time a table is shown; the setting stays in effect afterwards.
    """
    for alpha in [0.00001, 0.1]:
        for epoch in [30, 50, 60, 100, 300]:
            # Earlier configurations are kept commented out for reference.
            # No placeholder entry is needed: an empty-string path would make the
            # lookup fall back to "summary.json" in the current working directory.
            filename_formats = {
                #"DBAT_NP": f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/resnet50_pretrained=False/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=resnet50_pretrained=False_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False",
                #"DBAT": f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/resnet50_pretrained=True/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=resnet50_pretrained=True_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False",
                #"VIT_B_16": f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/['vit_b_16']_pretrained=True/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=['vit_b_16']_pretrained=True_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False",
                #"VIT_B_16_NP": f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/['vit_b_16']_pretrained=False/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=['vit_b_16']_pretrained=False_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False",
                #"VIT_MAE": f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/['vit_mae']_pretrained=True/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=['vit_mae']_pretrained=True_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False",
            }
            for model in ["vit_mae", "vit_dino", "robust_resnet50", "resnet50SwAV","resnet50SIMCLRv2", "resnet50MocoV2"]:
                # "{{seed}}" survives the f-string as "{seed}" and is filled in later per seed.
                filename_formats[model] = f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/['{model}']_pretrained=True/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=['{model}']_pretrained=True_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False"

            for name, filename_format in filename_formats.items():
                title = name.upper()
                df, can_display = display_results(filename_format=filename_format, title=title, seeds=[0,1,2], show_val=False)
                if can_display:
                    print(f"EPOCHS {epoch} alpha {alpha}")
                    print(title)
                    pd.options.display.float_format = "{:,.3f}".format
                    display(df)


# Render one table per (alpha, epoch, model) combination that has results available.
display_office_home_results()

EPOCHS 30 alpha 1e-05
VIT_MAE


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,,0.009,0.02,0.002,0.002
std,,0.008,0.001,0.002,0.004


EPOCHS 30 alpha 1e-05
VIT_DINO


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,,0.013,0.017,0.015,0.012
std,,0.007,0.004,0.004,0.004


EPOCHS 30 alpha 1e-05
ROBUST_RESNET50


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,0.562,0.536,0.555,0.7,0.7
std,0.001,0.007,0.007,0.003,0.011


EPOCHS 30 alpha 1e-05
RESNET50SWAV


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,0.225,0.206,0.213,0.459,0.457
std,0.004,0.014,0.009,0.008,0.011


EPOCHS 30 alpha 1e-05
RESNET50SIMCLRV2


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,0.465,0.439,0.461,0.687,0.684
std,0.009,0.006,0.005,0.009,0.008


EPOCHS 30 alpha 1e-05
RESNET50MOCOV2


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,0.292,0.252,0.271,0.405,0.438
std,0.004,0.015,0.012,0.049,0.041
