In [1]:
import numpy as np
import pandas as pd
import torch
import re
from pathlib import Path

In [2]:
def _to_1d_array(x):
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    if isinstance(x, (list, tuple, np.ndarray)):
        arr = np.array(x)
    else:
        return None
    if arr.shape == ():
        return None
    return arr.reshape(-1)

def extract_dice_and_count(value):
    """Summarise a collection of metric values as ``(mean, count)``.

    Parameters
    ----------
    value : object
        Raw metric payload; it is flattened with ``_to_1d_array``.

    Returns
    -------
    tuple
        ``(mean, count)`` where ``mean`` is the NaN-ignoring average as a
        Python float and ``count`` is the total number of entries as an int
        (NaN entries are included in the count). ``(None, None)`` is returned
        when the payload cannot be flattened or contains no entries.
    """
    flat = _to_1d_array(value)
    if flat is not None and flat.size > 0:
        n_entries = int(flat.size)
        mean_value = float(np.nanmean(flat))
        return mean_value, n_entries
    return None, None

In [10]:
# Directory that is scanned for ``*.pt`` evaluation files (one per context).
RESULT_DIR = Path('../../results/unet_eval_auto').resolve()
pt_files = sorted(RESULT_DIR.glob('*.pt'))
# The dataset family is read from the file-name prefix; other names map to 'unknown'.
pattern_dataset = re.compile(r'^(?P<root>pmri|mnmv2)')

# Rebuild results with dice + ece
row_records = []
for f in pt_files:
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load result files from a trusted source; consider passing
    # weights_only=True once the payloads are confirmed to contain nothing but
    # tensors and plain containers.
    data = torch.load(f, map_location='cpu')
    if not isinstance(data, dict):
        # Report instead of crashing on ``.items()`` for an unexpected payload.
        print(f'Skipping {f.name}: expected a dict of splits, got {type(data).__name__}')
        continue

    # Per-file values: computed once here rather than once per split row.
    dataset_match = pattern_dataset.match(f.name)
    dataset_guess = dataset_match.group('root') if dataset_match else 'unknown'
    context = f.stem

    for split_key, metrics in data.items():
        # Only entries that are metric dicts carrying a 'dice' value become rows.
        if not isinstance(metrics, dict) or 'dice' not in metrics:
            continue
        dice_val, case_count = extract_dice_and_count(metrics['dice'])
        if dice_val is None:
            continue
        # ECE is optional; a missing or unusable value is stored as None.
        ece_val, _ = extract_dice_and_count(metrics.get('ece'))
        row_records.append({
            'file': f.name,
            'dataset': dataset_guess,
            'context': context,
            # The split name is the last '_'-separated token of the key;
            # str() keeps non-string keys from raising.
            'split': str(split_key).split('_')[-1],
            'dice': dice_val,
            'ece': ece_val,
            'case_count': case_count,
        })

results_df = pd.DataFrame(row_records)
print('Collected rows:', len(results_df))
print('Columns found:', results_df.columns.tolist())
if 'dice' in results_df.columns:
    coverage = results_df['dice'].notna().mean()
    print(f'Dice coverage: {coverage*100:.1f}% of rows')
if 'ece' in results_df.columns:
    coverage = results_df['ece'].notna().mean()
    print(f'ECE coverage: {coverage*100:.1f}% of rows')

if not results_df.empty:
    # One row per context, one column per split (duplicates are averaged by pivot_table).
    dice_overview = results_df.pivot_table(index='context', columns='split', values='dice')
    sort_columns = [c for c in ['test', 'val', 'train'] if c in dice_overview.columns]
    # Skip sorting when none of the expected split columns exist, so that
    # sort_values is never called with an empty key list.
    if sort_columns:
        dice_overview = dice_overview.sort_values(by=sort_columns, ascending=False)
    print('\nDice overview sample:')
    print(dice_overview.head())
else:
    dice_overview = pd.DataFrame()

results_df

Collected rows: 12
Columns found: ['file', 'dataset', 'context', 'split', 'dice', 'ece', 'case_count']
Dice coverage: 100.0% of rows
ECE coverage: 100.0% of rows

Dice overview sample:
split                                          test     train       val
context                                                                
mnmv2_scanner-symphonytim                  0.864293  0.913121  0.870260
mnmv2_pathology-norm-vs-fall-scanners-all  0.724241  0.927814  0.906605
pmri_threet-to-onepointfivet               0.716169  0.927362  0.862062
pmri_promise12                             0.614633  0.926631  0.877784


  data = torch.load(f, map_location='cpu')


Unnamed: 0,file,dataset,context,split,dice,ece,case_count
0,mnmv2_pathology-norm-vs-fall-scanners-all.pt,mnmv2,mnmv2_pathology-norm-vs-fall-scanners-all,train,0.927814,0.02314,1085
1,mnmv2_pathology-norm-vs-fall-scanners-all.pt,mnmv2,mnmv2_pathology-norm-vs-fall-scanners-all,val,0.906605,0.034563,128
2,mnmv2_pathology-norm-vs-fall-scanners-all.pt,mnmv2,mnmv2_pathology-norm-vs-fall-scanners-all,test,0.724241,0.201755,587
3,mnmv2_scanner-symphonytim.pt,mnmv2,mnmv2_scanner-symphonytim,train,0.913121,0.026828,2390
4,mnmv2_scanner-symphonytim.pt,mnmv2,mnmv2_scanner-symphonytim,val,0.87026,0.049739,252
5,mnmv2_scanner-symphonytim.pt,mnmv2,mnmv2_scanner-symphonytim,test,0.864293,0.057573,3470
6,pmri_promise12.pt,pmri,pmri_promise12,train,0.926631,0.197525,461
7,pmri_promise12.pt,pmri,pmri_promise12,val,0.877784,0.194093,64
8,pmri_promise12.pt,pmri,pmri_promise12,test,0.614633,0.25831,1248
9,pmri_threet-to-onepointfivet.pt,pmri,pmri_threet-to-onepointfivet,train,0.927362,0.173464,1018


In [5]:
# NOTE(review): `metrics` is not assigned in this cell. The only visible
# assignment is the inner loop variable of the result-collection cell, so this
# displays whatever that variable last held in the kernel. The cell counters are
# out of order, so this output presumably comes from an earlier run; confirm,
# or bind the value explicitly, before relying on it.
metrics

{'dice': tensor([2.4438e-01, 5.2803e-01, 8.4462e-01, 9.1397e-01, 9.2060e-01, 9.1120e-01,
         9.3288e-01, 9.0192e-01, 8.9091e-01, 8.1924e-01, 8.5874e-01, 6.1230e-01,
         0.0000e+00, 6.4241e-01, 8.6251e-01, 8.1441e-01, 8.5253e-01, 6.8489e-01,
         7.1266e-01, 6.7020e-01, 6.2134e-01, 5.7256e-01, 5.7904e-01, 7.6354e-01,
         8.7793e-01, 9.0841e-01, 8.9981e-01, 6.8159e-01, 4.5101e-01, 1.3069e-01,
         4.0104e-01, 3.7676e-01, 1.1373e-01, 3.5430e-01, 0.0000e+00, 7.1328e-01,
         8.6445e-01, 9.2220e-01, 9.0951e-01, 7.8850e-01, 0.0000e+00, 4.8608e-02,
         0.0000e+00, 0.0000e+00, 6.4793e-01, 9.0452e-01, 9.1815e-01, 8.8125e-01,
         8.9597e-01, 9.1851e-01, 8.7226e-01, 7.6442e-01, 4.8705e-01, 4.2709e-01,
         4.6349e-01, 7.5853e-01, 7.2752e-01, 3.5294e-01, 7.6939e-01, 0.0000e+00,
         1.2522e-01, 0.0000e+00, 5.4878e-01, 4.8256e-01, 7.5024e-01, 6.4059e-01,
         5.2529e-01, 3.2961e-01, 7.3002e-01, 2.7809e-01, 4.7852e-02, 9.9941e-02,
         0.0000e+00,