# Training a Transformer (UNETR) on the PI-CAI dataset for clinically significant prostate cancer (csPCa) detection

## Train:

In [9]:
import os
import json
import torch
import datetime
import matplotlib.pyplot as plt
import torch.nn as nn
from tqdm import tqdm
import monai
from monai.losses import FocalLoss
from monai.transforms import (    
    AsDiscrete,EnsureChannelFirstd,Compose,LoadImaged,
    Orientationd,SpatialPadd, CenterSpatialCropd,Spacingd,ScaleIntensityRangePercentilesd)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR
from monai.data import DataLoader,Dataset,load_decathlon_datalist,decollate_batch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
# Ask CUDA to enumerate devices by PCI bus ID.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID"})

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Folder that receives checkpoints / training history, and the dataset description file.
root_dir, datasets_json = "./results_output_picai", "./workdir/dataset_picai.json"

# Number of training / validation cases kept from the data list.
num_train, num_val = 25, 5
# Number of passes of the outer training loop, and how often (in passes) validation runs.
max_iterations, eval_num = 10, 5

# Training pipeline: load, add a channel axis, reorient to RAS, resample, then
# force a fixed spatial size and normalise the image intensities.
# SpatialPadd only grows volumes that are too small; the centre crop that follows
# trims volumes that are too large, so every sample ends up at exactly
# (272, 272, 32). That is the fixed `img_size` the UNETR is built with and is also
# required for samples to be stacked into a batch.
train_transforms_simple = Compose(
    [LoadImaged(keys=["image", "label"]),
     EnsureChannelFirstd(keys=["image", "label"]),
     Orientationd(keys=["image", "label"], axcodes="RAS"),
     Spacingd(keys=["image", "label"],pixdim=(0.5, 0.5, 3.0),mode=("bilinear", "nearest"),),
     SpatialPadd(keys=["image", "label"], spatial_size=(272, 272, 32)),
     CenterSpatialCropd(keys=["image", "label"], roi_size=(272, 272, 32)),
     # Rescale the image so that its 1st..99th percentile range maps to [0, 1] (clipped).
     ScaleIntensityRangePercentilesd(keys=["image"], lower=1, upper=99, b_min=0.0, b_max=1.0, clip=True),])

# Validation pipeline: deterministic preprocessing that ends in a fixed
# (272, 272, 32) volume via pad + centre crop.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Bring every volume into RAS orientation.
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample both entries; bilinear for the image, nearest-neighbour for the label.
        Spacingd(
            keys=["image", "label"],
            pixdim=(0.5, 0.5, 3.0),
            mode=("bilinear", "nearest"),
        ),
        # Pad volumes that are too small, then crop volumes that are too large.
        SpatialPadd(keys=["image", "label"], spatial_size=(272, 272, 32)),
        CenterSpatialCropd(keys=["image", "label"], roi_size=(272, 272, 32)),
        # Map the 1st..99th intensity percentile range of the image to [0, 1], clipping outliers.
        ScaleIntensityRangePercentilesd(
            keys=["image"],
            lower=1,
            upper=99,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]
)

In [None]:
# Read the case lists from the dataset JSON and keep only the first
# `num_train` / `num_val` entries of the respective split.
datalist = load_decathlon_datalist(
    datasets_json, is_segmentation=True, data_list_key="training"
)[:num_train]
val_files = load_decathlon_datalist(
    datasets_json, is_segmentation=True, data_list_key="validation"
)[:num_val]

# Training data: shuffled, two cases per batch.
train_ds = Dataset(data=datalist, transform=train_transforms_simple)
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Validation data: fixed order, two cases per batch.
val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=2,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# UNETR taking 3 input channels and producing 2 output channels
# for volumes of size (272, 272, 32).
model = UNETR(
    in_channels=3,
    out_channels=2,
    img_size=(272, 272, 32),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.3,
)
model = model.to(device)

# Post-processing used for validation: labels are one-hot encoded (2 classes),
# predictions are arg-maxed over the channel axis and then one-hot encoded.
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)

# Mean Dice over the non-background channel.
dice_metric = DiceMetric(
    include_background=False,
    reduction="mean",
    get_not_nans=False,
)

Loss function and optimizer:

The loss function is chosen according to the task at hand: Focal loss, cross-entropy (CE) loss, or Dice-CE loss.

In [None]:
# Available losses, keyed by name. They are wrapped in factories so that only the
# selected loss is instantiated; assigning all three to `loss_function` in turn
# silently discards everything but the last one.
loss_factories = {
    # NOTE(review): nn.CrossEntropyLoss expects class-index targets without a
    # channel axis, so the label tensor fed by the training loop would presumably
    # need adapting before "ce" can be selected - confirm before use.
    "ce": lambda: nn.CrossEntropyLoss(),
    "dice_ce": lambda: monai.losses.DiceCELoss(include_background=False, to_onehot_y=True, softmax=True),
    "focal": lambda: FocalLoss(to_onehot_y=True, gamma=2.0),
}

# Name of the loss actually used for training (Focal loss, as before).
loss_name = "focal"
loss_function = loss_factories[loss_name]()

# Let cuDNN benchmark and pick convolution algorithms.
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
# Book-keeping for the best validation result and the training history.
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
current_date = datetime.datetime.now().strftime("%Y-%m-%d")

# Checkpoints are written into root_dir; create it up front so torch.save cannot
# fail on a missing directory.
os.makedirs(root_dir, exist_ok=True)

for global_step in range(max_iterations):
    model.train()
    epoch_loss = 0.0
    step = 0  # stays 0 when the loader yields no batch

    print(f"Global training step: {global_step}/{max_iterations}")

    # One full pass over the training loader.
    for step, batch in enumerate(train_loader, start=1):
        x = batch["image"].to(device)
        y = batch["label"].to(device)

        # Clear gradients before the forward/backward pass so nothing stale
        # (e.g. from an interrupted cell run) leaks into this update.
        optimizer.zero_grad()

        outputs = model(x)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Average over the number of batches seen (guarding against an empty loader).
    epoch_loss_avg = epoch_loss / max(step, 1)
    epoch_loss_values.append(epoch_loss_avg)
    print(f"Epoch {global_step} average loss: {epoch_loss_avg:.4f}\n")

    # ---- VALIDATION ----
    # Validate every `eval_num` passes (skipping pass 0) and always on the last
    # pass. range() stops at max_iterations - 1, so that is the value to compare with.
    is_last_step = global_step == max_iterations - 1
    if (global_step % eval_num == 0 and global_step != 0) or is_last_step:
        print(f"Validation at global iteration {global_step}")

        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)

                # Mixed precision only makes sense when actually running on CUDA.
                with torch.amp.autocast('cuda', enabled=device.type == "cuda"):
                    val_outputs = model(val_inputs)

                # Split the batch into single cases and discretise labels / predictions.
                val_labels_convert = [post_label(i) for i in decollate_batch(val_labels)]
                val_outputs_convert = [post_pred(i) for i in decollate_batch(val_outputs)]

                dice_metric(y_pred=val_outputs_convert, y=val_labels_convert)

        # Reduce the accumulated per-case scores, then clear the metric buffer.
        dice_val = dice_metric.aggregate().item()
        dice_metric.reset()

        metric_values.append(dice_val)
        print(f"Validation Dice: {dice_val:.4f}")

        # Save best model
        if dice_val > dice_val_best:
            dice_val_best = dice_val
            global_step_best = global_step
            torch.save(
                model.state_dict(),
                os.path.join(root_dir, f"best_metric_model_FocalLoss_{current_date}.pth")
            )

# The output folder is not guaranteed to exist at this point (a best checkpoint
# is only written when the validation Dice improves), so create it before saving.
os.makedirs(root_dir, exist_ok=True)

# ---- SAVE FINAL MODEL ----
torch.save(
    model.state_dict(),
    os.path.join(root_dir, f"final_model_FocalLoss_{current_date}.pth"),
)

# ---- SAVE TRAINING DATA ----
# "loss" has one entry per training pass; "dice_metric" has one entry per
# validation run.
json_data = {
    "global_step": list(range(max_iterations)),
    "loss": epoch_loss_values,
    "dice_metric": metric_values,
}

with open(os.path.join(root_dir, f"data_FocalLoss_{current_date}.json"), "w") as json_file:
    json.dump(json_data, json_file, indent=4)

print(f"Training complete. Best Dice: {dice_val_best:.4f} at iteration {global_step_best}")

## Plot the learning curves: 

In [None]:
import json
import matplotlib.pyplot as plt

import glob
import os

# Load results.
# An explicit "results_data.json" in the working directory is honoured if present;
# otherwise fall back to the newest "data_FocalLoss_<date>.json" under root_dir,
# which is where the training cell writes its history. The date stamp is
# YYYY-MM-DD, so a plain sort puts the most recent file last.
legacy_history_path = "results_data.json"
if os.path.exists(legacy_history_path):
    history_path = legacy_history_path
else:
    dated_histories = sorted(glob.glob(os.path.join(root_dir, "data_FocalLoss_*.json")))
    if not dated_histories:
        raise FileNotFoundError(
            f"No training history found: neither '{legacy_history_path}' nor "
            f"'data_FocalLoss_*.json' in '{root_dir}'."
        )
    history_path = dated_histories[-1]

with open(history_path) as f:
    data = json.load(f)

loss_values = data["loss"]
dice_values = data["dice_metric"]

# Plot: training loss on the left, validation Dice on the right.
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(12, 6), num="train")

loss_ax.set_title("Iteration Average Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.plot(range(len(loss_values)), loss_values, label="Training Loss", marker="o")
loss_ax.legend()

# Note: the x position here is the index of the validation run, not the training pass.
dice_ax.set_title("Validation Mean Dice")
dice_ax.set_xlabel("Epoch")
dice_ax.set_ylabel("Dice Score")
dice_ax.plot(range(len(dice_values)), dice_values, label="Validation Dice", marker="x")
dice_ax.legend()

plt.show()

## Prediction:

In [None]:
import nibabel as nib
import numpy as np 
from monai.transforms import ToTensord, ScaleIntensityd
import os
import monai

# Location of the test set and of the JSON file that lists its cases.
test_dir = "/picai_testset"
test_dataset_json = "dataset_test.json"

# Test-time preprocessing; `image_only=False` asks the loader to also return the
# metadata dictionaries alongside the arrays.
test_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], image_only=False),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Bilinear resampling for the image, nearest-neighbour for the label.
        Spacingd(
            keys=["image", "label"],
            pixdim=(0.5, 0.5, 3.0),
            mode=("bilinear", "nearest"),
        ),
        # Pad, then centre-crop, to a fixed (272, 272, 32) volume.
        SpatialPadd(keys=["image", "label"], spatial_size=(272, 272, 32)),
        CenterSpatialCropd(keys=["image", "label"], roi_size=(272, 272, 32)),
        ScaleIntensityRangePercentilesd(
            keys=["image"],
            lower=1,
            upper=99,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]
)

# One case per batch, in file order.
test_files = load_decathlon_datalist(test_dataset_json, True, "test")
test_ds = Dataset(data=test_files, transform=test_transforms)
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fresh UNETR with the same hyper-parameters as the training model, so that the
# trained weights can be loaded into it.
model = UNETR(
    in_channels=3,
    out_channels=2,
    img_size=(272, 272, 32),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.3,
)
model = model.to(device)

In [None]:
import os
import torch
from tqdm import tqdm
from monai.transforms import SaveImaged, AsDiscrete
from monai.data import decollate_batch

import glob

# Create output dirs
prediction_dir = "./prediction"
gt_dir = "./picai_testset/gt"
os.makedirs(prediction_dir, exist_ok=True)
os.makedirs(gt_dir, exist_ok=True)

# Post-processing transforms: labels are passed through unchanged, predictions
# are reduced to a class index via argmax over the channel axis.
post_label = AsDiscrete()
post_pred = AsDiscrete(argmax=True)

# Load model
# The training cell stores the best checkpoint as
# "best_metric_model_FocalLoss_<YYYY-MM-DD>.pth". An undated file is still
# preferred when it exists; otherwise the most recent dated checkpoint is used
# (ISO dates sort chronologically).
checkpoint_path = os.path.join(root_dir, "best_metric_model_FocalLoss.pth")
if not os.path.exists(checkpoint_path):
    dated_checkpoints = sorted(
        glob.glob(os.path.join(root_dir, "best_metric_model_FocalLoss_*.pth"))
    )
    if not dated_checkpoints:
        raise FileNotFoundError(
            f"No 'best_metric_model_FocalLoss*.pth' checkpoint found in '{root_dir}'."
        )
    checkpoint_path = dated_checkpoints[-1]

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Saving transforms
# With meta_key_postfix="meta_dict" each saver looks up its metadata under
# "<key>_meta_dict", i.e. "pred_meta_dict" and "gt_meta_dict".
save_pred = SaveImaged(
    keys="pred",
    output_dir=prediction_dir,
    output_postfix="pred",
    resample=False,
    separate_folder=False,
    print_log=False,
    meta_key_postfix="meta_dict",
    output_ext=".nii.gz"
)

save_gt = SaveImaged(
    keys="gt",
    output_dir=gt_dir,
    output_postfix="gt",
    resample=False,
    separate_folder=False,
    print_log=False,
    meta_key_postfix="meta_dict",
    output_ext=".nii.gz"
)

# Inference loop
# Start from an empty metric buffer so that scores left over from an earlier
# (possibly interrupted) run cannot leak into the test result.
dice_metric.reset()

with torch.no_grad():
    for batch in tqdm(test_loader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        outputs = model(images)

        # Split the batch into single cases and discretise labels / predictions.
        labels_post = [post_label(x) for x in decollate_batch(labels)]
        preds_post = [post_pred(x) for x in decollate_batch(outputs)]

        dice_metric(y_pred=preds_post, y=labels_post)

        # Softmax + select positive class
        pred_soft = torch.softmax(outputs, dim=1)[:, 1]

        # Meta dict reused for pred and gt
        meta = {
            "filename_or_obj": batch["label_meta_dict"]["filename_or_obj"][0],
            "affine": batch["label_meta_dict"]["affine"][0],
        }

        # The savers are configured with meta_key_postfix="meta_dict", so they
        # read metadata from "<key>_meta_dict"; a bare "meta_dict" entry is never
        # looked up.
        save_pred({"pred": pred_soft[0], "pred_meta_dict": meta})
        save_gt({"gt": labels_post[0], "gt_meta_dict": meta})

    # Reduce the per-case Dice scores and clear the buffer again.
    mean_dice = dice_metric.aggregate().item()
    dice_metric.reset()

print(mean_dice)

## Metric evaluation:

In [None]:
from picai_eval import evaluate
from report_guided_annotation import extract_lesion_candidates
import glob

def _dynamic_lesion_candidates(pred):
    """Return the first element of `extract_lesion_candidates` run with the "dynamic" threshold."""
    return extract_lesion_candidates(pred, threshold="dynamic")[0]


# Collect the saved prediction / ground-truth volumes; both lists are sorted by
# path, so cases are paired by their sort position.
pred_cases = sorted(glob.glob(f"{prediction_dir}/**/*.nii.gz", recursive=True))
gt_cases = sorted(glob.glob(f"{gt_dir}/**/*.nii.gz", recursive=True))

metrics = evaluate(
    y_det=pred_cases,
    y_true=gt_cases,
    y_det_postprocess_func=_dynamic_lesion_candidates,
)

# Report AUROC, AP and their average (the PI-CAI score), rounded to 4 decimals.
print("\n")
print(f"AUROC: {round(metrics.auroc, 4)}")
print(f"AP: {round(metrics.AP, 4)}")
print(f"PICAI score: {round(0.5 * (metrics.auroc + metrics.AP), 4)}")