In [2]:
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch

from collections import defaultdict
import os.path as osp
from typing import Union, Dict, List
from itertools import combinations


def get_results(filename_format, seeds:List[int], ensemble_size:int, inverse:bool, show_val:bool):
    """Collect per-seed metrics from each run's ``summary.json``.

    For every seed, ``filename_format.format(seed=seed)`` is treated as a run
    directory and ``<run dir>/summary.json`` is loaded. From each log this appends:

    * ``test_acc_ensemble``            <- ``logs["ensemble-test-acc"]``
    * ``test_m_{k}_acc`` (k = 1..ensemble_size) <- ``logs["test-acc"][k-1]``
    * ``test_similarity``              <- mean of ``logs["test_similarity"][i][j]``
      over all member pairs i < j
    * ``unlabeled_final_similarity``   <- same, from ``logs["unlabeled_final_similarty"]``

    If ``ensemble_size < 2`` there are no member pairs and the similarity
    entries are NaN.

    Seeds whose ``summary.json`` does not exist are skipped silently.

    Args:
        filename_format: run-directory path template with a ``{seed}`` field.
        seeds: seeds to substitute into ``filename_format``.
        ensemble_size: number of ensemble members to read from each log.
        inverse: unused here; kept for interface compatibility.
        show_val: unused here; kept for interface compatibility.

    Returns:
        ``defaultdict(list)`` mapping metric name -> one value per seed found.
    """
    res = defaultdict(list)
    # All unordered member pairs (i < j); identical for every seed, so build once.
    pairwise_indexes = list(combinations(range(ensemble_size), 2))

    for seed in seeds:
        run_dir = filename_format.format(seed=seed)
        filename = osp.join(run_dir, "summary.json")
        if not osp.exists(filename):
            continue

        # Close the file as soon as it is parsed; everything below works on `logs`.
        with open(filename) as f:
            logs = json.load(f)

        res["test_acc_ensemble"].append(logs["ensemble-test-acc"])
        for i in range(ensemble_size):
            res[f"test_m_{i+1}_acc"].append(logs["test-acc"][i])

        test_sim = [logs["test_similarity"][i][j] for i, j in pairwise_indexes]
        # The log key is spelled "similarty"; the output column uses the corrected spelling.
        unlabeled_sim = [logs["unlabeled_final_similarty"][i][j] for i, j in pairwise_indexes]

        # With fewer than two members there are no pairs: report NaN without
        # triggering numpy's "Mean of empty slice" RuntimeWarning.
        res["test_similarity"].append(np.mean(test_sim) if test_sim else np.nan)
        res["unlabeled_final_similarity"].append(np.mean(unlabeled_sim) if unlabeled_sim else np.nan)

    return res

def display_results(filename_format:str,title:str, seeds:List[int], ensemble_size=2 , inverse = False, show_val = False):
    """Print ``title`` and display mean/std across seeds of the collected metrics.

    Metrics come from ``get_results``; the table has one row per statistic
    (``mean``, ``std``) and one column per metric. ``std`` uses pandas'
    default ddof=1, so it is NaN when only one seed is found.

    Args:
        filename_format: run-directory path template with a ``{seed}`` field.
        title: heading printed above the table.
        seeds: seeds to look up.
        ensemble_size: number of ensemble members per run.
        inverse: forwarded to ``get_results``.
        show_val: forwarded to ``get_results``.

    Returns:
        True if at least one run was found and displayed, False otherwise
        (in which case nothing is printed).
    """
    res = get_results(
        filename_format=filename_format,
        seeds=seeds,
        ensemble_size=ensemble_size,
        inverse=inverse,
        show_val=show_val,
    )
    df = pd.DataFrame(res)
    if df.empty:
        return False

    results = df.aggregate(["mean", "std"])

    print(title)
    # Scope the 3-decimal float format to this table rather than mutating the
    # global pandas option, which would silently affect every later display.
    with pd.option_context("display.float_format", "{:,.3f}".format):
        display(results)  # IPython's notebook-provided display()
    return True


# Office Home: pretrained (P) vs. non-pretrained (NP) ResNet-50

In [6]:
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset

def display_office_home_results():
    
    for alpha in [0.00001, 0.1]:
        for epoch in [30, 50, 60, 100]:
            print(f"EPOCHS {epoch} alpha {alpha}")
            waterbirds_filename_format= {
                #"DBAT_NP": f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/resnet50_pretrained=False/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=resnet50_pretrained=False_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False",
                "DBAT": f"/datasets/home/hbenoit/D-BAT-exp/office_home/oh-65cls/perturb=ood_is_test/resnet50_pretrained=True/ep[{epoch}]/ep=[{epoch}]_lrmax=0.001_alpha={alpha}_dataset=oh-65cls_model=resnet50_pretrained=True_scheduler=none_seed={{seed}}_opt=sgd_ensemble_size=2_no_diversity=False"
            }

            for name, filename_format in waterbirds_filename_format.items():
                title = str.upper(name)
                display_results(filename_format=filename_format, title=title, seeds=[0,1,2], show_val=False)

# Print the header and mean/std tables for each (alpha, epoch) Office-Home configuration.
display_office_home_results()

EPOCHS 30 alpha 1e-05
DBAT


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,0.592,0.58,0.583,0.754,0.749
std,0.005,0.003,0.007,0.006,0.01


EPOCHS 50 alpha 1e-05
EPOCHS 60 alpha 1e-05
DBAT


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,0.597,0.583,0.582,0.748,0.753
std,0.004,0.003,0.009,0.006,0.01


EPOCHS 100 alpha 1e-05
EPOCHS 30 alpha 0.1
DBAT


Unnamed: 0,test_acc_ensemble,test_m_1_acc,test_m_2_acc,test_similarity,unlabeled_final_similarity
mean,0.583,0.58,0.555,0.726,0.733
std,0.005,0.003,0.007,0.008,0.017


EPOCHS 50 alpha 0.1
EPOCHS 60 alpha 0.1
