# nnUNet model exploration

In [1]:
%load_ext autoreload

In [2]:
%autoreload 2
from pathlib import Path

# Computation
import torch
from torch import nn
import numpy as np

# Radiology
import nibabel as nib

# Plotting
import matplotlib.pyplot as plt

# Custom
from brats21 import utils as bu
from brats21 import visualisation as vis
from nnunet.network_architecture.generic_UNet import Generic_UNet
from nnunet.network_architecture.segnet import SegNet



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [3]:
# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [4]:
# Location of the plans file that sits next to the trained model folds.
PLANS_PATH = "/sc-scratch/sc-scratch-gbm-radiomics/nnUNet_trained_models/nnUNet/3d_fullres/Task500_Brats21/nnUNetTrainerV2BraTSSegnet__nnUNetPlansv2.1/plans.pkl"

# NOTE(review): judging by its name, `load_pickle` unpickles the file; only point it at trusted files.
plans = bu.load_pickle(PLANS_PATH)

# Show which entries the plans provide.
plans.keys()

dict_keys(['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'dataset_properties', 'list_of_npz_files', 'original_spacings', 'original_sizes', 'preprocessed_data_folder', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', 'transpose_backward', 'data_identifier', 'plans_per_stage', 'preprocessor_name', 'conv_per_stage'])

In [5]:
# Despite the plural name, `plans_per_stage` holds only entry 0 of plans["plans_per_stage"].
stage_index = 0
plans_per_stage = plans["plans_per_stage"][stage_index]
plans_per_stage

{'batch_size': 2,
 'num_pool_per_axis': [5, 5, 5],
 'patch_size': array([128, 128, 128]),
 'median_patient_size_in_voxels': array([140, 171, 137]),
 'current_spacing': array([1., 1., 1.]),
 'original_spacing': array([1., 1., 1.]),
 'do_dummy_2D_data_aug': False,
 'pool_op_kernel_sizes': [[2, 2, 2],
  [2, 2, 2],
  [2, 2, 2],
  [2, 2, 2],
  [2, 2, 2]],
 'conv_kernel_sizes': [[3, 3, 3],
  [3, 3, 3],
  [3, 3, 3],
  [3, 3, 3],
  [3, 3, 3],
  [3, 3, 3]]}

In [6]:
# Derive the number of pooling steps from the plans instead of hard-coding it:
# the selected stage lists one `pool_op_kernel_sizes` entry per pooling operation
# (five entries in this plans file, so the value is still 5).
NUM_POOL = len(plans_per_stage["pool_op_kernel_sizes"])

In [7]:
# Value of the plans' `num_classes` entry; the same value is passed as `num_classes` when the network is built.
plans["num_classes"]

3

## Load the trained SegNet

In [None]:
# Checkpoint file of fold 4 that is loaded into the network below.
BEST_MODEL_PATH = Path(
    "/sc-scratch/sc-scratch-gbm-radiomics/nnUNet_trained_models/nnUNet/3d_fullres/Task500_Brats21/nnUNetTrainerV2BraTSSegnet__nnUNetPlansv2.1/fold_4/model_best.model"
)

In [None]:
# Read the checkpoint, mapping its tensors onto DEVICE, and keep only its "state_dict" entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(
    BEST_MODEL_PATH,
    map_location=torch.device(DEVICE),
)["state_dict"]

In [None]:
# Collect the constructor arguments first so each one can be annotated, then build
# the network and move it to DEVICE.
segnet_kwargs = dict(
    # Number of input channels, i.e. the image modalities (4 here: FLAIR, T1, T1CE, T2).
    input_channels=plans["num_modalities"],
    # Determines the feature-map size (32 here, per the original note).
    base_num_features=plans["base_num_features"],
    # NOTE(review): the plans report 3 here; whether a background channel has to be
    # added on top depends on the trainer that produced the checkpoint — confirm.
    num_classes=plans["num_classes"],
    # Number of pooling steps; presumably also the number of outputs the model
    # returns (the notebook later indexes `output[i]` for i < NUM_POOL) — confirm.
    num_pool=NUM_POOL,
    conv_op=nn.Conv3d,
    norm_op=nn.InstanceNorm3d,
    convolutional_pooling=False,
    convolutional_upsampling=False,
)
unet = SegNet(**segnet_kwargs).to(DEVICE)

In [None]:
# Copy the checkpoint weights into the network. strict=True makes the load fail
# if the checkpoint's keys do not exactly match the module's keys.
unet.load_state_dict(state_dict, strict=True)

## Run prediction

### On random data

In [None]:
# One random input: a leading 1, then one channel per modality, then the stage's patch size.
sample_shape = [1, plans["num_modalities"], *plans_per_stage["patch_size"]]
sample = torch.rand(*sample_shape).to(DEVICE)
print("Sample input shape:", sample_shape)

In [None]:
# Forward pass on the random sample; the result is indexed like a sequence of tensors below.
output = unet(sample)

In [None]:
# Iterate over whatever the model actually returned instead of assuming it has
# exactly NUM_POOL entries; this stays correct if the network depth or its
# number of outputs changes.
for i, pool_output in enumerate(output):
    print(f"Output pool {i} shape: {pool_output.shape}")

### Loss

In [None]:
from nnunet.training.loss_functions.dice_loss import Tversky_and_CE_loss, get_tp_fp_fn_tn, DC_and_CE_loss

In [None]:
# Settings for the Dice part of the loss; the cross-entropy part gets no extra arguments.
soft_dice_kwargs = {"batch_dice": True, "smooth": 1e-5, "do_bg": False}
ce_kwargs = {}
loss = DC_and_CE_loss(soft_dice_kwargs, ce_kwargs)

In [None]:
# The target must be a segmentation label map, not the input image: previously the
# 4-channel random `sample` was passed as target, which is not a valid label map.
# Build a random single-channel integer label map that matches the prediction's
# batch and spatial dimensions, with labels in [0, number of predicted classes).
prediction = output[0]
dummy_target = torch.randint(
    0,
    prediction.shape[1],
    (prediction.shape[0], 1, *prediction.shape[2:]),
    device=prediction.device,
)
loss(prediction, dummy_target)

### Grid Search

In [11]:
# Values explored by the grid search: pooling depths and convolutions per stage.
nPoolings = (5, 6)
nConvolutions = (2, 3, 4)

# Random input shared by every configuration: a leading 1, one channel per modality,
# then the stage's patch size.
sample_shape = [1, plans["num_modalities"], *plans_per_stage["patch_size"]]
sample = torch.rand(*sample_shape).to(DEVICE)
print("Sample input shape:", sample_shape)

Sample input shape: [1, 4, 128, 128, 128]


In [12]:
# Smoke-test every (pooling depth, convolutions per stage) combination: build the
# network and check that a forward pass on the random sample runs through.
# NOTE(review): each iteration rebinds `unet`, so after this cell `unet` is the last
# configuration built here (freshly initialised weights), not a model whose
# weights were loaded earlier.
for pool in nPoolings:
    for conv in nConvolutions:
        print(f"Poolings: {pool} Convs: {conv}")

        unet = SegNet(
            input_channels=plans["num_modalities"],  # one input channel per modality
            base_num_features=plans["base_num_features"],  # determines the feature-map size
            num_classes=plans["num_classes"],  # number of output classes taken from the plans
            num_pool=pool,
            num_conv_per_stage=conv,
            conv_op=nn.Conv3d,
            norm_op=nn.InstanceNorm3d,
            convolutional_pooling=False,
            convolutional_upsampling=False,
        ).to(DEVICE)

        # Only the forward pass is checked, so skip autograd bookkeeping; otherwise
        # the activations of each large 3D network are kept alive for a backward
        # pass that never happens.
        with torch.no_grad():
            out = unet(sample)

Poolings: 5 Convs: 2
Poolings: 5 Convs: 3
Poolings: 5 Convs: 4
Poolings: 6 Convs: 2
Poolings: 6 Convs: 3
Poolings: 6 Convs: 4


In [15]:
# Display the module tree of the network currently bound to `unet`.
unet

SegNet(
  (conv_blocks_localization): ModuleList(
    (0): Sequential(
      (0): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNonlinNorm(
            (conv): Conv3d(640, 320, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (instnorm): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
          (1): ConvDropoutNonlinNorm(
            (conv): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (instnorm): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
          (2): ConvDropoutNonlinNorm(
            (conv): Conv3d(320, 320, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (instnorm): InstanceNorm3d(320, eps=1e-05, momentum=0.1, affine=True, track_ru

### On train data

In [None]:
# Folder with the training images, relative to the notebook's working directory.
train_data_path = Path(
    "dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task500_Brats21/imagesTr"
)

In [None]:
# Project helper built on the training image folder; it is indexed by position below.
data_generator = bu.NNUnetDataGenerator(train_data_path)

In [None]:
# Take the first item of the generator and show its shape.
# NOTE(review): presumably a 5-D torch tensor — later cells index it with five
# indices and call `.detach()` on it; confirm against NNUnetDataGenerator.
real_sample = data_generator[0]
real_sample.shape

In [None]:
# Predict on the sample loaded from the training data. Previously the random
# `sample` was fed here, so the real volume was never passed through the network.
# The input is moved to DEVICE because that is where the network lives; this
# assumes `real_sample` is a torch tensor (the plotting cell calls `.detach()` on it).
output = unet(real_sample.to(DEVICE))

In [None]:
# Show one 2-D slice of the input: first item, first channel, index 55 along the last axis.
# NOTE(review): `.numpy()` requires a CPU tensor — confirm the generator yields CPU tensors.
plt.imshow(real_sample[0, 0, :, :, 55].detach().numpy())

In [None]:
# Show the matching slice of the first output: first item, channel index 2,
# index 55 along the last axis. The output lives on DEVICE, so it must be copied
# to host memory first; `.numpy()` raises on a CUDA tensor.
plt.imshow(output[0][0, 2, :, :, 55].detach().cpu().numpy())
plt.colorbar()

## Plot current progress

In [None]:
# Dataset root, relative to the notebook's working directory.
data_dir = Path("dataset")

In [None]:
# Project helper that draws the training progress found under `data_dir`; it returns
# the figure and its axes. NOTE(review): the keyword arguments look like styling and
# progress-bar options — confirm their meaning in brats21.visualisation.
fig, axs = vis.plot_nnunet_progress(data_dir, show_pbar=False, grid=True, alpha=.7, lw=2)