In [3]:
import json

from src.params import RESULTS_DIR, DatasetName, ModelName


def load_results():
    json_filenames = RESULTS_DIR.glob("*.json")
    results = []
    for json_filename in json_filenames:
        with open(json_filename, "r") as f:
            results.append(json.load(f))
    return results

def filter_results(results_list, filter_dict):
    filtered_results = [result for result in results_list if all(key in result["settings"] and result["settings"][key] == value for key, value in filter_dict.items())]
    return filtered_results

In [15]:
# Vary k test

results = load_results()

filters = {
    "dataset": DatasetName.IMAGENET.value,
    "strong_model": ModelName.VITB8_DINO.value,
    # "exp_id": "vary-k-new",
    "exp_id": "w-to-s-new",
    "num_heads": "100",
    "weight_decay_fixed": True
}

filtered_results = filter_results(results, filters)


In [16]:
from collections import defaultdict

# Should be 5 for Cifar, 4 for imagenet
print(f"Num Results: {len(filtered_results)}")

# Check that there is 1 of each type for num heads
num_heads_count = defaultdict(int)
for result in filtered_results:
    num_heads_count[result["settings"]["num_heads"]] += 1

print(num_heads_count)

Num Results: 10
defaultdict(<class 'int'>, {'100': 10})


In [17]:


import numpy as np


def display_results_for_nh(results_list):
    metrics = [
        "wk<-gt_cross_entropy__mean",
        "wk<-gt_accuracy__mean",
        "stgt<-wk_kl_divergence__mean",
        "stgt<-wk_accuracy__mean",
        "st<-gt_cross_entropy__mean",
        "st<-wk_kl_divergence__mean",
        "st<-gt_accuracy__mean",
        "stgt<-st_kl_divergence__mean",
        "stgt<-st_accuracy__mean",
    ]
    # Weak model loss r
    # Weak model acc r
    # Strong model loss  
    # Strong model misfit  
    # Strong model test acc  
    # Discrepancy
    # Strong model loss r
    # Strong model test acc r
    # Discrepancy r

    metric_values = defaultdict(lambda: defaultdict(list))
    for res in results_list:
        for metric in metrics:
            metric_values[int(res["settings"]["num_heads"])][metric].append(res["results"][metric])
    
    # Take mean and std of each metric
    for _, metric_dict in metric_values.items():
        for metric, values in metric_dict.items():
            mean = np.mean(values)
            std = np.std(values)
            metric_dict[metric] = (mean, std)
    
    # Print the results in the desired form
    for num_heads, metric_dict in sorted(metric_values.items(), key=lambda x: x[0]):
        print(f"k={num_heads}")
        print(f"Weak model loss\t{metric_dict['wk<-gt_cross_entropy__mean'][0]}\t{metric_dict['wk<-gt_cross_entropy__mean'][1]}")
        print(f"Weak model acc\t{metric_dict['wk<-gt_accuracy__mean'][0]}\t{metric_dict['wk<-gt_accuracy__mean'][1]}")
        print(f"Weak model loss r\t{metric_dict['stgt<-wk_kl_divergence__mean'][0]}\t{metric_dict['stgt<-wk_kl_divergence__mean'][1]}")
        print(f"Weak model acc r\t{metric_dict['stgt<-wk_accuracy__mean'][0]}\t{metric_dict['stgt<-wk_accuracy__mean'][1]}")
        print(f"Strong model loss\t{metric_dict['st<-gt_cross_entropy__mean'][0]}\t{metric_dict['st<-gt_cross_entropy__mean'][1]}")
        print(f"Strong model misfit\t{metric_dict['st<-wk_kl_divergence__mean'][0]}\t{metric_dict['st<-wk_kl_divergence__mean'][1]}")
        print(f"Strong model test acc\t{metric_dict['st<-gt_accuracy__mean'][0]}\t{metric_dict['st<-gt_accuracy__mean'][1]}")
        discrepancy = metric_dict['wk<-gt_cross_entropy__mean'][0] - metric_dict['st<-wk_kl_divergence__mean'][0]- metric_dict['st<-gt_cross_entropy__mean'][0]
        print(f"Discrepancy\t{discrepancy}\t")
        print(f"Strong model loss r\t{metric_dict['stgt<-st_kl_divergence__mean'][0]}\t{metric_dict['stgt<-st_kl_divergence__mean'][1]}")
        print(f"Strong model test acc r\t{metric_dict['stgt<-st_accuracy__mean'][0]}\t{metric_dict['stgt<-st_accuracy__mean'][1]}")
        discrepancy_r = metric_dict['stgt<-wk_kl_divergence__mean'][0] - metric_dict['st<-wk_kl_divergence__mean'][0] - metric_dict['stgt<-st_kl_divergence__mean'][0]
        print(f"Discrepancy r\t{discrepancy_r}\t")

display_results_for_nh(filtered_results)

k=100
Weak model loss	1.9258971333503723	2.843585119694844e-07
Weak model acc	0.5588001608848572	9.233911862867873e-08
Weak model loss r	1.5178813815116883	2.636116074025152e-07
Weak model acc r	0.566400146484375	7.633118912497579e-08
Strong model loss	1.4742154240608216	0.002815564253154825
Strong model misfit	1.6401374459266662	0.0023354003413453045
Strong model test acc	0.6986101388931274	0.0022055672457203393
Discrepancy	-1.1884557366371156	
Strong model loss r	0.7102438569068908	0.0018074483135753958
Strong model test acc r	0.7579601585865021	0.002253075204909093
Discrepancy r	-0.8324999213218688	


In [8]:
# Tests for w-to-s-forward
results = load_results()

filters = {
    "dataset": DatasetName.CIFAR10.value,
    "strong_model": ModelName.VITB8_DINO.value,
    "num_heads": "100",
    "exp_id": "w-to-s-forward-3",
    "weight_decay_fixed": True,
}

filtered_results = filter_results(results, filters)

In [9]:
from collections import defaultdict

# Should be 2
print(f"Num Results: {len(filtered_results)}")

# Check that there is 1 of each type for model
model_count = defaultdict(int)
for result in filtered_results:
    model_count[result["settings"]["strong_model"]] += 1

print(model_count)

Num Results: 3
defaultdict(<class 'int'>, {'vitb8_dino': 3})


In [14]:
import numpy as np


def display_results_for_w_to_s(results_list):
    metrics = [
        "wk<-gt_cross_entropy__mean",
        "wk<-gt_accuracy__mean",
        "st<-gt_cross_entropy__mean",
        "st<-wk_kl_divergence__mean",
        "st<-gt_accuracy__mean",

        # "stgt<-gt_cross_entropy__mean",
        # "stgt<-gt_accuracy__mean",        
        # "stgt<-wk_kl_divergence__mean",
        # "stgt<-wk_accuracy__mean",
        # "stgt<-st_kl_divergence__mean",
        # "stgt<-st_accuracy__mean",

        "wk<-st_kl_divergence__mean",
        "wk<-st_cross_entropy__mean",
    ]
    # Weak model loss
    # Weak model acc
    # Strong model loss  
    # Strong model misfit  
    # Strong model test acc  
    # Discrepancy
    # Test r acc
    # Test r loss
    # Weak model loss r
    # Weak model acc r
    # Strong model loss r
    # Strong model test acc r
    # Discrepancy r

    metric_values = defaultdict(lambda: defaultdict(list))
    for res in results_list:
        for metric in metrics:
            metric_values[res["settings"]["strong_model"]][metric].append(res["results"][metric])
    
    # Take mean and std of each metric
    for _, metric_dict in metric_values.items():
        for metric, values in metric_dict.items():
            mean = np.mean(values)
            std = np.std(values)
            metric_dict[metric] = (mean, std)
    
    # Print the results in the desired form
    for strong_model, metric_dict in sorted(metric_values.items(), key=lambda x: x[0]):
        print(f"{strong_model}")
        print(f"Weak model loss\t{metric_dict['wk<-gt_cross_entropy__mean'][0]}\t{metric_dict['wk<-gt_cross_entropy__mean'][1]}")
        print(f"Weak model acc\t{metric_dict['wk<-gt_accuracy__mean'][0]}\t{metric_dict['wk<-gt_accuracy__mean'][1]}")
        print(f"Strong model loss\t{metric_dict['st<-gt_cross_entropy__mean'][0]}\t{metric_dict['st<-gt_cross_entropy__mean'][1]}")
        print(f"Forward-trained Reverse KL misfit\t{metric_dict['st<-wk_kl_divergence__mean'][0]}\t{metric_dict['st<-wk_kl_divergence__mean'][1]}")
        print(f"Strong model test acc\t{metric_dict['st<-gt_accuracy__mean'][0]}\t{metric_dict['st<-gt_accuracy__mean'][1]}")
        # discrepancy = metric_dict['wk<-gt_cross_entropy__mean'][0] - metric_dict['st<-wk_kl_divergence__mean'][0]- metric_dict['st<-gt_cross_entropy__mean'][0]
        # print(f"Discrepancy\t{discrepancy}\t")

        print(f"Forward KL Misfit\t{metric_dict['wk<-st_kl_divergence__mean'][0]}\t{metric_dict['wk<-st_kl_divergence__mean'][1]}")
        print(f"Forward KL Misfit XE\t{metric_dict['wk<-st_cross_entropy__mean'][0]}\t{metric_dict['wk<-st_cross_entropy__mean'][1]}")

display_results_for_w_to_s(filtered_results)

vitb8_dino
Weak model loss	0.5824153621991476	1.2247590229846988e-07
Weak model acc	0.796300212542216	7.434005313662571e-08
Strong model loss	0.40521153807640076	0.0021192049083189287
Forward-trained Reverse KL misfit	0.3681456645329793	0.0017983313355465205
Strong model test acc	0.8981336355209351	0.0011556690588221707
Forward KL Misfit	0.26816285649935406	0.00034359007508725554
Forward KL Misfit XE	1.1024143695831299	0.004544285036629746


In [None]:
# Tests for w-to-s
results = load_results()

filters = {
    "dataset": DatasetName.CIFAR10.value,
    "strong_model": ModelName.VITB8_DINO.value,
    "num_heads": "100",
    "exp_id": "w-to-s-new",
    "weight_decay_fixed": True,
}

filtered_results = filter_results(results, filters)

In [5]:
from collections import defaultdict

# Should be 2
print(f"Num Results: {len(filtered_results)}")

# Check that there is 1 of each type for model
model_count = defaultdict(int)
for result in filtered_results:
    model_count[result["settings"]["strong_model"]] += 1

print(model_count)

Num Results: 3
defaultdict(<class 'int'>, {'vitb8_dino': 3})


In [7]:
import numpy as np


def display_results_for_w_to_s(results_list):
    metrics = [
        "wk<-gt_cross_entropy__mean",
        "wk<-gt_accuracy__mean",
        "st<-gt_cross_entropy__mean",
        "st<-wk_kl_divergence__mean",
        "st<-gt_accuracy__mean",

        "stgt<-gt_cross_entropy__mean",
        "stgt<-gt_accuracy__mean",        
        "stgt<-wk_kl_divergence__mean",
        "stgt<-wk_accuracy__mean",
        "stgt<-st_kl_divergence__mean",
        "stgt<-st_accuracy__mean",

        # "wk<-st_kl_divergence__mean",
        "wk<-st_cross_entropy__mean",
    ]
    # Weak model loss
    # Weak model acc
    # Strong model loss  
    # Strong model misfit  
    # Strong model test acc  
    # Discrepancy
    # Test r acc
    # Test r loss
    # Weak model loss r
    # Weak model acc r
    # Strong model loss r
    # Strong model test acc r
    # Discrepancy r

    metric_values = defaultdict(lambda: defaultdict(list))
    for res in results_list:
        for metric in metrics:
            metric_values[res["settings"]["strong_model"]][metric].append(res["results"][metric])
    
    # Take mean and std of each metric
    for _, metric_dict in metric_values.items():
        for metric, values in metric_dict.items():
            mean = np.mean(values)
            std = np.std(values)
            metric_dict[metric] = (mean, std)
    
    # Print the results in the desired form
    for strong_model, metric_dict in sorted(metric_values.items(), key=lambda x: x[0]):
        print(f"{strong_model}")
        print(f"Weak model loss\t{metric_dict['wk<-gt_cross_entropy__mean'][0]}\t{metric_dict['wk<-gt_cross_entropy__mean'][1]}")
        print(f"Weak model acc\t{metric_dict['wk<-gt_accuracy__mean'][0]}\t{metric_dict['wk<-gt_accuracy__mean'][1]}")
        print(f"Strong model loss\t{metric_dict['st<-gt_cross_entropy__mean'][0]}\t{metric_dict['st<-gt_cross_entropy__mean'][1]}")
        print(f"Strong model misfit\t{metric_dict['st<-wk_kl_divergence__mean'][0]}\t{metric_dict['st<-wk_kl_divergence__mean'][1]}")
        print(f"Strong model test acc\t{metric_dict['st<-gt_accuracy__mean'][0]}\t{metric_dict['st<-gt_accuracy__mean'][1]}")
        discrepancy = metric_dict['wk<-gt_cross_entropy__mean'][0] - metric_dict['st<-wk_kl_divergence__mean'][0]- metric_dict['st<-gt_cross_entropy__mean'][0]
        print(f"Discrepancy\t{discrepancy}\t")

        print(f"Test r loss\t{metric_dict['stgt<-gt_cross_entropy__mean'][0]}\t{metric_dict['stgt<-gt_cross_entropy__mean'][1]}")
        print(f"Test r acc\t{metric_dict['stgt<-gt_accuracy__mean'][0]}\t{metric_dict['stgt<-gt_accuracy__mean'][1]}")        
        print(f"Weak model loss r\t{metric_dict['stgt<-wk_kl_divergence__mean'][0]}\t{metric_dict['stgt<-wk_kl_divergence__mean'][1]}")
        print(f"Weak model acc r\t{metric_dict['stgt<-wk_accuracy__mean'][0]}\t{metric_dict['stgt<-wk_accuracy__mean'][1]}")
        print(f"Strong model loss r\t{metric_dict['stgt<-st_kl_divergence__mean'][0]}\t{metric_dict['stgt<-st_kl_divergence__mean'][1]}")
        print(f"Strong model test acc r\t{metric_dict['stgt<-st_accuracy__mean'][0]}\t{metric_dict['stgt<-st_accuracy__mean'][1]}")
        discrepancy_r = metric_dict['stgt<-wk_kl_divergence__mean'][0] - metric_dict['st<-wk_kl_divergence__mean'][0] - metric_dict['stgt<-st_kl_divergence__mean'][0]

        print(f"Strong model loss r\t{metric_dict['stgt<-st_kl_divergence__mean'][0]}\t{metric_dict['stgt<-st_kl_divergence__mean'][1]}")
        print(f"Strong model test acc r\t{metric_dict['stgt<-st_accuracy__mean'][0]}\t{metric_dict['stgt<-st_accuracy__mean'][1]}")

        # print(f"Forward KL Misfit r\t{metric_dict['wk<-st_kl_divergence__mean'][0]}\t{metric_dict['wk<-st_kl_divergence__mean'][1]}")
        print(f"Forward KL Misfit XE r\t{metric_dict['wk<-st_cross_entropy__mean'][0]}\t{metric_dict['wk<-st_cross_entropy__mean'][1]}")
        print(f"Discrepancy r\t{discrepancy_r}\t")

display_results_for_w_to_s(filtered_results)

KeyError: 'wk<-st_cross_entropy__mean'

In [None]:


filters = {
    "dataset": DatasetName.IMAGENET.value,
    "strong_model": ModelName.RESNET50_DINO.value,
    "version": 6.0,
    "num_heads": '100',
    "debug": False,
    "exp_id": "w-to-s-forward-2"
}
# if filters["dataset"] == DatasetName.IMAGENET.value:
    # filters["version"] = 3.0



filtered_results = filter_results(results, filters)

In [3]:
print(f"Num Results: {len(filtered_results)}")
[res["settings"]["num_heads"] for res in filtered_results]

Num Results: 10


['100', '100', '100', '100', '100', '100', '100', '100', '100', '100']

In [4]:
import numpy as np

# GT Error metrics
metrics = [
    "wk<-gt_cross_entropy__mean",
    "wk<-gt_accuracy__mean",
    "st<-gt_cross_entropy__mean",
    "st<-wk_kl_divergence__mean",
    "st<-gt_accuracy__mean",
]

# # STGT Error metrics
# metrics += [
#     "stgt<-gt_cross_entropy__mean",
#     "stgt<-gt_accuracy__mean",
#     "stgt<-st_kl_divergence__mean",
#     "stgt<-wk_kl_divergence__mean",
#     "st<-wk_kl_divergence__mean",
#     "stgt<-st_accuracy__mean",
#     "stgt<-wk_accuracy__mean",
#     "st<-wk_accuracy__mean",
# ]

# metrics = [
#     "stgt<-wk_kl_divergence__mean",
#     "stgt<-wk_accuracy__mean",
#     "st<-gt_cross_entropy__mean",
#     "st<-wk_kl_divergence__mean",
#     "st<-gt_accuracy__mean",
#     "stgt<-st_kl_divergence__mean",
#     "stgt<-st_accuracy__mean",
# ]

def print_metrics(metric_list, results_list):
    for metric in metric_list:
        print("==========")
        print(metric)
        for result in results_list:
            print(result["results"][metric])

# print_metrics(metrics, filtered_results)

def get_metric_stats(metric_list, results_list):
    extracted_metrics = {}
    for metric in metric_list:
        extracted_metrics[metric] = np.array([result["results"][metric] for result in results_list], dtype=float)
    
    stats = {}
    for metric, values in extracted_metrics.items():
        stats[metric] = {
            "mean": np.mean(values),
            "std": np.std(values),
        }
    return stats

# Display the metric stats in a form I can copy and paste
def display_metric_stats(metric_stats):
    for metric, stats in metric_stats.items():
        print(f"{metric}: {stats['mean']} ± {stats['std']}")

stats = get_metric_stats(metrics, filtered_results)
display_metric_stats(stats)

wk<-gt_cross_entropy__mean: 1.9258971571922303 ± 3.5383277129306385e-07
wk<-gt_accuracy__mean: 0.5588001310825348 ± 8.940696716308594e-08
st<-gt_cross_entropy__mean: 1.648470950126648 ± 0.0039875988601694315
st<-wk_kl_divergence__mean: 1.0988789796829224 ± 0.003200920804224958
st<-gt_accuracy__mean: 0.6022701621055603 ± 0.002588475704606889
