In [1]:
#import os, sys
#from time import time
import torch
from torch import Tensor
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
#from torch.cuda.amp import GradScaler, autocast
#from torch.utils.data._utils.collate import default_collate
#import copy
#from time import time
#import wandb
import matplotlib.pyplot as plt
from typing import Tuple

import sys
sys.path.append('..')
from dataset import AEDataset
from trainer import Trainer, WeakSupervisionTrainer
from model import DualBranchAE
from utils import *
from losses import MSELoss
from pretrainer import PreTrainer
from torchmetrics.classification import BinaryROC
from scipy import ndimage
from sklearn.ensemble import IsolationForest
from sklearn.linear_model import LogisticRegression

from tqdm.notebook import tqdm



In [2]:
# Auto-reload changes in the project's .py files (dataset, trainer, model, utils, ...)
# so edits to those modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# %cd example/

# Dataset

We currently work on the HCP (Human Connectome Project) diffusion data, and within it we built two different segmentation tasks. Further details are in the paper https://cg.cs.uni-bonn.de/backend/v1/files/publications/torayev-vcbm2020.pdf. Neither the whole dataset nor the model is in this repo; we will set you up and give you access once you start your work.

In [4]:
# Dataset options (passed to AEDataset in the next cell):
# - set: which segmentation task is used. 1 is a binary task for debugging, 2 is multi-class,
#   and 3 is multi-class with classes that are asymmetric w.r.t. the sagittal plane (harder).
#   Sets 2 and 3 are described in the paper.
# - modality: the target returned by __getitem__, either reconstruction or segmentation.
#   For segmentation the targets are taken from the `annotations` attribute, which is also
#   where the user model interacts with the dataset. Ground-truth masks are in the `label`
#   attribute. All other parameters come from past experiments and alter the behaviour; this
#   project has been around for a while, so some of them are no longer used.
# - normalize: normalizes the input; usually set to True.
# - augment: legacy; we did not have much success with data augmentation.
# - balance: balances the classes within a batch. Some classes are under-represented, so the
#   model sees them more often. It helps quite a bit during training, so consider integrating
#   it (we can go through the details once you start).
# - init: how the user model behaves; different annotation styles and quantities were tried.
# - to_gpu: moves ALL data to the GPU. Since we only work on a single volume (a couple hundred
#   slices), this avoids data-loading latency at the cost of a hefty chunk of VRAM.
#
# Feel free to rewrite anything you want; this is partly dated code.
#
# The config below handles the globals used throughout the notebook. Many settings tried in
# later experiments have not made it into the config yet, but most have.

_DATA_DIR = '../../../data/784565/Diffusion/'

cfg = dict(
    # --- general / logging ---
    name='location-unsupervised',
    project='IDVR-localization_pretrain',
    log=False,
    rank=0,

    # --- data locations (all inside the same diffusion directory) ---
    data_dir=_DATA_DIR,
    data_path=_DATA_DIR + 'data.nii',
    active_mask_path=_DATA_DIR + 'nodif_brain_mask.nii.gz',

    # --- self-supervised pre-training ---
    s_n_epochs=20,
    s_batch_size=16,   # default: 8
    s_lr=5e-4,         # alternatives: 1e-4, 1e-5

    # --- training with weak supervision ---
    p_n_epochs=100,
    w_n_epochs=10,
    w_batch_size=2,
    w_lr=5e-4,         # alternative: 5e-5
    w_eval_freq=100,

    # --- random forest ---
    min_samples_leaf=8,

    # --- user model ---
    init_voxels=200,
    refinement_voxels=200,
    num_interactions=10,
    brush=False,
    slice_selection='mean',
    voxel_selection='max',
    voxel_distribution='uniform',
)

In [5]:
# We set balance to True. This also affects the DataLoader created later
# (balanced sampling requires the custom balanced_collate function there).
balance = True
dataset = AEDataset(cfg, modality='segmentation', normalize=True,
                    set=2, augment=False, balance=balance, init='per_class', to_gpu=False)
# Alternative user-model initialization (presumably annotations on three slices):
# dataset = AEDataset(cfg, modality='segmentation', normalize=True,
                    # set=2, augment=False, balance=balance, init='three_slices', to_gpu=False)

# Currently there are no annotations; clear_annotation() enforces this at any point.
dataset.clear_annotation()
# Get the initial annotations from the user model (fixed seed) ...
annot = dataset.initial_annotation(seed=42)
# ... and write them into the dataset.
dataset.update_annotation(annot)
print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")

# The dataset currently always provides 4 items: the input (the image), the target (the input
# for reconstruction or the annotations for segmentation), weights that mask out voxels which
# are not annotated for segmentation, and a brain mask for masking the background during
# reconstruction.
item = dataset[0]
print(item.keys())

Shape torch.Size([145, 145])
tensor([0., 1.])
tensor([0., 1.])
torch.Size([5])
torch.Size([5, 1, 1])
Shape torch.Size([145, 145])
tensor([0., 1.])
tensor([0., 1.])
torch.Size([5])
torch.Size([5, 1, 1])
Shape torch.Size([145, 145])
tensor([0., 1.])
tensor([0., 1.])
torch.Size([5])
torch.Size([5, 1, 1])
Shape torch.Size([145, 145])
tensor([0., 1.])
tensor([0., 1.])
torch.Size([5])
torch.Size([5, 1, 1])
Shape torch.Size([145, 145])
tensor([0., 1.])
tensor([0., 1.])
torch.Size([5])
torch.Size([5, 1, 1])
number of annotations: 1267.0
dict_keys(['input', 'target', 'weight', 'mask'])


In [6]:
# Sanity check of the initial annotations: the total number of annotated entries, then how
# many voxels carry 0, 1 or 2 annotated classes (summed over dim 0, which appears to be the
# class axis as for dataset.label — TODO confirm).
print(dataset.annotations.sum())
print(dataset.annotations.sum(dim=0).unique(return_counts=True))

tensor(1267.)
(tensor([0., 1., 2.]), tensor([3047691,     601,     333]))


# Model and Inference

The overall pipeline is illustrated in the README.

In [7]:
# At first we do not have annotations but still need features for the Random Forest. So we
# pre-train on a reconstruction task, later re-use the same encoder (the part of the network
# that outputs our features), simply replace the decoder and resume training.

# Init the model with a segmentation decoder (see the model source for additional guidance).
# The dataset updates the config to contain 'labels'; we use one output channel per class.
model = DualBranchAE(encoder    = 'dual',
                     decoder    = 'segmentation',
                     in_size    = 145,
                     n_classes  = len(cfg['labels']),
                     thresholds = 'learned',
                     dropout = False,
                     dropout_rate=0.5) #.to(cfg['rank'])

# Example checkpoints from earlier experiments.
#model_path = 'example_dual_xy_0_best.pt'
model_path = 'models/Test_best.pt'
#model_path = 'models/Dropout-0.5_best.pt'
# model_path = 'models/Dropout-0.2_best.pt'

# Load the checkpoint onto the CPU. The model above stays on the CPU, and map_location lets a
# checkpoint saved from CUDA tensors load on machines without a GPU (load_state_dict copies
# the weights onto the model's device either way).
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
checkpoint       = torch.load(model_path, map_location='cpu')
model_state_dict = checkpoint['model_dict']
# Keep only the parameters under the top-level 'encoder.' prefix and strip that prefix once.
# A substring test ('encoder' in k) plus str.replace would also pick up keys such as
# 'decoder.skip_encoder.weight' and rewrite every 'encoder.' inside a key, which then makes
# the strict load below fail.
encoder_prefix     = 'encoder.'
encoder_state_dict = {k[len(encoder_prefix):]: v
                      for k, v in model_state_dict.items()
                      if k.startswith(encoder_prefix)}

# Copy the encoder weights into the model; the decoder keeps its random initialization.
model.encoder.load_state_dict(encoder_state_dict, strict=True)

# Define the DataLoader. Balanced sampling in the dataset requires the custom balanced_collate
# function, which handles the unusual batching logic.
if balance:
    loader  = DataLoader(dataset,
                         batch_size=cfg['w_batch_size'],
                         shuffle=True,
                         drop_last=False,
                         collate_fn=balanced_collate)
else:
    # NOTE(review): the unbalanced path uses a hard-coded batch size of 16 instead of
    # cfg['w_batch_size'] - confirm this is intended.
    loader  = DataLoader(dataset,
                         batch_size=16,
                         shuffle=True,
                         drop_last=False)

In [8]:
# For evaluation we are interested in the Random Forest (RF) prediction based on the CNN
# features.

# Temporarily switch off dataset behaviour that changes __getitem__ (augmentation, balanced
# sampling) and would therefore influence the extracted features. The original settings are
# restored in `finally`, so a failure during extraction cannot leave the dataset in this
# state for later cells.
augment_checkpoint = dataset.augment
balance_checkpoint = dataset.balance
dataset.augment = False
dataset.balance = False
try:
    # Layer to take the features from; usually the encoder output.
    f_layer = 'encoder'
    # FeatureExtractor collects layer outputs (see PyTorch's hook functionality).
    extractor = FeatureExtractor(model, layers=[f_layer])

    # Cache the features for the whole dataset and move the channel axis last
    # (presumably (N, C, H, W) -> (N, H, W, C)) as a numpy array for the RF.
    hooked_results = extractor(dataset)
    features = hooked_results[f_layer]
    features = features.permute(0, 2, 3, 1).numpy()
finally:
    # Restore the dataset attributes.
    dataset.augment = augment_checkpoint
    dataset.balance = balance_checkpoint

# The RF evaluation itself (F1 scores for the whole dataset against all ground-truth labels,
# plus the RF predictions needed to update the annotations with the user model) is done by
# evaluate_RF in train() below.

# Now you can change the model and features to your liking and try again (e.g. via
# contrastive learning ;)). The scores from the RF are the signal you need for evaluation,
# the rest is up to you.

In [9]:
# Move the class axis of the ground-truth label last so that indexing with the brain mask
# yields one row of class indicators per brain voxel.
x = dataset.label.permute(1,2,3,0)
# How many classes each brain voxel is labelled with (values > 1 mean overlapping labels).
print(x[dataset.brain_mask].sum(dim=1).unique(return_counts=True))
# Same count ignoring the first class channel (channel 0 looks like a catch-all/background
# class — TODO confirm).
print(x[dataset.brain_mask][:, 1:].sum(dim=1).unique(return_counts=True))

(tensor([1., 2., 3.], dtype=torch.float64), tensor([800508,  43841,      1]))
(tensor([0., 1., 2., 3.], dtype=torch.float64), tensor([644241, 156267,  43841,      1]))


In [10]:
# Shape of the masked label (brain voxels x classes) and how many brain voxels carry the
# first class channel (x is the class-last label from the previous cell).
print(x[dataset.brain_mask].shape)
print(x[dataset.brain_mask][:, :1].sum(dim=1).unique(return_counts=True))

torch.Size([844350, 5])
(tensor([0., 1.], dtype=torch.float64), tensor([200109, 644241]))


# Uncertainty Measures

In [11]:
def print_results(n_annots: list, f1_scores: list, head: int = 5, every: int = 5):
    """Print a table of annotation counts and RF F1 scores per user-interaction iteration.

    Shows the first `head` iterations, then every `every`-th iteration (1-based), and always
    the final one. The previous hard-coded index list stopped at iteration 55 and could omit
    the last iteration of a run.

    Args:
        n_annots: number of annotations after each iteration.
        f1_scores: average F1 score after each iteration.
        head: number of leading iterations that are always shown.
        every: after the head, show every `every`-th iteration; <= 0 disables this.
    """
    print('Iteration | # Annotations | F1 Score')
    print('----------|---------------|---------')
    rows = list(zip(n_annots, f1_scores))
    last = len(rows) - 1
    for i, (n, f1) in enumerate(rows):
        periodic = every > 0 and (i + 1) % every == 0
        if i < head or periodic or i == last:
            print(f'{i+1:>9} | {int(n):>13} | {f1:.4f}')

In [12]:
# The RF needs at least one annotation in the first iteration, so every run starts from the
# same seeded initial annotation.
def re_init_dataset(ds=None, seed: int = 42):
    """Reset a dataset's annotations to the user model's initial annotation.

    Args:
        ds: dataset to reset; defaults to the notebook-global ``dataset``.
        seed: seed passed to ``initial_annotation`` (42 reproduces the earlier runs).

    Returns:
        The initial annotation that was applied.
    """
    if ds is None:
        ds = dataset
    ds.clear_annotation()
    annot = ds.initial_annotation(seed=seed)
    ds.update_annotation(annot)
    return annot

In [13]:
def train(method: str, n_epochs: int, measures: list, keep_uncertainty: bool = False):
    """Simulate `n_epochs` rounds of user interaction and track the RF performance.

    Resets the (notebook-global) dataset to the initial annotation and evaluates the RF via
    evaluate_RF on the cached `features`. Then, each round, the user model adds annotations
    selected according to `method`, and the RF is evaluated again.

    Args:
        method: 'random', 'ground-truth', or the name of an uncertainty measure whose
            per-class map (a key of the uncertainty dict returned by evaluate_RF) guides
            the refinement annotation.
        n_epochs: number of user-interaction rounds.
        measures: list of uncertainty-measure names passed to evaluate_RF; for
            uncertainty-guided selection it presumably has to contain `method`.
        keep_uncertainty: if True, also collect the uncertainty maps of every round. Off by
            default because each map is volume-sized, which adds up in memory.

    Returns:
        (n_annots, annots, f1_scores, rf_predictions, uncs_pc, uncs), one entry per round;
        uncs_pc and uncs are empty unless keep_uncertainty is True.
    """
    re_init_dataset()
    print(f'Selection using {method}')
    scores, rf_prediction, unc, unc_pc = evaluate_RF(dataset, features, cfg, measures)
    print(f"Number of initial annotations: {dataset.annotations.detach().cpu().sum()}")
    print(f"Average F1 score for RF after initial user interaction:    {scores['Avg_f1_tracts'].item():.4f}")
    print()

    n_annots, annots, f1_scores = [], [], []
    rf_predictions, uncs_pc, uncs = [], [], []

    for _ in tqdm(range(n_epochs), desc='User interaction', unit='iteration'):
        if method == 'random':
            annot = dataset.refinement_annotation(prediction=rf_prediction, random=True, seed=42)
        elif method == 'ground-truth':
            annot = dataset.refinement_annotation(prediction=rf_prediction, seed=42)
        else:
            annot = dataset.refinement_annotation(prediction=rf_prediction, uncertainty_map=unc_pc[method], seed=42)
        dataset.update_annotation(annot)

        # Snapshot the annotations. Tensor.cpu() returns the tensor itself when it already
        # lives on the CPU, so without clone() every snapshot could share storage with the
        # live annotation tensor and change whenever it is updated in place.
        snapshot = dataset.annotations.detach().cpu().clone()
        annots.append(snapshot)
        n_annots.append(snapshot.sum().item())

        scores, rf_prediction, unc, unc_pc = evaluate_RF(dataset, features, cfg, measures)
        rf_predictions.append(rf_prediction)
        if keep_uncertainty:
            uncs_pc.append(unc_pc)
            uncs.append(unc)
        f1_scores.append(scores['Avg_f1_tracts'].item())

    print_results(n_annots, f1_scores)
    return n_annots, annots, f1_scores, rf_predictions, uncs_pc, uncs

### Ground Truth

In [37]:
# 20 rounds where the refinement annotation is requested without an uncertainty map.
# NOTE(review): later cells read uncs_pc[...]['entropy' | 'spatial-distance' | 'feature-distance']
# from this run, which presumably requires those measures (and collected uncertainty maps)
# here — confirm before re-running top to bottom.
ns, ans, f1s, rf_preds, uncs_pc, uncs = train(method='ground-truth', n_epochs=20, measures=['ground-truth'])

Selection using ground-truth
Number of initial annotations: 1267.0
Average F1 score for RF after initial user interaction:    0.3617



User interaction:   0%|          | 0/20 [00:00<?, ?iteration/s]

Iteration | # Annotations | F1 Score
----------|---------------|---------
        1 |          1968 | 0.4362
        2 |          2736 | 0.4664
        3 |          3394 | 0.4949
        4 |          4163 | 0.5281
        5 |          4608 | 0.5536
       10 |          8045 | 0.5885
       15 |         10708 | 0.6634
       20 |         13494 | 0.6734


### Entropy

In [38]:
# 20 rounds where the refinement annotation is guided by the per-class entropy map.
ns_e, ans_e, f1s_e, rf_preds_e, uncs_pc_e, uncs_e = train(method='entropy', n_epochs=20, measures=['entropy'])

Selection using entropy
Number of initial annotations: 1267.0
Average F1 score for RF after initial user interaction:    0.3617



User interaction:   0%|          | 0/20 [00:00<?, ?iteration/s]

Iteration | # Annotations | F1 Score
----------|---------------|---------
        1 |          2375 | 0.3956
        2 |          3529 | 0.4189
        3 |          4696 | 0.4208
        4 |          5865 | 0.4296
        5 |          6973 | 0.4333
       10 |         12358 | 0.4836
       15 |         17100 | 0.4927
       20 |         21653 | 0.5037


### Spatial Distance

In [52]:
# 20 rounds where the refinement annotation is guided by the per-class spatial-distance map.
ns_sd, ans_sd, f1s_sd, rf_preds_sd, uncs_pc_sd, uncs_sd = train(method='spatial-distance', n_epochs=20, measures=['spatial-distance'])

Selection using spatial-distance
Number of initial annotations: 1267.0
Average F1 score for RF after initial user interaction:    0.3617



User interaction:   0%|          | 0/20 [00:00<?, ?iteration/s]

Iteration | # Annotations | F1 Score
----------|---------------|---------
        1 |          2310 | 0.4281
        2 |          3318 | 0.4528
        3 |          4261 | 0.4625
        4 |          5165 | 0.4683
        5 |          6110 | 0.4768
       10 |          9922 | 0.5045
       15 |         12894 | 0.5126
       20 |         15837 | 0.5175


### Feature Distance

In [53]:
# 20 rounds where the refinement annotation is guided by the per-class feature-distance map.
ns_fd, ans_fd, f1s_fd, rf_preds_fd, uncs_pc_fd, uncs_fd = train(method='feature-distance', n_epochs=20, measures=['feature-distance'])

Selection using feature-distance
Number of initial annotations: 1267.0
Average F1 score for RF after initial user interaction:    0.3617



User interaction:   0%|          | 0/20 [00:00<?, ?iteration/s]

Iteration | # Annotations | F1 Score
----------|---------------|---------
        1 |          2170 | 0.4283
        2 |          2988 | 0.4441
        3 |          3779 | 0.4545
        4 |          4573 | 0.4605
        5 |          5465 | 0.4643
       10 |          9357 | 0.4847
       15 |         12932 | 0.4953
       20 |         16426 | 0.5125


### Random

In [41]:
# Baseline: 20 rounds with random refinement annotation (random=True in refinement_annotation).
ns_r, ans_r, f1s_r, rf_preds_r, uncs_pc_r, uncs_r = train(method='random', n_epochs=20, measures=['random'])

Selection using random
Number of initial annotations: 1267.0
Average F1 score for RF after initial user interaction:    0.3617



User interaction:   0%|          | 0/20 [00:00<?, ?iteration/s]

Iteration | # Annotations | F1 Score
----------|---------------|---------
        1 |          2243 | 0.4368
        2 |          3300 | 0.4568
        3 |          4265 | 0.4847
        4 |          5280 | 0.4880
        5 |          6337 | 0.5053
       10 |         11494 | 0.5356
       15 |         16644 | 0.5616
       20 |         21655 | 0.5893


In [45]:
# Per-class comparison of the final RF prediction (ground-truth run) with the label.
# Assuming 0/1 maps (as the printed values suggest): -1 = missed voxel, 1 = false positive,
# 0 = agreement.
x = rf_preds[-1] - dataset.label
# Iterate over the class axis instead of hard-coding five classes.
for class_idx in range(x.shape[0]):
    print(f'Klasse {class_idx + 1}: {x[class_idx].unique(return_counts=True)}')

Klasse 1: (tensor([-1.,  0.,  1.], dtype=torch.float64), tensor([ 183390, 2861184,    4051]))
Klasse 2: (tensor([-1.,  0.,  1.], dtype=torch.float64), tensor([    582, 3007134,   40909]))
Klasse 3: (tensor([-1.,  0.,  1.], dtype=torch.float64), tensor([    307, 3009528,   38790]))
Klasse 4: (tensor([0., 1.], dtype=torch.float64), tensor([3045455,    3170]))
Klasse 5: (tensor([-1.,  0.,  1.], dtype=torch.float64), tensor([  15093, 2967059,   66473]))


In [92]:
# Agreement between the user-model annotations and the ground truth on annotated voxels.
# mask: voxels that carry at least one annotated class (any over dim 0).
mask = dataset.annotations.any(dim=0)
# Per-class difference annotation - label at those voxels. Assuming 0/1 values: +1 means
# annotated but not in the label, -1 means in the label but not annotated.
y = dataset.annotations[:, mask] - dataset.label[:, mask]
# Summed over classes per voxel: 0 = no net difference; negative = the label has more
# classes at that voxel than were annotated.
z = y.sum(dim=0)
print(z.unique(return_counts=True))

(tensor([-1.,  0.], dtype=torch.float64), tensor([1063, 1567]))


In [34]:
# Persist the annotation snapshots of selected iterations (1-5, 10, 15, 20) of the entropy run.
import os
os.makedirs('annotations', exist_ok=True)  # torch.save does not create missing directories
for i, annots in enumerate(ans_e):
    if i in (0, 1, 2, 3, 4, 9, 14, 19):
        torch.save(annots, f'annotations/annots_{i+1}_entropy.pt')

In [35]:
# Persist the annotation snapshots of selected iterations (1-5, 10, 15, 20) of the
# spatial-distance run.
import os
os.makedirs('annotations', exist_ok=True)  # torch.save does not create missing directories
for i, annots in enumerate(ans_sd):
    if i in (0, 1, 2, 3, 4, 9, 14, 19):
        torch.save(annots, f'annotations/annots_{i+1}_sd.pt')

In [36]:
# Persist the annotation snapshots of selected iterations (1-5, 10, 15, 20) of the
# feature-distance run.
import os
os.makedirs('annotations', exist_ok=True)  # torch.save does not create missing directories
for i, annots in enumerate(ans_fd):
    if i in (0, 1, 2, 3, 4, 9, 14, 19):
        torch.save(annots, f'annotations/annots_{i+1}_fd.pt')

### Logistic Regression

In [37]:
# Binary RF error maps per iteration: 1 where the prediction differs from the label (int64).
# error_maps_mean reduces over dim 0 with any() — despite its name it is 1 wherever at least
# one class is wrong, not a mean.
error_maps = [(pred != dataset.label).long() for pred in rf_preds]
error_maps_mean = [em.any(dim=0).long() for em in error_maps]

In [38]:
# Flatten the three uncertainty maps of the first iteration and stack them column-wise into
# an (N, 3) matrix (columns: entropy, spatial distance, feature distance), where N is the
# number of elements of each map.
abc, defg, hijk = (uncs_pc[0][measure].flatten()
                   for measure in ('entropy', 'spatial-distance', 'feature-distance'))
lmno = torch.stack((abc, defg, hijk), dim=1)
print(lmno.shape)

torch.Size([15243125, 3])


In [39]:
# Create and fit logistic regression models that predict the RF error map from the uncertainty
# maps of each iteration: one model per measure and one on all three measures combined.
# NOTE(review): the measures are not standardized, so the coefficient magnitudes reflect each
# measure's scale as well as its predictive power — compare across measures with care.
_UNC_MEASURES = ('entropy', 'spatial-distance', 'feature-distance')


def _fit_logreg(X, y):
    """Fit a LogisticRegression with the settings shared by every fit in this cell."""
    clf = LogisticRegression(random_state=0, n_jobs=-1)
    clf.fit(X, y)
    return clf


for i, (unc_map, error_map) in enumerate(tqdm(zip(uncs_pc, error_maps), total=len(uncs_pc), desc='Calculation', unit='iteration')):
    # Flatten the (large) error map once per iteration instead of once per fit.
    error_flat = error_map.flatten()
    unc_map_e, unc_map_sd, unc_map_fd = (unc_map[m].flatten() for m in _UNC_MEASURES)

    logreg_e = _fit_logreg(unc_map_e.reshape(-1, 1), error_flat)
    logreg_sd = _fit_logreg(unc_map_sd.reshape(-1, 1), error_flat)
    logreg_fd = _fit_logreg(unc_map_fd.reshape(-1, 1), error_flat)
    print(f"Koeffizient für Iteration {i+1}: Entropy: {logreg_e.coef_}, SD: {logreg_sd.coef_}, FD: {logreg_fd.coef_}")

    logreg_c = _fit_logreg(torch.stack((unc_map_e, unc_map_sd, unc_map_fd), dim=1), error_flat)
    print(f"Koeffizient für Iteration {i+1}: Combined: {logreg_c.coef_}")

Calculation:   0%|          | 0/20 [00:00<?, ?iteration/s]

Koeffizient für Iteration 1: Entropy: [[6.80838514]], SD: [[0.03420244]], FD: [[5.14872597]]
Koeffizient für Iteration 1: Combined: [[ 6.72757369 -0.0340217   0.63251322]]
Koeffizient für Iteration 2: Entropy: [[7.02794357]], SD: [[0.03219433]], FD: [[5.32265189]]
Koeffizient für Iteration 2: Combined: [[ 6.73033755 -0.03509393  1.2549075 ]]
Koeffizient für Iteration 3: Entropy: [[6.6982623]], SD: [[0.02620928]], FD: [[5.05152733]]
Koeffizient für Iteration 3: Combined: [[ 6.29420073 -0.0529336   1.9025076 ]]
Koeffizient für Iteration 4: Entropy: [[6.8116268]], SD: [[0.02158613]], FD: [[4.87395059]]
Koeffizient für Iteration 4: Combined: [[ 6.59178342 -0.06249266  1.17257243]]
Koeffizient für Iteration 5: Entropy: [[7.19238669]], SD: [[0.01998846]], FD: [[4.75844355]]
Koeffizient für Iteration 5: Combined: [[ 7.08846542 -0.05720406  0.35012152]]
Koeffizient für Iteration 6: Entropy: [[7.31138677]], SD: [[0.02119565]], FD: [[4.94982176]]
Koeffizient für Iteration 6: Combined: [[ 7.19909