In [1]:
import pathlib
import os, sys
import numpy as np
from loguru import logger
from collections import defaultdict
from openset_imagenet.util import ccr_at_fpr
import openset_imagenet
from matplotlib import pyplot

In [2]:
def load_scores(args):
    """Load per-network score files and the ground-truth labels they share.

    For every name in ``args["networks"]`` this looks for ``<net>.npz`` first
    inside ``args["output_directory"]`` and then, for backward compatibility
    with setups where the score files sit next to the notebook, relative to the
    current working directory. Each file must contain the arrays ``"scores"``
    and ``"gt"``.

    Args:
        args: mapping with key ``"networks"`` (iterable of network names) and
            optionally ``"output_directory"`` (directory holding the score
            files; defaults to the current directory).

    Returns:
        ``(scores, ground_truths)``: ``scores`` maps each network whose file
        was found to its ``"scores"`` array; ``ground_truths`` is the ``"gt"``
        array cast to ``int`` (an empty int array if no file was found).
        Networks without a score file are logged as warnings and omitted.

    Raises:
        AssertionError: if two score files disagree on their ground truth.
    """
    output_directory = pathlib.Path(args.get("output_directory", "."))
    scores = {}
    ground_truths = None  # set from the first score file that is loaded

    for net in args["networks"]:
        file_name = f"{net}.npz"
        # Prefer the configured output directory; fall back to the bare file
        # name (resolved against the CWD), which is where files used to be read.
        candidates = [output_directory / file_name, pathlib.Path(file_name)]
        score_file = next((path for path in candidates if path.exists()), None)
        if score_file is None:
            logger.warning(
                f"Did not find score file {file_name} in {output_directory} "
                f"or the current directory for net {net}"
            )
            continue

        # An NpzFile keeps the underlying file open until closed; the context
        # manager releases it once both arrays have been read into memory.
        with np.load(score_file) as results:
            net_scores = results["scores"]
            net_gt = results["gt"].astype(int)

        if ground_truths is None:
            ground_truths = net_gt
        else:
            # array_equal also handles differing shapes (returns False).
            assert np.array_equal(net_gt, ground_truths), (
                f"Ground truth in {score_file} differs from previously loaded score files"
            )

        scores[net] = net_scores
        logger.info(f"Loaded score file {score_file} for net {net}")

    if ground_truths is None:
        # Nothing was loaded: keep the return type consistent (an int array).
        ground_truths = np.empty(0, dtype=int)
    return scores, ground_truths

In [3]:
# Experiment configuration: which networks to evaluate and where their score files live.
arguments = {
    "output_directory": "experiments/ex_1",
    "networks": ["net_1", "net_2", "net_3", "net_4", "net_5", "net_6"],
}
# False-positive-rate operating points, each mapped to a LaTeX-formatted label
# (presumably intended for plot annotations).
THRESHOLDS = {
    1e-3: "$10^{-3}$",
    1e-2: "$10^{-2}$",
    1e-1: "$10^{-1}$",
    1: "$1$",
}

scores, ground_truths = load_scores(arguments)

# Report the correct classification rate at each FPR operating point for every
# network whose score file was found.
for net in arguments["networks"]:
    if net not in scores:
        # load_scores already logged a warning for this network.
        continue
    ccrs = ccr_at_fpr(ground_truths, scores[net], THRESHOLDS)
    # One value per threshold; entries can be None (see net_2 in the recorded output).
    summary = ", ".join(f"FPR={fpr:g}: {ccr}" for fpr, ccr in zip(THRESHOLDS, ccrs))
    print(f"Network {net}: {summary}")

[32m2024-05-09 15:35:44.034[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_1.npz for net net_1[0m
[32m2024-05-09 15:35:44.038[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_2.npz for net net_2[0m
[32m2024-05-09 15:35:44.043[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_3.npz for net net_3[0m
[32m2024-05-09 15:35:44.048[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_4.npz for net net_4[0m
[32m2024-05-09 15:35:44.052[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_5.npz for net net_5[0m


[32m2024-05-09 15:35:44.061[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_scores[0m:[36m18[0m - [1mLoaded score file net_6.npz for net net_6[0m


Network net_1: [0.0011, 0.0317, 0.5124, 0.9861] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_2: [None, 0.0734, 0.6031, 0.987] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_3: [0.0002, 0.0309, 0.5139, 0.9875] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_4: [0.0001, 0.0429, 0.548, 0.9865] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_5: [0.0059, 0.0716, 0.5342, 0.9871] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
Network net_6: [0.0005, 0.046, 0.6094, 0.9867] for thresholds dict_values(['$10^{-3}$', '$10^{-2}$', '$10^{-1}$', '$1$'])
