Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[ignite,pyyaml]"
!pip install -q pytorch-lightning~=2.0.0

## Setup imports

In [None]:
from monai.config import print_config

print_config()

# Spleen Segmentation Lightning Bundle

In this tutorial we'll describe how to create a bundle for a segmentation network, including how to train and evaluate the network from the command line. The dataset is the spleen CT subset (Task09_Spleen) of the Medical Segmentation Decathlon. The bundle is based on the [Spleen 3D segmentation with MONAI](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d_lightning.ipynb) tutorial, which uses PyTorch Lightning.

This work is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-sa/4.0/.


Let's start by initialising a bundle directory structure:

In [None]:
%%bash

python -m monai.bundle init_bundle SpleenSegLightning
which tree && tree SpleenSegLightning || true

## Metadata

We'll first replace the `metadata.json` file with our description of what the network will do:

In [None]:
%%writefile SpleenSegLightning/configs/metadata.json

{
    "version": "0.0.1",
    "changelog": {
        "0.0.1": "Initial version"
    },
    "monai_version": "1.2.0",
    "pytorch_version": "2.0.0",
    "numpy_version": "1.23.5",
    "optional_packages_version": {},
    "name": "SpleenSegLightning",
    "task": "3D Spleen segmentation network using MONAI and PyTorch Lightning",
    "description": "This is a demo network for segmentation of the spleen from 3D CT images.",
    "authors": "Oeslle Lucena",
    "copyright": "Copyright (c) Oeslle Lucena",
    "data_source": "Task_09 subset from the Medical Segmentation Decathlon",
    "data_type": "Nifti",
    "intended_use": "This is suitable for demonstration only",
    "network_data_format": {
        "inputs": {
            "image": {
                "type": "image",
                "format": "hounsfield",
                "modality": "CT",
                "num_channels": 1,
                "spatial_shape": [160, 160, 160],
                "dtype": "float32",
                "value_range": [0, 1],
                "is_patch_data": false,
                "channel_def": {"0": "image"}
            }
        },
        "outputs": {
            "pred": {
                "type": "image",
                "format": "labels",
                "num_channels": 2,
                "spatial_shape": [160, 160, 160],
                "dtype": "float32",
                "value_range": [],
                "is_patch_data": false,
                "channel_def": {"0": "background", "1": "spleen"}
            }
        }
    }
}
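Before moving on, it can help to sanity-check the metadata programmatically. The sketch below is not MONAI's schema validation. `check_metadata` is a hypothetical helper that only checks three things: that each `channel_def` lists `num_channels` entries, that the input and output `spatial_shape` agree (true for this network), and that the current `version` has a `changelog` entry.

```python
import json


def check_metadata(meta: dict) -> list:
    """Hypothetical consistency checks on a bundle metadata dict (not MONAI's schema)."""
    problems = []
    fmt = meta.get("network_data_format", {})
    for section in ("inputs", "outputs"):
        for name, spec in fmt.get(section, {}).items():
            # every channel should be described exactly once in channel_def
            if len(spec.get("channel_def", {})) != spec.get("num_channels"):
                problems.append(f"{section}.{name}: channel_def does not match num_channels")
    # for this network the prediction has the same spatial shape as the input
    inputs = list(fmt.get("inputs", {}).values())
    outputs = list(fmt.get("outputs", {}).values())
    if inputs and outputs and inputs[0].get("spatial_shape") != outputs[0].get("spatial_shape"):
        problems.append("input and output spatial_shape differ")
    # the current version should have a changelog entry
    if meta.get("version") not in meta.get("changelog", {}):
        problems.append("current version missing from changelog")
    return problems


# Example on a trimmed-down copy of the metadata above
meta_text = """{
  "version": "0.0.1",
  "changelog": {"0.0.1": "Initial version"},
  "network_data_format": {
    "inputs": {"image": {"num_channels": 1, "spatial_shape": [160, 160, 160],
                         "channel_def": {"0": "image"}}},
    "outputs": {"pred": {"num_channels": 2, "spatial_shape": [160, 160, 160],
                         "channel_def": {"0": "background", "1": "spleen"}}}
  }
}"""
print(check_metadata(json.loads(meta_text)))  # → []
```

In the notebook you could run it on the written file with `check_metadata(json.load(open("SpleenSegLightning/configs/metadata.json")))`. An empty list only means these particular checks passed.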


## Common Definitions

Next we'll construct the bundle configuration files that implement training and evaluation, based on the original tutorial. Definitions needed by several configurations should go in a common file to avoid duplication. Here these are the bundle and data directories, the image/label key names, and the training/validation file lists:

In [None]:
%%writefile SpleenSegLightning/configs/common.yaml

# common imports
imports: 
- $import glob
- $import os

# define default values for the bundle and data directories; these can be
# overridden on the command line
bundle_dir: .
data_dir: .

# use constants from MONAI instead of hard-coding names
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL

# define a train and validation files from the data directory
train_images: '$sorted(glob.glob(os.path.join(@data_dir, ''imagesTr'', ''*.nii.gz'')))'
train_labels: '$sorted(glob.glob(os.path.join(@data_dir, ''labelsTr'', ''*.nii.gz'')))'

data_dicts: '$[{''image'': img, ''label'': lbl} for img, lbl in zip(@train_images, @train_labels)]'

train_files: '$@data_dicts[:-9]'
val_files: '$@data_dicts[-9:]'
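In a bundle config, `$` marks a Python expression to evaluate and `@` references another config entry. To see what the expressions in `common.yaml` evaluate to, here is a plain-Python sketch of the same file-list logic. `make_file_lists` and its `n_val` parameter are hypothetical names. Unlike the YAML, the sketch raises an error when the image and label counts differ, because `zip` would otherwise drop unmatched files silently.

```python
import glob
import os


def make_file_lists(data_dir: str, n_val: int = 9):
    """Plain-Python equivalent of train_images/train_labels/data_dicts/train_files/val_files."""
    # sorted() keeps image/label pairs aligned, assuming they share file names
    images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
    labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
    # zip() would silently drop unmatched files, so check the counts first
    if len(images) != len(labels):
        raise ValueError(f"{len(images)} images but {len(labels)} labels")
    data_dicts = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
    # hold out the last n_val pairs for validation, like '$@data_dicts[-9:]'
    split = max(len(data_dicts) - n_val, 0)
    return data_dicts[:split], data_dicts[split:]


# Usage (assuming the dataset is extracted to ./Task09_Spleen):
# train_files, val_files = make_file_lists("./Task09_Spleen")
```

Computing the split index instead of slicing with `[:-n_val]` also keeps `n_val=0` well defined, since `data_dicts[:-0]` would be empty.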

## Scripts for training and evaluation

We'll now define the training and evaluation YAML files, together with the Python scripts that contain the PyTorch Lightning network.
First, we'll create a `scripts` Python package directory with a `model.py` file containing the network definition:

In [None]:
!mkdir -p SpleenSegLightning/scripts

In [None]:
%%writefile SpleenSegLightning/scripts/__init__.py



In [None]:
%%writefile SpleenSegLightning/scripts/model.py

import pytorch_lightning
from monai.utils import set_determinism
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureType,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
import torch


class MySegNet(pytorch_lightning.LightningModule):
    def __init__(self):
        super().__init__()
        self._model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        self.learning_rate = 1e-4
        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        self.post_pred = Compose([EnsureType("tensor", device="cpu"),
                                  AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = Compose([EnsureType("tensor", device="cpu"),
                                   AsDiscrete(to_onehot=2)])
        self.dice_metric = DiceMetric(include_background=False, reduction="mean",
                                      get_not_nans=False)
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []

    def forward(self, x):
        return self._model(x)

    def configure_optimizers(self):
        print("configure_optimizers", self.learning_rate)
        optimizer = torch.optim.Adam(self._model.parameters(), self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        # a "log" key in the returned dict is not used for logging in Lightning 2.x; use self.log
        self.log("train_loss", loss, batch_size=images.shape[0])
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self.forward)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)
        d = {"val_loss": loss, "val_number": len(outputs)}
        self.validation_step_outputs.append(d)
        return d

    def on_validation_epoch_end(self):
        val_loss, num_items = 0, 0
        for output in self.validation_step_outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        mean_val_loss = val_loss / num_items
        # Lightning ignores the return value of on_validation_epoch_end, so log explicitly
        self.log("val_dice", mean_val_dice)
        self.log("val_loss", mean_val_loss)
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        self.validation_step_outputs.clear()  # free memory
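The validation step turns logits into one-hot predictions (`AsDiscrete(argmax=True, to_onehot=2)`) and scores them with `DiceMetric(include_background=False)`. The NumPy sketch below shows that computation for a single volume. It is not MONAI's implementation: it assumes a label map without a channel dimension, and it simplifies the handling of empty classes to NaN.

```python
import numpy as np


def one_hot(labels: np.ndarray, num_classes: int) -> np.ndarray:
    # integer map (H, W, D) -> float one-hot (C, H, W, D)
    return np.stack([labels == c for c in range(num_classes)]).astype(np.float32)


def foreground_dice(logits: np.ndarray, label: np.ndarray, num_classes: int = 2) -> float:
    # logits: (C, H, W, D); argmax over channels, then one-hot, as in post_pred
    pred = one_hot(np.argmax(logits, axis=0), num_classes)
    gt = one_hot(label, num_classes)
    scores = []
    for c in range(1, num_classes):  # skip channel 0 (background)
        inter = (pred[c] * gt[c]).sum()
        denom = pred[c].sum() + gt[c].sum()
        # class absent from both prediction and label: treated as NaN in this sketch
        scores.append(2.0 * inter / denom if denom > 0 else np.nan)
    return float(np.nanmean(scores))


label = np.zeros((4, 4, 4), dtype=np.int64)
label[:2] = 1                 # 32 foreground voxels
logits = np.zeros((2, 4, 4, 4))
logits[1, :3] = 1.0           # 48 voxels predicted as foreground
print(foreground_dice(logits, label))  # about 0.8 = 2 * 32 / (48 + 32)
```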

Next, we'll create a `main.py` file containing the `train` and `evaluate` functions called from the configs. Both receive a `lightning_param` dictionary from the YAML config. For training it supplies `max_epochs`, `default_root_dir`, and `check_val_every_n_epoch`; evaluation only uses `default_root_dir`. In `train`, the remaining PyTorch Lightning `Trainer` arguments are hard-coded for demonstration purposes: `num_nodes` and `devices` are set to 1, sanity checking is turned off (`num_sanity_val_steps=0`), and training metrics are logged every 3 steps (`log_every_n_steps=3`). For more information about the PyTorch Lightning `Trainer` arguments, please refer to the following [link](https://lightning.ai/docs/pytorch/stable/common/trainer.html).

In [None]:
%%writefile SpleenSegLightning/scripts/main.py
from scripts.model import MySegNet
import pytorch_lightning


def train(lightning_param, train_dl, val_dl):
    net = MySegNet()
    trainer = pytorch_lightning.Trainer(
        max_epochs=lightning_param["max_epochs"],
        default_root_dir=lightning_param["default_root_dir"],
        check_val_every_n_epoch=lightning_param["check_val_every_n_epoch"],
        devices=1,
        num_nodes=1,
        log_every_n_steps=3,
        num_sanity_val_steps=0,
    )
    trainer.fit(model=net, train_dataloaders=train_dl, val_dataloaders=val_dl)


def evaluate(lightning_param, ckpt_file, val_dl):
    net = MySegNet()
    trainer = pytorch_lightning.Trainer(
        default_root_dir=lightning_param["default_root_dir"], devices=1, num_nodes=1
    )
    # ckpt_path restores the trained weights into `net` before validating
    trainer.validate(model=net, dataloaders=val_dl, ckpt_path=ckpt_file)
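One possible design change is to stop hard-coding `devices`, `num_nodes`, `log_every_n_steps`, and `num_sanity_val_steps` in `train`, and let the parameter dictionary built in `train.yaml` override them. The sketch below shows that merge. `FIXED_TRAINER_ARGS` and `build_trainer_kwargs` are hypothetical names and are not part of the bundle; the keys themselves are real `Trainer` arguments.

```python
# Defaults that main.py currently hard-codes (hypothetical constant name)
FIXED_TRAINER_ARGS = {
    "devices": 1,
    "num_nodes": 1,
    "log_every_n_steps": 3,
    "num_sanity_val_steps": 0,
}


def build_trainer_kwargs(lightning_param: dict) -> dict:
    # values coming from the YAML config take precedence over the defaults;
    # a new dict is returned, so FIXED_TRAINER_ARGS is never mutated
    return {**FIXED_TRAINER_ARGS, **lightning_param}


kwargs = build_trainer_kwargs({"max_epochs": 1, "default_root_dir": "./lightning_logs"})
print(kwargs["max_epochs"], kwargs["devices"])  # → 1 1
# inside train(): trainer = pytorch_lightning.Trainer(**kwargs)
```

With this approach, any `Trainer` argument added to the YAML dictionary reaches the trainer without editing `main.py`. The trade-off is that typos in key names only surface when `Trainer` raises on an unexpected keyword.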

## Training
Now we'll define a `train.yaml` file holding the configuration for the training stage:


In [None]:
%%writefile SpleenSegLightning/configs/train.yaml

imports:
- $from scripts.main import train
- $import glob
- $import os

# define default root and data directory values; these can be overridden on the command line
bundle_dir: .
data_dir: .

# define hyperparameters for the lightning trainer
max_epochs: 50
default_root_dir: $@bundle_dir+"/lightning_logs"
check_val_every_n_epoch: 1

lightninig_param:  '${
    ''max_epochs'': @max_epochs,
    ''default_root_dir'': @default_root_dir,
    ''check_val_every_n_epoch'': @check_val_every_n_epoch,
}'


# define the training transform sequence by instantiating a Compose object
train_transform:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys: ['@image','@label']
    image_only: true
  - _target_: EnsureChannelFirstd
    keys:  ['@image','@label']
  - _target_: Orientationd
    keys:  ['@image','@label']
    axcodes: 'RAS'
  - _target_: Spacingd
    keys:  ['@image','@label']
    pixdim: [1.5, 1.5, 2.0]
  - _target_: ScaleIntensityRanged
    keys: '@image'
    a_min: -57
    a_max: 164
    b_min: 0.0
    b_max: 1.0
    clip: True
  - _target_: CropForegroundd
    keys: ['@image','@label']
    allow_smaller: False
    source_key: '@image'
  - _target_: RandCropByPosNegLabeld
    keys: ['@image','@label']
    label_key: '@label'
    spatial_size: [96, 96, 96]
    pos: 1
    neg: 1
    num_samples: 4
    image_key: '@image'
    image_threshold: 0

val_transform:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys: ['@image','@label']
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: ['@image','@label']
  - _target_: Orientationd
    keys: ['@image','@label']
    axcodes: 'RAS'
  - _target_: Spacingd
    keys: ['@image','@label']
    pixdim: [1.5, 1.5, 2.0]
  - _target_: ScaleIntensityRanged
    keys: '@image'
    a_min: -57
    a_max: 164
    b_min: 0.0
    b_max: 1.0
    clip: True
  - _target_: CropForegroundd
    keys: ['@image','@label']
    source_key: '@image'
    allow_smaller: False

val_dataset:
  _target_: CacheDataset
  data: '@val_files'
  transform: '@val_transform'
  cache_rate: 1.0
  num_workers: 4

train_dataset:
  _target_: CacheDataset
  data: '@train_files'
  transform: '@train_transform'
  cache_rate: 1.0
  num_workers: 4
  
train_dl:
  _target_: DataLoader
  dataset: '@train_dataset'
  batch_size: 1
  shuffle: true
  num_workers: 4
  
val_dl:
  _target_: DataLoader
  dataset: '@val_dataset'
  batch_size: 1
  shuffle: false
  num_workers: 4

train:
- '$train(@lightninig_param, @train_dl, @val_dl)'
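The `ScaleIntensityRanged` entry applies a CT intensity window. Values in the range [-57, 164] (Hounsfield units) are mapped linearly to [0, 1], and with `clip: True` anything outside that window is clipped. This is why `metadata.json` declares a `value_range` of `[0, 1]` for the input. The NumPy sketch below reproduces the mapping. The function name is hypothetical; it is not the MONAI transform itself.

```python
import numpy as np


def scale_intensity_range(x, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True):
    x = np.asarray(x, dtype=np.float32)
    # linear map: a_min -> b_min, a_max -> b_max
    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    # clip: True keeps the output inside [b_min, b_max]
    return np.clip(y, b_min, b_max) if clip else y


# air, window bottom, window centre, window top, dense bone
print(scale_intensity_range([-1000, -57, 53.5, 164, 400]))  # → values 0, 0, 0.5, 1, 1
```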

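We can now train to replicate the original code. The commands below assume the Task09_Spleen data has already been downloaded and extracted to `./Task09_Spleen`. For demonstration purposes, `max_epochs` is overridden to 1 on the command line.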

In [None]:
%%bash

BUNDLE="./SpleenSegLightning"
DATA_DIR="./Task09_Spleen"
export PYTHONPATH="$BUNDLE"

# run the bundle with epochs set to 1 for speed during testing, change this to get a better result
python -m monai.bundle run train \
    --bundle_dir "$BUNDLE" \
    --data_dir "$DATA_DIR" \
    --meta_file "$BUNDLE/configs/metadata.json" \
    --config_file "['$BUNDLE/configs/common.yaml','$BUNDLE/configs/train.yaml']" \
    --max_epochs 1

Lightning's default logger and checkpoint callback save the trained checkpoint below the `default_root_dir` set in `train.yaml` (`$BUNDLE/lightning_logs`), in a `lightning_logs/version_*/checkpoints` subdirectory.

In [None]:
# !which tree && tree SpleenSegLightning || true
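Rather than copying the checkpoint path by hand, you can use a small helper to find the most recently written `.ckpt` file below `default_root_dir`. `latest_checkpoint` is a hypothetical helper. It searches recursively instead of assuming an exact `version_*/checkpoints` layout, and it picks the file with the newest modification time.

```python
import glob
import os


def latest_checkpoint(root: str):
    """Return the newest *.ckpt file anywhere below root, or None if there is none."""
    ckpts = glob.glob(os.path.join(root, "**", "*.ckpt"), recursive=True)
    if not ckpts:
        return None
    return max(ckpts, key=os.path.getmtime)


# Usage (path assumed from train.yaml's default_root_dir):
# print(latest_checkpoint("./SpleenSegLightning/lightning_logs"))
```

The returned path can then be passed to the evaluation command via `--ckpt_file`.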

## Evaluation


Here we define an `evaluate.yaml` config that runs the `evaluate` function from `main.py` on the validation set, reproducing the evaluation from the original code.

In [None]:
%%writefile SpleenSegLightning/configs/evaluate.yaml

# common imports
imports:
- $from scripts.main import evaluate
- $import glob
- $import os

ckpt_file: ""

# define hyperparameters for the lightning trainer
default_root_dir: $@bundle_dir+"/lightning_logs"
lightninig_param:  '${''default_root_dir'': @default_root_dir,}'


val_transform:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys: ['@image','@label']
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: ['@image','@label']
  - _target_: Orientationd
    keys: ['@image','@label']
    axcodes: 'RAS'
  - _target_: Spacingd
    keys: ['@image','@label']
    pixdim: [1.5, 1.5, 2.0]
  - _target_: ScaleIntensityRanged
    keys: '@image'
    a_min: -57
    a_max: 164
    b_min: 0.0
    b_max: 1.0
    clip: True
  - _target_: CropForegroundd
    keys: ['@image','@label']
    source_key: '@image'
    allow_smaller: False

val_dataset:
  _target_: CacheDataset
  data: '@val_files'
  transform: '@val_transform'
  cache_rate: 1.0
  num_workers: 4
 
val_dl:
  _target_: DataLoader
  dataset: '@val_dataset'
  batch_size: 1
  shuffle: false
  num_workers: 4

  
# loads the weights from the given file (which needs to be set on the command line) then calls "evaluate" script
evaluate:
- '$evaluate(@lightninig_param,@ckpt_file, @val_dl)'

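Evaluation is then run on the command line, using `evaluate` as the program to run, passing the checkpoint path with `ckpt_file` and the dataset location with `data_dir`. The command below expects a checkpoint named `epoch=599-step=9600.ckpt` in the current directory; replace this path with the checkpoint from your own training run: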

In [None]:
%%bash

DATA_DIR="./Task09_Spleen"
BUNDLE="./SpleenSegLightning"
export PYTHONPATH="$BUNDLE"

python -m monai.bundle run evaluate \
    --bundle_dir "$BUNDLE" \
    --data_dir "$DATA_DIR" \
    --meta_file "$BUNDLE/configs/metadata.json" \
    --config_file "['$BUNDLE/configs/common.yaml','$BUNDLE/configs/evaluate.yaml']" \
    --ckpt_file "./epoch=599-step=9600.ckpt"
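A Lightning checkpoint is a dictionary whose `"state_dict"` entry holds the LightningModule's weights, keyed by attribute path. Because `MySegNet` stores its UNet in `self._model`, those keys begin with `_model.`. You may want plain UNet weights, for example to save them in the bundle's `models/` directory as MONAI bundles commonly do. In that case, stripping the prefix gives keys that a bare `UNet` would accept. `strip_prefix` is a hypothetical helper, and the `torch` lines are left as comments because they need the real checkpoint.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="_model."):
    # keep only entries under `prefix` and drop the prefix from their keys
    return OrderedDict(
        (key[len(prefix):], value) for key, value in state_dict.items() if key.startswith(prefix)
    )


# Usage sketch (requires torch and a real checkpoint, not run here):
# ckpt = torch.load("epoch=599-step=9600.ckpt", map_location="cpu")
# unet_state = strip_prefix(ckpt["state_dict"])
# torch.save(unet_state, "SpleenSegLightning/models/model.pt")
```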

## Summary and Next

This tutorial has covered:
* Creating full training and evaluation scripts in bundles using MONAI and PyTorch Lightning
* Training a network, then evaluating its performance with scripts

The next tutorial will discuss combining PyTorch Lightning CLI config files with MONAI bundles.