In [12]:
#import os, sys
#from time import time
import torch
from torch import Tensor
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
#from torch.cuda.amp import GradScaler, autocast
#from torch.utils.data._utils.collate import default_collate
#import copy
#from time import time
#import wandb
import matplotlib.pyplot as plt
from typing import Tuple

import sys
sys.path.append('..')
from dataset import AEDataset
from trainer import Trainer, WeakSupervisionTrainer
from model import DualBranchAE
from utils import *
from losses import MSELoss
from pretrainer import PreTrainer
from torchmetrics.classification import BinaryROC
from scipy import ndimage
from sklearn.ensemble import IsolationForest

In [13]:
# Automatically reload changes made to imported .py files (dataset, model, utils, ...).
# Mode 2 applies this to all imported modules. Re-running this cell only prints an
# "already loaded" notice (see the output below), which is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# %cd example/

# Dataset

We currently work on the HPC data, and within it we built two different segmentation tasks. Further details are in the paper https://cg.cs.uni-bonn.de/backend/v1/files/publications/torayev-vcbm2020.pdf. Neither the whole dataset nor the model is in this repo. We will set you up and give you access once you have started your work.

In [15]:
# Which task is used is controlled by the dataset's `set` argument: 1 is a binary task for debugging,
# 2 is multi-class, and 3 is also multi-class but with classes that are asymmetric w.r.t. the sagittal
# plane (harder). Details for sets 2 and 3 are in the paper.
# `modality` selects the target returned by __getitem__; options are reconstruction and segmentation.
# When segmentation is selected, the labels are taken from the `annotations` attribute. This is also
# where the user model interacts with the dataset. Ground-truth masks are in the `label` attribute.
# All other parameters come from past experiments and alter the behaviour. This project has been
# around for a while, so some of them are no longer used.

# `normalize` is usually set to True and simply normalizes the input. `augment` is legacy; we didn't
# have much success with data augmentation. `balance` takes care of class balancing within a batch:
# some classes are under-represented, so we show them to the model more often. It helps quite a bit
# during training, so consider integrating it. We can talk about how this is done in detail once you
# start. `init` defines how the user model behaves; we considered different behaviours w.r.t.
# annotation style, quantity and so on.
# `to_gpu` moves ALL data to the GPU. Since we only work on a single volume (i.e. a couple hundred
# slices), we move everything to the GPU and avoid data-loading latency. This takes a hefty chunk out
# of the VRAM but makes things faster.

# Feel free to rewrite anything you want. This is partly dated code that could use a rewrite anyway.

# Example:
# Make a config first. It holds the globals and is used throughout the script. Many things that were
# tried in later experiments have not made it into the config yet, but most have.
# Note: the dataset adds a 'labels' entry to this dict when it is constructed (used for n_classes).

cfg = {
    # CONFIG
    'name': 'location-unsupervised',
    'project': 'IDVR-localization_pretrain',
    'log': False,
    'rank': 0,
    
    # DATA
    'data_dir': '../../../data/784565/Diffusion/',
    'data_path': '../../../data/784565/Diffusion/data.nii',
    'active_mask_path': '../../../data/784565/Diffusion/nodif_brain_mask.nii.gz',
    
    # SELF-SUPERVISED PRE-TRAINING
    's_n_epochs': 20,
    's_batch_size': 16, # default: 8
    's_lr': 5e-4, # alternatives tried: 1e-4, 1e-5
    
    # TRAINING WITH WEAK SUPERVISION
    'p_n_epochs': 100,
    'w_n_epochs': 10,
    'w_batch_size': 2,
    'w_lr': 5e-4,    # alternative tried: 5e-5
    'w_eval_freq': 100,
    
    # RANDOM FOREST
    'min_samples_leaf': 8,
    
    # USER MODEL
    'init_voxels': 200,
    'refinement_voxels': 200,
    'num_interactions': 10,
}

In [16]:
# Class balancing is switched OFF here (set to True to enable balanced sampling). This flag also
# affects the DataLoader built later: balanced sampling needs the custom `balanced_collate` function.
balance = False
dataset = AEDataset(cfg, modality='segmentation', normalize=True,
                    set=2, augment=False, balance=balance, init='per_class', to_gpu=False)

# Start from a clean slate; clear_annotation() can be called at any point to drop all annotations.
dataset.clear_annotation()
# Get the initial annotations from the simulated user (seeded so the run is reproducible)
annot = dataset.initial_annotation(seed=42)
# and write them into the dataset
dataset.update_annotation(annot)
print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")

# Every item is a dict with 4 entries: 'input' (the image), 'target' (the input for reconstruction or
# the annotations for segmentation), 'weight' (masks out voxels that are not annotated, used for
# segmentation) and 'mask' (a brain mask for masking the background during reconstruction).
item = dataset[0]
print(item.keys())

number of annotations: 9510.0
dict_keys(['input', 'target', 'weight', 'mask'])


In [17]:
# Number of items the dataset yields (145 in the recorded output; presumably one per slice of the volume — confirm).
print(len(dataset))

145


In [18]:
# Print the shape of every entry of the first item, in a fixed order. For this dataset the recorded
# output is: input (288, 145, 145), target (5, 145, 145), weight (1, 145, 145), mask (145, 145).
for key in ('input', 'target', 'weight', 'mask'):
    print(item[key].shape)

torch.Size([288, 145, 145])
torch.Size([5, 145, 145])
torch.Size([1, 145, 145])
torch.Size([145, 145])


# Model and Inference

The overall pipeline is illustrated in the README.

In [19]:
# At first we have no annotations, but the Random Forest still needs features. So the encoder (the part
# of the network that outputs our features) is pre-trained on a reconstruction task and re-used here;
# only the decoder is replaced before training resumes.

# Build the model with a segmentation decoder (see DualBranchAE for additional guidance). The dataset
# adds a 'labels' entry to cfg; we use one output channel per class.
model = DualBranchAE(encoder    = 'dual',
                     decoder    = 'segmentation',
                     in_size    = 145,
                     n_classes  = len(cfg['labels']),
                     thresholds = 'learned')  # .to(cfg['rank'])

# Checkpoint from one of the experiments (another example: 'example_dual_xy_0_best.pt').
model_path = 'models/Test_best.pt'

# Load onto the CPU so a checkpoint saved on a GPU also loads on a CPU-only machine; the model above
# stays on the CPU anyway (the .to(...) call is commented out).
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint         = torch.load(model_path, map_location='cpu')
model_state_dict   = checkpoint['model_dict']

# Keep only the encoder parameters and strip their leading 'encoder.' prefix so the keys match
# model.encoder. Matching the prefix (instead of a substring anywhere in the key) avoids picking up
# unrelated keys and mangling names that contain 'encoder.' further inside.
encoder_prefix     = 'encoder.'
encoder_state_dict = {k[len(encoder_prefix):]: v
                      for k, v in model_state_dict.items()
                      if k.startswith(encoder_prefix)}

# Copy the encoder weights into the model; decoder weights keep their random initialization.
# strict=True makes any missing/unexpected encoder key an error.
model.encoder.load_state_dict(encoder_state_dict, strict=True)

# Build the DataLoader. With balanced sampling in the dataset we also need the custom
# `balanced_collate` function, which handles the unusual batching logic, and the weak-supervision
# batch size from cfg. Without balancing, the default collate and a fixed batch size of 16 are used.
# NOTE(review): the unbalanced branch hard-codes 16 instead of reading a value from cfg.
loader_kwargs = {'shuffle': True, 'drop_last': False}
if balance:
    loader_kwargs['batch_size'] = cfg['w_batch_size']
    loader_kwargs['collate_fn'] = balanced_collate
else:
    loader_kwargs['batch_size'] = 16
loader = DataLoader(dataset, **loader_kwargs)

In [20]:
# For evaluation, we are interested in the Random Forest (RF) prediction based on the CNN features.

# Some dataset attributes change __getitem__ (e.g. balancing) and would therefore influence the
# extracted features. Save them, switch them off for extraction and restore them afterwards. The
# restore happens in `finally`, so the dataset is not left modified if extraction raises.
augment_checkpoint = dataset.augment
balance_checkpoint = dataset.balance
dataset.augment = False
dataset.balance = False
try:
    # The layer to take the features from; usually the encoder output.
    f_layer = 'encoder'
    # Feature extractor built on PyTorch's hook functionality (presumably from utils).
    extractor = FeatureExtractor(model, layers=[f_layer])
    # Cache the features for the whole dataset and convert them for the (NumPy-based) RF.
    hooked_results = extractor(dataset)
    features = hooked_results[f_layer]
    # Move dim 1 to the end (assumes an (N, C, H, W) layout -> (N, H, W, C); confirm). detach()/cpu()
    # make the NumPy conversion safe if the features carry autograd history or live on a GPU.
    features = features.permute(0, 2, 3, 1).detach().cpu().numpy()
finally:
    # Restore the dataset attributes to their previous values.
    dataset.augment = augment_checkpoint
    dataset.balance = balance_checkpoint

# The evaluation helpers used in the cells below (evaluate_RF / evaluate_RF_with_uncertainty) return
# F1 scores for the whole dataset based on all ground-truth labels, together with the RF predictions
# themselves; the predictions are needed to update the annotations with the user model.

# Now you can change the model and features to your liking and try again (e.g. via contrastive learning ;)).
# The scores from the RF are the signal you need for evaluation; the rest is up to you.

In [21]:
# Sanity check: feature extraction should not have changed the annotations; the count is expected to
# match the one printed right after the initial annotation (9510 in the recorded outputs).
print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")

number of annotations: 9510.0


In [11]:
# Annotation-selection strategy: use the RF entropy map as the uncertainty map for the simulated user.
# Reset to the initial annotation first, so this strategy starts from the same state as the other
# strategy cells instead of inheriting annotations added by whichever of them ran before (Restart &
# Run All would otherwise give every later strategy a head start).
# NOTE(review): assumes the reset below reproduces the initial state exactly (seeded) — confirm.
n_refinement_rounds = 5
dataset.clear_annotation()
dataset.update_annotation(dataset.initial_annotation(seed=42))

print('Selection using entropy:')
scores, rf_prediction, entropy_map, em_m, sd, sd_pc, fd, fd_pc = evaluate_RF_with_uncertainty(dataset, features, cfg)
print(f"Average F1 score for RF after initial user interaction:    {scores['Avg_f1_tracts'].item():.4f}")
for _ in range(n_refinement_rounds):
    # The simulated user adds annotations guided by the current prediction and entropy map.
    annot = dataset.uncertainty_refinement_annotation(prediction=rf_prediction, uncertainty_map=entropy_map, seed=42)
    dataset.update_annotation(annot)
    print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")
    scores, rf_prediction, entropy_map, em_m, sd, sd_pc, fd, fd_pc = evaluate_RF_with_uncertainty(dataset, features, cfg)
    print(f"Average F1 score for RF after additional user interaction: {scores['Avg_f1_tracts'].item():.4f}")

Selection using entropy:
Average F1 score for RF after initial user interaction:    0.3152
number of annotations: 15463.0
Average F1 score for RF after additional user interaction: 0.3745
number of annotations: 20730.0
Average F1 score for RF after additional user interaction: 0.3876
number of annotations: 25456.0
Average F1 score for RF after additional user interaction: 0.4184
number of annotations: 29263.0
Average F1 score for RF after additional user interaction: 0.4243
number of annotations: 34116.0
Average F1 score for RF after additional user interaction: 0.4290


In [None]:
# Annotation-selection strategy: use the spatial-distance map (`sd_pc`) as the uncertainty map.
# Reset to the initial annotation first, so this strategy starts from the same state as the other
# strategy cells instead of inheriting annotations added by whichever of them ran before (Restart &
# Run All would otherwise give every later strategy a head start).
# NOTE(review): assumes the reset below reproduces the initial state exactly (seeded) — confirm.
n_refinement_rounds = 5
dataset.clear_annotation()
dataset.update_annotation(dataset.initial_annotation(seed=42))

print('Selection using spatial distance:')
scores, rf_prediction, entropy_map, em_m, sd, sd_pc, fd, fd_pc = evaluate_RF_with_uncertainty(dataset, features, cfg)
print(f"Average F1 score for RF after initial user interaction:    {scores['Avg_f1_tracts'].item():.4f}")
for _ in range(n_refinement_rounds):
    # The simulated user adds annotations guided by the current prediction and spatial-distance map.
    annot = dataset.uncertainty_refinement_annotation(prediction=rf_prediction, uncertainty_map=sd_pc, seed=42)
    dataset.update_annotation(annot)
    print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")
    scores, rf_prediction, entropy_map, em_m, sd, sd_pc, fd, fd_pc = evaluate_RF_with_uncertainty(dataset, features, cfg)
    print(f"Average F1 score for RF after additional user interaction: {scores['Avg_f1_tracts'].item():.4f}")

Selection using spatial distance:
Average F1 score for RF after initial user interaction:    0.3152
number of annotations: 13862.0
Average F1 score for RF after additional user interaction: 0.3823
number of annotations: 17573.0
Average F1 score for RF after additional user interaction: 0.3905
number of annotations: 20408.0
Average F1 score for RF after additional user interaction: 0.4115
number of annotations: 23559.0
Average F1 score for RF after additional user interaction: 0.4310
number of annotations: 26581.0
Average F1 score for RF after additional user interaction: 0.4353


In [22]:
# Annotation-selection strategy: use the feature-distance map (`fd_pc`) as the uncertainty map.
# Reset to the initial annotation first, so this strategy starts from the same state as the other
# strategy cells instead of inheriting annotations added by whichever of them ran before (Restart &
# Run All would otherwise give every later strategy a head start).
# NOTE(review): assumes the reset below reproduces the initial state exactly (seeded) — confirm.
n_refinement_rounds = 5
dataset.clear_annotation()
dataset.update_annotation(dataset.initial_annotation(seed=42))

print('Selection using feature distance:')
scores, rf_prediction, entropy_map, em_m, sd, sd_pc, fd, fd_pc = evaluate_RF_with_uncertainty(dataset, features, cfg)
print(f"Average F1 score for RF after initial user interaction:    {scores['Avg_f1_tracts'].item():.4f}")
for _ in range(n_refinement_rounds):
    # The simulated user adds annotations guided by the current prediction and feature-distance map.
    annot = dataset.uncertainty_refinement_annotation(prediction=rf_prediction, uncertainty_map=fd_pc, seed=42)
    dataset.update_annotation(annot)
    print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")
    scores, rf_prediction, entropy_map, em_m, sd, sd_pc, fd, fd_pc = evaluate_RF_with_uncertainty(dataset, features, cfg)
    print(f"Average F1 score for RF after additional user interaction: {scores['Avg_f1_tracts'].item():.4f}")

Selection using feature distance:
Average F1 score for RF after initial user interaction:    0.3152
number of annotations: 14036.0
Average F1 score for RF after additional user interaction: 0.3720
number of annotations: 17377.0
Average F1 score for RF after additional user interaction: 0.3974
number of annotations: 20672.0
Average F1 score for RF after additional user interaction: 0.4059
number of annotations: 23926.0
Average F1 score for RF after additional user interaction: 0.4125
number of annotations: 26672.0
Average F1 score for RF after additional user interaction: 0.4179


In [None]:
# Annotation-selection strategy: refinement driven by the ground truth (dataset.refinement_annotation).
# Reset to the initial annotation first, so this strategy starts from the same state as the other
# strategy cells instead of inheriting annotations added by whichever of them ran before (Restart &
# Run All would otherwise give every later strategy a head start).
# NOTE(review): assumes the reset below reproduces the initial state exactly (seeded) — confirm.
n_refinement_rounds = 5
dataset.clear_annotation()
dataset.update_annotation(dataset.initial_annotation(seed=42))

print('Selection using ground truth:')
scores, rf_prediction = evaluate_RF(dataset, features, cfg)
print(f"Average F1 score for RF after initial user interaction:    {scores['Avg_f1_tracts'].item():.4f}")
for _ in range(n_refinement_rounds):
    # The simulated user adds annotations based on the current RF prediction.
    annot = dataset.refinement_annotation(prediction=rf_prediction, seed=42)
    dataset.update_annotation(annot)
    print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")
    scores, rf_prediction = evaluate_RF(dataset, features, cfg)
    print(f"Average F1 score for RF after additional user interaction: {scores['Avg_f1_tracts'].item():.4f}")

Selection using ground truth:
Average F1 score for RF after initial user interaction:    0.3152
number of annotations: 16194.0
Average F1 score for RF after additional user interaction: 0.3604
number of annotations: 22890.0
Average F1 score for RF after additional user interaction: 0.4065
number of annotations: 28969.0
Average F1 score for RF after additional user interaction: 0.4237
number of annotations: 35380.0
Average F1 score for RF after additional user interaction: 0.4551
number of annotations: 39962.0
Average F1 score for RF after additional user interaction: 0.4915
