In [1]:
import json

from src.params import RESULTS_DIR


def load_results():
    json_filenames = RESULTS_DIR.glob("*.json")
    results = []
    for json_filename in json_filenames:
        with open(json_filename, "r") as f:
            results.append(json.load(f))
    return results

In [2]:
results = load_results()

In [38]:
from src.params import DatasetName, ModelName

filters = {
    "dataset": DatasetName.CIFAR10.value,
    "strong_model": ModelName.VITB8_DINO.value,
    "single_weak":True,
    "debug": False,
    # "version": 4.5
    "version": 4.2
}
# if filters["dataset"] == DatasetName.IMAGENET.value:
    # filters["version"] = 4.1

def filter_results(results_list, filter_dict):
    filtered_results = [result for result in results_list if all(key in result["settings"] and result["settings"][key] == value for key, value in filter_dict.items())]
    return filtered_results

filtered_results = filter_results(results, filters)

In [39]:
len(filtered_results)

11

In [40]:
def maybe(x, default):
    if x is None or x == "None":
        return default
    return x

metrics = [
    "gt_to_st_xe__mean",
    "misfit_xe_error__mean",
    "stgt_misfit_xe_error__mean"
]

def print_metrics(metric_list, results_list):
    for metric in metric_list:
        print("==========")
        print(metric)
        for result in sorted(results_list, key=lambda x: int(maybe(x["settings"]["num_labels"], 1000))):
            print(f"Num Labels: {maybe(result['settings']['num_labels'], 1000)}: {result['results'][metric]}")

print_metrics(metrics, filtered_results)

gt_to_st_xe__mean


KeyError: 'num_labels'

In [41]:
metrics = [
    "gt_to_st_xe__mean",
    "misfit_xe_error__mean",
    "stgt_misfit_xe_error__mean"
]

def print_metrics(metric_list, results_list):
    for metric in metric_list:
        print("==========")
        print(metric)
        for result in sorted(results_list, key=lambda x: int(x["settings"]["num_heads"])):
            print(f"Num Heads: {result['settings']['num_heads']}: {result['results'][metric]}")

print_metrics(metrics, filtered_results)

gt_to_st_xe__mean
Num Heads: 1: 0.34802767634391785
Num Heads: 2: 0.33735391497612
Num Heads: 3: 0.3314712345600128
Num Heads: 5: 0.3280852735042572
Num Heads: 10: 0.33353111147880554
Num Heads: 20: 0.3344651758670807
Num Heads: 50: 0.33753758668899536
Num Heads: 75: 0.3350053131580353
Num Heads: 100: 0.3372291028499603
Num Heads: 200: 0.33724161982536316
Num Heads: 500: 0.34301161766052246
misfit_xe_error__mean
Num Heads: 1: -0.09769145399332047
Num Heads: 2: -0.08562387526035309
Num Heads: 3: -0.08123717457056046
Num Heads: 5: -0.07395131886005402
Num Heads: 10: -0.07452872395515442
Num Heads: 20: -0.07508983463048935
Num Heads: 50: -0.0745418593287468
Num Heads: 75: -0.0723423957824707
Num Heads: 100: -0.07313665747642517
Num Heads: 200: -0.0719628855586052
Num Heads: 500: -0.07781611382961273
stgt_misfit_xe_error__mean
Num Heads: 1: -0.055576521903276443
Num Heads: 2: -0.043233472853899
Num Heads: 3: -0.03765132650732994
Num Heads: 5: -0.03055957891047001
Num Heads: 10: -0.03234254

In [8]:
# metrics = [
#     "gt_to_wk_acc__mean",
#     "gt_to_wk_xe__mean",
#     "gt_to_st_acc__mean",
#     "gt_to_st_xe__mean",
#     "st_to_wk__mean",
#     "stgt_to_st_xe__mean",
#     "gain_xe__mean",
#     "stgt_gain__mean"
# ]

metrics = [
    "gt_to_wk_acc__mean",
    "gt_to_wk_xe__mean",
    "gt_to_st_acc__mean",
    "gt_to_st_xe__mean",
    "st_to_wk__mean",
    "stgt_to_st_xe__mean",
    "gain_xe__mean",
    "stgt_gain__mean"
]


def print_metrics(metric_list, results_list):
    for metric in metric_list:
        print("==========")
        print(metric)
        for result in results_list:
            print(result["results"][metric])

print_metrics(metrics, filtered_results)

gt_to_wk_acc__mean
0.5588001608848572
0.5587999820709229
0.5588001012802124
0.5588000416755676
0.5588001608848572
0.5587999224662781
0.5587999820709229
0.5588003396987915
0.5588003396987915
0.5588003396987915
gt_to_wk_xe__mean
1.9258971214294434
1.9258971214294434
1.925897240638733
1.9258968830108643
1.9258973598480225
1.9258971214294434
1.9258967638015747
1.925897240638733
1.9258968830108643
1.925897240638733
gt_to_st_acc__mean
0.6677002906799316
0.6577003002166748
0.6617000699043274
0.6603003144264221
0.6608001589775085
0.666000247001648
0.664400041103363
0.6618001461029053
0.6599002480506897
0.6659001111984253
gt_to_st_xe__mean
1.3466585874557495
1.3731857538223267
1.3697400093078613
1.3825262784957886
1.3879469633102417
1.3701244592666626
1.3572076559066772
1.3780063390731812
1.3721835613250732
1.3481605052947998
st_to_wk__mean
1.1931260824203491
1.2114840745925903
1.2047206163406372
1.204551339149475
1.2112374305725098
1.212406873703003
1.1987900733947754
1.2093071937561035
1.2048