In [8]:
import warnings
import sys
import collections
from random import sample, seed

import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader, default_collate
from tqdm.auto import tqdm

warnings.filterwarnings("ignore")
sys.path.append('..')
from dataset import CalgaryCampinasDataset, ACDCDataset, MNMDataset
from model.unet import UNet2D, UNetEnsemble
from model.ae import AE
from model.dae import resDAE, AugResDAE
from model.wrapper import Frankenstein
from losses import (
    DiceScoreCalgary, 
    DiceScoreMMS, 
    SurfaceDiceCalgary,
    AccMMS
)
from utils import volume_collate
from eval.slice_wise import (
    PoolingMahalabonisDetector, 
    AEMahalabonisDetector, 
    MeanDistSamplesDetector, 
    EntropyDetector, 
    EnsembleEntropyDetector
)

In [2]:
# Globals
ROOT = '../../'   # repository root relative to this notebook; prefix for data and checkpoint paths
SEED = 42         # applied to the Python and torch RNGs at the end of this cell
debug = False     # forwarded to the dataset constructors
net_out = 'mms'   # 'calgary' or 'mms'; selects checkpoints, output channels and inference path
method = 'single' # only used for log output in this notebook
task = 'corr'     # 'ood', 'corr' or 'both'
scanner = 'A'     # 'val' -> ACDC validation split as test set, anything else -> M&M vendor
n_unets = 1       # number of pre-trained U-Nets to evaluate
post = 'localAug_multiImgSingleView_res_balanced_same'  # checkpoint name suffix of the DAE

data = 'data/mnm/'
data_path = ROOT + data

# SEED was previously defined but never applied, so repeated runs were not
# reproducible. Seed every RNG that is in scope here.
seed(SEED)
torch.manual_seed(SEED)

In [3]:


# Datasets: ACDC train / val splits, plus a test split chosen by `scanner`
# ('val' re-uses the ACDC validation split, any other value is passed to
# MNMDataset as the vendor).
train_set = ACDCDataset(data='train', debug=debug)
valid_set = ACDCDataset(data='val', debug=debug)
test_set = (
    ACDCDataset(data=scanner, debug=debug)
    if scanner == 'val'
    else MNMDataset(vendor=scanner, debug=debug)
)

# Loader options common to all three splits; only the batch size differs.
_shared_loader_opts = dict(shuffle=False, drop_last=False, num_workers=10)

train_loader = DataLoader(train_set, batch_size=32, **_shared_loader_opts)
valid_loader = DataLoader(valid_set, batch_size=1, **_shared_loader_opts)
test_loader = DataLoader(test_set, batch_size=1, **_shared_loader_opts)



loading dataset
loading all case properties
loading dataset
loading all case properties
loading dataset
loading all case properties


In [4]:
# U-Nets
# Checkpoint naming and output channel count both depend on the task family.
middle = 'unet' if net_out == 'calgary' else 'unet8_'
pre = 'calgary' if net_out == 'calgary' else 'acdc'
n_chans_out = 1 if net_out == 'calgary' else 4
unet_names = [f'{pre}_{middle}{i}' for i in range(n_unets)]  # TODO
unets = []
for name in unet_names:
    model_path = f'{ROOT}pre-trained-tmp/trained_UNets/{name}_best.pt'
    # Load onto the CPU regardless of the device the checkpoint was saved
    # from; the detector moves the assembled model to its target device later.
    state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']
    unet = UNet2D(n_chans_in=1,
                  n_chans_out=n_chans_out,
                  n_filters_init=8,
                  dropout=False)
    unet.load_state_dict(state_dict)
    unets.append(unet)

# One dict of results per U-Net, filled by the evaluation loop.
metrics = []

In [5]:
from typing import Iterable, Dict, List, Callable, Tuple, Union, List

import torch #
from torch import Tensor, nn #
from torch.utils.data import Dataset, DataLoader, default_collate
import torch.nn.functional as F
from sklearn.cluster import KMeans #
from sklearn.metrics import pairwise_distances_argmin_min #
from sklearn.covariance import LedoitWolf #
from scipy.stats import binned_statistic #
from tqdm.auto import tqdm #
from torchmetrics import (
    SpearmanCorrCoef, 
    AUROC)
from losses import DiceScoreCalgary, DiceScoreMMS #
from utils import _activate_dropout, UMapGenerator


class MeanDistSamplesDetector(nn.Module):
    """
    Evaluation class for OOD and ESCE tasks.

    The wrapped model is put into inference mode with its transformation hooks
    attached. For every batch, an uncertainty map is derived from the model
    output by ``UMapGenerator`` and reduced to one scalar score (its norm).

    NOTE(review): this notebook also imports a class of the same name from
    ``eval.slice_wise``; once this cell runs, this definition shadows it.

    Args:
        model: wrapper exposing ``remove_all_hooks``,
            ``hook_inference_transformations``, ``transformations`` and
            ``freeze_seg_model``.
        n_samples: forwarded to ``hook_inference_transformations``.
        net_out: ``'calgary'`` (input is processed slice by slice) or
            ``'mms'`` (input is processed in one forward pass).
        valid_loader: loader whose batches are labelled 0 in OOD detection.
        criterion: segmentation metric, called as ``criterion(net_out, target)``.
        device: device used for the model and its inputs.
        method: forwarded to ``UMapGenerator``.
    """

    def __init__(
        self,
        model: nn.Module,
        n_samples: int,
        net_out: str,
        valid_loader: DataLoader,
        criterion: nn.Module,  # e.g. DiceScoreCalgary()
        device: str = 'cuda:0',
        method: str = 'vae'
    ):
        super().__init__()
        self.device = device
        self.model = model.to(device)
        self.net_out = net_out
        # Remove training hooks, add evaluation hooks
        self.model.remove_all_hooks()
        self.model.hook_inference_transformations(self.model.transformations,
                                                  n_samples=n_samples)

        self.model.eval()
        self.model.freeze_seg_model()

        self.valid_loader = valid_loader
        self.criterion = criterion
        self.auroc = AUROC(task='binary')
        self.umap_generator = UMapGenerator(method=method,
                                            net_out=net_out)

    @torch.no_grad()
    def _predict_volume(self, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(umap, net_out)`` for one batch, following ``self.net_out``."""
        if self.net_out == 'calgary':
            # Forward one slice at a time and stitch the results back together.
            umap_volume = []
            net_out_volume = []
            for input_chunk in input_:
                umap, net_out = self.forward(input_chunk.unsqueeze(0).to(self.device))
                net_out_volume.append(net_out[:1].detach().cpu())
                umap_volume.append(umap)
            return torch.cat(umap_volume, dim=0), torch.cat(net_out_volume, dim=0)

        if self.net_out == 'mms':
            return self.forward(input_.to(self.device))

        raise ValueError(
            f"Unsupported net_out {self.net_out!r}; expected 'calgary' or 'mms'."
        )

    @torch.no_grad()
    def _volume_scores(self, loader: DataLoader) -> torch.Tensor:
        """Return one scalar score (norm of the uncertainty map) per batch."""
        scores = []
        for batch in loader:
            umap, _ = self._predict_volume(batch['input'])
            scores.append(torch.norm(umap).cpu())
        return torch.tensor(scores)

    @torch.no_grad()
    def testset_ood_detection(self, test_loader: DataLoader) -> torch.Tensor:
        """
        Compute the AUROC of separating validation batches from test batches.

        Validation batches get label 0, test batches label 1; the prediction
        is the per-batch score. Validation scores are computed on the first
        call and re-used afterwards.

        Returns:
            The value returned by the binary AUROC metric.
        """
        # The guard used to test for an attribute that was never set, so the
        # validation pass was repeated on every call.
        if not hasattr(self, 'valid_dists'):
            self.valid_dists = self._volume_scores(self.valid_loader)
            self.valid_labels = torch.zeros(len(self.valid_dists), dtype=torch.uint8)

        self.test_dists = self._volume_scores(test_loader)
        self.test_labels = torch.ones(len(self.test_dists), dtype=torch.uint8)

        self.pred = torch.cat([self.valid_dists, self.test_dists]).squeeze()
        self.target = torch.cat([self.valid_labels, self.test_labels]).squeeze()

        return self.auroc(self.pred, self.target)

    @torch.no_grad()
    def testset_correlation(self, test_loader: DataLoader) -> 'SpearmanCorrCoef':
        """
        Accumulate the rank correlation between score and segmentation error.

        For each batch the metric is updated with the pair
        ``(norm of the uncertainty map, 1 - mean criterion value)``.

        Returns:
            The updated ``SpearmanCorrCoef`` metric; call ``.compute()`` on it
            to obtain the coefficient.
        """
        corr_coeff = SpearmanCorrCoef()
        for batch in test_loader:
            input_ = batch['input']
            target = batch['target']

            if self.net_out == 'mms':
                # Map the -1 label to class 0 without modifying the batch in place.
                target = target.masked_fill(target == -1, 0)
                # convert to one-hot encoding
                target = F.one_hot(target.long(), num_classes=4).squeeze(1).permute(0, 3, 1, 2)

            umap, net_out = self._predict_volume(input_)

            # NOTE(review): for 'calgary' the target is not one-hot encoded
            # here, yet argmax over dim 1 is still applied - confirm that this
            # is what the criterion expects.
            loss = self.criterion(net_out, torch.argmax(target, dim=1))
            loss = loss.mean().float()

            score = torch.norm(umap)
            corr_coeff.update(score.cpu().view(1), 1 - loss.view(1))

        return corr_coeff

    @torch.no_grad()
    def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the wrapped model and derive an uncertainty map from its output.

        Returns:
            ``(umap, net_out[:1])``, both on the CPU. The uncertainty map is
            computed from the full model output; only the first entry of the
            output is returned (presumably the untransformed prediction -
            TODO confirm against the model wrapper).
        """
        net_out = self.model(input_).cpu()
        umap = self.umap_generator(net_out).cpu()
        return umap, net_out[:1]

In [6]:
# Layers that get an identity transformation; only 'up3' carries a DAE.
disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64,
                                    in_dim      = 32,
                                    latent_dim  = 256 if net_out=='mms' else 64,
                                    depth       = 3,
                                    block_size  = 4)})

for layer_id in disabled_ids:
    DAEs[layer_id] = nn.Identity()

for i, unet in enumerate(tqdm(unets)):
    print(f"Method {method}, Unet {i} - {net_out}")

    model = Frankenstein(seg_model=unet,
                         transformations=DAEs,
                         disabled_ids=disabled_ids,
                         copy=True)
    model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_AugResDAE{i}_{post}_best.pt'
    # Alternative checkpoints:
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/{pre}_resDAE{i}_{post}_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_epinet_CE-only_prior-1_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_resDAE0_venus_best.pt'

    # Load onto the CPU regardless of the device the checkpoint was saved
    # from; the detector moves the model to its own device.
    state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']
    model.load_state_dict(state_dict)

    metrics.append({})
    detector = MeanDistSamplesDetector(model=model,
                                       n_samples=1,
                                       valid_loader=valid_loader,
                                       net_out=net_out,
                                       method='mse',
                                       criterion=AccMMS())
    if task in ('ood', 'both'):
        metrics[i]['ood'] = detector.testset_ood_detection(test_loader)
    if task in ('corr', 'both'):
        metrics[i]['corr'] = detector.testset_correlation(test_loader)

  0%|          | 0/1 [00:00<?, ?it/s]

Method single, Unet 0 - mms


In [7]:
# Finalise the correlation metric accumulated for the first U-Net.
first_unet_corr = metrics[0]['corr']
first_unet_corr.compute()

tensor(0.4242)

In [42]:
# Same query as the previous cell, from a later execution (count 42); the
# displayed value differs from the one above, so runs are not deterministic.
rerun_corr = metrics[0]['corr']
rerun_corr.compute()

tensor(0.4281)

In [1]:
import torch
import torchmetrics

In [9]:
# Multiclass accuracy over 4 classes, reported per sample rather than pooled
# over the batch; class index 3 is passed as `ignore_index`.
accuracy_options = dict(
    task='multiclass',
    num_classes=4,
    multidim_average='samplewise',
    average='macro',
    ignore_index=3,
)
acc = torchmetrics.Accuracy(**accuracy_options)

In [10]:
# Toy data: random 4-channel scores for 10 samples of size 20x20, and a target
# obtained by taking argmax of the one-hot vector [1, 0, 0, 0] at every pixel,
# i.e. class 0 everywhere.
inp = torch.rand(10, 4, 20, 20)
one_hot_class0 = torch.tensor([1, 0, 0, 0]).reshape(1, 4, 1, 1)
tar = one_hot_class0.expand(10, 4, 20, 20).argmax(dim=1)

In [11]:
# Evaluate the metric on the toy data; the result holds one value per sample.
samplewise_acc = acc(inp, tar)
samplewise_acc

tensor([0.2450, 0.2300, 0.2425, 0.2325, 0.2500, 0.2275, 0.2625, 0.2350, 0.2450,
        0.2625])