## Setup Environment
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/3d_registration/paired_lung_ct.ipynb)

In [None]:
%pip install -q "monai[nibabel, tqdm]"

In [None]:
%pip install -q matplotlib
%matplotlib inline

## Setup imports

In [18]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import MSELoss
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, Dataset, CacheDataset
from monai.losses import DiceLoss, BendingEnergyLoss
from monai.metrics import compute_meandice
from monai.networks.blocks import Warp
from monai.networks.nets import LocalNet
from monai.transforms import LoadImaged, AddChanneld, ToTensord, Compose, ScaleIntensityRanged, RandAffined
from monai.utils import set_determinism, first

# Report the installed MONAI configuration (presumably library/dependency versions) so runs can be reproduced.
print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.
This allows you to save results and reuse downloads.
If not specified, a temporary directory will be used.

In [20]:
# Use the user-configured data directory when set; otherwise create a fresh temp dir.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

paired_ct_lung


## Download dataset

In [21]:
# Fetch and unpack the paired lung CT archive only when the dataset is not already on disk.
resource = "https://zenodo.org/record/3835682/files/training.zip"
compressed_file = os.path.join(root_dir, "paired_ct_lung.zip")
data_dir = os.path.join(root_dir, "paired_ct_lung")

dataset_missing = not os.path.exists(data_dir)
if dataset_missing:
    download_and_extract(resource, compressed_file, root_dir)
    # The archive unpacks into "<root_dir>/training"; move it to `data_dir`.
    extracted_dir = os.path.join(root_dir, "training")
    os.rename(extracted_dir, data_dir)

## Set dataset path

In [22]:
# Cases are numbered 1..NUM_CASES; the first NUM_TRAIN are used for training,
# the rest for validation.
NUM_CASES = 20
NUM_TRAIN = 18


def _case_files(case_idx):
    """Return the image/mask paths for one case: `_insp` files are fixed, `_exp` files are moving."""
    return {
        "fixed_image": os.path.join(data_dir, "scans/case_%03d_insp.nii.gz" % case_idx),
        "moving_image": os.path.join(data_dir, "scans/case_%03d_exp.nii.gz" % case_idx),
        "fixed_label": os.path.join(data_dir, "lungMasks/case_%03d_insp.nii.gz" % case_idx),
        "moving_label": os.path.join(data_dir, "lungMasks/case_%03d_exp.nii.gz" % case_idx),
    }


data_dicts = [_case_files(case_idx) for case_idx in range(1, NUM_CASES + 1)]

train_files = data_dicts[:NUM_TRAIN]
val_files = data_dicts[NUM_TRAIN:]

## Set deterministic training for reproducibility

In [23]:
SEED = 0  # fixed seed so repeated runs are reproducible
set_determinism(seed=SEED)

## Setup transforms for training and validation
Here we use several transforms to load, preprocess and augment the dataset:
1. `LoadImaged` loads the lung CT images and lung masks from NIfTI format files.
2. `AddChanneld` adds a channel dimension (the original data doesn't have one) to construct a "channel first" shape.
3. `ScaleIntensityRanged` clips the image intensities to the range [-285, 3770] and scales them to [0, 1].
4. `RandAffined` efficiently performs rotation, scaling, shearing, translation, etc. together based on a PyTorch affine transform; here it applies a random rotation and scaling, and it is used for training only.
5. `ToTensord` converts the numpy arrays to PyTorch tensors for further steps.

In [24]:
# The four volumes that make up one registration pair, and the two intensity images among them.
pair_keys = ["fixed_image", "moving_image", "fixed_label", "moving_label"]
image_keys = ["fixed_image", "moving_image"]

train_transforms = Compose(
    [
        LoadImaged(keys=pair_keys),
        AddChanneld(keys=pair_keys),
        ScaleIntensityRanged(
            keys=image_keys, a_min=-285, a_max=3770, b_min=0.0, b_max=1.0, clip=True,
        ),
        # random affine transforms; `mode` is matched one-to-one with `pair_keys`:
        # both images are resampled bilinearly, both label masks with nearest
        # neighbour so label values are not blended.
        RandAffined(
            keys=pair_keys,
            mode=("bilinear", "bilinear", "nearest", "nearest"),
            prob=1.0,
            spatial_size=(192, 192, 208),
            rotate_range=(0, 0, np.pi / 15),
            scale_range=(0.1, 0.1, 0.1),
        ),
        ToTensord(keys=pair_keys),
    ]
)
# Validation uses the same deterministic preprocessing, without random augmentation.
val_transforms = Compose(
    [
        LoadImaged(keys=pair_keys),
        AddChanneld(keys=pair_keys),
        ScaleIntensityRanged(
            keys=image_keys, a_min=-285, a_max=3770, b_min=0.0, b_max=1.0, clip=True,
        ),
        ToTensord(keys=pair_keys),
    ]
)

## Check transforms in DataLoader

In [25]:
# Run one validation pair through the *training* transforms (including the random
# affine) and show one slice of every volume to sanity-check the pipeline.
check_ds = Dataset(data=val_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
# [0][0] drops the batch dimension (batch_size=1) and the channel added by AddChanneld.
fixed_image, fixed_label = (check_data["fixed_image"][0][0], check_data["fixed_label"][0][0])
moving_image, moving_label = (check_data["moving_image"][0][0], check_data["moving_label"][0][0])
print(f"image shape: {fixed_image.shape}, label shape: {fixed_label.shape}")

check_slice = 80  # index along the last array axis to display
check_panels = [
    ("moving_image", moving_image, "gray"),
    ("moving_label", moving_label, None),
    ("fixed_image", fixed_image, "gray"),
    ("fixed_label", fixed_label, None),
]
# All four panels go in a single figure; show it once, after every subplot is drawn.
plt.figure("check", (12, 6))
for position, (title, volume, cmap) in enumerate(check_panels, start=1):
    plt.subplot(1, len(check_panels), position)
    plt.title(title)
    plt.imshow(volume[:, :, check_slice], cmap=cmap)
plt.show()


## Define CacheDataset and DataLoader for training and validation

Here we use `CacheDataset` to accelerate the training and validation process; it can be up to 10x faster than the regular `Dataset`.
To achieve the best performance, set `cache_rate=1.0` to cache all the data; if memory is not sufficient, set a lower value.
Users can also set `cache_num` instead of `cache_rate`; the minimum of the two settings will be used.
Set `num_workers` to enable multi-threading during caching.
To try the regular `Dataset` instead, switch to the commented-out code below.


In [27]:
# CacheDataset caches the output of the transforms that precede the first random
# transform, so RandAffined is still re-sampled every epoch.
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
# train_ds = Dataset(data=train_files, transform=train_transforms)

# each training batch holds 2 registration pairs (fixed/moving image + label)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=0)
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4)


100%|██████████| 18/18 [00:03<00:00,  5.49it/s]
100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


#### Create Model, Loss, Optimizer

In [28]:
# standard PyTorch program style: create the LocalNet registration network, the
# image (MSE), label (Dice) and DDF regularization (bending energy) losses, and an Adam optimizer.
# Fall back to the CPU when no GPU is available (training will be much slower).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = LocalNet(
    spatial_dims=3,
    in_channels=2,  # moving and fixed image concatenated along the channel axis
    out_channels=3,  # one displacement component per spatial dimension
    num_channel_initial=32,
    extract_levels=[0, 1, 2, 3],
    out_activation=None,
    # presumably zero-initialises the output layer so the initial DDF is ~0 (identity warp)
    out_initializer="zeros",
).to(device)
warp_layer = Warp(spatial_dims=3).to(device)
image_loss = MSELoss()
label_loss = DiceLoss()
regularization = BendingEnergyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)

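Define a forward-pass function that computes the DDF (dense displacement field) and warps the moving image and label with it, so that training, validation and inference share the same code.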

In [29]:
def forward(batch_data, model):
    """Predict a DDF for one batch and warp the moving image and label with it.

    Args:
        batch_data: dict holding at least "fixed_image", "moving_image" and
            "moving_label" tensors; they are moved to the global `device`.
        model: network mapping the channel-wise concatenation
            (moving_image, fixed_image) to a DDF.

    Returns:
        Tuple ``(ddf, pred_image, pred_label)``, where the predictions are the
        moving image and label resampled by the global `warp_layer`.
    """
    fixed_image, moving_image, moving_label = (
        batch_data[key].to(device) for key in ("fixed_image", "moving_image", "moving_label")
    )
    # Channel order matters: moving image first, fixed image second.
    image_pair = torch.cat((moving_image, fixed_image), dim=1)
    ddf = model(image_pair)
    return ddf, warp_layer(moving_image, ddf), warp_layer(moving_label, ddf)

## Execute a typical PyTorch training process

In [None]:
epoch_num = 1
# Validate every `val_interval` epochs. It is capped at `epoch_num`: otherwise
# (e.g. epoch_num=1, val_interval=2) validation never runs, no checkpoint is
# saved and the cells that load "best_metric_model.pth" fail.
val_interval = min(2, epoch_num)
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()

for epoch in range(epoch_num):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epoch_num}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        optimizer.zero_grad()

        ddf, pred_image, pred_label = forward(batch_data, model)

        fixed_image = batch_data["fixed_image"].to(device)
        fixed_label = batch_data["fixed_label"].to(device)
        # image similarity + label overlap + regularization of the DDF
        loss = image_loss(pred_image, fixed_image) + label_loss(pred_label, fixed_label) + regularization(ddf)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}")

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            for val_data in val_loader:
                _, _, val_pred_label = forward(val_data, model)
                val_fixed_label = val_data["fixed_label"].to(device)
                # The warped label is resampled by `warp_layer` and may hold fractional
                # values, so binarize it before computing Dice. The single label channel
                # is the lung mask itself, so there is no background channel to drop.
                value = compute_meandice(
                    y_pred=(val_pred_label > 0.5).float(),
                    y=val_fixed_label,
                    include_background=True,
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

----------
epoch 1/1


In [None]:
# Summarise the best validation Dice (-1 means validation never ran).
print("train completed, best_metric: {:.4f}  at epoch: {}".format(best_metric, best_metric_epoch))

#### Plot the loss and metric

In [None]:
# Left: mean training loss per epoch. Right: validation Dice, recorded every `val_interval` epochs.
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(12, 6), num="train")

loss_ax.set_title("Epoch Average Loss")
loss_ax.set_xlabel("epoch")
loss_epochs = list(range(1, len(epoch_loss_values) + 1))
loss_ax.plot(loss_epochs, epoch_loss_values)

dice_ax.set_title("Val Mean Dice")
dice_ax.set_xlabel("epoch")
dice_epochs = [val_interval * k for k in range(1, len(metric_values) + 1)]
dice_ax.plot(dice_epochs, metric_values)

plt.show()

#### Check best model output with the input image and label

In [None]:
# Reload the best checkpoint; map it onto the active device so this also works on CPU-only machines.
best_model_path = os.path.join(root_dir, "best_metric_model.pth")
if not os.path.exists(best_model_path):
    raise FileNotFoundError(
        f"{best_model_path} not found: validation must run at least once during training "
        "(check that val_interval <= epoch_num)."
    )
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

plot_slice = 80  # index along the last array axis to display
sample_idx = 0  # running index over all validation samples, used in panel titles
with torch.no_grad():
    for val_data in val_loader:
        _, val_pred_image, val_pred_label = forward(val_data, model)
        # The predictions live on `device`; matplotlib needs CPU tensors.
        val_pred_image = val_pred_image.cpu()
        val_pred_label = val_pred_label.cpu()

        # Plot every sample in the batch, not only the first one.
        for b in range(val_pred_image.shape[0]):
            panels = [
                ("moving_image", val_data["moving_image"][b, 0], "gray"),
                ("moving_label", val_data["moving_label"][b, 0], None),
                ("fixed_image", val_data["fixed_image"][b, 0], "gray"),
                ("fixed_label", val_data["fixed_label"][b, 0], None),
                ("pred_image", val_pred_image[b, 0], "gray"),
                ("pred_label", val_pred_label[b, 0], None),
            ]
            plt.figure("check", (18, 6))
            for position, (title, volume, cmap) in enumerate(panels, start=1):
                plt.subplot(1, len(panels), position)
                plt.title(f"{title} {sample_idx}")
                plt.imshow(volume[:, :, plot_slice], cmap=cmap)
            plt.show()
            sample_idx += 1


## Cleanup data directory

Remove the directory if a temporary one was used.

In [None]:
# Only delete the data directory when it was created by tempfile.mkdtemp() above,
# i.e. when MONAI_DATA_DIRECTORY was not set.
created_temp_dir = directory is None
if created_temp_dir:
    shutil.rmtree(root_dir)