# Experimenting and training the model
This was mostly done in Colab with a T4 GPU, but it can be run on a CPU as well; it will just take longer.
Either way, you need the CUDA toolkit installed. <br>
I used a template from MONAI and modified it to fit my needs; the original can be found here: https://github.com/Project-MONAI/tutorials/blob/main/2d_segmentation/torch/unet_training_dict.py

In [16]:
# First make all the necessary imports
import glob
import logging
import os
import random
import sys
import tempfile

import numpy as np

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

import monai
from monai.utils import set_determinism
from monai.data import (
    decollate_batch, 
    DataLoader, 
    pad_list_data_collate
)
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    EnsureChannelFirstd,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandRotate90d,
    ScaleIntensityd,
    RandSpatialCropSamplesd,
)
from monai.visualize import plot_2d_or_3d_image

from sklearn.model_selection import train_test_split

### Set the configurations

In [2]:
# Seed MONAI's determinism helper for reproducibility
set_determinism(seed=42)

# Send log records of level INFO and above to stdout
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Seed Python's built-in random module
random.seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Import data

Since the controls are missing ground truths, we need to create the ground truths for them. <br>
We create images with the same size as the controls and fill them with zeros. <br>

***(Edit the paths to where you have the data)***

In [18]:
# Path to healthy scans
healthy_images = sorted(glob.glob('../data/controls/imgs/*.png'))

# Folder that receives the generated masks. The directory name is spelled
# "lables" because the cell that builds the data lists reads from that path.
healthy_gt_dir = '../data/controls/lables/'
# Saving fails if the folder is missing, so create it up front
os.makedirs(healthy_gt_dir, exist_ok=True)

# Create empty ground truths for healthy images
for img_path in healthy_images:
    # Only the dimensions are needed; the context manager closes the file handle
    with Image.open(img_path) as img:
        width, height = img.size
    # RGBA array of zeros with the same height and width as the scan
    empty_gt = np.zeros((height, width, 4), dtype=np.uint8)
    empty_gt[..., 3] = 255  # Set the alpha channel to 255 so they are black
    # Build the target from the file name alone so the result never points back
    # at the source image, whatever path separator glob returned
    empty_gt_path = os.path.join(healthy_gt_dir, os.path.basename(img_path))
    Image.fromarray(empty_gt).save(empty_gt_path)

In [3]:
# Make a list of healthy images
healthy_images = sorted(glob.glob('../data/controls/imgs/*.png'))

# Draw at most 182 healthy scans. Capping at the number of files found avoids
# the ValueError random.sample raises when fewer scans are available.
num_healthy_samples = min(182, len(healthy_images))
sampled_healthy_images = random.sample(healthy_images, num_healthy_samples)

# Make a list of healthy ground truths (same file name, "lables" folder)
healthy_gt_dir = '../data/controls/lables/'
sampled_healthy_gt = [
    os.path.join(healthy_gt_dir, os.path.basename(img_path))
    for img_path in sampled_healthy_images
]

# Make a list of patient images
patient_images = sorted(glob.glob('../data/patients/imgs/*.png'))

# Make a list of patient ground truths
patient_gt = sorted(glob.glob('../data/patients/labels/*.png'))

# Images and labels are paired purely by sorted position, so a count mismatch
# means the pairs would be misaligned; fail loudly instead of relying on assert
if len(patient_images) != len(patient_gt):
    raise ValueError(
        f"Found {len(patient_images)} patient images but "
        f"{len(patient_gt)} patient ground truths"
    )

# Combine patient and sampled healthy data
all_images = patient_images + sampled_healthy_images
all_gts = patient_gt + sampled_healthy_gt

# Create a list of dictionaries
data_dicts = [{'img': img, 'seg': gt} for img, gt in zip(all_images, all_gts)]

print(data_dicts[0:5])

[{'img': '../data/patients/imgs/patient_000.png', 'seg': '../data/patients/labels/segmentation_000.png'}, {'img': '../data/patients/imgs/patient_001.png', 'seg': '../data/patients/labels/segmentation_001.png'}, {'img': '../data/patients/imgs/patient_002.png', 'seg': '../data/patients/labels/segmentation_002.png'}, {'img': '../data/patients/imgs/patient_003.png', 'seg': '../data/patients/labels/segmentation_003.png'}, {'img': '../data/patients/imgs/patient_004.png', 'seg': '../data/patients/labels/segmentation_004.png'}]


In [4]:
# Hold out 20% of the samples for validation; the fixed random_state keeps
# the split identical between runs
train_files, val_files = train_test_split(
    data_dicts,
    random_state=42,
    test_size=0.2,
)

## Transformation pipeline
Using MONAI's pipeline to transform the data. <br>

In [5]:
# Dictionary keys every transform below operates on
IMG_SEG_KEYS = ["img", "seg"]

# Training pipeline: load, move channels first, rescale intensities, then
# draw 4 random 96x96 crops per item and randomly rotate them by 90 degrees
train_transforms = Compose([
    LoadImaged(keys=IMG_SEG_KEYS),
    EnsureChannelFirstd(keys=IMG_SEG_KEYS),
    ScaleIntensityd(keys=IMG_SEG_KEYS),
    RandSpatialCropSamplesd(keys=IMG_SEG_KEYS, roi_size=[96, 96], num_samples=4),
    RandRotate90d(keys=IMG_SEG_KEYS, prob=0.5, spatial_axes=[0, 1]),
])

# Validation pipeline: same loading and scaling, but no random augmentation
val_transforms = Compose([
    LoadImaged(keys=IMG_SEG_KEYS),
    EnsureChannelFirstd(keys=IMG_SEG_KEYS),
    ScaleIntensityd(keys=IMG_SEG_KEYS),
])

Next we set up the data loaders for training and validation. <br>

In [6]:
# Sanity check: pull a single batch through the training transforms and
# print the shapes of the image and segmentation tensors
check_ds = monai.data.Dataset(transform=train_transforms, data=train_files)
check_loader = DataLoader(
    check_ds,
    collate_fn=pad_list_data_collate,
    num_workers=2,
    batch_size=6,
)
check_data = monai.utils.misc.first(check_loader)
print(*(check_data[key].shape for key in ("img", "seg")))

torch.Size([24, 4, 96, 96]) torch.Size([24, 4, 96, 96])


In [19]:
# Training dataset: every item goes through the random-crop training pipeline
train_ds = monai.data.Dataset(transform=train_transforms, data=train_files)
# batch_size=3 loads 3 images per step; RandSpatialCropSamplesd yields 4 crops
# for each of them, so the network trains on 3 x 4 images per step
train_loader = DataLoader(
    train_ds,
    collate_fn=pad_list_data_collate,
    pin_memory=torch.cuda.is_available(),
    num_workers=2,  # can be set to higher if you have more cores
    shuffle=True,
    batch_size=3,
)

# Validation dataset and loader: no shuffling, no random augmentation
val_ds = monai.data.Dataset(transform=val_transforms, data=val_files)
val_loader = DataLoader(
    val_ds,
    collate_fn=pad_list_data_collate,
    num_workers=2,
    batch_size=6,
)

# Evaluation metric: mean Dice with include_background=False
dice_metric = DiceMetric(get_not_nans=False, reduction="mean", include_background=False)
# Post-processing of raw network outputs: sigmoid, then binarise at 0.4
post_trans = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.4),
    ]
)

## Model Setup
I used the `AttentionUnet` architecture, a model from the MONAI library. The concept behind this model was originally introduced in a research paper, which can be accessed at: https://arxiv.org/abs/1804.03999 <br>
I aligned my approach with recommendations from another study on whole-body MIP-PET imaging, detailed in a paper available here: https://aapm.onlinelibrary.wiley.com/doi/10.1002/mp.16438. Following the guidelines from this paper, I used the `NAdam` optimizer along with the `DiceFocal` loss function for training. <br>


The model is set up to use the GPU if one is available; otherwise it will use the CPU. <br>


In [8]:
# Build the 2D Attention U-Net (4 input and 4 output channels), then move it
# to the selected device
model = monai.networks.nets.AttentionUnet(
    spatial_dims=2,
    in_channels=4,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
)
model = model.to(device)

# Loss: Dice + focal, with the focal term weighted 10x relative to Dice
loss_function = monai.losses.DiceFocalLoss(
    sigmoid=True,
    lambda_dice=1,
    lambda_focal=10,
)

# Optimizer: NAdam with a learning rate of 1e-3
optimizer = torch.optim.NAdam(model.parameters(), lr=1e-3)

# Scheduler: multiply the learning rate by 0.1 once the monitored value
# (mode='max', i.e. higher is better) has not improved for 10 scheduler steps
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    patience=10,
    factor=0.1,
    verbose=True,
)


### Model Training

In [23]:

def training(tempdir, epochs=10):
    """Train the notebook's ``model`` and validate it every second epoch.

    The function works on the notebook globals ``model``, ``train_loader``,
    ``val_loader``, ``loss_function``, ``optimizer``, ``scheduler``,
    ``dice_metric``, ``post_trans`` and ``device``. The per-step training
    loss and the validation mean Dice are written to TensorBoard through a
    ``SummaryWriter``. Whenever the validation mean Dice improves, the model
    weights are saved to ``best_metric_model_segmentation2d_dict.pth`` in the
    current working directory.

    Args:
        tempdir: Not used by the function; kept so existing calls keep working.
        epochs: Number of training epochs to run.
    """
    monai.config.print_config()
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    # Sliding-window settings used during validation
    roi_size = (96, 96)
    sw_batch_size = 4
    # len(train_loader) is the number of batches actually produced per epoch
    # (it counts a final partial batch), so the TensorBoard global step below
    # never repeats between epochs
    epoch_len = len(train_loader)
    writer = SummaryWriter()
    try:
        for epoch in range(epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data["img"].to(device)
                labels = batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                loss_value = loss.item()
                epoch_loss += loss_value
                print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
                writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            # Validation only runs every `val_interval` epochs
            if (epoch + 1) % val_interval != 0:
                continue

            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images = val_data["img"].to(device)
                    val_labels = val_data["seg"].to(device)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # the scheduler monitors the validation Dice (mode='max')
                scheduler.step(metric)
                # reset the status for next validation round
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_dict.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Close the writer even when training is interrupted or raises
        writer.close()

Adjust the number of epochs and run the training.

In [None]:
# Run 20 epochs; the temporary directory is removed when the block exits
with tempfile.TemporaryDirectory() as tempdir:
    training(tempdir, epochs=20)

### Model Evaluation

In [20]:
from monai.transforms import SaveImage

# Load the best model
# NOTE: this is an absolute path on the author's machine; edit it so that it
# points to your own checkpoint file
best_model_path = "/home/magsam/workspace/tumor-segmentation/src/model/best_metric_model_segmentation2d_dict (4).pth"
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine
model.load_state_dict(torch.load(best_model_path, map_location=device))

# Set the model to evaluation mode
model.eval()

# Sliding window inference needs 1 input image in every iteration
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2, collate_fn=pad_list_data_collate)
saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")

# define sliding window size and batch size for windows inference
roi_size = (96, 96)
sw_batch_size = 4

with torch.no_grad():
    for val_data in val_loader:
        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
        val_labels = decollate_batch(val_labels)
        # compute metric for current iteration
        dice_metric(y_pred=val_outputs, y=val_labels)
        # Uncomment the two lines below to pass every predicted mask to `saver`
        # for val_output in val_outputs:
        #     saver(val_output)
    # aggregate the final mean dice result
    print("evaluation metric:", dice_metric.aggregate().item())
    # reset the status
    dice_metric.reset()

evaluation metric: 0.8637876510620117
