In [187]:
import matplotlib.pyplot as plt

import collections
import os, sys
import time
from typing import Iterable, Dict, Callable, Tuple

import numpy as np
import torch
from torch import Tensor, nn
from torch.utils.data import Dataset, DataLoader
from torch import optim
import torch.nn.functional as F
from torchvision.transforms import Resize, CenterCrop
from torchmetrics import SpearmanCorrCoef
import wandb
from tqdm.auto import tqdm
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.training.dataloading.dataset_loading import *
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2

sys.path.append('..')
from utils import EarlyStopping, epoch_average, average_metrics, UMapGenerator, volume_collate
from dataset import CalgaryCampinasDataset, ACDCDataset, MNMDataset
from model.ae import AE
from model.unet import UNet2D, UNetEnsemble
from losses import MNMCriterionAE, CalgaryCriterionAE, SampleDice, UnetDice, DiceScoreCalgary, DiceScoreMMS #
from trainer.ae_trainer import AETrainerCalgary, AETrainerACDC


nnUnet_prefix = '../../../nnUNet/'


In [188]:

class UNetEnsemble(nn.Module):
    def __init__(
        self,
        unets: List,
    ):
        super().__init__()
        self.ensemble = nn.ModuleList(unets)

    def forward(self, x, reduce='none'):
        x_out = torch.cat([module(x.clone()) for module in self.ensemble])
        #print(x_out.shape)
        if reduce=='none':
            return x_out
        elif reduce=='mean':
            return x_out.mean(0)

In [206]:
cfg = {
    'debug': False,
    'log': True,
    'description': 'calgary_ae_test',
    'project': 'MICCAI2023',

    # Data params
    'n': 0,
    'root': '../../',
    'data_path': 'data/conp-dataset/projects/calgary-campinas/CC359/Reconstructed/',
    'train_site': 6,
    'unet': 'calgary_unet',
    'channel_out': 8,
    
    # Hyperparams
    'batch_size': 1,
    'augment': False,
    'difference': True,
    'loss': 'huber',
    'target': 'output',
    'identity_layers': ['shortcut0', 'shortcut1', 'shortcut2'],
    
    # outputs
    'plot_dir': '../experiments/unet/calgary/logs/'
}


description = cfg['description'] + str(cfg['n'])
if cfg['augment']:
    description += 'augment'

### data loading 
root      = cfg['root']
data_path = root + cfg['data_path']
train_set = CalgaryCampinasDataset(data_path=data_path, 
                                   site=cfg['train_site'], 
                                   augment=cfg['augment'], 
                                   normalize=True, 
                                   split='train', 
                                   debug=cfg['debug'])

valid_set = CalgaryCampinasDataset(data_path=data_path, 
                                   site=cfg['train_site'], 
                                   normalize=True,
                                   volume_wise=True,
                                   split='validation', 
                                   debug=cfg['debug'])


test_set = CalgaryCampinasDataset(data_path=data_path, 
                                  site=1, 
                                  normalize=True, 
                                  volume_wise=True,
                                  split='all', 
                                  debug=cfg['debug'])

train_loader = DataLoader(train_set, 
                          batch_size=cfg['batch_size'], 
                          shuffle=True, 
                          drop_last=False,
                          num_workers=10)

valid_loader = DataLoader(valid_set, 
                          batch_size=cfg['batch_size'], 
                          shuffle=False, 
                          drop_last=False,
                          num_workers=1,
                          collate_fn=volume_collate)

test_loader = DataLoader(test_set, 
                         batch_size=cfg['batch_size'], 
                         shuffle=False, 
                         drop_last=False,
                         num_workers=1,
                         collate_fn=volume_collate)

In [207]:
test_set = CalgaryCampinasDataset(data_path=data_path, 
                                  site=2, 
                                  normalize=True, 
                                  volume_wise=True,
                                  split='all', 
                                  debug=cfg['debug'])

test_loader = DataLoader(test_set, 
                         batch_size=cfg['batch_size'], 
                         shuffle=False, 
                         drop_last=False,
                         num_workers=1,
                         collate_fn=volume_collate)

In [208]:
# - load unets
unets = []
unet_path = cfg['unet']

for i in range(10):
    seg_model = UNet2D(n_chans_in=1, n_chans_out=1, n_filters_init=cfg['channel_out'])
    model_path = f'{root}pre-trained-tmp/trained_UNets/{unet_path + str(i)}_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    seg_model.load_state_dict(state_dict)
    unets.append(seg_model)
    
# - verify that we have different models
param_sums = torch.tensor(
    [torch.tensor(
        [p.data.sum() for p in unets[i].parameters()]).sum()
     for i in range(10)]
)
assert torch.any(param_sums.mean() != param_sums)

# - build ensemble
ensemble = UNetEnsemble(unets)

In [209]:
class EnsembleEntropyDetector(nn.Module):
    """
    Evaluation class for OOD and ESCE tasks based on AEs.
    """
    
    def __init__(
        self, 
        model: nn.Module, 
        net_out: str,
        valid_loader: DataLoader,
        criterion: nn.Module,
        device: str = 'cuda:0'
    ):
        super().__init__()
        self.net_out = net_out
        self.device = device
        self.model = model.to(device)
        torch.manual_seed(42)
        self.ensemble_compositions = torch.cat([torch.arange(10).view(10,1), 
                                                torch.randint(0, 9, (10, 4))],
                                               dim=1)
        # Remove trainiung hooks, add evaluation hooks
        # self.model.remove_all_hooks()        
        
        self.valid_loader = valid_loader
        self.criterion = criterion
        self.umap_generator = UMapGenerator(method='probs',
                                            net_out=net_out)
        
        
    @torch.no_grad()
    def testset_ood_detection(self, test_loader: DataLoader) -> Dict[str, torch.Tensor]:
        
        if not hasattr(self, 'threshold'):
            valid_dists = []
            for batch in self.valid_loader:
                input_ = batch['input'].to(0)
                net_out = self.forward(input_.to(self.device))
                score = torch.norm(umap).cpu()
                valid_dists.append(score)
                    
            self.threshold = 0
            valid_dists = torch.tensor(valid_dists)
            self.threshold = torch.sort(valid_dists)[0][len(valid_dists) - (len(valid_dists) // 20) - 1]
        
        test_dists = []
        for batch in test_loader:
            input_ = batch['input']

            net_out = self.forward(input_.to(self.device))

            score = torch.norm(umap).cpu()
            test_dists.append(score)
            
        test_dists = torch.tensor(test_dists).cpu()
        accuracy = (test_dists > self.threshold).sum() / len(test_dists)
        
        return accuracy    
        
    
    @torch.no_grad()
    def testset_correlation(self, test_loader: DataLoader) -> Dict[str, torch.Tensor]:
        corr_coeffs = [SpearmanCorrCoef() for _ in range(10)]
        losses = []
        for batch in tqdm(test_loader):
            input_ = batch['input'].to(0)
            #print(input_.shape)
            target = batch['target']
            
            if self.net_out == 'calgary':
                net_out_volume = []
                #umap_volume  = []

                for input_chunk in input_:
                    net_out = self.forward(input_chunk.unsqueeze(0).to(self.device))
                    net_out_volume.append(net_out.detach().cpu())
                    #umap_volume.append(umap)
                    
                net_out = torch.stack(net_out_volume, dim=0)
                #print(net_out.shape)
                #umap = torch.cat(umap_volume, dim=0)
            
            if self.net_out == 'mms':
                target[target == -1] = 0
                # convert to one-hot encoding
                target = F.one_hot(target.long(), num_classes=4).squeeze(1).permute(0,3,1,2)
                net_out = self.forward(input_.to(self.device))
            
            #score = torch.norm(umap).cpu()
            
            for i, corr_coeff in enumerate(corr_coeffs):
                ensemble_idxs = self.ensemble_compositions[i]
                #umap = self.umap_generator(net_out[ensemble_idxs].mean(0, keepdim=True))
                
                if self.net_out == 'calgary':
                    loss = self.criterion(net_out[:, i:i+1].cpu(), target.cpu())
                    umap_volume  = []
                    for slc in net_out:
                        umap = self.umap_generator(slc[ensemble_idxs].mean(0, keepdim=True))
                        #break
                        umap_volume.append(umap)
                    umap = torch.cat(umap_volume, dim=0)
                
                if self.net_out == 'mms':
                    loss = self.criterion(net_out[i:i+1].cpu(), target.cpu())
                    umap = self.umap_generator(net_out[ensemble_idxs].mean(0, keepdim=True))
                    
                score = torch.norm(umap).cpu()    
                loss = loss.mean().cpu().float()
                corr_coeff.update(score, 1-loss)
                #losses.append(1-loss.view(1))
            
        return corr_coeffs

    
    @torch.no_grad()  
    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        self.model.eval()
        net_out = self.model(input_)
        #umap = self.umap_generator(net_out.mean(0, keepdim=True))
        #score = torch.linalg.norm(umap, dim=(-2, -1))
        return net_out

In [210]:
# class EnsembleEntropyDetector(nn.Module):
#     """
#     Evaluation class for OOD and ESCE tasks based on AEs.
#     """
    
#     def __init__(
#         self, 
#         model: nn.Module, 
#         net_out: str,
#         valid_loader: DataLoader,
#         criterion: nn.Module,
#         device: str = 'cuda:0'
#     ):
#         super().__init__()
#         self.net_out = net_out
#         self.device = device
#         self.model = model.to(device)
#         # Remove trainiung hooks, add evaluation hooks
#         # self.model.remove_all_hooks()        
        
#         self.valid_loader = valid_loader
#         self.criterion = criterion
#         self.umap_generator = UMapGenerator(method='probs',
#                                             net_out=net_out)
        
        
#     @torch.no_grad()
#     def testset_ood_detection(self, test_loader: DataLoader) -> Dict[str, torch.Tensor]:
        
#         if not hasattr(self, 'threshold'):
#             valid_dists = []
#             for batch in self.valid_loader:
#                 input_ = batch['input'].to(0)
                
#                 if self.net_out == 'calgary':
#                     net_out_volume = []
#                     umap_volume  = []

#                     for input_chunk in input_:
#                         umap, net_out = self.forward(input_chunk.unsqueeze(0).to(self.device))
#                         net_out_volume.append(net_out.detach().cpu())
#                         umap_volume.append(umap)

#                     net_out = torch.cat(net_out_volume, dim=0)
#                     umap = torch.cat(umap_volume, dim=0)
                    
#                 if self.net_out == 'mms':
#                     umap, net_out = self.forward(input_.to(self.device))
#                 score = torch.norm(umap).cpu()
#                 valid_dists.append(score)
                    
#             self.threshold = 0
#             valid_dists = torch.tensor(valid_dists)
#             self.threshold = torch.sort(valid_dists)[0][len(valid_dists) - (len(valid_dists) // 20) - 1]
        
#         test_dists = []
#         for batch in test_loader:
#             input_ = batch['input']

#             if self.net_out == 'calgary':
#                 net_out_volume = []
#                 umap_volume  = []

#                 for input_chunk in input_:
#                     umap, net_out = self.forward(input_chunk.unsqueeze(0).to(self.device))
#                     net_out_volume.append(net_out.detach().cpu())
#                     umap_volume.append(umap)

#                 net_out = torch.cat(net_out_volume, dim=0)
#                 umap = torch.cat(umap_volume, dim=0)

#             if self.net_out == 'mms':
#                 umap, net_out = self.forward(input_.to(self.device))

#             score = torch.norm(umap).cpu()
#             test_dists.append(score)
            
#         test_dists = torch.tensor(test_dists).cpu()
#         accuracy = (test_dists > self.threshold).sum() / len(test_dists)
        
#         return accuracy    
        
    
#     @torch.no_grad()
#     def testset_correlation(self, test_loader: DataLoader) -> Dict[str, torch.Tensor]:
#         corr_coeffs = [SpearmanCorrCoef() for _ in range(10)]
#         losses = []
#         for batch in tqdm(test_loader):
#             input_ = batch['input'].to(0)
#             target = batch['target']
            
#             if self.net_out == 'calgary':
#                 net_out_volume = []
#                 umap_volume  = []

#                 for input_chunk in input_:
#                     umap, net_out = self.forward(input_chunk.unsqueeze(0).to(self.device))
#                     net_out_volume.append(net_out.detach().cpu())
#                     umap_volume.append(umap)
                    
#                 net_out = torch.stack(net_out_volume, dim=0)
#                 umap = torch.cat(umap_volume, dim=0)
            
#             if self.net_out == 'mms':
#                 target[target == -1] = 0
#                 # convert to one-hot encoding
#                 target = F.one_hot(target.long(), num_classes=4).squeeze(1).permute(0,3,1,2)
#                 umap, net_out = self.forward(input_.to(self.device))
            
#             score = torch.norm(umap).cpu()
            
#             for i, corr_coeff in enumerate(corr_coeffs):
#                 if self.net_out == 'calgary':
#                     loss = self.criterion(net_out[:, i:i+1].cpu(), target.cpu())
                
#                 if self.net_out == 'mms':
#                     loss = self.criterion(net_out[i:i+1].cpu(), target.cpu())
#                 loss = loss.mean().cpu().float()
#                 corr_coeff.update(score, 1-loss)
#                 #losses.append(1-loss.view(1))
            
#         return corr_coeffs

    
#     @torch.no_grad()  
#     def forward(self, input_: torch.Tensor) -> torch.Tensor:
#         self.model.eval()
#         net_out = self.model(input_)
#         umap = self.umap_generator(net_out.mean(0, keepdim=True))
#         #score = torch.linalg.norm(umap, dim=(-2, -1))
#         return umap, net_out

In [211]:
detector = EnsembleEntropyDetector(model=ensemble, 
                                   net_out='calgary', 
                                   valid_loader=valid_loader, 
                                   criterion=DiceScoreCalgary())


In [212]:
# tmp_acc = detector.testset_ood_detection(test_loader)

In [213]:
tmp_cor = detector.testset_correlation(test_loader)



  0%|          | 0/60 [00:00<?, ?it/s]

In [204]:
tmp_cor

[SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef(),
 SpearmanCorrCoef()]

In [215]:
torch.tensor([tmp_cor[i].compute() for i in range(10)]).mean()

tensor(0.6294)

'1'
tensor(0.8288), tensor(0.8383), tensor(0.8342), tensor(0.8465), tensor(0.8315), tensor(0.8254), tensor(0.8407), tensor(0.8318), tensor(0.8290), tensor(0.8295)

'2' : 0.6854
tensor(0.7273), tensor(0.7371), tensor(0.6117), tensor(0.7785), tensor(0.7030), tensor(0.6041), tensor(0.6962), tensor(0.6205), tensor(0.6614), tensor(0.7140)

In [165]:
cfg = {
        'debug': False,
        'log': False,
        'description': f'acdc_ae_test', #'mms_vae_for_nnUNet_fc3_0_bs50',
        'project': 'MICCAI2023',

        # Data params
        'n': 0,
        'root': '../../',
        'data_path': 'data/mnm/',
        'train_vendor': 'B',
        'unet': f'acdc_unet8_',
        'channel_out': 8,

        # Hyperparams
        'batch_size': 32,
        'augment': False,
        'difference': True,
        'loss': 'huber',  # huber or ce
        'target': 'output', #gt or output
        'disabled_ids': ['shortcut0', 'shortcut1', 'shortcut2']
}

description = cfg['description']
root = cfg['root']
debug = cfg['debug']

data = 'data/mnm/'
data_path = root + data
train_set = ACDCDataset(data='train', 
                        debug=debug)

train_loader = DataLoader(train_set, 
                          batch_size=32, 
                          shuffle=False,
                          drop_last=False,
                          num_workers=10)

valid_set = ACDCDataset(data='val', 
                        debug=debug)

valid_loader = DataLoader(valid_set, 
                          batch_size=1, 
                          shuffle=False, 
                          drop_last=False, 
                          num_workers=10)

test_set = MNMDataset(vendor='B', 
                      debug=debug)


test_loader = DataLoader(test_set, 
                         batch_size=1, 
                         shuffle=False, 
                         drop_last=False,
                         num_workers=10)

loading dataset
loading all case properties
loading dataset
loading all case properties
loading dataset
loading all case properties


In [166]:
# - load unets
unets = []
unet_path = cfg['unet']

for i in range(10):
    seg_model = UNet2D(n_chans_in=1, n_chans_out=4, n_filters_init=cfg['channel_out'])
    model_path = f'{root}pre-trained-tmp/trained_UNets/{unet_path + str(i)}_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    seg_model.load_state_dict(state_dict)
    unets.append(seg_model)
    
# - verify that we have different models
param_sums = torch.tensor(
    [torch.tensor(
        [p.data.sum() for p in unets[i].parameters()]).sum()
     for i in range(10)]
)
assert torch.any(param_sums.mean() != param_sums)

# - build ensemble
ensemble = UNetEnsemble(unets)

In [168]:
detector = EnsembleEntropyDetector(model=ensemble, 
                                   net_out='mms', 
                                   valid_loader=valid_loader, 
                                   criterion=DiceScoreMMS())

In [101]:
# tmp_acc = detector.testset_ood_detection(test_loader)

KeyboardInterrupt: 

In [169]:
tmp_cor = detector.testset_correlation(test_loader)

  0%|          | 0/2642 [00:01<?, ?it/s]

KeyboardInterrupt: 

In [113]:
'B'
[tmp_cor[i].compute() for i in range(10)]

[tensor(0.2339),
 tensor(0.1467),
 tensor(0.2289),
 tensor(0.1222),
 tensor(0.1496),
 tensor(0.1020),
 tensor(0.2445),
 tensor(0.1846),
 tensor(0.0649),
 tensor(0.1746)]

In [109]:
[tmp_cor[i].compute() for i in range(10)]

[tensor(0.2736),
 tensor(0.2631),
 tensor(0.3509),
 tensor(0.0849),
 tensor(0.3006),
 tensor(0.2418),
 tensor(0.3077),
 tensor(0.2769),
 tensor(0.1515),
 tensor(0.2748)]

'B'
tensor(0.2339), tensor(0.1467), tensor(0.2289), tensor(0.1222), tensor(0.1496), tensor(0.1020), tensor(0.2445), tensor(0.1846), tensor(0.0649), tensor(0.1746)

'A'
tensor(0.2736), tensor(0.2631), tensor(0.3509), tensor(0.0849), tensor(0.3006), tensor(0.2418), tensor(0.3077), tensor(0.2769), tensor(0.1515), tensor(0.2748)

In [127]:
x = torch.ones((10,10,10))

In [129]:
idx = [0,1,3]
x[:, idx].shape

torch.Size([10, 3, 10])

In [130]:
import random


In [137]:
torch.randint(0, 9, (10, 4))

tensor([[4, 7, 5, 3],
        [7, 2, 6, 2],
        [7, 8, 0, 7],
        [3, 6, 4, 1],
        [1, 3, 0, 4],
        [4, 5, 3, 6],
        [6, 6, 1, 2],
        [0, 2, 7, 6],
        [3, 8, 3, 0],
        [5, 4, 0, 2]])

In [153]:
idx = torch.cat([torch.arange(10).view(10,1), torch.randint(0, 9, (10, 4))], dim=1)[0]

In [160]:
x[:, idx]

tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1.,

In [151]:
idx.data

[0, 5, 2, 2, 7]

In [155]:
x[:, [0, 5, 2, 2, 7]]

tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1.,