### Imports, settings, globals

In [1]:
### Set CUDA device
import os

# Expose only physical GPU 3 to this process. This runs before torch is
# imported in the next cell, so the setting is in place before CUDA starts up.
os.environ.update(CUDA_VISIBLE_DEVICES='3')

In [2]:
from torch import Tensor
import sys
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List

sys.path.append('../')
from data_utils import get_eval_data
from model.unet import get_unet

In [3]:
### Load basic config
# Base configuration for this notebook; the run iteration is pinned to 0.
cfg = OmegaConf.load('../configs/basic_config.yaml')
OmegaConf.update(cfg, key='run.iteration', value=0)

### UNet Models

In [4]:
### Set dataset, either brain or heart
DATA_KEY = 'heart'
OmegaConf.update(cfg, key='run.data_key', value=DATA_KEY)


### get model
# Model names are '<arch>-<n_filters_init>' or, for monai,
# 'monai-<n_filters_init>-<depth>-<num_res_units>'; 'swinunetr' carries no
# parameters. Available models:
#     - default-8
#     - default-16
#     - monai-8-4-4
#     - monai-16-4-4
#     - monai-16-4-8
#     - monai-32-4-4
#     - monai-64-4-4
#     - swinunetr

#unet_name = 'monai-64-4-4'
unet_name = 'swinunetr'
args = unet_name.split('-')

# Write the parsed architecture settings into the per-dataset UNet config node.
unet_cfg = cfg.unet[DATA_KEY]
unet_cfg.pre = unet_name
unet_cfg.arch = args[0]
if unet_name == 'swinunetr':
    unet_cfg.n_filters_init = None
else:
    unet_cfg.n_filters_init = int(args[1])
if args[0] == 'monai':
    unet_cfg.depth = int(args[2])
    unet_cfg.num_res_units = int(args[3])

# Build the network, load the state dict returned with it, and move it to the GPU.
unet, state_dict = get_unet(cfg, update_cfg_with_swivels=False, return_state_dict=True)
unet.load_state_dict(state_dict)
_ = unet.cuda()


### Data

In [7]:
# Debug mode uses only a small fraction of each dataset, which speeds this cell up a lot.
cfg.debug = True

# Fill cfg.eval with the default evaluation settings.
OmegaConf.update(cfg, 'eval', OmegaConf.load('../configs/eval/unet_config.yaml'))

# Whether and how to subset the data (intended for the brain data). WARNING:
# after subsetting, the evaluation below no longer works with surface Dice,
# because the volumes are fragmented.
APPLY_SUBSETTING = True
OmegaConf.update(cfg, 'eval.data.subset.apply', APPLY_SUBSETTING)
subset_params = dict(
    n_cases=256,   # select at most this many cases
    fraction=0.1,  # select from the worst 10% of cases w.r.t. a model
)
OmegaConf.update(cfg, 'eval.data.subset.params', OmegaConf.create(subset_params))

subset_dict = None
if cfg.eval.data.subset.apply:
    subset_dict = OmegaConf.to_container(
        cfg.eval.data.subset.params, resolve=True, throw_on_missing=True
    )
    # The model is handed over for the worst-case selection described above.
    subset_dict['unet'] = unet

### select the datasets within the domain
cfg.eval.data.training = True    # include the training data
cfg.eval.data.validation = True  # include the validation data
# Test data. Brain: any subset of [1, 2, 3, 4, 5] or 'all';
#            heart: any subset of ['A', 'B', 'C', 'D'] or 'all'.
cfg.eval.data.testing = ['A']

raw_data = get_eval_data(
    cfg=cfg,
    train_set=cfg.eval.data.training,
    val_set=cfg.eval.data.validation,
    test_sets=cfg.eval.data.testing,
    subset_dict=subset_dict,
)

print(f'\nAvailable datasets are: {list(raw_data.keys())}')

loading dataset
loading all case properties


loading dataset
loading all case properties
loading dataset
loading all case properties

Available datasets are: ['train', 'val', 'A']


### Old Mahalanobis Detector Code

In [6]:
from typing import Dict, List, Callable, Tuple, List


import torch #
from torch.utils.data import DataLoader, default_collate
from torch import Tensor, nn #
from torchmetrics import (
    SpearmanCorrCoef, 
    AUROC
)
from sklearn.covariance import LedoitWolf #

from losses import DiceScoreCalgary, DiceScoreMMS #


class PoolingMahalabonisDetector(nn.Module):
    """
    Evaluation class for OOD and ESCE tasks based on https://arxiv.org/abs/2107.05975.

    Forward pre-hooks on the submodules named in ``layer_ids`` capture each
    layer's first input. That input is average-pooled and flattened. On
    construction, the features captured on ``train_loader`` are used to fit one
    Gaussian per layer (sample mean + Ledoit-Wolf covariance). Afterwards,
    ``forward`` returns the squared Mahalanobis distance of new inputs to
    those Gaussians.

    Args:
        model: network to probe; moved to ``device`` and put in eval mode.
        layer_ids: names accepted by ``model.get_submodule`` whose *inputs*
            are scored.
        train_loader: loader of dicts with an ``'input'`` tensor, used to fit
            the Gaussians.
        valid_loader: in-distribution loader; its scores get label 0 in
            ``testset_ood_detection``.
        net_out: ``'calgary'`` (each element along the first dimension of a
            batch is scored on its own and the scores are averaged) or
            ``'mms'`` (the batch is scored in one pass).
        criterion: loss used by ``testset_correlation``. A new
            ``DiceScoreCalgary()`` is created when omitted.
        device: device for the model and the fitted statistics.
    """

    def __init__(
        self,
        model: nn.Module,
        layer_ids: List[str],
        train_loader: DataLoader,
        valid_loader: DataLoader,
        net_out: str,
        criterion: "nn.Module | None" = None,
        device: str = 'cuda:0'
    ):
        super().__init__()
        self.device       = device
        self.model        = model.to(device)
        self.model.eval()
        self.layer_ids    = layer_ids
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.net_out      = net_out
        # Create the default loss per instance instead of sharing one module
        # that was instantiated once, when the class was defined.
        self.criterion    = criterion if criterion is not None else DiceScoreCalgary()
        self.pool         = nn.AvgPool3d(kernel_size=(2,2,2), stride=(2,2,2))
        self.auroc        = AUROC(task = 'binary')

        # Per-layer state: collected training features, fitted mean and
        # inverse covariance, and the latest captured features / distances.
        self.latents   = {layer_id: [] for layer_id in self.layer_ids}
        self.mu        = {layer_id: None for layer_id in self.layer_ids}
        self.sigma_inv = {layer_id: None for layer_id in self.layer_ids}
        self.dist      = {layer_id : 0 for layer_id in self.layer_ids}

        self._get_latents()
        self._fit_gaussian_to_latents()


    @torch.no_grad()
    def _get_hook_fn(self, layer_id: str, mode: str = 'collect') -> Callable:
        """Build a forward pre-hook that pools and flattens a layer's first input.

        ``mode='collect'`` appends the flattened features to
        ``self.latents[layer_id]`` on the CPU. ``mode='single'`` stores them
        in ``self.dist[layer_id]`` on ``self.device`` for the next distance
        computation in ``forward``.
        """

        def hook_fn(module: nn.Module, x: Tuple[Tensor]):
            x = x[0]
            # Pool until at most 1e4 values per sample remain.
            while torch.prod(torch.tensor(x.shape[1:])) > 1e4:
                x = self.pool(x)
            # NOTE(review): one extra pooling step beyond the 1e4 threshold;
            # PoolingMahalanobis._reduce does not do this. Confirm it is intended.
            x = self.pool(x)
            batch_size = x.shape[0]

            if mode == 'collect':
                self.latents[layer_id].append(x.view(batch_size, -1).detach().cpu())
            elif mode == 'single':
                self.dist[layer_id] = x.view(batch_size, -1).to(self.device)

        return hook_fn


    def _register_hooks(self, mode: str) -> Dict[str, "torch.utils.hooks.RemovableHandle"]:
        """Attach one pre-hook (see ``_get_hook_fn``) per layer; return the handles."""
        handles = {}
        for layer_id in self.layer_ids:
            layer = self.model.get_submodule(layer_id)
            handles[layer_id] = layer.register_forward_pre_hook(
                self._get_hook_fn(layer_id, mode=mode)
            )
        return handles


    @torch.no_grad()
    def _get_latents(self) -> None:
        """Run the training set through the model and stack the pooled features per layer."""
        handles = self._register_hooks(mode='collect')
        try:
            for batch in self.train_loader:
                self.model(batch['input'].to(self.device))
        finally:
            # Always detach the hooks, even if the forward pass fails.
            for handle in handles.values():
                handle.remove()

        for layer_id in self.layer_ids:
            self.latents[layer_id] = torch.cat(self.latents[layer_id], dim=0)


    @torch.no_grad()
    def _fit_gaussian_to_latents(self) -> None:
        """Fit the mean and the inverse Ledoit-Wolf covariance of each layer's features."""
        for layer_id in self.layer_ids:
            self.mu[layer_id] = self.latents[layer_id].mean(0, keepdim=True).to(self.device)
            latents_centered = (self.latents[layer_id] - self.mu[layer_id].cpu()).detach().numpy()
            # sklearn returns a NumPy covariance (float64 in practice). forward()
            # casts the centred features to sigma_inv's dtype before multiplying.
            sigma = torch.from_numpy(LedoitWolf().fit(latents_centered).covariance_)
            self.sigma_inv[layer_id] = torch.linalg.inv(sigma).unsqueeze(0).to(self.device)


    @torch.no_grad()
    def _score_batch(
        self, input_: Tensor, keep_output: bool = False
    ) -> Tuple[Dict[str, Tensor], "Tensor | None"]:
        """Return per-layer distances and the network output for one batch.

        ``'calgary'``: each element along the first dimension of ``input_`` is
        passed through ``forward`` on its own. The per-element distance dicts
        are collated, i.e. stacked along a new first dimension. The outputs are
        concatenated on the CPU if ``keep_output`` is set; otherwise ``None``
        is returned for them.
        ``'mms'``: the whole batch goes through ``forward`` once.

        Raises:
            ValueError: if ``self.net_out`` is neither ``'calgary'`` nor ``'mms'``.
        """
        if self.net_out == 'calgary':
            dist_volume = []
            net_out_volume = []
            for input_chunk in input_:
                dist, net_out = self.forward(input_chunk.unsqueeze(0).to(self.device))
                # forward() returns the same dict object on every call, so copy it.
                dist_volume.append(dist.copy())
                if keep_output:
                    net_out_volume.append(net_out.cpu())
            net_out = torch.cat(net_out_volume, dim=0) if keep_output else None
            return default_collate(dist_volume), net_out
        if self.net_out == 'mms':
            return self.forward(input_.to(self.device))
        raise ValueError(
            f"Unknown net_out {self.net_out!r}; expected 'calgary' or 'mms'."
        )


    @torch.no_grad()
    def _loader_distances(self, loader: DataLoader) -> Dict[str, Tensor]:
        """Score every batch in ``loader``; return one CPU tensor of scores per layer.

        ``'calgary'``: one score per batch (the mean over its elements).
        ``'mms'``: the per-sample distances of all batches concatenated along
        the first dimension.
        """
        collected = {layer_id: [] for layer_id in self.layer_ids}
        for batch in loader:
            dist, _ = self._score_batch(batch['input'])
            for layer_id in self.layer_ids:
                if self.net_out == 'calgary':
                    collected[layer_id].append(dist[layer_id].mean())
                else:
                    collected[layer_id].append(dist[layer_id])

        if self.net_out == 'calgary':
            return {layer_id: torch.stack(scores).cpu() for layer_id, scores in collected.items()}
        return {layer_id: torch.cat(scores, dim=0).cpu() for layer_id, scores in collected.items()}


    @torch.no_grad()
    def testset_ood_detection(self, test_loader: DataLoader) -> Dict[str, torch.Tensor]:
        """Return the per-layer AUROC for separating validation (0) from test (1) scores.

        Side effects: sets ``valid_dists``, ``valid_labels``, ``test_dists``,
        ``test_labels``, ``pred`` and ``target`` on the instance.
        """
        self.pred = {}
        self.target = {}

        self.valid_dists = self._loader_distances(self.valid_loader)
        self.test_dists = self._loader_distances(test_loader)
        # Labels are sized from the concatenated scores so their length matches
        # the number of scored samples, even when an 'mms' batch holds several.
        self.valid_labels = {layer_id: torch.zeros(len(self.valid_dists[layer_id]), dtype=torch.uint8)
                             for layer_id in self.layer_ids}
        self.test_labels = {layer_id: torch.ones(len(self.test_dists[layer_id]), dtype=torch.uint8)
                            for layer_id in self.layer_ids}

        aurocs = {}
        for layer_id in self.layer_ids:
            self.pred[layer_id]   = torch.cat([self.valid_dists[layer_id], self.test_dists[layer_id]]).squeeze()
            self.target[layer_id] = torch.cat([self.valid_labels[layer_id], self.test_labels[layer_id]]).squeeze()

            print(self.pred[layer_id].shape, self.target[layer_id].shape)

            aurocs[layer_id] = self.auroc(self.pred[layer_id], self.target[layer_id])

        return aurocs


    @torch.no_grad()
    def testset_correlation(self, test_loader: DataLoader) -> Dict[str, "SpearmanCorrCoef"]:
        """Accumulate, per layer, the Spearman correlation between mean distance and 1 - loss.

        Returns the (un-computed) ``SpearmanCorrCoef`` metric objects, one per layer.
        """
        corr_coeffs = {layer_id: SpearmanCorrCoef() for layer_id in self.layer_ids}
        for batch in test_loader:
            target = batch['target']
            if self.net_out == 'mms':
                target[target == -1] = 0
                # (B, 1, H, W) integer labels -> (B, 4, H, W) one-hot encoding
                target = nn.functional.one_hot(target.long(), num_classes=4).squeeze(1).permute(0, 3, 1, 2)
            dist, net_out = self._score_batch(batch['input'], keep_output=True)

            loss = self.criterion(net_out.cpu(), target)
            loss = loss.mean().float().cpu()
            for layer_id in self.layer_ids:
                corr_coeffs[layer_id].update(dist[layer_id].cpu().mean().view(1), 1 - loss.view(1))

        return corr_coeffs


    @torch.no_grad()
    def forward(self, input_: torch.Tensor) -> Tuple[Dict[str, Tensor], Tensor]:
        """Run the model on ``input_`` and score the hooked layer inputs.

        Returns:
            ``(dist, net_out)``. ``dist`` maps each layer id to the squared
            Mahalanobis distances, shape (batch, 1, 1). It is ``self.dist``
            itself and is overwritten on the next call, so copy it if you need
            to keep it. ``net_out`` is the model output.
        """
        handles = self._register_hooks(mode='single')
        try:
            net_out = self.model(input_)
        finally:
            for handle in handles.values():
                handle.remove()

        for layer_id in self.layer_ids:
            latent = self.dist[layer_id]
            sigma_inv = self.sigma_inv[layer_id]
            latent_centered = latent.view(latent.shape[0], 1, -1) - self.mu[layer_id].unsqueeze(0)
            # torch.matmul does not promote dtypes: compute the quadratic form in
            # sigma_inv's dtype and return it in the features' dtype.
            latent_centered = latent_centered.to(sigma_inv.dtype)
            self.dist[layer_id] = (
                latent_centered @ sigma_inv @ latent_centered.permute(0, 2, 1)
            ).to(latent.dtype)

        return self.dist, net_out
    

### Refactor Mahalanobis Detection  


class Adapter:

    attributes:
        swivel
        hook_type: forward or backward

    methods:
        (to be defined)


class Transformation:

    methods:
        transform (possibly implemented as `forward`?)
        score (Mahalanobis distance or log-likelihood; for how the two relate, see https://stats.stackexchange.com/questions/97408/relation-of-mahalanobis-distance-to-log-likelihood)

In [None]:
class PoolingMahalanobis(nn.Module):
    """
    Work-in-progress, single-position version of the pooling Mahalanobis detector.

    In training mode, ``forward`` pools, flattens and stores (on the CPU) every
    tensor it receives. ``fit`` then stacks these representations and
    estimates their mean and covariance. In eval mode, ``forward`` writes the
    squared Mahalanobis distance of each sample to ``self.batch_distances``.
    In both modes ``forward`` returns its input unchanged, unless
    ``transform`` is set.

    Args:
        swivel: name of the position this module is meant to be attached to
            (only stored here).
        hook_fn: ``'pre'`` stores ``register_forward_pre_hook`` in
            ``self.hook_fn``; any other value stores ``register_forward_hook``.
        transform: when True, ``forward`` raises ``NotImplementedError``
            (placeholder for a future transformation).
        device: device holding the fitted mean and inverse covariance.
    """

    def __init__(
        self,
        swivel: str,
        hook_fn: str = 'pre',
        transform: bool = False,
        device: str = 'cuda:0'
    ):
        super().__init__()
        # args
        self.swivel = swivel
        self.hook_fn = self.register_forward_pre_hook if hook_fn == 'pre' else self.register_forward_hook
        self.transform = transform
        self.device = device
        # attributes
        self.training_representations = []
        self.pool = nn.AvgPool3d(kernel_size=(2,2,2), stride=(2,2,2))

    ### private methods ###

    @torch.no_grad()
    def _reduce(self, x: Tensor) -> Tensor:
        """Pool ``x`` until at most 1e4 values per sample remain; return (batch_size, n_features)."""
        while torch.prod(torch.tensor(x.shape[1:])) > 1e4:
            x = self.pool(x)
        return x.reshape(x.shape[0], -1)

    @torch.no_grad()
    def _collect(self, x: Tensor) -> None:
        """Reduce ``x`` and keep a detached CPU copy for ``fit``."""
        # Store on the CPU so activations captured on the GPU do not pile up in
        # device memory. (Comparing ``x.device == 'cpu'`` is never True,
        # because it compares a torch.device with a str.)
        self.training_representations.append(self._reduce(x).detach().cpu())

    @torch.no_grad()
    def _merge(self) -> None:
        """Concatenate the collected batches into one (n_samples, n_features) tensor."""
        self.training_representations = torch.cat(self.training_representations, dim=0)

    @torch.no_grad()
    def _estimate_gaussians(self, ledoitWolf: bool = False) -> None:
        """Estimate mean, covariance and inverse covariance of the merged representations."""
        representations = self.training_representations
        self.mu = representations.mean(0, keepdim=True).to(self.device)
        if ledoitWolf:
            self.sigma = torch.from_numpy(LedoitWolf().fit(representations.numpy()).covariance_)
        else:
            # torch.cov expects variables in rows and observations in columns,
            # hence the transpose of the (n_samples, n_features) matrix.
            self.sigma = torch.cov(representations.T)
        self.sigma_inv = torch.linalg.inv(self.sigma).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def _distance(self, x: Tensor) -> Tensor:
        """Return the squared Mahalanobis distance of each sample in ``x``, shape (batch_size,)."""
        sigma_inv = self.sigma_inv[0]
        # Match sigma_inv's device and dtype (a Ledoit-Wolf covariance is float64).
        x = self._reduce(x).to(device=sigma_inv.device, dtype=sigma_inv.dtype)
        x = x - self.mu.to(dtype=sigma_inv.dtype)
        # Row-wise quadratic form x_i^T Sigma^-1 x_i, i.e. the diagonal of x @ Sigma^-1 @ x.T.
        return torch.einsum('bi,ij,bj->b', x, sigma_inv, x)

    ### public methods ###

    def fit(self):
        """Stack the collected representations and estimate the Gaussian (plain sample covariance)."""
        self._merge()
        self._estimate_gaussians()


    def forward(self, x: Tensor) -> Tensor:
        """Collect (training mode) or score (eval mode) ``x``, then return it unchanged."""
        if self.training:
            self._collect(x)

        else:
            self.batch_distances = self._distance(x)

        # implements identity function from a hooks perspective
        if self.transform:
            raise NotImplementedError('Implement transformation functionality')
        else:
            return x

