In [1]:
import pathlib
import os, sys
import numpy as np
from loguru import logger
from collections import defaultdict
from openset_imagenet.util import ccr_at_fpr
import openset_imagenet
from matplotlib import pyplot

In [2]:
def load_scores(args):
    """Load per-network score files and the ground-truth labels they share.

    For every name ``net`` in ``args["networks"]`` the file ``f"{net}.npz"`` is
    looked up relative to the current working directory. Each file is expected
    to contain a ``"scores"`` array and a ``"gt"`` array. Missing files are
    logged as warnings and skipped.

    NOTE(review): ``args["output_directory"]`` is currently NOT used to locate
    the score files -- they are resolved against the CWD, which matches the
    recorded run output. Confirm whether the files should live under
    ``output_directory`` instead.

    Args:
        args: mapping with a ``"networks"`` key (iterable of network names).

    Returns:
        ``(scores, ground_truths)`` where ``scores`` maps each successfully
        loaded network name to its ``"scores"`` array, and ``ground_truths`` is
        the ``"gt"`` array of the first loaded file cast to ``int`` (an empty
        dict if no file could be loaded).

    Raises:
        AssertionError: if a later file's ``"gt"`` array differs (in shape or
            values) from the one loaded first.
    """
    # Kept as a defaultdict for backward compatibility: looking up a network
    # whose file was missing yields an empty mapping rather than a KeyError.
    scores = defaultdict(lambda: defaultdict(dict))
    ground_truths = {}
    have_ground_truths = False

    for net in args["networks"]:
        score_file = f"{net}.npz"
        if not os.path.exists(score_file):
            logger.warning(f"Did not find score file {score_file} for net {net}")
            continue

        # For .npz archives np.load returns an NpzFile that keeps the file open;
        # the context manager closes it once the arrays have been read.
        with np.load(score_file) as results:
            scores[net] = results["scores"]
            gt = results["gt"]

        if not have_ground_truths:
            ground_truths = gt.astype(int)
            have_ground_truths = True
        else:
            # array_equal also compares shapes; ``np.all(a == b)`` could
            # broadcast mismatched shapes or error out instead of failing here.
            assert np.array_equal(gt, ground_truths), (
                f"ground truth in {score_file} differs from previously loaded files"
            )

        logger.info(f"Loaded score file {score_file} for net {net}")

    return scores, ground_truths

In [3]:
# Configuration: the networks to evaluate, and the FPR thresholds at which the
# correct-classification rate (CCR) is reported, mapped to LaTeX labels.
arguments = {
    "output_directory": "experiments/ex_6",
    "networks": ["net_7_val_curr", "net_8_val_curr", "net_9_val_curr", "net_10_val_curr"],
}
THRESHOLDS = {
    1e-3: "$10^{-3}$",
    1e-2: "$10^{-2}$",
    1e-1: "$10^{-1}$",
    1: "$1$",
}

scores, ground_truths = load_scores(arguments)

# CCR@FPR for every network whose score file was actually loaded; networks with
# a missing file were already reported by load_scores and are skipped here.
for net in arguments["networks"]:
    if net not in scores:
        continue
    ccrs = ccr_at_fpr(ground_truths, scores[net], THRESHOLDS, unk_label=-1)
    print(f"Network {net}: {ccrs} for thresholds {list(THRESHOLDS.values())}")

[32m2024-06-07 10:35:04.047[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_7_val_curr.npz for net net_7_val_curr[0m
[32m2024-06-07 10:35:04.083[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_8_val_curr.npz for net net_8_val_curr[0m
[32m2024-06-07 10:35:04.089[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_9_val_curr.npz for net net_9_val_curr[0m
[32m2024-06-07 10:35:04.095[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_10_val_curr.npz for net net_10_val_curr[0m


Network net_7_val_curr: [None, 0.09842192691029901, 0.33734772978959027, 0.6503322259136213] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_8_val_curr: [None, None, 0.32918050941306753, 0.6324750830564784] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_9_val_curr: [None, 0.07641196013289037, 0.20196566998892582, 0.6413344407530454] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_10_val_curr: [None, None, 0.2728405315614618, 0.5700442967884828] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
