In [1]:
# NOTE(review): leftover scratch cell. pyplot is imported again in the main
# import cell below and no figure has been created at this point, so the former
# `plt.show()` call was a no-op. Consider deleting this cell entirely.
import matplotlib.pyplot as plt

In [1]:
import matplotlib.pyplot as plt
import os, sys
from typing import Iterable, Dict, List, Callable, Tuple, Union, List

import numpy as np
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import fiftyone as fo
import shutil
from torchmetrics import Dice

sys.path.append('../../')
from dataset import CalgaryCampinasDataset
from model.unet import UNet2D
from model.ae import AE
from model.dae import resDAE, AugResDAE
from model.wrapper import Frankenstein, ModelAdapter
from losses import DiceScoreCalgary, SurfaceDiceCalgary
from utils import  epoch_average, UMapGenerator, volume_collate
from trainer.unet_trainer import UNetTrainerCalgary
from data_utils import get_subset



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [2]:
### Config
root = '../../../'
data_dir = 'data/conp-dataset/projects/calgary-campinas/CC359/Reconstructed/'
data_path = root + data_dir
debug = False
augment = False
# Calgary-Campinas site used for the train/validation splits. The cells below
# build trainset/valset with site=6 and hold out sites 1-5 as test domains, so
# the former value 4 (itself one of the test sites) contradicted the setup.
site = 6

In [3]:
# Training split of the source domain (site 6); augmentation and debug mode
# come from the config cell, normalization is always on.
trainset = CalgaryCampinasDataset(
    split='train',
    site=6,
    data_path=data_path,
    normalize=True,
    augment=augment,
    debug=debug,
)

In [4]:
# Validation split of the same source domain (site 6), configured exactly like
# the training set.
valset = CalgaryCampinasDataset(
    split='validation',
    site=6,
    data_path=data_path,
    normalize=True,
    augment=augment,
    debug=debug,
)

In [5]:
# Held-out test domains: the full ('all') split of each of sites 1-5.
# A comprehension is used instead of a `for site in ...` loop so that the
# config variable `site` is not silently overwritten by the loop variable.
testsets = [
    CalgaryCampinasDataset(
        data_path=data_path,
        site=test_site,
        split='all',
        augment=augment,
        normalize=True,
        debug=debug
    )
    for test_site in [1, 2, 3, 4, 5]
]

In [10]:
# Pre-trained 2D UNet segmentation model (1 input channel, 1 output channel).
model_path = '../../../pre-trained-tmp/trained_UNets/calgary_unet0_augmentednnUNet_best.pt'
# map_location='cpu' lets the checkpoint load even if it was saved from a GPU
# that is not present here; the model is moved to GPU 0 below.
state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']
n_chans_out = 1
seg_model = UNet2D(
    n_chans_in=1,
    n_chans_out=n_chans_out,
    n_filters_init=8,
    dropout=False
)
seg_model.load_state_dict(state_dict)
# Assignment (rather than a trailing bare `print()`) keeps the module repr out
# of the cell output; nn.Module.to returns the same module.
seg_model = seg_model.to(0)




In [18]:
# One subset per dataset, in the order train, validation, then each test site.
# NOTE(review): how n_cases / fraction select the subset is defined by
# get_subset in data_utils; confirm there.
subsets = [
    get_subset(
        ds,
        seg_model,
        criterion=nn.BCEWithLogitsLoss(reduction='none'),
        n_cases=50,
        fraction=0.1,
        batch_size=32,
    )
    for ds in (trainset, valset, *testsets)
]

In [19]:
# Layers in disabled_ids get an identity transformation; a single AugResDAE is
# attached at the 'up3' layer of the segmentation model.
disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
DAEs = nn.ModuleDict({
    'up3': AugResDAE(
        in_channels=64,
        in_dim=32,
        latent_dim=256,
        depth=3,
        block_size=4,
    )
})
for layer_id in disabled_ids:
    DAEs[layer_id] = nn.Identity()

model = ModelAdapter(
    seg_model=seg_model,
    transformations=DAEs,
    disabled_ids=disabled_ids,
    copy=True,
)
model_path = '../../../pre-trained-tmp/trained_AEs/calgary_AugResDAE0_localAug_multiImgSingleView_res_balanced_same_best.pt'
# map_location='cpu' lets the checkpoint load regardless of the device it was
# saved from; load_state_dict copies the values into the existing parameters.
state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']
model.load_state_dict(state_dict)

# Remove training hooks, add evaluation hooks
model.remove_all_hooks()
model.hook_inference_transformations(model.transformations, n_samples=1)
# Put model in evaluation state
model.to(0)
model.eval()
model.freeze_seg_model()

In [23]:
@torch.no_grad()
def get_downstream_perf(
    dataset,
    model: nn.Module,
    criterion: nn.Module,
    device: str = 'cuda:0',
    batch_size: int = 1
) -> torch.Tensor:
    """Score binarized predictions per slice for both halves of the model output.

    For a batch of ``B`` inputs the model is expected to return its outputs
    stacked along dim 0, with the first ``B`` entries scored as the "first"
    prediction and the remaining entries as the "second" one (presumably the
    unmodified segmentation and the DAE-adapted one produced by the inference
    hooks -- TODO confirm against ``ModelAdapter``).

    Args:
        dataset: map-style dataset yielding dicts with ``'input'`` and
            ``'target'`` tensors.
        model: module returning logits; they are binarized with
            ``sigmoid(logits) > 0.5``.
        criterion: callable ``criterion(pred, target)`` returning an
            unreduced score tensor that can be reshaped to ``(B, -1)``; it is
            averaged over everything except the batch dimension.
        device: device the inputs and targets are moved to.
        batch_size: number of slices per forward pass.

    Returns:
        CPU tensor of shape ``(2, N)``, one score per slice (``N`` = number of
        slices in ``dataset``). Row 0 is the first output half and row 1 the
        second. Returns shape ``(2, 0)`` for an empty dataset.
    """
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False
    )

    model.eval()
    scores_per_batch = []
    for batch in dataloader:
        input_ = batch['input'].to(device)
        target = batch['target'].to(device)
        n = input_.shape[0]
        net_out = (torch.sigmoid(model(input_)) > 0.5) * 1
        # Split at the batch size, not at 1: with batch_size > 1 the previous
        # [:1] / [1:] split paired the wrong slices with `target`.
        first, second = net_out[:n], net_out[n:]
        scores = torch.stack([
            criterion(first, target).reshape(n, -1).mean(1),
            criterion(second, target).reshape(n, -1).mean(1),
        ])  # shape (2, n)
        # Results are kept on the CPU, as the previous torch.tensor(...) did.
        scores_per_batch.append(scores.cpu())

    if not scores_per_batch:
        return torch.empty(2, 0)
    return torch.cat(scores_per_batch, dim=1)


In [28]:
# Print, for each subset, the per-slice Dice averaged over slices: one value
# for each half of the model output (see get_downstream_perf).
for subset in subsets:
    print(
        get_downstream_perf(subset, model, DiceScoreCalgary()).mean(dim=1)
    )

tensor([0.9815, 0.9816])
tensor([0.9784, 0.9785])
tensor([0.9476, 0.9503])
tensor([0.9589, 0.9789])
tensor([0.9533, 0.9726])
tensor([0.9268, 0.9436])
tensor([0.8859, 0.9257])


In [27]:
# NOTE(review): this cell used `tmp`, which is never defined in this notebook
# (it leaked from a deleted cell), so Restart & Run All failed here. Its
# recorded output matched row 4 of the evaluation loop, i.e. subsets[3]
# (testsets[1], site 2), so it is recomputed explicitly -- confirm this was
# the intended subset.
site2_dice = get_downstream_perf(subsets[3], model, DiceScoreCalgary())
site2_dice.mean(1)

tensor([0.9589, 0.9789])