In [None]:
!pip -q install monai gdown einops mlflow pynrrd torchinfo 
!pip install pandas numpy nibabel tqdm
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [1]:
import os
from glob import glob
import shutil
import tempfile
import time
import warnings
import torch
import numpy as np
import pandas as pd
import nibabel as nib
from pathlib import Path
from collections.abc import Callable, Sequence, Hashable
from typing import Mapping,Dict
import matplotlib.pyplot as plt

from monai.transforms import (
    EnsureType,
    FillHoles,
    OneOf,
    SpatialCropd,
    Activations,
    Activationsd,
    ConcatItemsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    RandAffined,
    NormalizeIntensityd,
    ToTensord,
    EnsureChannelFirstd ,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    NormalizeIntensityd,
    Resized,
    SaveImaged,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    RandSpatialCropSamplesd,
    RandFlipd,
    RandRotated,
    EnsureTyped,
    ScaleIntensityd,
    RandCropByPosNegLabeld,
)

from monai.apps import DecathlonDataset
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.nets import SegResNet, SwinUNETR
from monai.data import Dataset, DataLoader, CacheDataset, decollate_batch
from monai.utils import first
from monai.utils import set_determinism
from monai.config import print_config
from monai.data.meta_tensor import MetaTensor
from monai.config.type_definitions import NdarrayOrTensor
from monai.utils.misc import ImageMetaKey

In [3]:
# Fix the global seed so that runs are repeatable.
# NOTE(review): per the MONAI docs this presumably seeds Python `random`, NumPy and PyTorch,
# which the `random.shuffle` based split further down relies on - confirm for the installed version.
set_determinism(seed=123)

In [4]:
# `root_dir` is where model checkpoints are written; `data_dir` holds the resampled
# CT / PET / label volumes in three sibling sub-directories.
root_dir = 'Hecktor22/model_data'
data_dir = 'hecktor2022_training/hecktor2022'
resampled_ct_path = os.path.join(data_dir, 'resampled_largerCt')
resampled_pt_path = os.path.join(data_dir, 'resampled_largerPt')
resampled_label_path = os.path.join(data_dir, 'resampled_largerlabel')

# Sorting makes the i-th CT, PET and label file line up, assuming consistent file naming.
train_images = sorted(glob(os.path.join(resampled_ct_path, "*_CT*")))
train_images2 = sorted(glob(os.path.join(resampled_pt_path, "*_PT*")))
train_labels = sorted(glob(os.path.join(resampled_label_path, "*.nii.gz")))

# zip() silently stops at the shortest list, which would hide a missing file and
# mis-pair every case after it - fail loudly instead.
if not (len(train_images) == len(train_images2) == len(train_labels)):
    raise ValueError(
        "CT/PET/label file counts differ: "
        f"{len(train_images)} CT, {len(train_images2)} PET, {len(train_labels)} labels"
    )

# One dict per case: "image" = CT path, "image2" = PET path, "label" = segmentation path.
data_dicts = [
    {"image": image_name, "image2": pet_image, 'label': label_name}
    for image_name, pet_image, label_name in zip(train_images, train_images2, train_labels)
]
len(data_dicts)

488

In [5]:
# Print the number of CT, PET and label files found, one count per line.
for file_list in (train_images, train_images2, train_labels):
    print(len(file_list))

488
488
488


In [6]:
import random

# Shuffle case indices and carve fixed slices out of the shuffled order.
# The index list is sized from the data actually found instead of a hard-coded 488.
# NOTE(review): repeatability relies on the earlier set_determinism(seed=123) call
# having seeded Python's `random` module - confirm.
shuffled_indices = list(range(len(data_dicts)))
random.shuffle(shuffled_indices)

# Only a subset of the cases is used: 100 for training, 10 for validation and the
# tail (from position 480 on) for testing. The slices are disjoint by construction.
train_index = shuffled_indices[:100]
val_index = shuffled_indices[400:410]
test_index = shuffled_indices[480:]

train_files = [data_dicts[i] for i in train_index]
val_files = [data_dicts[i] for i in val_index]
test_files = [data_dicts[i] for i in test_index]

In [5]:
# Print the size of the train, validation and test splits, one count per line.
for split in (train_files, val_files, test_files):
    print(len(split))

100
10
8


In [None]:
# Display one training entry (dict with "image", "image2" and "label" paths) as a spot check.
train_files[22]

In [None]:
# root_dir = 'model_data'
# data_dir = 'Hecktor22/data'

# train_images_ct = sorted(glob(os.path.join(data_dir, 'TrainData', '*_CT.nii.gz')))
# train_images_pt = sorted(glob(os.path.join(data_dir, 'TrainData', '*_PT.nii.gz')))
# train_labels = sorted(glob(os.path.join(data_dir, 'TrainLabels', '*.nii.gz')))
# train_files = [{"image": image_name, "image2": pet_image, 'label': label_name} for image_name, pet_image, label_name in zip(train_images_ct, train_images_pt, train_labels)]

# val_images_ct = sorted(glob(os.path.join(data_dir, 'ValData', '*_CT.nii.gz')))
# val_images_pt = sorted(glob(os.path.join(data_dir, 'ValData', '*_PT.nii.gz')))
# val_labels = sorted(glob(os.path.join(data_dir, 'ValLabels', '*.nii.gz')))
# val_files = [{"image": image_name, "image2": pet_image, 'label': label_name} for image_name, pet_image, label_name in zip(val_images_ct, val_images_pt, val_labels)]

In [None]:
# print(train_files)
# print(val_files)

In [6]:
class ConvertToMultiChannelBasedOnClassesd(MapTransform):
    """Convert an integer label map into a stack of per-class binary masks.

    For every configured key the value ``d[key]`` is compared against each class id
    and the resulting boolean masks are stacked along a new leading dimension, then
    cast to float. In this notebook the label is not passed through
    ``EnsureChannelFirstd``, so the new leading dimension acts as the channel axis.

    Args:
        keys: key(s) of the items to convert (forwarded to ``MapTransform``).
        allow_missing_keys: if True, keys absent from the input dict are skipped
            instead of raising ``KeyError`` (forwarded to ``MapTransform``).
        classes: label values to extract, one output channel per value, in order.
            Defaults to ``(0, 1, 2)``, i.e. the original hard-coded behaviour.

    Raises:
        ValueError: if ``classes`` is empty.
    """

    def __init__(self, keys, allow_missing_keys=False, classes=(0, 1, 2)):
        super().__init__(keys, allow_missing_keys)
        self.classes = tuple(classes)
        if not self.classes:
            raise ValueError("`classes` must contain at least one label value")

    def __call__(self, data):
        # Shallow copy so the caller's dict is not modified.
        d = dict(data)
        # key_iterator honours allow_missing_keys; a missing key still raises KeyError by default.
        for key in self.key_iterator(d):
            masks = [d[key] == class_id for class_id in self.classes]
            d[key] = torch.stack(masks, dim=0).float()
        return d

In [10]:
# Intensity windows that are clipped and rescaled to [0, 1] for CT ("image") and PET ("image2").
ct_a_min = -200
ct_a_max = 400
pt_a_min = 0
pt_a_max = 25
# Only referenced by the disabled RandCropByPosNegLabeld block below.
crop_samples = 2
input_size = [192, 192, 192]
# Interpolation per key in `image_keys` order: bilinear for CT and PET, nearest for the label.
modes_2d = ['bilinear', 'bilinear', 'nearest']
# Overall flip probability budget, split evenly across the three spatial axes.
p = 0.5
image_keys = ["image", "image2", "label"]


def _base_transforms():
    """Return fresh instances of the deterministic steps shared by training and validation.

    Steps: load all three volumes, add a channel dimension to CT/PET, expand the label
    into per-class channels, reorient to RAS, resample to 1x1x1 spacing and rescale the
    CT and PET intensities to [0, 1] with clipping.
    """
    return [
        LoadImaged(keys=image_keys),
        # The label gets its channel dimension from ConvertToMultiChannelBasedOnClassesd instead.
        EnsureChannelFirstd(keys=["image", "image2"]),
        ConvertToMultiChannelBasedOnClassesd(keys="label"),
        Orientationd(keys=image_keys, axcodes="RAS"),
        Spacingd(keys=image_keys, pixdim=(1, 1, 1), mode=modes_2d),
        ScaleIntensityRanged(
            keys=['image'], a_min=ct_a_min, a_max=ct_a_max, b_min=0.0, b_max=1.0, clip=True
        ),
        ScaleIntensityRanged(
            keys=['image2'], a_min=pt_a_min, a_max=pt_a_max, b_min=0.0, b_max=1.0, clip=True
        ),
    ]


# Training: shared preprocessing followed by an independent random flip per spatial axis.
# Foreground cropping / patch sampling is currently disabled:
#         CropForegroundd(keys=image_keys, source_key='image'),
#         RandCropByPosNegLabeld(
#             keys=image_keys,
#             label_key='label',
#             spatial_size=input_size,
#             pos=1,
#             neg=1,
#             num_samples=crop_samples,
#             image_key='image',
#             image_threshold=0,
#         ),
train_transforms = Compose(
    _base_transforms()
    + [RandFlipd(keys=image_keys, prob=p / 3, spatial_axis=axis) for axis in range(3)]
    + [ToTensord(keys=image_keys)]
)

# Validation: shared preprocessing, then crop all keys to the CT foreground; no augmentation.
val_transforms = Compose(
    _base_transforms()
    + [
        CropForegroundd(keys=image_keys, source_key='image'),
        ToTensord(keys=image_keys),
    ]
)

# Loading only, without any preprocessing (the label stays a single integer map).
orig_transforms = Compose([LoadImaged(keys=image_keys)])

In [None]:
# Sanity check: push one training case through the full training pipeline and look
# at a single slice of each volume. The random flips are active, so the orientation
# of what is shown can differ between runs.
check_ds = Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = check_ds[2]

# (dict key, index along the last spatial axis) for each volume to display.
# The figure name and title are derived from the key so the PET panel is no longer
# mislabelled as "image".
for key, slice_idx in (("image", 90), ("image2", 150), ("label", 97)):
    plt.figure(key, (6, 6))
    plt.title(f"{key} channel 0")
    # Channel 0 only: the single CT/PET channel, or the class-0 mask of the label.
    plt.imshow(check_data[key][0, :, :, slice_idx].detach().cpu())
    plt.show()

In [None]:
# `label` was previously only defined in a commented-out line, so this cell raised
# NameError on a fresh kernel. Take channel 0 of the sanity-check sample's label
# (the mask for class 0) and count the voxels where that mask is 0.
# NOTE(review): channel 0 mirrors the commented-out `check_data["label"][0][0]` on a
# batched sample - confirm this is the quantity that was intended.
label = check_data["label"][0]
torch.count_nonzero(label == 0).item()

In [11]:
# Settings shared by both splits. cache_rate=0.0 is passed to CacheDataset for both.
_dataset_kwargs = dict(cache_rate=0.0, num_workers=2)
_loader_kwargs = dict(batch_size=1, num_workers=2)

# Training data is reshuffled by the loader; validation order stays fixed.
train_ds = CacheDataset(data=train_files, transform=train_transforms, **_dataset_kwargs)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_ds = CacheDataset(data=val_files, transform=val_transforms, **_dataset_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Spot check the transformed label shape of one training case
# (runs the whole training pipeline for that case).
train_ds[5]['label'].shape

In [12]:
# ---- Hyper-parameters -------------------------------------------------------
max_epochs = 10
val_interval = 1          # run validation every `val_interval` epochs
VAL_AMP = True            # use mixed precision during sliding-window inference
lr = 1e-4
momentum = 0              # currently unused
weight_decay = 1e-5
T_0 = 40                  # only used by the disabled warm-restart scheduler below
n_classes = 3             # output channels, one per label value 0/1/2
n_channels = 2            # input channels: CT + PET
input_size = (192, 192, 192)  # sliding-window ROI size

# ---- Model, loss, optimizer -------------------------------------------------
# Fall back to CPU so the notebook at least runs without a GPU (AMP is then disabled by PyTorch).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=n_channels,
    out_channels=n_classes,
    dropout_prob=0.2,
).to(device)

# The labels are already multi-channel masks (to_onehot_y=False) and each output
# channel is treated independently through a sigmoid.
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Disabled alternative schedule. Both lines must stay commented: leaving the
# continuation line active makes the whole cell fail with an IndentationError.
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=T_0,
#                                                                      T_mult=1, eta_min=1e-8)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# ---- Metrics and post-processing --------------------------------------------
# Background (channel 0) is excluded, so the per-class metric has two entries.
dice_metric = DiceMetric(include_background=False, reduction="mean")
dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch")
# Turn raw network outputs into binary masks; the sigmoid matches the one used by the loss.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


def inference(input):
    """Run sliding-window inference with the global ``model``.

    Args:
        input: batched input volume; it is passed unchanged to
            ``sliding_window_inference`` with ROI ``input_size``, 4 windows per
            forward pass and 50% overlap.

    Returns:
        The raw (pre-activation) model output assembled over all windows.
        Autocast is enabled when ``VAL_AMP`` is true.
    """
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=input_size,
            sw_batch_size=4,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)


# Gradient scaler for mixed-precision training.
scaler = torch.cuda.amp.GradScaler()
# Let cuDNN auto-tune convolution algorithms.
# NOTE(review): this may undo part of what set_determinism() configured - confirm
# whether strict reproducibility or speed is wanted here.
torch.backends.cudnn.benchmark = True

In [13]:
# ---- Bookkeeping ------------------------------------------------------------
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]  # [best dice values, their epochs, elapsed seconds]
epoch_loss_values = []   # mean training loss per epoch
metric_values = []       # mean validation Dice per validation run
metric_values_1 = []     # validation Dice of the first non-background channel
metric_values_2 = []     # validation Dice of the second non-background channel

# Binarise network outputs for the Dice metric: sigmoid (matching the loss) + 0.5 threshold.
# Defined here so the loop does not depend on a name that only existed in commented-out code.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# torch.save() does not create missing directories.
os.makedirs(root_dir, exist_ok=True)

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- Training -----------------------------------------------------------
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        inputsct, inputspt, labels = (
            batch_data['image'].to(device),
            batch_data['image2'].to(device),
            batch_data['label'].to(device),
        )
        # Stack CT and PET along the channel axis -> 2-channel network input.
        inputs = torch.concat([inputsct, inputspt], axis=1)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        # Mixed-precision backward pass and optimizer step via the gradient scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_loader)}"
            f", train_loss: {loss.item():.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )
    # NOTE(review): the scheduler is created in the setup cell but never stepped, so the
    # learning rate stays constant. Re-enable the next line if the cosine schedule is wanted.
    # lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---- Validation ---------------------------------------------------------
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_loss = 0
            count = 0
            for val_data in val_loader:
                count += 1
                val_inputsct, val_inputspt, val_labels = (
                    val_data['image'].to(device),
                    val_data['image2'].to(device),
                    val_data['label'].to(device),
                )
                val_inputs = torch.concat([val_inputsct, val_inputspt], axis=1)
                val_outputs = inference(val_inputs)
                # The loss is computed on the raw outputs (it applies the sigmoid itself).
                loss = loss_function(val_outputs, val_labels)
                val_loss += loss.item()
                # Both metrics see the same binarised predictions and per-sample labels.
                val_preds = [post_trans(i) for i in decollate_batch(val_outputs)]
                val_targets = decollate_batch(val_labels)
                dice_metric(y_pred=val_preds, y=val_targets)
                dice_metric_batch(y_pred=val_preds, y=val_targets)
            print(f"val loss: {val_loss / count}")
            metric = dice_metric.aggregate().item()
            metric_values.append(metric)
            # include_background=False in the setup cell -> two per-class entries.
            metric_batch = dice_metric_batch.aggregate()
            metric_1 = metric_batch[0].item()
            metric_values_1.append(metric_1)
            metric_2 = metric_batch[1].item()
            metric_values_2.append(metric_2)
            dice_metric.reset()
            dice_metric_batch.reset()

            # Keep only the weights of the best epoch so far.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join(root_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" 1: {metric_1:.4f} 2: {metric_2:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

----------
epoch 1/10
1/100, train_loss: 1.7973, step time: 20.7050
2/100, train_loss: 1.6567, step time: 0.8898
3/100, train_loss: 1.6226, step time: 0.8731
4/100, train_loss: 1.5979, step time: 0.9499



KeyboardInterrupt



In [None]:
# Summarise the run. `total_time` is only assigned after the training loop finishes,
# so this cell fails with NameError if training was interrupted.
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")

In [None]:
# Two figures with two side-by-side panels each. Every panel is described by
# (title, recorded values, epochs between consecutive values, line colour).
# The training loss is recorded every epoch; the Dice values every `val_interval` epochs.
_curve_figures = (
    (
        ("Epoch Average Loss", epoch_loss_values, 1, "red"),
        ("Val Mean Dice", metric_values, val_interval, "green"),
    ),
    (
        ("Val Mean Dice : 1", metric_values_1, val_interval, "blue"),
        ("Val Mean Dice : 2", metric_values_2, val_interval, "brown"),
    ),
)

for _panels in _curve_figures:
    plt.figure("train", (12, 6))
    for _column, (_title, y, _stride, _color) in enumerate(_panels, start=1):
        plt.subplot(1, 2, _column)
        plt.title(_title)
        # Epoch number at which each value was recorded.
        x = [_stride * (i + 1) for i in range(len(y))]
        plt.xlabel("epoch")
        plt.plot(x, y, color=_color)
    plt.show()

In [None]:
# Restore the best checkpoint; map_location lets it load on whatever `device` is in use.
model.load_state_dict(
    torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()

sample_idx = 5   # validation case to visualise
slice_idx = 70   # index along the last spatial axis
# Binarise raw outputs: sigmoid (matching the loss) followed by a 0.5 threshold.
# Defined here because `post_trans` previously only existed in commented-out code.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

with torch.no_grad():
    # Fetch the case once; every val_ds[...] access re-runs the whole transform pipeline.
    sample = val_ds[sample_idx]
    val_inputct = sample["image"].unsqueeze(0).to(device)
    val_inputpt = sample["image2"].unsqueeze(0).to(device)
    # CT and PET become the two input channels of a batch of one.
    val_input = torch.concat([val_inputct, val_inputpt], axis=1)
    val_output = post_trans(inference(val_input)[0])

    # CT slice.
    plt.figure("image", (6, 6))
    plt.title("image channel 0")
    plt.imshow(sample["image"][0, :, :, slice_idx].detach().cpu(), cmap="gray")
    plt.show()

    # Ground-truth label channels and the matching model output channels.
    for name, volume in (("label", sample["label"]), ("output", val_output)):
        n_panels = volume.shape[0]
        plt.figure(name, (6 * n_panels, 6))
        for i in range(n_panels):
            plt.subplot(1, n_panels, i + 1)
            plt.title(f"{name} channel {i}")
            plt.imshow(volume[i, :, :, slice_idx].detach().cpu())
        plt.show()

In [None]:
sample_idx = 7   # validation case to visualise
slice_idx = 70   # index along the last spatial axis
# Binarise raw outputs: sigmoid (matching the loss) followed by a 0.5 threshold.
# Defined here because `post_trans` previously only existed in commented-out code.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

with torch.no_grad():
    # Fetch the case once; every val_ds[...] access re-runs the whole transform pipeline.
    sample = val_ds[sample_idx]
    val_inputct = sample["image"].unsqueeze(0).to(device)
    val_inputpt = sample["image2"].unsqueeze(0).to(device)
    # CT and PET become the two input channels of a batch of one.
    val_input = torch.concat([val_inputct, val_inputpt], axis=1)
    val_output = post_trans(inference(val_input)[0])

    # CT slice.
    plt.figure("image", (6, 6))
    plt.title("image channel 0")
    plt.imshow(sample["image"][0, :, :, slice_idx].detach().cpu(), cmap="gray")
    plt.show()

    # Ground-truth label channels and the matching model output channels.
    for name, volume in (("label", sample["label"]), ("output", val_output)):
        n_panels = volume.shape[0]
        plt.figure(name, (6 * n_panels, 6))
        for i in range(n_panels):
            plt.subplot(1, n_panels, i + 1)
            plt.title(f"{name} channel {i}")
            plt.imshow(volume[i, :, :, slice_idx].detach().cpu())
        plt.show()