In [5]:
import numpy as np
import pandas as pd
import torch
import re
from pathlib import Path

In [6]:
def _to_1d_array(x):
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    if isinstance(x, (list, tuple, np.ndarray)):
        arr = np.array(x)
    else:
        return None
    if arr.shape == ():
        return None
    return arr.reshape(-1)

def extract_dice_and_count(value):
    """Reduce a Dice metric value to ``(mean, number_of_entries)``.

    The value is flattened via ``_to_1d_array``. The mean ignores NaN entries
    (``np.nanmean``), while the count is the total number of flattened
    entries, NaNs included.

    Returns:
        ``(float, int)`` on success, or ``(None, None)`` when the value is not
        array-like, is zero-dimensional, or is empty.
    """
    flat = _to_1d_array(value)
    if flat is not None and flat.size > 0:
        return float(np.nanmean(flat)), int(flat.size)
    return None, None

In [7]:
RESULT_DIR = Path('../../results/unet_eval_auto').resolve()
pt_files = sorted(RESULT_DIR.glob('*.pt'))
if not pt_files:
    # Path.glob() on a missing or empty directory yields nothing without raising;
    # make that visible instead of silently producing an empty table.
    print(f'WARNING: no .pt files found in {RESULT_DIR}')
pattern_dataset = re.compile(r'^(?P<root>pmri|mnmv2)')

# Rebuild results with dice only: one row per (file, split) whose metrics
# dict carries a usable 'dice' entry.
row_records = []
for f in pt_files:
    # weights_only is pinned explicitly: torch>=2.4 warns when it is omitted and
    # torch>=2.6 changes the default to True, which can reject non-tensor payloads.
    # False means full unpickling, so only point this at trusted, locally produced files.
    data = torch.load(f, map_location='cpu', weights_only=False)
    if not isinstance(data, dict):
        print(f'Skipping {f.name}: expected a dict of splits, got {type(data).__name__}')
        continue
    # These depend only on the file, not on the split, so derive them once per file.
    dataset_match = pattern_dataset.match(f.name)
    dataset_guess = dataset_match.group('root') if dataset_match else 'unknown'
    context = f.stem
    for split_key, metrics in data.items():
        if not isinstance(metrics, dict) or 'dice' not in metrics:
            continue
        dice_val, case_count = extract_dice_and_count(metrics['dice'])
        if dice_val is None:
            continue
        row_records.append({
            'file': f.name,
            'dataset': dataset_guess,
            'context': context,
            # Assumes split keys end in the split name after the last '_'
            # (e.g. 'train' / 'val' / 'test') — TODO confirm for new result files.
            'split': split_key.split('_')[-1],
            'dice': dice_val,
            'case_count': case_count,
        })

results_df = pd.DataFrame(row_records)
print('Collected rows:', len(results_df))
print('Columns found:', results_df.columns.tolist())
if 'dice' in results_df.columns:
    coverage = results_df['dice'].notna().mean()
    print(f'Dice coverage: {coverage*100:.1f}% of rows')

if not results_df.empty:
    # Contexts as rows, splits as columns; pivot_table averages any duplicate
    # (context, split) pairs with its default aggfunc ('mean').
    dice_overview = results_df.pivot_table(index='context', columns='split', values='dice')
    sort_cols = [c for c in ['test', 'val', 'train'] if c in dice_overview.columns]
    dice_overview = dice_overview.sort_values(by=sort_cols, ascending=False)
    print('\nDice overview sample:')
    print(dice_overview.head())
else:
    dice_overview = pd.DataFrame()

results_df

Collected rows: 12
Columns found: ['file', 'dataset', 'context', 'split', 'dice', 'case_count']
Dice coverage: 100.0% of rows

Dice overview sample:
split                                                   test     train  \
context                                                                  
mnmv2_scanner_source=SymphonyTim                    0.864001  0.911730   
mnmv2_pathology_protocol=norm_vs_fall_scanners=ALL  0.727115  0.925501   
pmri_threeT_to_onePointFiveT                        0.666439  0.937258   
pmri_promise12                                      0.591685  0.946627   

split                                                    val  
context                                                       
mnmv2_scanner_source=SymphonyTim                    0.866769  
mnmv2_pathology_protocol=norm_vs_fall_scanners=ALL  0.902778  
pmri_threeT_to_onePointFiveT                        0.918780  
pmri_promise12                                      0.875570  


  data = torch.load(f, map_location='cpu')


Unnamed: 0,file,dataset,context,split,dice,case_count
0,mnmv2_pathology_protocol=norm_vs_fall_scanners...,mnmv2,mnmv2_pathology_protocol=norm_vs_fall_scanners...,train,0.925501,1085
1,mnmv2_pathology_protocol=norm_vs_fall_scanners...,mnmv2,mnmv2_pathology_protocol=norm_vs_fall_scanners...,val,0.902778,128
2,mnmv2_pathology_protocol=norm_vs_fall_scanners...,mnmv2,mnmv2_pathology_protocol=norm_vs_fall_scanners...,test,0.727115,587
3,mnmv2_scanner_source=SymphonyTim.pt,mnmv2,mnmv2_scanner_source=SymphonyTim,train,0.91173,2390
4,mnmv2_scanner_source=SymphonyTim.pt,mnmv2,mnmv2_scanner_source=SymphonyTim,val,0.866769,252
5,mnmv2_scanner_source=SymphonyTim.pt,mnmv2,mnmv2_scanner_source=SymphonyTim,test,0.864001,3470
6,pmri_promise12.pt,pmri,pmri_promise12,train,0.946627,461
7,pmri_promise12.pt,pmri,pmri_promise12,val,0.87557,64
8,pmri_promise12.pt,pmri,pmri_promise12,test,0.591685,1248
9,pmri_threeT_to_onePointFiveT.pt,pmri,pmri_threeT_to_onePointFiveT,train,0.937258,1009
