In [1]:
# Reload edited project modules automatically so changes to local .py files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
### Set CUDA device
# Restrict this kernel to GPU index 2. This assignment sits in the first cell,
# ahead of the cell that imports torch.
# NOTE(review): presumably it must run before CUDA is initialised to have any
# effect, and the index '2' is machine-specific -- confirm on other hosts.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [None]:
import os
import sys
from time import time

import matplotlib.pyplot as plt
import numpy as np
from monai.metrics import (
    DiceMetric,
    compute_hausdorff_distance,
    SurfaceDiceMetric
)
from omegaconf import OmegaConf
from torch import (
    sigmoid,
    softmax,
    stack,
    cat,
    corrcoef,
    zeros,
    sqrt,
    tensor,
    save,
    log,
    load,
    linspace,
    exp,
    triu_indices,
    manual_seed,
    cuda,
    no_grad,
)
from torch.nn.functional import one_hot
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local modules live one directory up from the notebook.
sys.path.append('../')
from data_utils import get_data_module, Transforms
from model.unet import get_unet_module
from losses import dice_per_class_loss, surface_loss

In [3]:
def pairwise_dice(
    predicted_segmentation, 
    num_classes
):  
    """Mean foreground Dice over all unordered pairs of predictions.

    The label maps are one-hot encoded and every pair ``(i, j)`` with
    ``i < j`` along axis 0 is scored in a single batched ``DiceMetric`` call
    (equivalent to looping over the pairs one by one). The background
    channel (index 0) is dropped, the remaining per-class scores are
    averaged per pair while ignoring NaNs, and the per-pair values are
    averaged.

    Args:
        predicted_segmentation: Integer label tensor. Axis 0 indexes the
            predictions that are compared with each other; axis 1 is
            squeezed before one-hot encoding, so it is expected to be a
            singleton channel axis.
        num_classes: Number of classes including background.

    Returns:
        0-dim CPU tensor holding the mean pairwise Dice, or NaN when fewer
        than two predictions are given (there is no pair to compare).
    """
    onehot = one_hot(
        predicted_segmentation.squeeze(1), num_classes=num_classes
    ).moveaxis(-1, 1)

    n_predictions = onehot.shape[0]
    if n_predictions < 2:
        # No pair exists; the mean over zero pairs is undefined.
        return tensor(float('nan'))

    # Index pairs of the strict upper triangle: each unordered pair once.
    i_idx, j_idx = triu_indices(n_predictions, n_predictions, offset=1)

    per_pair_per_class = DiceMetric(
        include_background=True, 
        reduction="none",
        num_classes=num_classes,
        ignore_empty=False
    )(onehot[i_idx], onehot[j_idx])

    # Drop background, average the foreground classes per pair (NaN-aware),
    # then average over pairs.
    per_pair = per_pair_per_class[..., 1:].nanmean(-1).nan_to_num(0)
    return per_pair.cpu().detach().mean()


def pairwise_hausdorff(
    predicted_segmentation, 
    num_classes,
    sigma: float
):  
    """Mean Hausdorff-based similarity over all unordered pairs of predictions.

    For every pair ``(i, j)`` with ``i < j`` along axis 0 the 95th-percentile
    Hausdorff distance ``d`` is computed per class and mapped to a similarity
    ``exp(-d**2 / (2 * sigma**2))``. The background channel (index 0) is
    dropped, the remaining classes are averaged per pair while ignoring
    NaNs, and the per-pair values are averaged.

    All pairs are passed to ``compute_hausdorff_distance`` in one batched
    call, using the same upper-triangle pairing as the other pairwise
    agreement helpers in this notebook.

    Args:
        predicted_segmentation: Integer label tensor. Axis 0 indexes the
            predictions that are compared with each other; axis 1 is
            squeezed before one-hot encoding, so it is expected to be a
            singleton channel axis.
        num_classes: Number of classes including background.
        sigma: Bandwidth of the Gaussian kernel that turns distances into
            similarities. Must be non-zero.

    Returns:
        0-dim CPU tensor holding the mean pairwise similarity, or NaN when
        fewer than two predictions are given (there is no pair to compare).
    """
    onehot = one_hot(
        predicted_segmentation.squeeze(1), num_classes=num_classes
    ).moveaxis(-1, 1)

    n_predictions = onehot.shape[0]
    if n_predictions < 2:
        # No pair exists; the mean over zero pairs is undefined.
        return tensor(float('nan'))

    # Index pairs of the strict upper triangle: each unordered pair once.
    i_idx, j_idx = triu_indices(n_predictions, n_predictions, offset=1)

    distances = compute_hausdorff_distance(
        y_pred=onehot[i_idx],
        y=onehot[j_idx],
        include_background=True,
        percentile=95,
    ).detach()

    # Gaussian kernel: distance 0 -> similarity 1, large distance -> 0.
    similarity = exp(-(distances ** 2) / (2 * sigma**2))
    per_pair = similarity[..., 1:].nanmean(-1).nan_to_num(0)
    return per_pair.cpu().detach().mean()


def pairwise_surface_dice(
    predicted_segmentation, 
    num_classes,
    threshold=3,
):  
    """Mean foreground surface Dice over all unordered pairs of predictions.

    The label maps are one-hot encoded and every pair ``(i, j)`` with
    ``i < j`` along axis 0 is scored in a single batched
    ``SurfaceDiceMetric`` call (equivalent to looping over the pairs one by
    one). The background channel (index 0) is dropped, the remaining
    per-class scores are averaged per pair while ignoring NaNs, and the
    per-pair values are averaged.

    Args:
        predicted_segmentation: Integer label tensor. Axis 0 indexes the
            predictions that are compared with each other; axis 1 is
            squeezed before one-hot encoding, so it is expected to be a
            singleton channel axis.
        num_classes: Number of classes including background.
        threshold: Boundary tolerance handed to ``SurfaceDiceMetric`` as
            ``class_thresholds`` for every class. Defaults to 3, the value
            that was previously hard-coded.

    Returns:
        0-dim CPU tensor holding the mean pairwise surface Dice, or NaN when
        fewer than two predictions are given (there is no pair to compare).
    """
    onehot = one_hot(
        predicted_segmentation.squeeze(1), num_classes=num_classes
    ).moveaxis(-1, 1)

    n_predictions = onehot.shape[0]
    if n_predictions < 2:
        # No pair exists; the mean over zero pairs is undefined.
        return tensor(float('nan'))

    # Index pairs of the strict upper triangle: each unordered pair once.
    i_idx, j_idx = triu_indices(n_predictions, n_predictions, offset=1)

    per_pair_per_class = SurfaceDiceMetric(
        include_background=True, 
        reduction="none",
        class_thresholds=[threshold] * num_classes,
    )(onehot[i_idx], onehot[j_idx]).detach()

    # Drop background, average the foreground classes per pair (NaN-aware),
    # then average over pairs.
    per_pair = per_pair_per_class[..., 1:].nanmean(-1).nan_to_num(0)
    return per_pair.cpu().mean()

In [None]:


# Repeated evaluation on a single test domain: for every test image, compare
# the agreement between MC-dropout predictions with the true segmentation
# quality of the deterministic prediction, then report their correlation.

UNET_CKPTS = {
    "mnmv2": 'mnmv2_symphony_dropout-0-1_2025-01-14-15-19', 
    'pmri': 'pmri_runmc_dropout-0-1_2025-01-14-15-58',
}

# Number of MC-dropout samples drawn per test image.
batch_size = 2

eval_metrics = {
    'dice': dice_per_class_loss,
    'surface': surface_loss
}

unet_cfg = OmegaConf.load('../configs/unet/monai_unet.yaml')


dataset = 'mnmv2'

if dataset == 'mnmv2':
    unet_cfg.out_channels = 4
    num_classes = 4
    data_cfg = OmegaConf.load('../configs/data/mnmv2.yaml')
    domain = 'Symphony'

else:
    unet_cfg.out_channels = 1
    num_classes = 2
    data_cfg = OmegaConf.load('../configs/data/pmri.yaml')
    domain = 'RUNMC'
    sigma = 6.9899

print(f"Train Vendor: {domain}")
results = {}
data_cfg.dataset = dataset
data_cfg.domain = domain
data_cfg.non_empty_target = True

# Build the data module here so the cell runs on a fresh kernel; it was
# previously commented out and `datamodule` only existed as leftover state.
datamodule = get_data_module(
    cfg=data_cfg
)
datamodule.setup('test')

ckpt = UNET_CKPTS[data_cfg.dataset]
unet_cfg.checkpoint_path = f'../../{unet_cfg.checkpoint_dir}{ckpt}.ckpt'
unet_cfg.dropout = 0.1

unet = get_unet_module(
    cfg=unet_cfg,
    metadata=OmegaConf.to_container(unet_cfg),
    load_from_checkpoint=True
).model

test_domain = 'Trio'

# One image per step, in fixed order. Built once instead of on every
# iteration of the repeat loop.
test_dataset = datamodule.test_dataloader()[test_domain].dataset
test_dl = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
)

for iteration in range(10):
    scores = {
        'dice': [],
        'surface': [],
        'dice_agreement': [],
        'surface_agreement': []
    }

    print(f"test_domain: {test_domain}")

    # Seed once per repeat so the dropout masks are reproducible while still
    # differing between repeats and between images within a repeat.
    manual_seed(iteration)

    for batch in tqdm(test_dl):
        # Replicate the image so one forward pass yields `batch_size`
        # MC-dropout samples.
        inputs = batch['input'].repeat(batch_size, 1, 1 ,1)
        target = batch['target']

        # Deterministic prediction: every layer in eval mode.
        unet.eval()
        with no_grad():
            logits = unet(inputs[:1].cuda())

        # Stochastic predictions: switch only the dropout layers back on.
        for m in unet.modules():
            if m.__class__.__name__.startswith('Dropout'):
                m.train()

        with no_grad():
            logits_dropout = unet(inputs.cuda())

        # A single output channel means binary segmentation (2 classes).
        num_classes = max(logits_dropout.shape[1], 2)
        if num_classes > 2:
            predictions = logits.argmax(1, keepdim=True)
            predictions_dropout = logits_dropout.argmax(1, keepdim=True)
            
        else:
            predictions = (logits > 0) * 1
            predictions_dropout = (logits_dropout > 0) * 1

        # Agreement between the MC-dropout samples.
        dice_agreement = pairwise_dice(predictions_dropout, num_classes=num_classes)
        surface_agreement = pairwise_surface_dice(predictions_dropout, num_classes=num_classes)
        scores['dice_agreement'].append(dice_agreement.detach().cpu().view(1,))
        scores['surface_agreement'].append(surface_agreement.detach().cpu().view(1,))

        # True quality of the deterministic prediction against the target.
        for key, fn in eval_metrics.items():
            _, _, true_score = fn(
                predicted_segmentation=predictions, 
                target_segmentation=target.cuda(),
                prediction=zeros((inputs.size(0), 1, num_classes)).cuda(),
                num_classes=num_classes,
                sigma=0,
                return_scores=True
            )

            scores[key].append(true_score.squeeze(1).detach().cpu())

    scores = {
        key: cat(scores[key]) for key in scores.keys()
    }

    # Correlation between sample agreement and true score, per metric.
    corr_dice = corrcoef(stack([scores['dice_agreement'], scores['dice']], dim=0))[0,1]
    corr_surface = corrcoef(stack([scores['surface_agreement'], scores['surface']], dim=0))[0,1]

    print(f"Correlation Dice: {corr_dice} | Correlation Surface: {corr_surface}")

In [19]:
### OOD EVAL
# For every out-of-domain test set, compare the agreement between MC-dropout
# predictions with the true segmentation quality of the deterministic
# prediction. Repeated with five seeds; one result file is saved per seed.


UNET_CKPTS = {
    "mnmv2": 'mnmv2_symphony_dropout-0-1_2025-01-14-15-19', 
    'pmri': 'pmri_runmc_dropout-0-1_2025-01-14-15-58',
}

# Number of MC-dropout samples drawn per test image.
batch_size = 2

eval_metrics = {
    'dice': dice_per_class_loss,
    'surface': surface_loss
}

unet_cfg = OmegaConf.load('../configs/unet/monai_unet.yaml')

for dataset in ['pmri']:

    if dataset == 'mnmv2':
        unet_cfg.out_channels = 4
        num_classes = 4
        data_cfg = OmegaConf.load('../configs/data/mnmv2.yaml')
        domain = 'Symphony'

    else:
        unet_cfg.out_channels = 1
        num_classes = 2
        data_cfg = OmegaConf.load('../configs/data/pmri.yaml')
        domain = 'RUNMC'
        sigma = 6.9899

    data_cfg.dataset = dataset
    data_cfg.domain = domain
    data_cfg.non_empty_target = True

    # Data and checkpoint do not depend on the seed, so load them once per
    # dataset rather than once per repeat.
    datamodule = get_data_module(
        cfg=data_cfg
    )
    datamodule.setup('test')

    ckpt = UNET_CKPTS[data_cfg.dataset]
    unet_cfg.checkpoint_path = f'../../{unet_cfg.checkpoint_dir}{ckpt}.ckpt'
    unet_cfg.dropout = 0.1

    unet = get_unet_module(
        cfg=unet_cfg,
        metadata=OmegaConf.to_container(unet_cfg),
        load_from_checkpoint=True
    ).model

    # One image per step, in fixed order; train/val splits are skipped.
    test_loaders = {
        test_domain: DataLoader(
            test_dl.dataset,
            batch_size=1,
            shuffle=False,
        )
        for test_domain, test_dl in datamodule.test_dataloader().items()
        if not ('train' in test_domain or 'val' in test_domain)
    }

    # Make sure the output directory exists before any long run starts.
    os.makedirs('../../results', exist_ok=True)

    for it in range(0, 5):
        print(f"Dataset: {dataset}")
        print(f"Train Vendor: {domain}")
        results = {}

        for test_domain, test_dl in test_loaders.items():
            print(f"test_domain: {test_domain}")
            scores = {
                'dice': [],
                'surface': [],
                'dice_agreement': [],
                'surface_agreement': []
            }

            # Same seed for every test domain within a repeat.
            manual_seed(it)
            for batch in tqdm(test_dl):
                # Replicate the image so one forward pass yields
                # `batch_size` MC-dropout samples.
                inputs = batch['input'].repeat(batch_size, 1, 1 ,1)
                target = batch['target']

                # Deterministic prediction: every layer in eval mode.
                unet.eval()
                with no_grad():
                    logits = unet(inputs[:1].cuda())

                # Stochastic predictions: switch only dropout back on.
                for m in unet.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

                with no_grad():
                    logits_dropout = unet(inputs.cuda())

                # A single output channel means binary segmentation.
                num_classes = max(logits_dropout.shape[1], 2)
                if num_classes > 2:
                    predictions = logits.argmax(1, keepdim=True)
                    predictions_dropout = logits_dropout.argmax(1, keepdim=True)

                else:
                    predictions = (logits > 0) * 1
                    predictions_dropout = (logits_dropout > 0) * 1

                # Agreement between the MC-dropout samples.
                dice_agreement = pairwise_dice(predictions_dropout, num_classes=num_classes)
                surface_agreement = pairwise_surface_dice(predictions_dropout, num_classes=num_classes)
                scores['dice_agreement'].append(dice_agreement.detach().cpu().view(1,))
                scores['surface_agreement'].append(surface_agreement.detach().cpu().view(1,))

                # True quality of the deterministic prediction.
                for key, fn in eval_metrics.items():
                    _, _, true_score = fn(
                        predicted_segmentation=predictions, 
                        target_segmentation=target.cuda(),
                        prediction=zeros((inputs.size(0), 1, num_classes)).cuda(),
                        num_classes=num_classes,
                        sigma=0,
                        return_scores=True
                    )

                    scores[key].append(true_score.squeeze(1).detach().cpu())

            scores = {
                key: cat(scores[key]) for key in scores.keys()
            }

            # Correlation between sample agreement and true score.
            corr_dice = corrcoef(stack([scores['dice_agreement'], scores['dice']], dim=0))[0,1]
            corr_surface = corrcoef(stack([scores['surface_agreement'], scores['surface']], dim=0))[0,1]
            results[test_domain] = {
                'scores': scores,
            }

            print(f"Correlation Dice: {corr_dice} | Correlation Surface: {corr_surface}")

        save(results, f'../../results/{dataset}_{domain}_score-agreement-{batch_size}-{it}.pt')


Dataset: pmri
Train Vendor: RUNMC
test_domain: BMC


100%|██████████| 324/324 [00:11<00:00, 28.23it/s]


Correlation Dice: 0.7855059504508972 | Correlation Surface: 0.5963176488876343
test_domain: I2CVB


100%|██████████| 505/505 [00:17<00:00, 28.29it/s]


Correlation Dice: 0.5624947547912598 | Correlation Surface: 0.48880690336227417
test_domain: UCL


100%|██████████| 171/171 [00:06<00:00, 28.15it/s]


Correlation Dice: 0.7880407571792603 | Correlation Surface: 0.6342188119888306
test_domain: BIDMC


100%|██████████| 197/197 [00:07<00:00, 28.14it/s]


Correlation Dice: 0.17424719035625458 | Correlation Surface: 0.17736348509788513
test_domain: HK


100%|██████████| 157/157 [00:05<00:00, 28.38it/s]


Correlation Dice: 0.7673707008361816 | Correlation Surface: 0.40917640924453735
Dataset: pmri
Train Vendor: RUNMC
test_domain: BMC


100%|██████████| 324/324 [00:11<00:00, 27.78it/s]


Correlation Dice: 0.7627540826797485 | Correlation Surface: 0.5852421522140503
test_domain: I2CVB


100%|██████████| 505/505 [00:18<00:00, 27.52it/s]


Correlation Dice: 0.5695006847381592 | Correlation Surface: 0.4669407606124878
test_domain: UCL


100%|██████████| 171/171 [00:06<00:00, 27.59it/s]


Correlation Dice: 0.756831169128418 | Correlation Surface: 0.6367197632789612
test_domain: BIDMC


100%|██████████| 197/197 [00:06<00:00, 28.35it/s]


Correlation Dice: 0.17942918837070465 | Correlation Surface: 0.16730129718780518
test_domain: HK


100%|██████████| 157/157 [00:05<00:00, 27.13it/s]


Correlation Dice: 0.7575271725654602 | Correlation Surface: 0.3883814513683319
Dataset: pmri
Train Vendor: RUNMC
test_domain: BMC


100%|██████████| 324/324 [00:11<00:00, 28.46it/s]


Correlation Dice: 0.7550802230834961 | Correlation Surface: 0.615959644317627
test_domain: I2CVB


100%|██████████| 505/505 [00:18<00:00, 27.45it/s]


Correlation Dice: 0.5631864070892334 | Correlation Surface: 0.49558478593826294
test_domain: UCL


100%|██████████| 171/171 [00:06<00:00, 28.06it/s]


Correlation Dice: 0.7716273665428162 | Correlation Surface: 0.655192494392395
test_domain: BIDMC


100%|██████████| 197/197 [00:07<00:00, 28.08it/s]


Correlation Dice: 0.15750263631343842 | Correlation Surface: 0.12678833305835724
test_domain: HK


100%|██████████| 157/157 [00:05<00:00, 26.89it/s]


Correlation Dice: 0.7519684433937073 | Correlation Surface: 0.4254618287086487
Dataset: pmri
Train Vendor: RUNMC
test_domain: BMC


100%|██████████| 324/324 [00:11<00:00, 28.90it/s]


Correlation Dice: 0.7759231925010681 | Correlation Surface: 0.6130805015563965
test_domain: I2CVB


100%|██████████| 505/505 [00:17<00:00, 28.69it/s]


Correlation Dice: 0.5713918805122375 | Correlation Surface: 0.509616494178772
test_domain: UCL


100%|██████████| 171/171 [00:05<00:00, 29.07it/s]


Correlation Dice: 0.7830401659011841 | Correlation Surface: 0.6543428301811218
test_domain: BIDMC


100%|██████████| 197/197 [00:06<00:00, 28.66it/s]


Correlation Dice: 0.1883888989686966 | Correlation Surface: 0.19922161102294922
test_domain: HK


100%|██████████| 157/157 [00:05<00:00, 28.17it/s]


Correlation Dice: 0.7747690081596375 | Correlation Surface: 0.4089813530445099
Dataset: pmri
Train Vendor: RUNMC
test_domain: BMC


100%|██████████| 324/324 [00:11<00:00, 28.78it/s]


Correlation Dice: 0.7810331583023071 | Correlation Surface: 0.610005259513855
test_domain: I2CVB


100%|██████████| 505/505 [00:18<00:00, 27.62it/s]


Correlation Dice: 0.5716392397880554 | Correlation Surface: 0.5139486193656921
test_domain: UCL


100%|██████████| 171/171 [00:05<00:00, 28.55it/s]


Correlation Dice: 0.7739604711532593 | Correlation Surface: 0.6454976201057434
test_domain: BIDMC


100%|██████████| 197/197 [00:06<00:00, 28.63it/s]


Correlation Dice: 0.16067102551460266 | Correlation Surface: 0.16493631899356842
test_domain: HK


100%|██████████| 157/157 [00:05<00:00, 28.74it/s]

Correlation Dice: 0.7289554476737976 | Correlation Surface: 0.3832800090312958





In [None]:
# 472 minutes for three eval runs across all datasets and tasks

In [7]:
# time measurements
# Average wall-clock cost of (a) the MC-dropout forward pass alone, (b) the
# forward pass plus volumetric Dice agreement and (c) the forward pass plus
# surface Dice agreement, for one image replicated `batch_size` times.


UNET_CKPTS = {
    "mnmv2": 'mnmv2_symphony_dropout-0-1_2025-01-14-15-19', 
    'pmri': 'pmri_runmc_dropout-0-1_2025-01-14-15-58',
}

# Number of MC-dropout samples per image.
batch_size = 15
# Number of repetitions each measurement is averaged over.
n_timing_runs = 100

eval_metrics = {
    'dice': dice_per_class_loss,
    'surface': surface_loss
}

unet_cfg = OmegaConf.load('../configs/unet/monai_unet.yaml')


def _seconds_per_call(step, n_calls):
    """Return the average wall-clock seconds of `step()` over `n_calls` calls.

    CUDA kernels are launched asynchronously, so the device is synchronised
    before the clock is started and again before it is read; otherwise work
    still queued on the GPU would not be counted.
    """
    cuda.synchronize()
    start = time()
    for _ in range(n_calls):
        step()
    cuda.synchronize()
    return (time() - start) / n_calls


for dataset in ['mnmv2', 'pmri']:
    print(f"Dataset: {dataset}")

    if dataset == 'mnmv2':
        unet_cfg.out_channels = 4
        num_classes = 4
        data_cfg = OmegaConf.load('../configs/data/mnmv2.yaml')
        domain = 'Symphony'

    else:
        unet_cfg.out_channels = 1
        num_classes = 2
        data_cfg = OmegaConf.load('../configs/data/pmri.yaml')
        domain = 'RUNMC'
        sigma = 6.9899

    print(f"Train Vendor: {domain}")
    results = {}
    data_cfg.dataset = dataset
    data_cfg.domain = domain
    data_cfg.non_empty_target = True

    datamodule = get_data_module(
        cfg=data_cfg
    )

    datamodule.setup('fit')

    ckpt = UNET_CKPTS[data_cfg.dataset]
    unet_cfg.checkpoint_path = f'../../{unet_cfg.checkpoint_dir}{ckpt}.ckpt'
    unet_cfg.dropout = 0.1

    unet = get_unet_module(
        cfg=unet_cfg,
        metadata=OmegaConf.to_container(unet_cfg),
        load_from_checkpoint=True
    ).model
    if dataset == 'mnmv2':
        data = datamodule.mnm_train

    else:
        data = datamodule.pmri_train

    # A single training image, replicated once per MC-dropout sample.
    inputs = data[10:11]['input'].repeat(batch_size, 1, 1 ,1)

    # Whole network in train mode so dropout is active.
    # NOTE(review): this also affects any other train/eval-dependent layers,
    # unlike the evaluation cells which re-enable only Dropout -- confirm
    # this is intended for the timing.
    unet.train()

    def _forward():
        # The host-to-device copy is deliberately part of the measured cost.
        return unet(inputs.cuda())

    def _mc_predictions():
        """Forward pass followed by conversion of logits to label maps."""
        logits_dropout = _forward()
        # A single output channel means binary segmentation (2 classes).
        n_classes = max(logits_dropout.shape[1], 2)
        if n_classes > 2:
            return logits_dropout.argmax(1, keepdim=True), n_classes
        return (logits_dropout > 0) * 1, n_classes

    def _forward_and_dice():
        predictions_dropout, n_classes = _mc_predictions()
        return pairwise_dice(predictions_dropout, num_classes=n_classes)

    def _forward_and_surface_dice():
        predictions_dropout, n_classes = _mc_predictions()
        return pairwise_surface_dice(predictions_dropout, num_classes=n_classes)

    seconds = _seconds_per_call(_forward, n_timing_runs)
    print(f'{seconds} Seconds per image for forward passes')

    seconds = _seconds_per_call(_forward_and_dice, n_timing_runs)
    print(f'{seconds} Seconds per image for forward passes + volumetric Dice agreement')

    seconds = _seconds_per_call(_forward_and_surface_dice, n_timing_runs)
    print(f'{seconds} Seconds per image for forward passes + surface Dice agreement')

Dataset: mnmv2
Train Vendor: Symphony
0.015183804035186767 Seconds per image for forward passes
0.10741067409515381 Seconds per image for forward passesd + volumetric Dice agreement
1.879386966228485 Seconds per image for forward passes + surface Dice agreement
Dataset: pmri
Train Vendor: RUNMC
0.032966599464416504 Seconds per image for forward passes
0.0803830099105835 Seconds per image for forward passesd + volumetric Dice agreement
1.026671540737152 Seconds per image for forward passes + surface Dice agreement


Dataset: mnmv2
Train Vendor: Symphony
0.020400516986846924 Seconds per image for forward passes
0.12667125463485718  Seconds per image for forward passes + volumetric Dice agreement
1.933464493751526    Seconds per image for forward passes + surface Dice agreement



Dataset: pmri
Train Vendor: RUNMC
0.036640105247497556 Seconds per image for forward passes
0.09966156721115112  Seconds per image for forward passes + volumetric Dice agreement
1.0490441370010375   Seconds per image for forward passes + surface Dice agreement


Dataset: mnmv2
Train Vendor: Symphony
0.0204 Seconds per image for forward passes
0.1266 Seconds per image for forward passes + volumetric Dice agreement
1.9334 Seconds per image for forward passes + surface Dice agreement

Dataset: pmri
Train Vendor: RUNMC
0.0366 Seconds per image for forward passes
0.0996 Seconds per image for forward passes + volumetric Dice agreement
1.0490 Seconds per image for forward passes + surface Dice agreement