## Segmentation example
Here we provide a simple demo that performs UNet segmentation of brain tumors on FLAIR data.

For this purpose, we used publicly available MRI data provided by the [Decathlon 10 Challenge](https://decathlon-10.grand-challenge.org).

## File structure
For segmentation, NiftyTorch requires the following folder/file organization:
```
StudyName
    └───train
    │   └───subjectID
    │          flair.nii.gz
    │          t1w.nii.gz
    │          seg.nii.gz
    │           ...
    └───val
    │   └───subjectID
    │          flair.nii.gz
    │          t1w.nii.gz
    │          seg.nii.gz
    │           ...
    └───test
        └───subjectID
               flair.nii.gz
               t1w.nii.gz
                ...
```

`flair.nii.gz`, `t1w.nii.gz`, etc. are the inputs to the UNet segmentation, and `seg.nii.gz` is the label mask.
> Note that the test folder does not have to contain `seg.nii.gz`. If labels are provided, the prediction code also outputs the loss and accuracy.
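
Before training, it can help to check that a study folder actually follows this layout. The helper below is a hypothetical sketch, not part of NiftyTorch: `find_missing_files` and its `inputs`/`label` parameters are names chosen for illustration. It assumes every subfolder of `train/`, `val/` and `test/` is one subject, and that labels are required everywhere except in `test/`.

```python
from pathlib import Path


def find_missing_files(study_dir, inputs=("flair.nii.gz",), label="seg.nii.gz"):
    """Return (split, subject, filename) tuples for files missing from the layout.

    A missing split folder is reported as (split, None, None).
    Labels are checked in train/ and val/ only, since test/ may omit them.
    """
    missing = []
    root = Path(study_dir)
    for split in ("train", "val", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            missing.append((split, None, None))
            continue
        # Only directories are treated as subjects (e.g. saved model files are skipped)
        for subject in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            required = list(inputs) + ([label] if split != "test" else [])
            for name in required:
                if not (subject / name).is_file():
                    missing.append((split, subject.name, name))
    return missing
```

For a multimodality study, pass every input file, e.g. `find_missing_files('path_to_data', inputs=('flair.nii.gz', 't1w.nii.gz'))`; an empty list means the layout looks complete.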

## Example data
An axial mosaic view of the FLAIR data that contains the tumor is shown here:

![flair_example](./flair_example_mosaic.png)
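
A mosaic like this can be produced by tiling axial slices into a single 2D image. The sketch below is an illustration, not NiftyTorch code: it assumes the volume is already loaded as a NumPy array (for example with nibabel's `get_fdata()`) and that the axial direction is the last array axis, which depends on the image orientation. `axial_mosaic`, `n_cols` and `step` are hypothetical names.

```python
import numpy as np


def axial_mosaic(volume, n_cols=6, step=1):
    """Tile every `step`-th axial slice (last axis) of a 3D volume into one 2D image."""
    slices = [volume[:, :, k] for k in range(0, volume.shape[2], step)]
    n_rows = int(np.ceil(len(slices) / n_cols))
    h, w = slices[0].shape
    # Unused tiles in the last row stay zero (black)
    mosaic = np.zeros((n_rows * h, n_cols * w), dtype=volume.dtype)
    for i, sl in enumerate(slices):
        r, c = divmod(i, n_cols)  # row-major placement of slice i
        mosaic[r * h:(r + 1) * h, c * w:(c + 1) * w] = sl
    return mosaic
```

The resulting array can then be displayed with, e.g., `matplotlib.pyplot.imshow(mosaic, cmap='gray')`.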

## Define paths 

```python
import torch
from niftytorch.Models.Unet import *  # provides train_unet used below

# Root folder that follows the layout described above
data_dir = 'path_to_data'
train_path = data_dir + '/train/'
val_path = data_dir + '/val/'
test_path = data_dir + '/test/'
```

## Training

```python
trainer = train_unet()

# Single-modality (FLAIR) training configuration
trainer.set_params(train_data = train_path,
                   val_data = val_path,
                   test_data = test_path,
                   batch_size = 10,
                   in_channels = 1,         # one input modality (FLAIR)
                   out_channels = 4,
                   num_epochs = 100,
                   downsample = 80,
                   cuda = 'cuda:2',         # GPU device; change to match your machine
                   filename = ('flair.nii.gz'),
                   init_features = 64,
                   model_name = 'UNet_training')

trainer.train()
```

The training results, including the Dice loss across epochs, are reported.
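
For reference, the soft Dice loss compares a predicted probability map $p$ with a binary label $t$ as $1 - 2\sum p\,t / (\sum p + \sum t)$. The NumPy sketch below illustrates that formula for a single class; NiftyTorch's exact variant (smoothing constant, averaging over classes or batches) is not described here and may differ. `soft_dice_loss` and `eps` are illustrative names.

```python
import numpy as np


def soft_dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss: 1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps).

    `eps` avoids division by zero when both prediction and label are empty.
    """
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    intersection = np.sum(pred * target)
    denom = np.sum(pred) + np.sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)
```

A perfect prediction gives a loss of 0, and a prediction with no overlap gives a loss close to 1.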

## Testing

The training code also saves the trained model in the training data folder. It can be loaded as shown in the code below.

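```python
import os
import torch

# Model file saved by trainer.train() in the training folder
PATH = train_path + 'UNet_training_generated_model'

if os.path.isfile(PATH):
    # Load the saved model; PyTorch 2.6+ defaults to weights_only=True,
    # so a fully pickled model may need torch.load(PATH, weights_only=False)
    model_val = torch.load(PATH)
else:
    # No saved model found: keep the placeholder value from the original example
    model_val = " "

trainer.predict(model_val)

print("UNet done")
```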

At this stage, a new file called `pred.nii.gz` containing the segmentation result is saved in the test folder.

`pred.nii.gz` is a probability map of the tumor, which you can binarize to obtain a mask. 
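
A minimal sketch of that binarization, assuming `pred.nii.gz` holds a single-channel probability map in [0, 1] and using a threshold of 0.5 (an assumption; tune it for your data). The function names are hypothetical, and nibabel is imported only inside `binarize_nifti` so that `binarize` works on plain NumPy arrays.

```python
import numpy as np


def binarize(prob_map, threshold=0.5):
    """Convert a probability map into a 0/1 mask (threshold of 0.5 is an assumption)."""
    return (np.asarray(prob_map) >= threshold).astype(np.uint8)


def binarize_nifti(pred_path, out_path, threshold=0.5):
    """Threshold a NIfTI probability map and save the mask with the same affine."""
    import nibabel as nib  # lazy import: only needed for file I/O

    img = nib.load(pred_path)
    mask = binarize(img.get_fdata(), threshold)
    mask_img = nib.Nifti1Image(mask, img.affine, img.header)
    mask_img.set_data_dtype(np.uint8)  # store the mask as integers, not floats
    nib.save(mask_img, out_path)
```

For example, `binarize_nifti('path_to_data/test/subjectID/pred.nii.gz', 'path_to_data/test/subjectID/mask.nii.gz')`, where `mask.nii.gz` is an arbitrary output name.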

Below, `pred.nii.gz` is overlaid on the input `flair.nii.gz`:

<video controls width=300 src="segmentation.mov" />

## Multiclass Multimodality 

To get a multi-channel (multiclass) output, set `out_channels` to the number of classes, e.g. `out_channels = 4`.
The UNet model can also be trained on several modalities at once: set `in_channels` to the number of modalities and list the corresponding files in `filename`, e.g. `in_channels = 2` and `filename = ('flair.nii.gz', 't1w.nii.gz')`.
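
Conceptually, each modality becomes one input channel. The sketch below shows what `in_channels = 2` corresponds to: two co-registered volumes stacked channel-first, the layout PyTorch's 3D convolutions expect per sample. How NiftyTorch builds its input tensors internally is not shown in this example, so treat `stack_modalities` as a hypothetical illustration.

```python
import numpy as np


def stack_modalities(volumes):
    """Stack co-registered 3D volumes into a (channels, X, Y, Z) array.

    All modalities must share the same voxel grid, otherwise stacking is meaningless.
    """
    shapes = {v.shape for v in volumes}
    if len(shapes) != 1:
        raise ValueError(f"all modalities must share one shape, got {shapes}")
    return np.stack(volumes, axis=0)  # channel axis first
```

With FLAIR and T1w volumes of the same shape, the result has 2 channels, matching `in_channels = 2`.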

For multiclass prediction, class weights should be passed to the loss function to handle class imbalance. Provide them as an array to the `weights` argument, with one weight per output class.

If the weights are unknown, you can omit the `weights` argument; the internal implementation then handles the weighting.
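
If you want to derive weights from your own labels, a common starting point is to count voxels per class in the `seg.nii.gz` masks. The sketch below counts voxels and shows inverse-frequency weighting as one common heuristic; whether NiftyTorch's `weights` argument expects raw counts, inverse frequencies, or something else is not stated here, so check before passing the result. Both function names are hypothetical.

```python
import numpy as np


def class_voxel_counts(label_maps, num_classes):
    """Count voxels of each class (0..num_classes-1) over a list of label arrays."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for labels in label_maps:
        # Round first: NIfTI labels loaded with get_fdata() are floats
        flat = np.rint(np.asarray(labels)).astype(np.int64).ravel()
        counts += np.bincount(flat, minlength=num_classes)[:num_classes]
    return counts


def inverse_frequency_weights(counts):
    """Weight each class by 1 / frequency, normalized to sum to 1 (one heuristic)."""
    counts = np.asarray(counts, dtype=np.float64)
    freq = counts / counts.sum()
    weights = 1.0 / np.maximum(freq, 1e-12)  # guard against absent classes
    return weights / weights.sum()
```

Rare classes receive larger weights under this scheme, which is the usual intent of imbalance weighting.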
    
Below are the code sample and output for multiclass, multimodality segmentation.

```python
# Two input modalities, four output classes, and one class weight per output channel
trainer.set_params(train_data = train_path, val_data = val_path, test_data = test_path,
                   batch_size = 10, in_channels = 2, out_channels = 4, num_epochs = 100,
                   downsample = 80, cuda = 'cuda:2', weights = [6130, 887, 480, 101],
                   filename = ('t1w.nii.gz', 'flair.nii.gz'), init_features = 64,
                   model_name = 'UNet_multimodality_multiclass')

trainer.train()
```

![unet_multiclass.png](./unet_multiclass.png)
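
With several output channels, a single label map is usually obtained by picking the most probable class at each voxel. The sketch below assumes the multiclass prediction is available as an array with one probability channel per class; the channel axis defaults to the last one, which is typical for 4D NIfTI files, but the actual layout of NiftyTorch's multiclass output is not described here. `probabilities_to_labels` is a hypothetical name.

```python
import numpy as np


def probabilities_to_labels(prob, channel_axis=-1):
    """Collapse per-class probabilities into an integer label map via argmax.

    Ties resolve to the lowest class index (NumPy argmax behavior).
    """
    return np.argmax(prob, axis=channel_axis).astype(np.uint8)
```

For a prediction of shape (X, Y, Z, 4), this returns an (X, Y, Z) map with values 0 to 3.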