# Segmentation in 3D using U-Nets with Delira - A very short introduction

*Authors: Justus Schock, Alexander Moriz*

*Date: 17.12.2018*
 
This example shows how to use the U-Net implementation in Delira with PyTorch.

Let's first set up the essential hyperparameters. We will use `delira`'s `Hyperparameters` class for this:

In [1]:
import torch
from delira.training import Hyperparameters

hyper_params = Hyperparameters(batch_size=1, # batchsize to use
                               num_epochs=10, # number of epochs to train
                               optimizer_cls=torch.optim.Adam, # optimization algorithm to use
                               optimizer_params={'lr': 1e-3}, # initialization parameters for this algorithm
                               criterions=[torch.nn.CrossEntropyLoss()], # the loss function
                               lr_sched_cls=None,  # the learning rate scheduling algorithm to use
                               lr_sched_params={}, # the corresponding initialization parameters
                               metrics=[]) # and some evaluation metrics


Since we did not specify any metrics, only the `CrossEntropyLoss` will be calculated for each batch. Since segmentation is a voxel-wise classification task, this should be sufficient. We will train our network with a batch size of 1, using `Adam` as the optimizer.
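
To make the role of the loss concrete, here is a minimal NumPy sketch of what a cross-entropy loss computes on a 3D volume: a softmax over the class axis, followed by the mean negative log-probability of the true class over all voxels. The function name `voxelwise_cross_entropy` and the toy shapes are assumptions for illustration; this is not the code path `delira` or PyTorch actually runs.

```python
import numpy as np

def voxelwise_cross_entropy(logits, target):
    """Mean cross-entropy over all voxels.

    logits: float array of shape (C, D, H, W), one score per class and voxel
    target: integer array of shape (D, H, W) with class indices in [0, C)
    """
    # numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    # pick the log-probability of the true class at every voxel
    picked = np.take_along_axis(log_probs, target[np.newaxis], axis=0)
    return float(-picked.mean())

num_classes = 4
rng = np.random.default_rng(0)
target = rng.integers(0, num_classes, size=(2, 3, 3))

# uniform scores: every class gets probability 1/4, so the loss is log(4)
uniform_loss = voxelwise_cross_entropy(np.zeros((num_classes, 2, 3, 3)), target)

# confident and correct scores drive the loss towards zero
confident = np.full((num_classes, 2, 3, 3), -20.0)
np.put_along_axis(confident, target[np.newaxis], 20.0, axis=0)
confident_loss = voxelwise_cross_entropy(confident, target)

print(uniform_loss, confident_loss)
```

An untrained network that predicts all classes equally therefore starts at a loss of about `log(num_classes)`, which is a handy sanity check for the first logged values.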

## Logging and Visualization
To get a visualization of our results, we should monitor them somehow. For logging we will use `Visdom`. To start a visdom server you need to execute the following command inside an environment which has visdom installed: 
```shell
visdom -port=9999
```
This starts a visdom server on port 9999 of your machine. To view your results, open [http://localhost:9999](http://localhost:9999) in your browser. Now we can configure our logging environment.

In [2]:
from trixi.logger import PytorchVisdomLogger
from delira.logging import TrixiHandler
import logging

logger_kwargs = {
    'name': 'SegmentationExampleLogger', # name of our logging environment
    'port': 9999 # port on which our visdom server is alive
}

logger_cls = PytorchVisdomLogger

# configure logging module (and root logger)
logging.basicConfig(level=logging.INFO,
                    handlers=[TrixiHandler(logger_cls, **logger_kwargs)])


# derive logger from root logger
# (don't do `logger = logging.Logger("...")` since this will create a new
# logger which is unrelated to the root logger)
logger = logging.getLogger("Test Logger")


Visdom successfully connected to server


Since a single visdom server can run multiple environments, we need to specify a (unique) name for our environment and tell the logger on which port it can find the visdom server.
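
The comment in the logging cell about `logging.getLogger` versus `logging.Logger` can be checked with the standard library alone. The sketch below uses a hypothetical `ListHandler` as a stand-in for `TrixiHandler` (so no visdom server is needed) and attaches it with `addHandler` instead of `basicConfig`, because `basicConfig` does nothing if the root logger already has handlers.

```python
import logging

class ListHandler(logging.Handler):
    """Hypothetical stand-in for TrixiHandler: stores messages in a list."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())

handler = ListHandler()
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.INFO)

# derived from the root logger: records propagate up to the root handlers
derived = logging.getLogger("Test Logger")
derived.info("reaches the root handler")

# instantiated directly: outside the logger hierarchy, so no root handlers
detached = logging.Logger("Detached Logger")
detached.info("never reaches the root handler")

root.removeHandler(handler)
print(handler.messages)
```

Only the message from the derived logger ends up in `handler.messages`, which is why the notebook obtains its logger via `logging.getLogger`.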

## Data Preparation
### Loading
Next we will create a small training and validation set (in this case they will be the same, to show the overfitting capability of the U-Net).

Our data is a brain MR image kindly provided by [FSL](https://fsl.fmrib.ox.ac.uk/fsl/fslwiki) in their [introduction](http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/IntroBox3.html).

We first download the data and extract the T1 image and the corresponding segmentation:

In [3]:
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

resp = urlopen("http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip")
zipfile = ZipFile(BytesIO(resp.read()))
#zipfile_list = zipfile.namelist()
#print(zipfile_list)
img_file = zipfile.extract("ExBox3/T1_brain.nii.gz")
mask_file = zipfile.extract("ExBox3/T1_brain_seg.nii.gz")
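
The download cell wraps the response bytes in a `BytesIO` object so that `ZipFile` can read the archive without writing it to disk first; `extract` then writes a single member to disk and returns its path. The following offline sketch shows the same pattern on a tiny archive built in memory. The member name `ExBox3/example.txt` and the temporary target directory are assumptions for illustration (the notebook cell extracts into the current working directory).

```python
import tempfile
from io import BytesIO
from zipfile import ZipFile

# build a small archive in memory (stand-in for the downloaded bytes)
buffer = BytesIO()
with ZipFile(buffer, "w") as archive:
    archive.writestr("ExBox3/example.txt", "hello")
payload = buffer.getvalue()

with tempfile.TemporaryDirectory() as target_dir:
    with ZipFile(BytesIO(payload)) as archive:
        # list the members, then extract one of them to disk
        members = archive.namelist()
        extracted_path = archive.extract("ExBox3/example.txt", path=target_dir)
    with open(extracted_path) as handle:
        content = handle.read()

print(members, content)
```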

Now we load the image and the mask (both are 3D), convert them to 32-bit floating point numpy arrays and ensure that they have the same shape (i.e. that for each voxel in the image there is a voxel in the mask):

In [4]:
import SimpleITK as sitk
import numpy as np

# load image and mask
img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))
img = img.astype(np.float32)
mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))
mask = mask.astype(np.float32)

assert mask.shape == img.shape
print(img.shape)

(192, 192, 174)


By querying the unique values in the mask, we get the following:

In [5]:
np.unique(mask)

array([0., 1., 2., 3.], dtype=float32)

This means there are 4 classes (background and 3 types of tissue) in our sample.
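
Beyond the label set itself, it is often useful to know how many voxels belong to each class, since background usually dominates. A small sketch using `np.unique` with `return_counts=True` on a hypothetical toy mask (the array below is made up and only shares the label values 0 to 3 with the real mask):

```python
import numpy as np

# hypothetical toy mask with the same label set as the real one (0 to 3)
toy_mask = np.array([[0, 0, 1],
                     [2, 3, 3],
                     [0, 1, 3]], dtype=np.float32)

labels, counts = np.unique(toy_mask, return_counts=True)
fractions = counts / toy_mask.size

for label, count, fraction in zip(labels, counts, fractions):
    print(f"class {int(label)}: {count} voxels ({fraction:.1%})")
```

Running the same two lines on the real `mask` array would give the class balance of the sample.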

To load the data, we have to use a `Dataset`. The following defines a very simple dataset that accepts an image volume, a mask volume and the number of samples. It always returns the same sample until `num_samples` samples have been returned.

In [6]:
from delira.data_loading import AbstractDataset

class CustomDataset(AbstractDataset):
    def __init__(self, img, mask, num_samples=1000):
        super().__init__(None, None, None, None)
        self.data = {"data": img.reshape(1, *img.shape), "label": mask.reshape(1, *mask.shape)}
        self.num_samples = num_samples
        
    def __getitem__(self, index):
        return self.data
    
    def __len__(self):
        return self.num_samples
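
To see what this dataset hands to the data manager, the sketch below rebuilds it as a plain Python class without the `delira` base class (the name `ConstantSampleDataset` and the toy shapes are assumptions for illustration). It shows that `reshape(1, *shape)` prepends a channel axis and that every index yields the very same sample.

```python
import numpy as np

class ConstantSampleDataset:
    """Plain-Python sketch of the dataset above (no delira base class)."""

    def __init__(self, img, mask, num_samples=1000):
        # reshape(1, *shape) prepends a channel axis: (D, H, W) -> (1, D, H, W)
        self.data = {"data": img.reshape(1, *img.shape),
                     "label": mask.reshape(1, *mask.shape)}
        self.num_samples = num_samples

    def __getitem__(self, index):
        # the index is ignored on purpose: every item is the same volume
        return self.data

    def __len__(self):
        return self.num_samples

toy_img = np.zeros((4, 5, 6), dtype=np.float32)
toy_mask = np.ones((4, 5, 6), dtype=np.float32)
toy_dataset = ConstantSampleDataset(toy_img, toy_mask, num_samples=3)

first, last = toy_dataset[0], toy_dataset[2]
print(len(toy_dataset), first["data"].shape, first is last)
```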

Now, we can finally instantiate our datasets:

In [7]:
dataset_train = CustomDataset(img, mask, num_samples=10000)
dataset_val = CustomDataset(img, mask, num_samples=1)

### Augmentation
For data augmentation we will apply a few transformations:

In [8]:
from batchgenerators.transforms import RandomCropTransform, \
                                        ContrastAugmentationTransform, Compose
from batchgenerators.transforms.spatial_transforms import ResizeTransform
from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform

transforms = Compose([
    ContrastAugmentationTransform(), # randomly adjust contrast
    # use concrete values since we only have one sample
    # (they would have to be estimated over the whole dataset otherwise)
    MeanStdNormalizationTransform(mean=[img.mean()], std=[img.std()])])
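
The normalization step boils down to shifting and scaling the intensities with a fixed mean and standard deviation. A minimal NumPy sketch of that arithmetic, assuming a synthetic toy volume (the helper `mean_std_normalize` is hypothetical and not part of `batchgenerators`):

```python
import numpy as np

def mean_std_normalize(volume, mean, std):
    """Shift and scale intensities: (x - mean) / std."""
    return (volume - mean) / std

rng = np.random.default_rng(42)
toy_volume = rng.normal(loc=300.0, scale=50.0, size=(8, 8, 8)).astype(np.float32)

normalized = mean_std_normalize(toy_volume, toy_volume.mean(), toy_volume.std())
print(normalized.mean(), normalized.std())
```

When the statistics come from the same volume that is normalized, the result has roughly zero mean and unit standard deviation. In the pipeline above the contrast augmentation is listed first, so the values reaching the normalization have already been altered and the output will only be approximately standardized.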

With these transformations we can now wrap our datasets into data managers:

In [9]:
from delira.data_loading import BaseDataManager, SequentialSampler, RandomSampler

manager_train = BaseDataManager(dataset_train, hyper_params.batch_size,
                                transforms=transforms,
                                sampler_cls=RandomSampler,
                                n_process_augmentation=4)

manager_val = BaseDataManager(dataset_val, hyper_params.batch_size,
                              transforms=transforms,
                              sampler_cls=SequentialSampler,
                              n_process_augmentation=4)

No DataLoader Class specified. Using BaseDataLoader
No DataLoader Class specified. Using BaseDataLoader
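
The two samplers differ only in the order in which dataset indices are visited. The sketch below illustrates the idea in plain Python; it is not `delira`'s implementation, and whether `RandomSampler` draws with or without replacement is not stated here, so the shuffle without replacement is an assumption.

```python
import random

def sequential_indices(num_samples):
    """Visit the samples in their stored order."""
    return list(range(num_samples))

def shuffled_indices(num_samples, seed=None):
    """Visit every sample once, in random order (assumed behaviour)."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return indices

sequential = sequential_indices(5)
shuffled = shuffled_indices(5, seed=0)
print(sequential, shuffled)
```

Because the dataset in this notebook returns the same volume for every index, the choice of sampler does not change what the network sees here; it matters as soon as the dataset holds different samples.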


## Training

After we have done that, we can finally specify our experiment and run it. We will use the already implemented `UNet3dPyTorch` for this:

In [10]:
import warnings
warnings.simplefilter("ignore", UserWarning) # ignore UserWarnings raised by dependency code
warnings.simplefilter("ignore", FutureWarning) # ignore FutureWarnings raised by dependency code


from delira.training import PyTorchExperiment
from delira.training.train_utils import create_optims_default_pytorch
from delira.models.segmentation import UNet3dPyTorch

logger.info("Init Experiment")
experiment = PyTorchExperiment(hyper_params, UNet3dPyTorch,
                               name="Segmentation3dExample",
                               save_path="./tmp/delira_Experiments",
                               model_kwargs={'in_channels': 1, 'num_classes': 4},
                               optim_builder=create_optims_default_pytorch,
                               gpu_ids=[0])
experiment.save()

model = experiment.run(manager_train, manager_val)

Init Experiment
{'text': {'text': 'Hyperparameters:\n\tbatch_size = 1\n\tnum_epochs = 10\n\toptimizer_cls = <class \'torch.optim.adam.Adam\'>\n\toptimizer_params = {\n    "lr": 0.001\n}\n\t_criterions = [CrossEntropyLoss()]\n\tlr_sched_cls = None\n\tlr_sched_params = {}\n\t_metrics = []\n\n\tmodel_class = type'}}


RuntimeError: CUDA out of memory. Tried to allocate 1.53 GiB (GPU 0; 10.92 GiB total capacity; 3.47 GiB already allocated; 1.12 GiB free; 1.18 MiB cached)
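
A rough back-of-the-envelope calculation helps to interpret this error. The sketch below computes the size of dense float32 tensors for the full volume shape printed earlier. The channel count of 64 and the 64×64×64 patch size are assumptions for illustration; the network's actual layer widths are not stated in this notebook.

```python
def tensor_gib(shape, bytes_per_element=4):
    """Size in GiB of a dense tensor (float32 by default)."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return num_elements * bytes_per_element / 2**30

volume_shape = (192, 192, 174)   # shape printed earlier in this notebook
assumed_channels = 64            # assumption, not taken from the source

input_gib = tensor_gib((1, 1, *volume_shape))
feature_map_gib = tensor_gib((1, assumed_channels, *volume_shape))
patch_feature_map_gib = tensor_gib((1, assumed_channels, 64, 64, 64))

print(f"input volume:                 {input_gib:.3f} GiB")
print(f"full-resolution feature map:  {feature_map_gib:.3f} GiB")
print(f"feature map for a 64^3 patch: {patch_feature_map_gib:.4f} GiB")
```

Under these assumptions a single full-resolution feature map takes about 1.53 GiB, the same size as the failed allocation in the error message. That match is suggestive rather than proof, but it shows why feeding the whole volume at once is expensive and why training on smaller patches reduces the memory footprint considerably.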