In [1]:
import os
import json
from tqdm import tqdm
import warnings

import numpy as np

from dpipe.io import load
from dpipe.dataset.wrappers import apply, cache_methods
from dpipe.im.metrics import dice_score, iou
from ood.dataset.cc359 import CC359
from ood.dataset.utils import Rescale3D, scale_mri
from ood.paths import CC359_DATA_PATH
from ood.torch.module.unet_mc_dropout import UNet3D_MC_Dropout
from ood.utils import sdice
from ood.metric.ood_metric import calc_ood_scores


data_path = CC359_DATA_PATH

# if `voxel_spacing[i]` is `None` when `i`-th dimension will be used without scaling
voxel_spacing = (1, 0.95, 0.95)
sdice_metric = lambda x, y, i: sdice(x, y, dataset.load_spacing(i), 1)

dataset = apply(Rescale3D(CC359(data_path), voxel_spacing), load_image=scale_mri)

in_distr_id = 1
n_folds = 6
ood_ids = [i for i in range(n_folds) if i != in_distr_id]
n_seeds = 5
bin_threshold = 0.5
eps = 1e-8

experiment_dir = '/mount/sdc/experiments/ood_playground/cc359/ensemble/'

## Dice drop with domain shift

In [2]:
ood_metrics = np.zeros((n_folds, n_seeds))

for seed in range(n_seeds):
    base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/')
    with open(os.path.join(base_dir, 'test_metrics', 'dice_score.json')) as file:
        metrics = json.load(file)

    ids = set(dataset.df[dataset.df['fold'] == in_distr_id].index)
    metric_keys = set(metrics.keys())
    ids = metric_keys.intersection(ids)

    in_distr_metric = np.mean([metrics[uid] for uid in ids])
    ood_metrics[in_distr_id, seed] = in_distr_metric
    print(f'In distribution: {in_distr_metric:.4f}')

    for fold_id in ood_ids:
        ids = list(dataset.df[dataset.df['fold'] == fold_id].index)

        ood_metric = np.mean([metrics[uid] for uid in ids])
        ood_metrics[fold_id, seed] = ood_metric
        print(f'OOD fold {fold_id}: {ood_metric:.4f}')
        
means = ood_metrics.mean(axis=1)
stds = ood_metrics.std(axis=1)

for i in range(n_folds):
    print(f'Fold {i}:\tood = {i != in_distr_id}\t{means[i]:.4f} ± {1.96 * stds[i]:.4f}')
    
print()
ood_ids = [i for i in range(n_folds) if i != in_distr_id]
print(f'In distr:\t{means[in_distr_id]:.4f} ± {1.96 * stds[in_distr_id]:.4f}')
print(f'OOD:\t\t{ood_metrics[ood_ids].mean():.4f} ± {1.96 * ood_metrics[ood_ids].std():.4f}')

In distribution: 0.9816
OOD fold 0: 0.9517
OOD fold 2: 0.6995
OOD fold 3: 0.9606
OOD fold 4: 0.9151
OOD fold 5: 0.8860
In distribution: 0.9786
OOD fold 0: 0.9427
OOD fold 2: 0.8180
OOD fold 3: 0.9468
OOD fold 4: 0.9212
OOD fold 5: 0.8992
In distribution: 0.9809
OOD fold 0: 0.9541
OOD fold 2: 0.8248
OOD fold 3: 0.9598
OOD fold 4: 0.9443
OOD fold 5: 0.9084
In distribution: 0.9807
OOD fold 0: 0.9412
OOD fold 2: 0.7124
OOD fold 3: 0.9571
OOD fold 4: 0.8927
OOD fold 5: 0.8836
In distribution: 0.9790
OOD fold 0: 0.9245
OOD fold 2: 0.6794
OOD fold 3: 0.9371
OOD fold 4: 0.7828
OOD fold 5: 0.8899
Fold 0:	ood = True	0.9429 ± 0.0205
Fold 1:	ood = False	0.9802 ± 0.0023
Fold 2:	ood = True	0.7468 ± 0.1212
Fold 3:	ood = True	0.9523 ± 0.0177
Fold 4:	ood = True	0.8912 ± 0.1110
Fold 5:	ood = True	0.8934 ± 0.0180

In distr:	0.9802 ± 0.0023
OOD:		0.8853 ± 0.1625


In [3]:
0.9802 - 0.8853

0.09489999999999998

In [4]:
ood_metrics = np.zeros((n_folds, n_seeds))

for seed in range(n_seeds):
    base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/')
    with open(os.path.join(base_dir, 'test_metrics', 'sdice_score.json')) as file:
        metrics = json.load(file)

    ids = set(dataset.df[dataset.df['fold'] == in_distr_id].index)
    metric_keys = set(metrics.keys())
    ids = metric_keys.intersection(ids)

    in_distr_metric = np.mean([metrics[uid] for uid in ids])
    ood_metrics[in_distr_id, seed] = in_distr_metric
    print(f'In distribution: {in_distr_metric:.4f}')

    for fold_id in ood_ids:
        ids = list(dataset.df[dataset.df['fold'] == fold_id].index)

        ood_metric = np.mean([metrics[uid] for uid in ids])
        ood_metrics[fold_id, seed] = ood_metric
        print(f'OOD fold {fold_id}: {ood_metric:.4f}')
        
means = ood_metrics.mean(axis=1)
stds = ood_metrics.std(axis=1)

for i in range(n_folds):
    print(f'Fold {i}:\tood = {i != in_distr_id}\t{means[i]:.4f} ± {1.96 * stds[i]:.4f}')
    
print()
ood_ids = [i for i in range(n_folds) if i != in_distr_id]
print(f'In distr:\t{means[in_distr_id]:.4f} ± {1.96 * stds[in_distr_id]:.4f}')
print(f'OOD:\t\t{ood_metrics[ood_ids].mean():.4f} ± {1.96 * ood_metrics[ood_ids].std():.4f}')

In distribution: 0.9116
OOD fold 0: 0.7130
OOD fold 2: 0.1953
OOD fold 3: 0.7532
OOD fold 4: 0.6052
OOD fold 5: 0.5313
In distribution: 0.8795
OOD fold 0: 0.6494
OOD fold 2: 0.2205
OOD fold 3: 0.6986
OOD fold 4: 0.5368
OOD fold 5: 0.4647
In distribution: 0.9123
OOD fold 0: 0.7405
OOD fold 2: 0.3246
OOD fold 3: 0.7729
OOD fold 4: 0.6621
OOD fold 5: 0.5596
In distribution: 0.9034
OOD fold 0: 0.6286
OOD fold 2: 0.1377
OOD fold 3: 0.7083
OOD fold 4: 0.4808
OOD fold 5: 0.4324
In distribution: 0.8883
OOD fold 0: 0.6120
OOD fold 2: 0.1915
OOD fold 3: 0.6506
OOD fold 4: 0.4436
OOD fold 5: 0.4904
Fold 0:	ood = True	0.6687 ± 0.0973
Fold 1:	ood = False	0.8990 ± 0.0256
Fold 2:	ood = True	0.2139 ± 0.1206
Fold 3:	ood = True	0.7167 ± 0.0844
Fold 4:	ood = True	0.5457 ± 0.1563
Fold 5:	ood = True	0.4957 ± 0.0892

In distr:	0.8990 ± 0.0256
OOD:		0.5281 ± 0.3635


In [5]:
0.8990 - 0.5281

0.3709

## Baseline OOD (softmax)

#### Distance from 0.5 (softmax)

In [6]:
det_accs = []
roc_aucs = []
tprs = []

for seed in range(n_seeds):
    labels = {}
    base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
    filenames = os.listdir(base_dir)
    for i in tqdm(range(len(filenames))):
        preds = load(os.path.join(base_dir, filenames[i]))
        uncertainty_result = np.zeros_like(preds)
        uncertainty_result[preds > 0.5] = (1 - preds)[preds > 0.5]
        uncertainty_result[preds <= 0.5] = preds[preds <= 0.5]
        label = uncertainty_result.mean()
        labels[filenames[i].split('.')[0]] = label
        
    is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
    det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)
    det_accs.append(det_acc)
    roc_aucs.append(roc_auc)
    tprs.append(tpr)

100%|██████████| 329/329 [07:07<00:00,  1.30s/it]


Detection accuracy: 0.9833
AUROC: 0.9982
TNR @ 95% TPR: 1.0000


100%|██████████| 329/329 [07:09<00:00,  1.31s/it]


Detection accuracy: 0.9833
AUROC: 0.9989
TNR @ 95% TPR: 1.0000


100%|██████████| 329/329 [07:09<00:00,  1.30s/it]


Detection accuracy: 0.9833
AUROC: 0.9973
TNR @ 95% TPR: 1.0000


100%|██████████| 329/329 [07:09<00:00,  1.31s/it]


Detection accuracy: 0.9833
AUROC: 0.9982
TNR @ 95% TPR: 1.0000


100%|██████████| 329/329 [07:10<00:00,  1.31s/it]

Detection accuracy: 0.9833
AUROC: 0.9981
TNR @ 95% TPR: 1.0000





In [7]:
print(f'Detection accuracy:\t{np.mean(det_accs):.4f} ± {1.96 * np.std(det_accs):.4f}')
print(f'AUROC:\t{np.mean(roc_aucs):.4f} ± {1.96 * np.std(roc_aucs):.4f}')
print(f'TNR @ 95% TPR:\t{np.mean(tprs):.4f} ± {1.96 * np.std(tprs):.4f}')

Detection accuracy:	0.9833 ± 0.0000
AUROC:	0.9981 ± 0.0010
TNR @ 95% TPR:	1.0000 ± 0.0000


In [5]:
det_accs = []
roc_aucs = []
tprs = []

for seed in [0]:
    labels = {}
    base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
    filenames = os.listdir(base_dir)
    for i in tqdm(range(len(filenames))):
        preds = load(os.path.join(base_dir, filenames[i]))
        uncertainty_result = np.zeros_like(preds)
        uncertainty_result[preds > 0.5] = (1 - preds)[preds > 0.5]
        uncertainty_result[preds <= 0.5] = preds[preds <= 0.5]
        label = uncertainty_result.mean()
        labels[filenames[i].split('.')[0]] = label
        
    is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
    det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)
    det_accs.append(det_acc)
    roc_aucs.append(roc_auc)
    tprs.append(tpr)

100%|██████████| 329/329 [05:20<00:00,  1.03it/s]

Detection accuracy: 0.9833333333333334
AUROC: 0.9982162764771461
TNR @ 95% TPR: 0.993920972644377





In [52]:
labels_arr = np.array(list(labels.values()))
folds = np.array([dataset.df.fold[uid] for uid in labels.keys()])
sorted_ids = np.argsort(labels_arr)
# print(labels_arr[sorted_ids])
print(is_ood_true[sorted_ids])
print(folds[sorted_ids])

[False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True False  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  T

#### Entropy (softmax)

In [10]:
det_accs = []
roc_aucs = []
tprs = []

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for seed in range(n_seeds):
        labels = {}
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        filenames = os.listdir(base_dir)
        for i in tqdm(range(len(filenames))):
            preds = load(os.path.join(base_dir, filenames[i]))
            uncertainty_result = - (preds * np.log2(preds + eps) + (1 - preds) * np.log2(1 - preds + eps))
            uncertainty_result[preds == 0] = 0
            uncertainty_result[preds == 1] = 0
            label = uncertainty_result.mean()
            labels[filenames[i].split('.')[0]] = label

        is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
        det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)
        det_accs.append(det_acc)
        roc_aucs.append(roc_auc)
        tprs.append(tpr)

100%|██████████| 329/329 [13:38<00:00,  2.49s/it]


Detection accuracy: 0.9833333333333334
AUROC: 0.9983277591973244
TNR @ 95% TPR: 0.993920972644377


100%|██████████| 329/329 [15:26<00:00,  2.82s/it]


Detection accuracy: 0.9916387959866221
AUROC: 0.9994983277591973
TNR @ 95% TPR: 0.993920972644377


100%|██████████| 329/329 [13:18<00:00,  2.43s/it]


Detection accuracy: 0.9833333333333334
AUROC: 0.9983277591973244
TNR @ 95% TPR: 0.993920972644377


100%|██████████| 329/329 [13:59<00:00,  2.55s/it]


Detection accuracy: 0.9833333333333334
AUROC: 0.9984392419175028
TNR @ 95% TPR: 0.993920972644377


100%|██████████| 329/329 [13:41<00:00,  2.50s/it]

Detection accuracy: 0.9833333333333334
AUROC: 0.998773690078038
TNR @ 95% TPR: 0.993920972644377





In [11]:
print(f'Detection accuracy:\t{np.mean(det_accs):.4f} ± {1.96 * np.std(det_accs):.4f}')
print(f'AUROC:\t{np.mean(roc_aucs):.4f} ± {1.96 * np.std(roc_aucs):.4f}')
print(f'TNR @ 95% TPR:\t{np.mean(tprs):.4f} ± {1.96 * np.std(tprs):.4f}')

Detection accuracy:	 0.9849944258639912 ± 0.006644370122630949
AUROC:	 0.9986733556298774 ± 0.0008873912482556409
TNR @ 95% TPR:	 0.993920972644377 ± 0.0


## Ensemble predictions

### Mean uncertainty

In [12]:
labels = {}
dices = {}
sdices = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy.gz'))
        ensemble_preds.append(preds)
    
    ensemble_preds = np.array(ensemble_preds)
    mean_preds = ensemble_preds.mean(axis=0)
    std_preds = ensemble_preds.std(axis=0)
    true_mask = dataset.load_segm(uid)
    dices[uid] = dice_score(mean_preds > bin_threshold, true_mask > bin_threshold)
    sdices[uid] = sdice_metric(mean_preds > bin_threshold, true_mask > bin_threshold, uid)
    labels[uid] = std_preds.mean()
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

  0%|          | 1/329 [00:33<3:04:29, 33.75s/it]


KeyboardInterrupt: 

In [30]:
print('DICE SCORE')
ood_metrics = np.zeros((n_folds))
metrics = dices

ids = set(dataset.df[dataset.df['fold'] == in_distr_id].index)
metric_keys = set(metrics.keys())
ids = metric_keys.intersection(ids)

in_distr_metric = np.mean([metrics[uid] for uid in ids])
ood_metrics[in_distr_id] = in_distr_metric

for fold_id in ood_ids:
    ids = list(dataset.df[dataset.df['fold'] == fold_id].index)

    ood_metric = np.mean([metrics[uid] for uid in ids])
    ood_metrics[fold_id] = ood_metric

for i in range(n_folds):
    print(f'Fold {i}:\tood = {i != in_distr_id}\t{ood_metrics[i]:.4f}')
    
print()
ood_ids = [i for i in range(n_folds) if i != in_distr_id]
print(f'In distr:\t{ood_metrics[in_distr_id].mean():.4f}')
print(f'OOD:\t\t{ood_metrics[ood_ids].mean():.4f} ± {1.96 * ood_metrics[ood_ids].std():.4f}')

DICE SCORE
Fold 0:	ood = True	0.9493
Fold 1:	ood = False	0.9816
Fold 2:	ood = True	0.7535
Fold 3:	ood = True	0.9603
Fold 4:	ood = True	0.9100
Fold 5:	ood = True	0.8978

In distr:	0.9816
OOD:		0.8942 ± 0.1453


In [29]:
print('SDICE SCORE')
ood_metrics = np.zeros((n_folds))
metrics = sdices

ids = set(dataset.df[dataset.df['fold'] == in_distr_id].index)
metric_keys = set(metrics.keys())
ids = metric_keys.intersection(ids)

in_distr_metric = np.mean([metrics[uid] for uid in ids])
ood_metrics[in_distr_id] = in_distr_metric

for fold_id in ood_ids:
    ids = list(dataset.df[dataset.df['fold'] == fold_id].index)

    ood_metric = np.mean([metrics[uid] for uid in ids])
    ood_metrics[fold_id] = ood_metric

for i in range(n_folds):
    print(f'Fold {i}:\tood = {i != in_distr_id}\t{ood_metrics[i]:.4f}')
    
print()
ood_ids = [i for i in range(n_folds) if i != in_distr_id]
print(f'In distr:\t{ood_metrics[in_distr_id].mean():.4f}')
print(f'OOD:\t\t{ood_metrics[ood_ids].mean():.4f} ± {1.96 * ood_metrics[ood_ids].std():.4f}')

SDICE SCORE
Fold 0:	ood = True	0.6970
Fold 1:	ood = False	0.9122
Fold 2:	ood = True	0.1954
Fold 3:	ood = True	0.7538
Fold 4:	ood = True	0.5555
Fold 5:	ood = True	0.5065

In distr:	0.9122
OOD:		0.5416 ± 0.3825


### Sample diversity

In [14]:
labels = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy'))
        ensemble_preds.append(preds)
    
    ensemble_preds = np.array(ensemble_preds)
    mean_preds = ensemble_preds.mean(axis=0)
    ious = []
    for i in range(n_seeds):
        ious.append(1 - iou(mean_preds > bin_threshold, ensemble_preds[i] > bin_threshold))
    labels[uid] = np.mean(ious)
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

100%|██████████| 329/329 [17:24<00:00,  3.17s/it]

Detection accuracy: 0.9732998885172799
AUROC: 0.9869565217391304
TNR @ 95% TPR: 0.993920972644377





### std(volume) / mean(volume)

In [15]:
labels = {}
dices = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy'))
        ensemble_preds.append(preds)
    
    ensemble_preds = np.array(ensemble_preds)
    volumes = np.array([(pred > bin_threshold).sum() for pred in ensemble_preds])
    mean_volume = volumes.mean()
    std_volume = volumes.std()
#     true_mask = dataset.load_segm(uid)
#     dices[uid] = dice_score(mean_preds > bin_threshold, true_mask > bin_threshold)
    labels[uid] = std_volume / mean_volume
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

100%|██████████| 329/329 [05:33<00:00,  1.01s/it]

Detection accuracy: 0.9833333333333334
AUROC: 0.9861761426978818
TNR @ 95% TPR: 0.993920972644377





### Top uncertain voxels

In [16]:
top_elements = 1000000

labels = {}
dices = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy'))
        sorted_preds = np.sort(preds)
        ensemble_preds.append(sorted_preds.flatten()[-top_elements:])
    
    ensemble_preds = np.array(ensemble_preds)
    
#     mean_preds = ensemble_preds.mean(axis=0)
    std_preds = ensemble_preds.std(axis=0)
#     true_mask = dataset.load_segm(uid)
#     dices[uid] = dice_score(mean_preds > bin_threshold, true_mask > bin_threshold)
    labels[uid] = std_preds.mean()
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

100%|██████████| 329/329 [26:33<00:00,  4.84s/it]

Detection accuracy: 0.8946488294314381
AUROC: 0.8011705685618729
TNR @ 95% TPR: 0.8085106382978723





In [17]:
top_elements = 500000

labels = {}
dices = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy'))
        sorted_preds = np.sort(preds)
        ensemble_preds.append(sorted_preds.flatten()[-top_elements:])
    
    ensemble_preds = np.array(ensemble_preds)
    
#     mean_preds = ensemble_preds.mean(axis=0)
    std_preds = ensemble_preds.std(axis=0)
#     true_mask = dataset.load_segm(uid)
#     dices[uid] = dice_score(mean_preds > bin_threshold, true_mask > bin_threshold)
    labels[uid] = std_preds.mean()
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

100%|██████████| 329/329 [24:14<00:00,  4.42s/it]

Detection accuracy: 0.8746376811594203
AUROC: 0.7948160535117058
TNR @ 95% TPR: 0.7963525835866262





In [18]:
top_elements = 100000

labels = {}
dices = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy'))
        sorted_preds = np.sort(preds)
        ensemble_preds.append(sorted_preds.flatten()[-top_elements:])
    
    ensemble_preds = np.array(ensemble_preds)
    
#     mean_preds = ensemble_preds.mean(axis=0)
    std_preds = ensemble_preds.std(axis=0)
#     true_mask = dataset.load_segm(uid)
#     dices[uid] = dice_score(mean_preds > bin_threshold, true_mask > bin_threshold)
    labels[uid] = std_preds.mean()
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

100%|██████████| 329/329 [22:22<00:00,  4.08s/it]

Detection accuracy: 0.8729654403567447
AUROC: 0.8016164994425864
TNR @ 95% TPR: 0.7963525835866262





### Variance instead of std

In [22]:
labels = {}
dices = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy'))
        ensemble_preds.append(preds)
    
    ensemble_preds = np.array(ensemble_preds)
    mean_preds = ensemble_preds.mean(axis=0)
    std_preds = ensemble_preds.var(axis=0)
#     true_mask = dataset.load_segm(uid)
#     dices[uid] = dice_score(mean_preds > bin_threshold, true_mask > bin_threshold)
    labels[uid] = std_preds.mean()
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

100%|██████████| 329/329 [1:13:06<00:00, 13.33s/it]

Detection accuracy: 0.9833333333333334
AUROC: 0.9936454849498327
TNR @ 95% TPR: 0.993920972644377





In [23]:
top_elements = 500000

labels = {}
dices = {}

base_dir = os.path.join(experiment_dir, f'seed0/experiment_0/test_predictions')
filenames = os.listdir(base_dir)
uids = [fname.split('.')[0] for fname in filenames]

for uid in tqdm(uids):
    ensemble_preds = []
    for seed in range(n_seeds):
        base_dir = os.path.join(experiment_dir, f'seed{seed}/experiment_0/test_predictions')
        preds = load(os.path.join(base_dir, uid + '.npy'))
        sorted_preds = np.sort(preds)
        ensemble_preds.append(sorted_preds.flatten()[-top_elements:])
    
    ensemble_preds = np.array(ensemble_preds)
    
#     mean_preds = ensemble_preds.mean(axis=0)
    std_preds = ensemble_preds.var(axis=0)
#     true_mask = dataset.load_segm(uid)
#     dices[uid] = dice_score(mean_preds > bin_threshold, true_mask > bin_threshold)
    labels[uid] = std_preds.mean()
        
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
det_acc, roc_auc, tpr = calc_ood_scores(np.array(list(labels.values())), is_ood_true)

100%|██████████| 329/329 [25:14<00:00,  4.60s/it]

Detection accuracy: 0.705685618729097
AUROC: 0.705685618729097
TNR @ 95% TPR: 0.46504559270516715





## MC dropout

In [21]:
# p = 0.3

labels = {}
cur_labels = []

labels = load('/experiments/ood_playground/cc359/mc_dropout/03/mc_drop_10/experiment_0/test_metrics/get_all_labels_var.json')

is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
print('Mean var - all voxels')
calc_ood_scores(np.array(list(labels.values())), is_ood_true)
print()

labels = load('/experiments/ood_playground/cc359/mc_dropout/03/mc_drop_10/experiment_0/test_metrics/get_top_n_labels_var_1000000.json')
print('Mean var - top 1000000 uncertain voxels')
calc_ood_scores(np.array(list(labels.values())), is_ood_true)
print()

labels = load('/experiments/ood_playground/cc359/mc_dropout/03/mc_drop_10/experiment_0/test_metrics/get_top_n_labels_var_500000.json')
print('Mean var - top 500000 uncertain voxels')
calc_ood_scores(np.array(list(labels.values())), is_ood_true)

Mean var - all voxels
Detection accuracy: 0.9816610925306577
AUROC: 0.9900780379041249
TNR @ 95% TPR: 0.9908814589665653

Mean var - top 1000000 uncertain voxels
Detection accuracy: 0.6053511705685619
AUROC: 0.30641025641025643
TNR @ 95% TPR: 0.3009118541033435

Mean var - top 500000 uncertain voxels
Detection accuracy: 0.5652173913043478
AUROC: 0.20328874024526197
TNR @ 95% TPR: 0.24620060790273557


(0.5652173913043478, 0.20328874024526197, 0.24620060790273557)

In [86]:
# p = 0.5

labels = {}
cur_labels = []

for i in range(1, 5):
    cur_labels.append(load(f'/experiments/ood_playground/cc359/mc_dropout/05/mc_drop_res{i}/experiment_0/test_metrics/get_top_n_labels_std.json'))

labels = {**cur_labels[0], **cur_labels[1], **cur_labels[2], **cur_labels[3]}
is_ood_true = np.array([dataset.df.fold[uid] != in_distr_id for uid in labels.keys()])
calc_ood_scores(np.array(list(labels.values())), is_ood_true)

Detection accuracy: 0.6003344481605352
AUROC: 0.317670011148272
TNR @ 95% TPR: 0.3009118541033435


(0.6003344481605352, 0.317670011148272, 0.3009118541033435)