In [1]:
%load_ext autoreload
%autoreload 2
### Set CUDA device
# Expose only physical GPU 2 to this kernel (it then shows up as cuda:0).
# NOTE: CUDA reads this variable when it is initialised, so this cell has to
# run before the imports cell touches CUDA; changing it later needs a kernel
# restart to take effect.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [None]:
import sys, os
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import numpy as np
from torch import (
    sigmoid,
    softmax,
    stack,
    cat,
    corrcoef,
    zeros,
    sqrt,
    tensor,
    save,
    log,
    load,
    linspace,
    exp,
    triu_indices,
    manual_seed,
    load
)
from torch.special import entr
from time import time
from torch.nn.functional import one_hot
from torch.utils.data import DataLoader
from monai.metrics import (
    DiceMetric,
    compute_hausdorff_distance,
    SurfaceDiceMetric
)
from tqdm import tqdm
sys.path.append('../')
from model.unet import get_unet_module
from losses import dice_per_class_loss, surface_loss


In [None]:
### OOD EVAL
# Out-of-distribution evaluation with MC dropout.
#
# For every held-out test domain the U-Net is run once deterministically
# (dropout off) to get the prediction scored against the ground truth, and
# `batch_size` times with dropout on to form an MC-dropout ensemble. The
# ensemble's mean class probabilities are turned into a confidence score
# (1 - normalised predictive entropy), and its Pearson correlation with the
# true Dice / surface-Dice scores is reported and saved per domain.
#
# NOTE(review): `get_data_module` is not imported anywhere in this notebook,
# so this cell only runs if that name leaked into the kernel from elsewhere.
# TODO: add the proper import to the imports cell (module not visible here).
from torch import no_grad

UNET_CKPTS = {
    "mnmv2": 'mnmv2_symphony_dropout-0-1_2025-01-14-15-19',
    'pmri': 'pmri_runmc_dropout-0-1_2025-01-14-15-58',
}

# Number of MC-dropout forward passes per test image (also part of the
# results file name).
batch_size = 15

eval_metrics = {
    'dice': dice_per_class_loss,
    'surface': surface_loss
}

unet_cfg = OmegaConf.load('../configs/unet/monai_unet.yaml')


def _enable_dropout(model):
    """Switch only the Dropout* submodules of `model` to training mode."""
    for module in model.modules():
        if module.__class__.__name__.startswith('Dropout'):
            module.train()


def _mc_confidence(probs, num_classes):
    """Return 1 - normalised predictive entropy of an MC-dropout ensemble.

    Args:
        probs: per-pass class probabilities; dim 0 indexes the dropout passes
            and dim 1 the classes (presumably (passes, C, H, W) - TODO confirm).
        num_classes: number of classes C (>= 2).

    Returns:
        A scalar tensor in [0, 1]; 1 means every pass predicts every pixel
        with full certainty.
    """
    mean_probs = probs.mean(0)
    pixel_entropy = entr(mean_probs).sum(0)
    # `entr` uses the natural log, so the maximum entropy over C classes is
    # log(C). Dividing by sqrt(C) (as before) did not normalise to [0, 1].
    return 1 - pixel_entropy.mean() / log(tensor(float(num_classes)))


def _pearson(x, y):
    """Pearson correlation coefficient between two equal-length 1-D tensors."""
    return corrcoef(stack([x, y], dim=0))[0, 1]


for it in range(0, 1):

    for dataset in ['mnmv2']:
        print(f"Dataset: {dataset}")

        if dataset == 'mnmv2':
            unet_cfg.out_channels = 4
            num_classes = 4
            data_cfg = OmegaConf.load('../configs/data/mnmv2.yaml')
            domain = 'Symphony'
        else:
            unet_cfg.out_channels = 1
            num_classes = 2
            data_cfg = OmegaConf.load('../configs/data/pmri.yaml')
            domain = 'RUNMC'

        print(f"Train Vendor: {domain}")
        results = {}
        data_cfg.dataset = dataset
        data_cfg.domain = domain
        data_cfg.non_empty_target = True

        datamodule = get_data_module(
            cfg=data_cfg
        )
        datamodule.setup('test')

        ckpt = UNET_CKPTS[data_cfg.dataset]
        unet_cfg.checkpoint_path = f'../../{unet_cfg.checkpoint_dir}{ckpt}.ckpt'
        unet_cfg.dropout = 0.1

        unet = get_unet_module(
            cfg=unet_cfg,
            metadata=OmegaConf.to_container(unet_cfg),
            load_from_checkpoint=True
        ).model

        for test_domain, test_dl in datamodule.test_dataloader().items():

            # Only held-out domains are evaluated.
            if 'train' in test_domain or 'val' in test_domain:
                continue

            # Re-wrap with batch_size=1: each image is repeated `batch_size`
            # times below to build the MC-dropout ensemble.
            test_dl = DataLoader(
                test_dl.dataset,
                batch_size=1,
                shuffle=False,
            )
            print(f"test_domain: {test_domain}")
            scores = {
                'dice': [],
                'surface': [],
                'dice_agreement': [],
                'surface_agreement': []
            }

            # Seed so the dropout masks are reproducible per iteration.
            manual_seed(it)
            for batch in tqdm(test_dl):
                inputs = batch['input'].repeat(batch_size, 1, 1, 1)
                target = batch['target']

                # Inference only: no autograd graph is needed.
                with no_grad():
                    unet.eval()
                    logits = unet(inputs[:1].cuda())
                    _enable_dropout(unet)
                    logits_dropout = unet(inputs.cuda())

                # A single output channel means a binary sigmoid head.
                num_classes = max(logits_dropout.shape[1], 2)
                if num_classes > 2:
                    predictions = logits.argmax(1, keepdim=True)
                    probs = softmax(logits_dropout, dim=1)
                else:
                    predictions = (logits > 0) * 1
                    probs = sigmoid(logits_dropout)
                    probs = cat([1 - probs, probs], dim=1)

                confidence = _mc_confidence(probs, num_classes).detach().cpu().view(1,)

                # Both "agreement" entries currently hold the same
                # entropy-based confidence (pairwise agreement was replaced).
                scores['dice_agreement'].append(confidence)
                scores['surface_agreement'].append(confidence)

                for key, fn in eval_metrics.items():
                    _, _, true_score = fn(
                        predicted_segmentation=predictions,
                        target_segmentation=target.cuda(),
                        # Dummy predicted scores, one row per scored image; the
                        # deterministic pass has batch size 1, not batch_size.
                        prediction=zeros((predictions.size(0), 1, num_classes)).cuda(),
                        num_classes=num_classes,
                        sigma=0,
                        return_scores=True
                    )

                    scores[key].append(true_score.squeeze(1).detach().cpu())

            scores = {key: cat(values) for key, values in scores.items()}

            # Correlation between the confidence score and the true Dice /
            # surface-Dice scores across the images of this domain.
            corr_dice = _pearson(scores['dice_agreement'], scores['dice'])
            corr_surface = _pearson(scores['surface_agreement'], scores['surface'])
            results[test_domain] = {
                'scores': scores,
            }

            print(f"Correlation Dice: {corr_dice} | Correlation Surface: {corr_surface}")

        save(results, f'../../results/{dataset}_{domain}_aggregated-PE-{batch_size}-{it}.pt')


Dataset: mnmv2
Train Vendor: Symphony
Loading Trio data
Loading Avanto data
Loading HDxt data
Loading EXCITE data
Loading Explorer data
Loading Achieva data
test_domain: Trio


100%|██████████| 94/94 [00:03<00:00, 24.34it/s]


tensor(0.9752) tensor(0.9986)
tensor(0.9752) tensor(0.9986)
Correlation Dice: 0.34463590383529663 | Correlation Surface: 0.141238272190094
test_domain: Avanto


100%|██████████| 695/695 [00:29<00:00, 23.95it/s]


tensor(0.9703) tensor(0.9992)
tensor(0.9703) tensor(0.9992)
Correlation Dice: 0.23627662658691406 | Correlation Surface: 0.23752176761627197
test_domain: HDxt


100%|██████████| 426/426 [00:18<00:00, 23.29it/s]


tensor(0.9447) tensor(0.9990)
tensor(0.9447) tensor(0.9990)
Correlation Dice: 0.2559165060520172 | Correlation Surface: 0.27786508202552795
test_domain: EXCITE


100%|██████████| 459/459 [00:19<00:00, 23.71it/s]


tensor(0.9614) tensor(0.9983)
tensor(0.9614) tensor(0.9983)
Correlation Dice: 0.25722527503967285 | Correlation Surface: 0.3126246929168701
test_domain: Explorer


100%|██████████| 18/18 [00:00<00:00, 23.35it/s]


tensor(0.9884) tensor(0.9963)
tensor(0.9884) tensor(0.9963)
Correlation Dice: 0.4669869542121887 | Correlation Surface: 0.5198052525520325
test_domain: Achieva


100%|██████████| 1422/1422 [01:01<00:00, 23.24it/s]

tensor(0.9661) tensor(0.9994)
tensor(0.9661) tensor(0.9994)
Correlation Dice: 0.11796378344297409 | Correlation Surface: 0.14833880960941315





In [29]:
# Reload the per-domain scores saved by the OOD-eval cell. File names follow
# {dataset}_{train_domain}_aggregated-PE-{batch_size}-{it}.pt; PMRI is loaded first.
pmri_results, mnmv2_results = (
    load(f'../../results/{run}_aggregated-PE-15-0.pt')
    for run in ('pmri_RUNMC', 'mnmv2_Symphony')
)

In [43]:
# Mean over PMRI test domains of the per-domain Pearson correlation between
# the confidence score ("*_agreement") and the true Dice / surface-Dice.
dice_corrs, surface_corrs = [], []

for domain, domain_results in pmri_results.items():
    scores = domain_results['scores']
    for corr_list, conf_key, true_key in (
        (dice_corrs, 'dice_agreement', 'dice'),
        (surface_corrs, 'surface_agreement', 'surface'),
    ):
        paired = stack([scores[conf_key], scores[true_key]], dim=0)
        corr_list.append(corrcoef(paired)[0, 1])

print(f"PMRI Dice Correlation: {np.mean(dice_corrs):.4f} | PMRI Surface Correlation: {np.mean(surface_corrs):.4f}")

PMRI Dice Correlation: 0.2583 | PMRI Surface Correlation: 0.2898


In [44]:
# Per-domain surface-Dice correlations. When run top to bottom this shows the
# PMRI values from the cell above; the MNMV2 cell below overwrites the list.
surface_corrs

[tensor(0.2442),
 tensor(-0.0421),
 tensor(0.5589),
 tensor(0.1281),
 tensor(0.5598)]

In [42]:
# Mean over MNMV2 test domains of the per-domain Pearson correlation between
# the confidence score ("*_agreement") and the true Dice / surface-Dice.
dice_corrs, surface_corrs = [], []

for domain, domain_results in mnmv2_results.items():
    scores = domain_results['scores']
    for corr_list, conf_key, true_key in (
        (dice_corrs, 'dice_agreement', 'dice'),
        (surface_corrs, 'surface_agreement', 'surface'),
    ):
        paired = stack([scores[conf_key], scores[true_key]], dim=0)
        corr_list.append(corrcoef(paired)[0, 1])

print(f"MNMV2 Dice Correlation: {np.mean(dice_corrs):.4f} | MNMV2 Surface Correlation: {np.mean(surface_corrs):.4f}")

MNMV2 Dice Correlation: 0.2798 | MNMV2 Surface Correlation: 0.2729
