In [1]:
import matplotlib.pyplot as plt
# No figure is created in this cell before the call below.
plt.show()

In [22]:
import matplotlib.pyplot as plt
import os, sys
from typing import Iterable, Dict, List, Callable, Tuple, Union, List

import numpy as np
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import fiftyone as fo
import shutil
from torchmetrics import Dice

sys.path.append('../../')
from dataset import CalgaryCampinasDataset
from model.unet import UNet2D
from model.ae import AE
from model.dae import resDAE, AugResDAE
from model.wrapper import Frankenstein, ModelAdapter
from losses import DiceScoreCalgary, SurfaceDiceCalgary
from utils import  epoch_average, UMapGenerator, volume_collate
from trainer.unet_trainer import UNetTrainerCalgary
from data_utils import get_subset

In [3]:
### Config
# Root directory that all data paths below are resolved against.
root = '../../../'
# Folder (relative to `root`) that is handed to CalgaryCampinasDataset.
data_dir = 'data/conp-dataset/projects/calgary-campinas/CC359/Reconstructed/'
data_path = f'{root}{data_dir}'
# Flags forwarded to every CalgaryCampinasDataset constructed in this notebook.
debug, augment = False, False
# Site passed to the single-site `testset`.
site = 4

In [4]:
# 'train' split of site 6, iterated one item per batch in dataset order.
trainset = CalgaryCampinasDataset(
    data_path=data_path, site=6, split='train',
    augment=augment, normalize=True, debug=debug,
)
train_loader = DataLoader(trainset, batch_size=1, shuffle=False, drop_last=False)

In [5]:
# 'validation' split of site 6, iterated one item per batch in dataset order.
valset = CalgaryCampinasDataset(
    data_path=data_path, site=6, split='validation',
    augment=augment, normalize=True, debug=debug,
)
val_loader = DataLoader(valset, batch_size=1, shuffle=False, drop_last=False)

In [6]:
# 'all' split of the configured `site`, iterated one item per batch in dataset order.
testset = CalgaryCampinasDataset(
    data_path=data_path, site=site, split='all',
    augment=augment, normalize=True, debug=debug,
)
test_loader = DataLoader(testset, batch_size=1, shuffle=False, drop_last=False)

In [7]:
# One DataLoader per test site (1-5), each over that site's 'all' split.
#
# The site index is scoped to the comprehension on purpose: a plain
# `for site in ...` loop would rebind the config-level `site` (ending at 5),
# silently changing what the `testset` cell builds when it is re-run.
testloader = [
    DataLoader(
        CalgaryCampinasDataset(
            data_path=data_path,
            site=test_site,
            split='all',
            augment=augment,
            normalize=True,
            debug=debug,
        ),
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )
    for test_site in [1, 2, 3, 4, 5]
]


In [201]:
# Restore the pre-trained segmentation U-Net from its checkpoint.
model_path = '../../../pre-trained-tmp/trained_UNets/calgary_unet0_augmentednnUNet_best.pt'
# map_location='cpu' keeps the load independent of the device the checkpoint was
# saved from; the weights are copied into the freshly constructed model below.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']
n_chans_out = 1
seg_model = UNet2D(
    n_chans_in=1,
    n_chans_out=n_chans_out,
    n_filters_init=8,
    dropout=False
)
seg_model.load_state_dict(state_dict)
# NOTE(review): `unet_trainer` is only constructed in the commented-out code below,
# but the Dice-evaluation cell still passes `unet_trainer` to get_dice_for_subset.
# On a fresh kernel that cell fails until this is restored or the call is adapted.
# criterion   = nn.BCEWithLogitsLoss()
# eval_metrics = {
#     "Volumetric Dice": DiceScoreCalgary(),
#     "Surface Dice": SurfaceDiceCalgary()
# }

# unet_trainer = UNetTrainerCalgary(
#     model=seg_model, 
#     criterion=criterion, 
#     train_loader=None, 
#     valid_loader=None, 
#     root=root, 
#     eval_metrics=eval_metrics, 
#     description=f'calgary_unet0_augmentednnUNet',
#     log=False
# )

<All keys matched successfully>

In [202]:
seg_model

UNet2D(
  (init_path): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): ReLU()
    (2): ResBlock(
      (conv_path): Sequential(
        (0): PreActivationND(
          (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activation): ReLU()
          (layer): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): PreActivationND(
          (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activation): ReLU()
          (layer): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (3): ResBlock(
      (conv_path): Sequential(
        (0): PreActivationND(
          (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activation): ReLU()
          (layer): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), paddi

In [195]:
#unet_trainer.load_model()

In [208]:
# Layer ids that receive a pass-through (nn.Identity) instead of a trained module.
disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
# Only 'up3' gets an actual AugResDAE.
DAEs = nn.ModuleDict(
    {'up3': AugResDAE(
        in_channels = 64,
        in_dim      = 32,
        latent_dim  = 256,
        depth       = 3,
        block_size  = 4)
    }
)

for layer_id in disabled_ids:
    DAEs[layer_id] = nn.Identity()

# Wrap the segmentation model together with the per-layer transformations.
model = ModelAdapter(
    seg_model=seg_model,
    transformations=DAEs,
    disabled_ids=disabled_ids,
    copy=True
)
model_path = '../../../pre-trained-tmp/trained_AEs/calgary_AugResDAE0_localAug_multiImgSingleView_res_balanced_same_best.pt'
# map_location='cpu' keeps the load independent of the device the checkpoint was
# saved from; the model is moved to its evaluation device further down.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']
model.load_state_dict(state_dict)


# Remove training hooks, add evaluation hooks
model.remove_all_hooks()
model.hook_inference_transformations(model.transformations,
                           n_samples=1)
# Put model in evaluation state. Use the first GPU when there is one and fall
# back to the CPU otherwise, instead of hard-coding device index 0.
eval_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(eval_device)
model.eval()
model.freeze_seg_model()

In [209]:
# NOTE(review): `img` is not assigned by any earlier cell. In this notebook it is
# only bound inside the sample-building loop of the FiftyOne cell further down,
# so this cell works only after that cell has run in the same kernel.
# Add a leading dimension to `img` before feeding it to the models.
xin = img.unsqueeze(0)
# Move the wrapped model to the CPU and put the plain U-Net in eval mode.
model.cpu()
seg_model.eval()
# Bare print as last statement; presumably there to keep the repr of
# `seg_model.eval()` out of the cell output.
print()





In [213]:
# Forward passes compared in the following cells.
# Two passes of the stand-alone U-Net on the same input.
out1 = seg_model(xin)
out2 = seg_model(xin)
# First entry along dim 0 of the adapter's internal seg_model and of the adapter itself.
# The recorded outputs of the next cells show 0 differing elements between
# out1/out4 and out3/out4, i.e. this first entry matches the plain U-Net output.
out3 = model.seg_model(xin)[:1]
out4 = model(xin)[:1]

In [216]:
(out1 != out4).sum()

tensor(0)

In [215]:
(out4 != out3).sum()

tensor(0)

In [10]:
def get_dice_for_subset(dataloader, fraction, trainer):
    """Score a trainer-selected subset of `dataloader` with DiceScoreCalgary.

    Args:
        dataloader: loader to draw the subset from. Its length times `fraction`
            (truncated to int) is the number of cases requested.
            NOTE(review): `len(dataloader)` counts batches; the loaders built in
            this notebook use batch_size=1 - confirm for other callers.
        fraction: forwarded to `trainer.get_subset` and used to derive `n_cases`.
        trainer: object providing `model`, `device`, `inference_step(input)` and
            `get_subset(dataloader, fraction=..., n_cases=...)`.

    Returns:
        1-D tensor with one mean metric value per item of the subset.
    """

    def test_set(dataloader, trainer=trainer, metric=DiceScoreCalgary()):
        # Evaluate every batch and collect one scalar score per batch.
        trainer.model.eval()
        scores = []
        # Pure evaluation: no autograd graph is needed, so do not build one.
        with torch.no_grad():
            for batch in dataloader:
                input_ = batch['input'].to(trainer.device)
                target = batch['target'].to(trainer.device)
                net_out = trainer.inference_step(input_)
                scores.append(metric(net_out, target).detach().mean().cpu())

        return scores

    len_    = len(dataloader)
    n_cases = int(len_ * fraction)
    subset  = trainer.get_subset(dataloader, fraction=fraction, n_cases=n_cases)
    # Fixed order, one item per batch, so each score maps to one subset item.
    loader  = DataLoader(subset, batch_size=1, shuffle=False, drop_last=False)
    scores  = test_set(loader)

    return torch.tensor(scores)

In [18]:
# NOTE(review): `unet_trainer` is not created by any active code in this notebook
# (its construction is commented out in the U-Net loading cell), so this cell
# depends on a variable left over in the kernel.
seg_model.cuda()
# One score tensor per loader: train, validation, then every entry of `testloader`.
scores = []
for loader in [train_loader, val_loader, *testloader]:
    # 0.05 is the `fraction` argument of get_dice_for_subset.
    scores.append(get_dice_for_subset(loader, 0.05, unet_trainer))

  kmeans = KMeans(n_clusters=n_cases).fit(kmeans_in)


In [19]:
[s.mean() for s in scores]

[tensor(0.9808),
 tensor(0.9730),
 tensor(0.6513),
 tensor(0.9619),
 tensor(0.9567),
 tensor(0.8479),
 tensor(0.9365)]

In [24]:
# Loaders in reporting order: train, validation, then each test-site loader.
subset_sources = [train_loader, val_loader, *testloader]
# One get_subset() result per loader, always drawn from the loader's dataset.
subsets = [
    get_subset(
        src.dataset,
        seg_model,
        criterion=nn.BCEWithLogitsLoss(reduction='none'),
        n_cases=25,
        fraction=0.05,
        batch_size=32,
    )
    for src in subset_sources
]

In [11]:
# Labels in the same order as the loaders: 'train', 'validation', 'test1', 'test2', ...
names = ['train', 'validation'] + [f'test{idx}' for idx in range(1, len(testloader) + 1)]

In [228]:
# Clear FiftyOne's non-persistent datasets before the dataset cell below is (re)run.
fo.delete_non_persistent_datasets()

In [222]:
# Move both networks to the CPU; the sample-building cell below feeds them
# tensors taken straight from the subsets without any device transfer.
seg_model.cpu()
model.cpu()
seg_model.eval()
# Bare print as last statement; presumably there to keep the repr of
# `seg_model.eval()` out of the cell output.
print()




In [223]:
    
    
class UMapGenerator(nn.Module):
    """
    Calculates uncertainty maps from UNets in different ways.

    PyTorch Module to generate uncertainty maps from
    * VAE samples
    * Entropy in drop out samples
    * Entropy in model outputs

    NOTE(review): this notebook also imports a ``UMapGenerator`` from ``utils``;
    this cell-level definition shadows it for every later cell.

    Args:
        method:  how the map is computed. One of 'none', 'cross_entropy',
            'kl_divergence', 'mse', 'ae', 'entropy' or 'probs'.
        net_out: output convention of the network. 'mms' selects softmax over
            dim 1 and cross-entropy; any other value selects sigmoid and
            BCE-with-logits. The 'ae' and 'probs' methods are only defined
            for 'mms' and 'calgary'.

    Several methods split the input along dim 0 into ``x[:1]`` and ``x[1:]``.
    NOTE(review): presumably ``x[:1]`` is the plain network output and ``x[1:]``
    the transformed/sampled outputs - confirm against the model wrapper.
    """

    def __init__(
        self,
        method  = 'ae',
        net_out = 'mms'
    ):
        super().__init__()
        self.method  = method
        self.net_out = net_out
        self.m       = nn.Softmax(dim=1) if net_out=='mms' else nn.Sigmoid()
        self.ce      = nn.CrossEntropyLoss(reduction='none') if net_out=='mms' else nn.BCEWithLogitsLoss(reduction='none')
        # Needed by method='kl_divergence'; element-wise so forward() can reduce itself.
        self.kl      = nn.KLDivLoss(reduction='none')

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """Return the uncertainty map for logits ``x`` (``None`` for method 'none').

        Raises:
            ValueError: if the configured method / net_out combination has no
                implementation.
        """
        if self.method == 'none':
            return None

        x = x.detach()
        umap = None
        is_mms     = self.net_out == 'mms'
        is_calgary = self.net_out == 'calgary'

        #################################
        ### experimental / M&M only   ###
        #################################

        if self.method == 'cross_entropy':
            umap = self.ce(x[:1], self.m(x[1:]))
            umap = umap.mean(dim=0, keepdims=True)

        elif self.method == 'kl_divergence':
            # KLDivLoss expects log-probabilities as input and probabilities as target.
            x_in = F.log_softmax(x[:1], dim=1)
            umap = self.kl(x_in, self.m(x[1:]))
            umap = umap.sum(dim=(0,1), keepdims=True)

        elif self.method == 'mse':
            # self.m returns a new tensor, so the in-place ops do not touch the caller's data.
            x      = self.m(x)
            x     -= x.min(dim=1, keepdims=True).values
            x     /= x.sum(dim=1, keepdims=True)
            umap   = torch.pow(x[:1] - x[1:], 2).mean(0, keepdim=True)
            umap   = umap.mean(dim=1, keepdims=True)

        #################################
        ###   old umaps from MICCAI   ###
        #################################

        elif self.method == 'ae':
            if is_mms:
                umap = self.ce(x[:1], self.m(x[1:]))
                umap = umap.mean(dim=0, keepdims=True)

            elif is_calgary:
                x    = torch.sigmoid(x)
                umap = torch.pow(x[:1] - x[1:], 2).mean(0, keepdim=True)

        elif self.method == 'entropy':
            if is_mms:
                # Entropy of the class distribution averaged over dim 0.
                x_mean = F.softmax(x, dim=1).mean(dim=0, keepdims=True)
                umap   = torch.distributions.Categorical(x_mean.permute(0,2,3,1)).entropy()

            elif is_calgary:
                # Binary entropy of the mean foreground probability over x[1:].
                # xlogy(p, p) is 0 at p == 0, so saturated pixels give 0 instead of NaN.
                x_probs = torch.sigmoid(x[1:])
                x_mean  = x_probs.mean(dim=0, keepdims=True)
                umap    = - torch.xlogy(x_mean, x_mean) - torch.xlogy(1-x_mean, 1-x_mean)

            else:
                # Fallback for any other net_out: categorical entropy of x[:1].
                x_prob = self.m(x[:1])
                umap   = torch.distributions.Categorical(x_prob.permute(0,2,3,1)).entropy()

        elif self.method == 'probs':
            if is_mms:
                x_probs = F.softmax(x, dim=1)
                umap    = torch.distributions.Categorical(x_probs.permute(0,2,3,1)).entropy()

            elif is_calgary:
                x_probs = torch.sigmoid(x)
                umap    = - x_probs * torch.log(x_probs+1e-6) - (1-x_probs) * torch.log(1-x_probs+1e-6)

        if umap is None:
            # Fail with a clear message instead of an UnboundLocalError on `umap`.
            raise ValueError(
                f"UMapGenerator: no uncertainty map defined for method={self.method!r} "
                f"with net_out={self.net_out!r}"
            )

        return umap

In [229]:
# (Re)build the FiftyOne dataset holding images, ground truth, predictions,
# error maps and uncertainty maps for every item of every subset.
dataset_name = "calgary1"
# Remove a stale dataset by *name*. Relying on a leftover `dataset` variable
# (and a bare `except`) misses datasets that exist in FiftyOne but not in the
# kernel, and hides any other error.
if fo.dataset_exists(dataset_name):
    fo.delete_dataset(dataset_name)


# build first dataset
# NOTE(review): `n_unets` and `export_dir` are not used anywhere in this cell.
n_unets = 2
export_dir = os.path.expanduser('~') + '/fiftyone/calgary-dataset'
# init dataset
dataset = fo.Dataset(name=dataset_name)

# set mask targets
## ground truth targets
dataset.mask_targets = {
    "ground_truth": {0: "background",
                     1: "foreground",}
}

# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'
# init sample list. We save each sample here and add it to the
# dataset in the end
samples = []
# init umap generators: one for the plain U-Net output, one for the adapter output
umap_generator_baseline = UMapGenerator(method='probs', net_out='calgary')
umap_generator_AE = UMapGenerator(method='cross_entropy', net_out='calgary')

# iteratively make samples; inference only, so no autograd graph is needed
with torch.no_grad():
    for subset, name in zip(subsets, names):
        for i in range(len(subset)):
            # get data
            data = subset[i]
            img = data['input']
            # Work on a copy so the tensor handed out by the subset is not modified.
            mask = data['target'].clone()
            mask[mask < 0] = 0

            # save image to disk, rescaled to [0, 1] for the PNG
            img_path = path + f'img_{name}_{i}.png'
            img_norm = img - img.min()
            peak = img_norm.max()
            if peak > 0:
                # A constant image has peak == 0; skip the division to avoid NaNs.
                img_norm = img_norm / peak

            save_image(img_norm, img_path)
            # keep the un-normalised tensor next to the PNG
            torch.save(img, path + f'img_{name}_{i}.pt')

            sample_image = fo.Sample(
                filepath=img_path,
                ground_truth=fo.Segmentation(
                    mask=mask.squeeze().numpy()
                ),
                tags=[name]
            )

            unet_output = seg_model(img.unsqueeze(0))
            ae_output   = model(img.unsqueeze(0))
            # Binarise at probability 0.5. For the adapter, entries from index 1
            # on along dim 0 are used as its prediction.
            pred_unet = (torch.sigmoid(unet_output) > 0.5) * 1
            pred_ae   = (torch.sigmoid(ae_output[1:]) > 0.5) * 1
            err_map_unet  = (pred_unet != mask)
            err_map_ae    = (pred_ae   != mask)
            umap_baseline = umap_generator_baseline(unet_output)
            umap_ae       = umap_generator_AE(ae_output)

            sample_image['pred_unet']  = fo.Segmentation(mask=pred_unet.squeeze().numpy())
            sample_image['pred_ae']    = fo.Segmentation(mask=pred_ae.squeeze().numpy())
            sample_image['error_unet'] = fo.Segmentation(mask=err_map_unet.squeeze().numpy())
            sample_image['error_ae']   = fo.Segmentation(mask=err_map_ae.squeeze().numpy())
            sample_image['umap_unet']  = fo.Heatmap(map=umap_baseline.squeeze().numpy())
            sample_image['umap_ae']    = fo.Heatmap(map=umap_ae.squeeze().numpy())

            samples.append(sample_image)

# add samples to dataset
dataset.add_samples(samples)
# Bare print as last statement so the return value of add_samples is not echoed.
print()




 100% |█████████████████| 175/175 [8.8s elapsed, 0s remaining, 20.4 samples/s]       



In [231]:
# Open the FiftyOne App on the dataset built above, using port 8090.
session = fo.launch_app(dataset, port=8090)

In [233]:
# Load a stored evaluation result; the next cell reads its `auc_pr` attribute.
# NOTE(review): torch.load unpickles the file - only load files you trust.
tmp = torch.load('../../../results-tmp/results/eval/calgary/pixel/calgary-localAug_multiImgSingleView_res_balanced_same-4-0')

In [234]:
tmp.auc_pr

0.6256674528121948