In [None]:
%matplotlib inline

In [1]:
from IPython.display import clear_output

import torch
from torch import nn
import os
import numpy as np
from nnunet.run.default_configuration import get_default_configuration
from nnunet.training.network_training.competitions_with_custom_Trainers.BraTS2021.nnUNetTrainerV2BraTSSegnet import nnUNetTrainerV2SegnetFocal, nnUNetTrainerSegNetPool5Conv3
from nnunet.training.loss_functions.dice_loss import Tversky_and_CE_loss
from nnunet.paths import (
    network_training_output_dir,
    preprocessing_output_dir,
    default_plans_identifier,
    
)
from nnunet.utilities.to_torch import maybe_to_torch, to_cuda
from nnunet.network_architecture.segnet import SegNet
from brats21 import utils as bu
import matplotlib.pyplot as plt



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [None]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE)

# On a GPU machine, list every visible device as "<index> = <name>".
if DEVICE == "cuda":
    for i in range(torch.cuda.device_count()):
        print("\t", i, "=", torch.cuda.get_device_name(i))

## Load single trainer
The trainer contains everything you need for training and validation, such as the dataloaders, the network, the augmentation, and the training and loading logic.

In [None]:
# Settings that select the experiment; the basic configuration (saved in
# plans.pkl) is looked up from them below.
task = "Task500_Brats21"
network = "3d_fullres"
plans_identifier = "nnUNetPlansv2.1"
network_trainer = "nnUNetTrainerSegNetPool6Conv4"

# The last returned item, `trainer`, is called by the next cell to build the
# trainer object (presumably the trainer class named above -- confirm).
(
    plans_file,
    output_folder_name,
    dataset_directory,
    batch_dice,
    stage,
    trainer,
) = get_default_configuration(network, task, network_trainer, plans_identifier)

In [None]:
# Build the trainer object and let it set itself up.
# NOTE: `trainer` is rebound here from the callable returned by
# get_default_configuration to the object it creates, so re-running this cell
# without re-running the configuration cell calls that object instead.
trainer = trainer(
    plans_file,
    4,  # presumably the cross-validation fold -- confirm against the trainer signature
    output_folder_name,
    dataset_directory,
    batch_dice,
    stage,
    # NOTE(review): the three flags below are positional; in the stock nnU-Net
    # trainer they would be unpack_data, deterministic, fp16 -- confirm for this class.
    False,
    True,
    True,
)
trainer.initialize()

In [None]:
# Pull the next batch from the trainer's training generator.
data_dict = next(trainer.tr_gen)

# Keep the two entries of the batch that the following cells work with.
data, target = data_dict['data'], data_dict['target']

# Show what kind of objects came back; `target` supports len().
print("Type of data", type(data))
print("Type of target", type(target), len(target))

In [None]:
# Transform everything to tensors and put on DEVICE.
# to_cuda is only applied when DEVICE is "cuda": it moves tensors to the GPU
# unconditionally, so calling it on a CPU-only machine would fail even though
# DEVICE falls back to "cpu" there.
data = maybe_to_torch(data)
target = maybe_to_torch(target)
if DEVICE == "cuda":
    data = to_cuda(data)
    target = to_cuda(target)
trainer.network = trainer.network.to(DEVICE)

In [None]:
# Forward pass: run the batch through the trainer's network to get its prediction.
output = trainer.network(data)

In [None]:
# Compute the trainer's loss between the network output and the target.
l = trainer.loss(output, target)
# Bare expression so the notebook displays the loss value.
l

## Run *n* epochs

In [None]:
def run_n_iterations(network_trainer, network="3d_fullres", task="Task500_Brats21", plans_identifier = "nnUNetPlansv2.1", epochs = 1, batches_per_epoch = 1, fold = 4):
    """Smoke-test a trainer class by running a short training loop.

    The trainer is looked up with ``get_default_configuration``, instantiated,
    initialized, and trained for ``epochs`` epochs of ``batches_per_epoch``
    batches each. Any exception raised along the way (lookup, construction,
    initialization or training) is caught and reported in the returned dict
    instead of propagating, so that several trainers can be checked in one loop.

    Args:
        network_trainer (str): Class name of the trainer as in nnUNet.
        network (str): Network type. Defaults to "3d_fullres".
        task (str): Task ID. Defaults to "Task500_Brats21".
        plans_identifier (str): Identifier used to read the plans.pkl file.
            Defaults to "nnUNetPlansv2.1".
        epochs (int): Value assigned to ``trainer.max_num_epochs``. Defaults to 1.
        batches_per_epoch (int): Value assigned to
            ``trainer.num_batches_per_epoch``. Defaults to 1.
        fold (int): Second positional argument of the trainer constructor
            (presumably the cross-validation fold -- confirm). Defaults to 4.

    Returns:
        dict: ``{"Trainer": network_trainer, "Error": bool, "Message": str or None}``.
        ``Error`` is True when an exception was caught; ``Message`` then holds
        the exception type and text, otherwise it is None.
    """
    try:
        plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

        # NOTE(review): the three trailing flags are positional; in the stock
        # nnU-Net trainer they would be unpack_data, deterministic, fp16 -- confirm.
        trainer = trainer_class(plans_file, fold, output_folder_name, dataset_directory, batch_dice, stage, False, True, True)
        trainer.initialize()

        # Shrink the training schedule so the run finishes quickly.
        trainer.max_num_epochs = epochs
        trainer.num_batches_per_epoch = batches_per_epoch

        trainer.run_training()
    except Exception as e:
        # Include the exception type: str(e) alone can be empty.
        return {"Trainer": network_trainer, "Error": True, "Message": f"{type(e).__name__}: {e}"}
    return {"Trainer": network_trainer, "Error": False, "Message": None}

In [None]:
# Names of the SegNet trainer variants to check, grouped by their "Pool<n>" part.
trainer_classes = []
# Pool5 variants
trainer_classes += [
    "nnUNetTrainerSegNetPool5Conv3",
    "nnUNetTrainerSegNetPool5Conv4",
]
# Pool6 variants
trainer_classes += [
    "nnUNetTrainerSegNetPool6Conv2",
    "nnUNetTrainerSegNetPool6Conv3",
    "nnUNetTrainerSegNetPool6Conv4",
]

In [None]:
# Smoke-test every trainer class and collect one result dict per class.
summary = []
for tr_class in trainer_classes:
    res = run_n_iterations(network_trainer=tr_class)
    summary += [res]
    # Wipe the cell output produced so far so it does not pile up across runs.
    clear_output()

In [None]:
# Report every trainer whose run was flagged as failed, or confirm that none was.
found_error = False
for res in summary:
    if not res["Error"]:
        continue  # this trainer ran cleanly, nothing to report
    found_error = True
    print("Found error in ", res["Trainer"])
    print("Message = ", res["Message"])
    print()
if not found_error:
    print("No Errors found!! 🥳️")

## Analyse architecture

In [None]:
# Instantiate and initialize every trainer class, keeping each (name, network) pair.
task="Task500_Brats21"
network="3d_fullres"
plans_identifier = "nnUNetPlansv2.1"
nets = []

for tr_class in trainer_classes:
    (
        plans_file,
        output_folder_name,
        dataset_directory,
        batch_dice,
        stage,
        trainer,
    ) = get_default_configuration(network, task, tr_class, plans_identifier)
    # Same constructor arguments as the single-trainer cell earlier in the notebook.
    trainer = trainer(plans_file, 4, output_folder_name, dataset_directory, batch_dice, stage, False, True, True)
    trainer.initialize()
    nets += [(tr_class, trainer.network)]
    # Drop whatever the initialization printed before moving to the next class.
    clear_output()

In [None]:
def get_n_params(model):
    """Count the trainable parameters of a PyTorch model.

    Args:
        model (torch.nn.Module): Model whose ``parameters()`` are inspected.

    Returns:
        int: Total number of elements over all parameters that have
        ``requires_grad`` set; 0 when there are none.
    """
    # numel() is an exact Python int for every shape. np.prod(p.size()) yields
    # the float 1.0 for a 0-dim parameter, which would turn the whole sum into
    # a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Print the trainable-parameter count of each collected network, formatted
# with thousands separators.
for name, arch in nets:
    print(f"{name} -> Number of parameters = {get_n_params(arch):,}")

In [None]:
# Average pooling over a 50 x 44 x 31 window; with no stride given, the stride
# defaults to the window size.
m = nn.AvgPool3d((50, 44, 31))

In [None]:
# Random input of shape (20, 16, 50, 44, 31) to feed through the pooling layer.
inp = torch.randn((20, 16, 50, 44, 31))

In [None]:
# Apply the pooling layer to the random input. This rebinds `output`, which an
# earlier cell used for the network prediction.
output = m(inp)

In [None]:
# Display the pooled tensor's shape; the window covers the whole 50 x 44 x 31
# volume, so torch.Size([20, 16, 1, 1, 1]) is expected.
output.shape