Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

## 3D Segmentation with DAF3D

This tutorial shows how to construct a training workflow of the DAF3D network based on 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>. 

![DAF3D](../figures/DAF3D_scheme.png)

It contains the following features:
1. Transforms for dictionary format data
1. Use of DecathlonDataset to facilitate loading and caching images from NIfTI files
1. DAF3D model, Dice metric, customized loss function to work on multiple supervised signals
1. Visualization of supervised signals and Attentive Maps of the last validation image
1. Sliding window inference method


Instead of prostate scans, the data is obtained from the Decathlon Dataset Task09 (Spleen), since the image properties are similar (single-channel input, single-label ground truth). The Spleen dataset can be downloaded from http://medicaldecathlon.com/.

## Setup environment

In [None]:
# Install MONAI (with the gdown, nibabel, tqdm and ignite extras) only if it cannot be imported.
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
# Install matplotlib only if it cannot be imported.
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Setup imports

In [None]:
from monai.utils import first
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    ToTensord,
    ScaleIntensityRanged,
    Resized,
    Orientationd
)
from monai.networks.nets.daf3d import DAF3D
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.data import DataLoader
from monai.config import print_config
from monai.apps import DecathlonDataset
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os

# Print MONAI's configuration summary for the current environment.
print_config()

## Setup paths to your data

In [None]:
# Use the directory named by the MONAI_DATA_DIRECTORY environment variable
# when it is set; otherwise fall back to a freshly created temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

## Load and transform spleen data from Decathlon Dataset

In [None]:
# Spatial size that every image and label volume is resized to.
spatial_size = (128, 128, 90)

# Preprocessing pipeline shared by the training and validation sections.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Map the intensity window [-57, 164] to [0, 1]; clip=True is set so
        # values outside the window do not leave that range.
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # `mode` and `align_corners` are given per key, in the order of `keys`:
        # trilinear for the image, nearest-exact for the label mask.
        Resized(
            keys=["image", "label"],
            spatial_size=spatial_size,
            mode=["trilinear", "nearest-exact"],
            allow_missing_keys=True,
            align_corners=[False, None],
        ),
        ToTensord(keys=["image", "label"]),
    ]
)

msd_task = "Task09_Spleen"

# DecathlonDataset does not download by default, so with a fresh (empty)
# root_dir the dataset could never be found. Request the download only when
# the task folder is missing; an existing copy is used exactly as before.
dataset_missing = not os.path.isdir(os.path.join(root_dir, msd_task))
train_set = DecathlonDataset(root_dir, msd_task, "training", train_transforms, download=dataset_missing)
# The training dataset above has already made the data available.
val_set = DecathlonDataset(root_dir, msd_task, "validation", train_transforms)

## Check transforms in DataLoader

In [None]:
train_loader = DataLoader(train_set)
val_loader = DataLoader(val_set)

# Fetch the first batch once and read both entries from it, instead of
# building a new loader iterator (and loading the batch again) per key.
check_data = first(train_loader)
image = check_data["image"]
label = check_data["label"]
print("Image shape: ", image.shape, ", label shape: ", label.shape)

# Drop the singleton dimensions so the volumes can be indexed for plotting.
image = torch.squeeze(image)
label = torch.squeeze(label)
print(f"image shape: {image.shape},label shape: {label.shape}")

# Show slice 80 along the last axis of the image next to its label.
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 80], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 80])
plt.show()

## Create Model, Optimizer, Dice Metric

In [None]:
# Always build a torch.device so `device` has the same type on GPU and CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# visual_output=True is needed by the validation loop, which reads the inner
# outputs in addition to the final prediction.
model = DAF3D(in_channels=1, out_channels=1, visual_output=True).to(device)

# Module-level loss objects; `loss_function` builds its own instances.
criterion_dice = DiceLoss(smooth_nr=1, smooth_dr=1, squared_pred=True, reduction="none")
criterion_bce = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dice_metric = DiceMetric()

## Create custom loss function
Based on 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>.
The custom loss is necessary because the network is trained on multiple supervised signals, consisting of:
1. Layerwise outputs of Feature Pyramid Network (Single Layer Features / SLFs)
2. Layerwise outputs of Attention Module (Attentive Features / Refined SLFs)
3. Final prediction

In [None]:
def loss_function(outputs, label):
    """Combine Dice and BCE losses over all supervised signals.

    Args:
        outputs: sequence of nine probability maps (sigmoid already applied),
            indexed as: 0-3 single-layer features, 4-7 attentive (refined)
            features, 8 final prediction.
        label: ground-truth tensor that every output is compared against.

    Returns:
        Scalar tensor: the weighted sum of the Dice and BCE terms.
    """
    # reduction="mean" makes every Dice term a scalar, like BCELoss, so the
    # total supports .item() and .backward() for any batch size. For a single
    # sample and channel the value equals the unreduced loss.
    criterion_dice = DiceLoss(smooth_nr=1, smooth_dr=1, squared_pred=True, reduction="mean")
    criterion_bce = torch.nn.BCELoss()

    dice_losses = [criterion_dice(output, label) for output in outputs]
    bce_losses = [criterion_bce(output, label) for output in outputs]

    weights = (0.4, 0.5, 0.7, 0.8)  # weights for slfs & attentive bce
    weights_special = (0.4, 0.7, 0.8, 1)  # weights for attentive dice

    # Single-layer features: Dice and BCE share one weight per layer.
    loss_slf = sum(
        weight * (dice + bce)
        for weight, dice, bce in zip(weights, dice_losses[0:4], bce_losses[0:4])
    )
    # Attentive features: BCE and Dice use separate weight sets.
    loss_attentive = sum(
        weight * bce + weight_special * dice
        for weight, weight_special, dice, bce in zip(
            weights, weights_special, dice_losses[4:8], bce_losses[4:8]
        )
    )
    # Final prediction: unweighted Dice + BCE.
    loss_output = dice_losses[8] + bce_losses[8]

    return loss_slf + loss_attentive + loss_output

## Execute DAF3D training process

In [None]:
numEpochs = 20
numTrainingData = len(train_set)
best_dice = -1
best_dice_epoch = -1
epoch_loss_values = []
epoch_dice_values = []
plotted_outputs = []

banner = "-" * 10

for epoch in range(1, numEpochs + 1):
    print(f"{banner} Epoch: {epoch} {banner}")

    # Start training
    model.train()
    epoch_loss = 0.0
    for batch_idx, batch_data in enumerate(train_loader):
        image = batch_data["image"].to(device)
        label = batch_data["label"].to(device)

        optimizer.zero_grad()
        # The loss expects probabilities, so apply sigmoid to every output.
        outputs = [torch.sigmoid(i) for i in model(image)]

        loss = loss_function(outputs, label)
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Average the accumulated loss over the number of training batches.
    epoch_loss = epoch_loss / (batch_idx + 1)
    epoch_loss_values.append(epoch_loss)
    print(f"Epoch {epoch} Finished ! Loss is {epoch_loss:.4f}")

    # Start validation
    model.eval()
    # No gradients are needed for validation; disabling autograd avoids
    # building a graph for every validation volume.
    with torch.no_grad():
        for batch_data in val_loader:
            image = batch_data["image"].to(device)
            label = batch_data["label"].to(device)

            # since visual_output=True in shape of [final_prediction, inner_outputs]
            outputs = model(image)
            predict = torch.sigmoid(outputs[0])
            # Kept for the visualization cell, which plots the inner outputs
            # of the last validation image.
            inner_outputs = [torch.sigmoid(i) for i in outputs[1:]]

            # Binarize the probabilities at 0.5 before accumulating the metric.
            dice_metric(predict > 0.5, label)

    epoch_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    epoch_dice_values.append(epoch_dice)
    print(f"Epoch {epoch} Dice score: {epoch_dice:.4f}")

    # Save the weights whenever the validation Dice improves.
    if epoch_dice > best_dice:
        best_dice = epoch_dice
        best_dice_epoch = epoch
        torch.save(model.state_dict(), os.path.join(root_dir, "best_dice_model.pth"))


print(f"Training completed, best dice score: {best_dice:.4f} at epoch: {best_dice_epoch}")

## Visualize loss and metric results

In [None]:
# Plot the per-epoch training loss and validation Dice score side by side,
# with epochs numbered from 1 on the x axis.
plt.figure("train", (12, 6))
panels = (
    ("Epoch Average Loss", epoch_loss_values),
    ("Epoch Dice Score", epoch_dice_values),
)
for position, (panel_title, values) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(panel_title)
    epochs = list(range(1, len(values) + 1))
    plt.xlabel("epoch")
    plt.plot(epochs, values)
plt.show()

## Visualize all outputs of last validation image

In [None]:
# Index along the last axis of the slice to display. Named `slice_idx` so the
# builtin `slice` is not shadowed for the rest of the session.
slice_idx = 55

# Extract the chosen 2D slice from every volume of the last validation batch.
inner_outputs_slice = [torch.squeeze(i).detach().cpu().numpy()[:, :, slice_idx] for i in inner_outputs]
label_slice = torch.squeeze(label).detach().cpu().numpy()[:, :, slice_idx]
image_slice = torch.squeeze(image).detach().cpu().numpy()[:, :, slice_idx]
predict_slice = torch.squeeze(predict).detach().cpu().numpy()[:, :, slice_idx]

# NOTE(review): this split assumes the inner outputs are ordered as 4 SLFs,
# 4 refined SLFs, then the attentive maps - confirm against DAF3D's
# eval-mode output order.
slfs = inner_outputs_slice[0:4]
refined = inner_outputs_slice[4:8]
att_maps = inner_outputs_slice[8:]

# 3 x 7 grid: row 1 holds image, SLFs, prediction and ground truth; row 2 the
# attentive maps; row 3 the refined SLFs.
plt.figure("Outputs", (30, 10))

plt.subplot(3, 7, 1)
plt.title("Image")
plt.imshow(image_slice, cmap="gray")

for i, slf in enumerate(slfs):
    plt.subplot(3, 7, i + 2)
    plt.title("SLF " + str(i + 1))
    plt.imshow(slf, cmap="gray")

for i, am in enumerate(att_maps):
    plt.subplot(3, 7, i + 9)
    plt.title("Attentive Map " + str(i + 1))
    plt.imshow(am)

for i, rslf in enumerate(refined):
    plt.subplot(3, 7, i + 16)
    plt.title("Refined SLF " + str(i + 1))
    plt.imshow(rslf, cmap="gray")

plt.subplot(3, 7, 6)
plt.title("Prediction")
plt.imshow(predict_slice, cmap="gray")

plt.subplot(3, 7, 7)
plt.title("Ground truth")
plt.imshow(label_slice, cmap="gray")

# Render the figure explicitly, as the other plotting cells do.
plt.show()

## Load test data

In [None]:
# The test pipeline only processes the "image" key.
test_keys = ["image"]

# Intensity window [-57, 164] mapped to [0, 1]; clip=True is set so values
# outside the window do not leave that range.
test_intensity_window = {
    "a_min": -57,
    "a_max": 164,
    "b_min": 0.0,
    "b_max": 1.0,
    "clip": True,
}

test_pipeline_steps = [
    LoadImaged(keys=test_keys),
    EnsureChannelFirstd(keys=test_keys),
    ScaleIntensityRanged(keys=test_keys, **test_intensity_window),
    Orientationd(keys=test_keys, axcodes="RAS"),
    Resized(keys=test_keys, spatial_size=spatial_size, mode=["trilinear"]),
    ToTensord(keys=test_keys),
]
test_transforms = Compose(test_pipeline_steps)

# Build the "test" section of the same Decathlon task and wrap it in a loader.
test_set = DecathlonDataset(root_dir, msd_task, "test", test_transforms)
test_loader = DataLoader(test_set)

## Inference on Test Set

In [None]:
# Rebuild the network without visual outputs and move it to the same device as
# the inputs; without .to(device) a CUDA input would meet CPU weights.
model = DAF3D(in_channels=1, out_channels=1, visual_output=False).to(device)
model.load_state_dict(torch.load(os.path.join(root_dir, "best_dice_model.pth"), map_location=device))
model.eval()

# Sliding-window settings: one window of the full resized volume at a time.
roi_size = spatial_size
sw_batch_size = 1

with torch.no_grad():
    for test_data in test_loader:
        image = test_data["image"].to(device)

        # A single forward pass through the sliding-window inferer; a separate
        # direct model(image) call would only repeat the work.
        prediction = sliding_window_inference(image, roi_size, sw_batch_size, model)
        # Convert logits to probabilities, as done in training and validation.
        prediction = torch.sigmoid(prediction)

        plt.figure("Prediction on test set")

        plt.subplot(1, 2, 1)
        plt.title("image")
        plt.imshow(torch.squeeze(image).detach().cpu()[:, :, 55], cmap="gray")

        plt.subplot(1, 2, 2)
        plt.title("prediction")
        plt.imshow(torch.squeeze(prediction).detach().cpu()[:, :, 55], cmap="gray")

        plt.show()


## Cleanup data directory

In [None]:
# Only remove root_dir when it is the temporary directory that was created
# because MONAI_DATA_DIRECTORY was not set; a user-provided directory is kept.
used_temp_dir = directory is None
if used_temp_dir:
    shutil.rmtree(root_dir)