In [1]:
import pathlib
import os, sys
import numpy as np
from loguru import logger
from collections import defaultdict
from openset_imagenet.util import ccr_at_fpr
import openset_imagenet
from matplotlib import pyplot

In [2]:
def load_scores(args):
    # we sort them as follows: protocol, loss, algorithm
    scores = defaultdict(lambda: defaultdict(dict))
    ground_truths = {}
    for net in args["networks"]:
                output_directory = pathlib.Path(args["output_directory"])
                score_file = f"{net}.npz"
                if os.path.exists(score_file):
                    # remember files
                    results = np.load(score_file)
                    scores[net] = results["scores"] # only change in maxlogits case

                    if len(ground_truths) == 0:
                        ground_truths = results["gt"].astype(int)
                    else:
                        assert np.all(results["gt"] == ground_truths)

                    logger.info(f"Loaded score file {score_file} for net {net}")
                else:
                    logger.warning(f"Did not find score file {score_file} for net {net}")

    return scores, ground_truths

In [3]:
arguments = { 
  "output_directory": "experiments/ex_6",
  "networks": ["net_1_curr", "net_1_best", "net_2_curr", "net_2_best"]
  }
THRESHOLDS = {
              1e-3: "$10^{-3}$",
              1e-2: "$10^{-2}$",
              1e-1: "$10^{-1}$",
              1: "$1$",
}

scores, ground_truths = load_scores(arguments)
# we get ccr@fpr for each network
for net in arguments["networks"]:
    ccrs = ccr_at_fpr(ground_truths, scores[net], THRESHOLDS)
    print(f"Network {net}: {ccrs} for thresholds {THRESHOLDS.values()}")

[32m2024-06-03 11:12:08.024[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_1_curr.npz for net net_1_curr[0m
[32m2024-06-03 11:12:08.029[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_1_best.npz for net net_1_best[0m
[32m2024-06-03 11:12:08.037[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_2_curr.npz for net net_2_curr[0m
[32m2024-06-03 11:12:08.041[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_2_best.npz for net net_2_best[0m


Network net_1_curr: [None, 0.09533333333333334, 0.34933333333333333, 0.6446666666666667] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_1_best: [None, 0.132, 0.38066666666666665, 0.6426666666666667] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_2_curr: [0.025333333333333333, 0.09933333333333333, 0.33866666666666667, 0.6213333333333333] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_2_best: [0.021333333333333333, 0.07733333333333334, 0.28, 0.5926666666666667] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
