In [1]:
import sys
sys.path.insert(0, "../")

import numpy as np
from mdu.eval.eval_utils import load_pickle
from mdu.data.constants import DatasetName
from collections import defaultdict
from mdu.data.data_utils import split_dataset_indices

In [2]:
# Model ids 0..19 partitioned into four consecutive, non-overlapping
# groups of five ids each: [0..4], [5..9], [10..14], [15..19].
ENSEMBLE_GROUPS = [
    list(range(first_id, first_id + 5)) for first_id in range(0, 20, 5)
]

In [3]:
# Keyed accumulator: looking up a missing key yields a fresh empty list.
# It is not populated in the cells shown here — presumably later cells fill it.
results = defaultdict(list)

# String values of the in-distribution and out-of-distribution dataset enum
# members; they are interpolated into the pickle paths when results are loaded.
ind_dataset, ood_dataset = (
    DatasetName.CIFAR10.value,
    DatasetName.TINY_IMAGENET.value,
)

In [4]:
# For every ensemble group, load each member's pickled results for the
# in-distribution and OOD datasets and collect their "embeddings" entries.
#
# NOTE(review): both lists are re-created at the start of every outer iteration,
# so once this loop finishes only the data of the LAST group in ENSEMBLE_GROUPS
# is still available to later cells — confirm this is intended.
for group in ENSEMBLE_GROUPS:
    all_ind_logits, all_ood_logits = [], []

    for model_id in group:
        ood_path = f"../model_weights/{ind_dataset}/checkpoints/resnet18/CrossEntropy/{model_id}/{ood_dataset}.pkl"
        ind_path = f"../model_weights/{ind_dataset}/checkpoints/resnet18/CrossEntropy/{model_id}/{ind_dataset}.pkl"

        # Load order is kept: in-distribution results first, then OOD.
        ind_res = load_pickle(ind_path)
        ood_res = load_pickle(ood_path)

        # np.newaxis is None: prepend a length-1 axis so the per-model arrays
        # can later be stacked along a leading "model" axis.
        all_ind_logits.append(ind_res["embeddings"][np.newaxis])
        all_ood_logits.append(ood_res["embeddings"][np.newaxis])

    # Values taken from the last model loaded in this group.
    logits_ind = ind_res["embeddings"]
    y_ind, y_ood = ind_res["labels"], ood_res["labels"]

In [5]:
# Split the in-distribution samples into three index sets (train_cond, calib,
# test); the first value returned by split_dataset_indices is discarded.
# NOTE(review): the exact meaning of the three ratios is defined by
# split_dataset_indices — confirm there before changing them.
_, train_cond_idx, calib_idx, test_idx = split_dataset_indices(
    logits_ind,
    y_ind,
    train_ratio=0.0,
    calib_ratio=0.1,
    test_ratio=0.8,
)

# Labels for the two labelled subsets used downstream.
y_train_cond = y_ind[train_cond_idx]
y_calib = y_ind[calib_idx]

# Stack the per-model arrays a single time and reuse the result for every split
# instead of rebuilding the same stacked array for each one. Each list entry
# carries a leading length-1 axis, so axis 0 of the stack indexes the models
# and axis 1 indexes the samples selected below.
ind_logits_stacked = np.vstack(all_ind_logits)
X_train_cond = ind_logits_stacked[:, train_cond_idx, :]
X_calib = ind_logits_stacked[:, calib_idx, :]
X_test = ind_logits_stacked[:, test_idx, :]

# OOD data is not split: all samples are kept.
X_ood = np.vstack(all_ood_logits)

In [7]:
# Sanity check: show the shape of each stacked array, one per line, in the
# order train_cond / calib / test / ood.
print(
    X_train_cond.shape,
    X_calib.shape,
    X_test.shape,
    X_ood.shape,
    sep="\n",
)

(5, 1000, 10)
(5, 1800, 10)
(5, 7200, 10)
(5, 10000, 10)
