In [1]:
import os
from glob import glob

import warnings

import torch
from monai.transforms import (
    Compose,
    LoadImaged,
    ToTensord,
    EnsureChannelFirstd ,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    Resized,
    SaveImaged,
    RandSpatialCropd,
    RandFlipd,
    RandRotated,
    ToDeviced
)
# import torch.mps

import numpy as np
import pandas as pd
import nibabel as nib
from pathlib import Path
from collections.abc import Callable, Sequence, Hashable
from typing import Mapping,Dict

from monai.data import Dataset, DataLoader
from monai.utils import first
import matplotlib.pyplot as plt
from monai.data.meta_tensor import MetaTensor
from monai.config.type_definitions import NdarrayOrTensor
from monai.utils.misc import ImageMetaKey

In [2]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism
from monai.config import print_config
print_config()

MONAI version: 1.1.0
Numpy version: 1.24.1
Pytorch version: 2.0.0+cu118
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /home/maximus/DataspellProjects/Varian/venv/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.1.0
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 9.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.15.1+cu118
tqdm version: NOT INSTALLED or UNKNOWN VERSION.
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.0.1
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about instal

In [3]:
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
mps_available

False

In [4]:
root_dir = '/home/maximus/DataspellProjects/Varian_Hecktor'
data_dir = 'data'

train_images_ct = sorted(glob(os.path.join(data_dir, 'TrainData', '*_CT.nii.gz')))
train_images_pt = sorted(glob(os.path.join(data_dir, 'TrainData', '*_PT.nii.gz')))
train_labels = sorted(glob(os.path.join(data_dir, 'TrainLabels', '*.nii.gz')))

trai_files = [{"image": image_name, "image2": pet_image, 'label': label_name} for image_name, pet_image, label_name in zip(train_images_ct, train_images_pt, train_labels)]

val_images_ct = sorted(glob(os.path.join(data_dir, 'ValData', '*_CT.nii.gz')))
val_images_pt = sorted(glob(os.path.join(data_dir, 'ValData', '*_PT.nii.gz')))
val_labels = sorted(glob(os.path.join(data_dir, 'ValLabels', '*.nii.gz')))
val_files = [{"image": image_name, "image2": pet_image, 'label': label_name} for image_name, pet_image, label_name in zip(val_images_ct, val_images_pt, val_labels)]

In [5]:
train_files = []
train_files.append(trai_files[0])
train_files.append(trai_files[1])
val_files = val_files[0]

In [6]:
tfiles = []
tfiles.append(train_files[3])
tfiles.append(train_files[1])
tfiles

IndexError: list index out of range

In [7]:
train_files
val_files

{'image': 'data/ValData/CHUM-010__CT.nii.gz',
 'image2': 'data/ValData/CHUM-010__PT.nii.gz',
 'label': 'data/ValLabels/CHUM-010.nii.gz'}

In [12]:
class HecktorCropNeckRegion(CropForegroundd):
    """
    A simple pre-processing transform to approximately crop the head and neck region based on a PET image.
    This transform relies on several assumptions of patient orientation with a head location on the top,
    and is specific for Hecktor22 dataset, and should not be used for an arbitrary PET image pre-processing.
    """

    def __init__(
        self,
        keys=["image", "image2", "label"],
        source_key="image",
        box_size=[200, 200, 310],
        allow_missing_keys=True,
        **kwargs,
    ) -> None:
        super().__init__(keys=keys, source_key=source_key, allow_missing_keys=allow_missing_keys, **kwargs)
        self.box_size = box_size

    def __call__(self, data : Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:

        d = dict(data)
        im_pet = d["image2"][0]
        #print(im_pet)
        box_size = np.array(self.box_size)  # H&N region to crop in mm , defaults to 200x200x310mm
        filename = ""

        if isinstance(im_pet, MetaTensor):
            filename = im_pet.meta[ImageMetaKey.FILENAME_OR_OBJ]
            box_size = (box_size / np.array(im_pet.pixdim)).astype(int)  # compensate for resolution

        box_start, box_end = self.extract_roi(im_pet=im_pet, box_size=box_size)
        
        if "label" in d and "label" in self.keys:
            # if label mask is available, let's check if the cropped region includes all foreground
            before_sum = d["label"].sum().item()
            after_sum = (
                (d["label"][0, box_start[0] : box_end[0], box_start[1] : box_end[1], box_start[2] : box_end[2]])
                .sum()
                .item()
            )
            if before_sum != after_sum:
                print("WARNING, H&N crop could be incorrect!!!", before_sum, after_sum)

        d[self.start_coord_key] = box_start
        d[self.end_coord_key] = box_end
        
        for key, m in self.key_iterator(d, self.mode): #question: what is mode in the iterators?
            self.push_transform(d, key, extra_info={"box_start": box_start, "box_end": box_end})
            d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m)
        return d

    def extract_roi(self, im_pet, box_size):

        crop_len = int(0.75 * im_pet.shape[2])
        im = im_pet[..., crop_len:]

        mask = ((im - im.mean()) / im.std()) > 1
        comp_idx = torch.argwhere(mask)
        center = torch.mean(comp_idx.float(), dim=0).cpu().int().numpy()
        xmin = torch.min(comp_idx, dim=0).values.cpu().int().numpy()
        xmax = torch.max(comp_idx, dim=0).values.cpu().int().numpy()

        xmin[:2] = center[:2] - box_size[:2] // 2
        xmax[:2] = center[:2] + box_size[:2] // 2

        xmax[2] = xmax[2] + crop_len
        xmin[2] = max(0, xmax[2] - box_size[2])

        return xmin.astype(int), xmax.astype(int)

In [16]:
device = torch.device("cuda:0")
train_transforms = Compose(
    [
        LoadImaged(keys=['image', 'image2', 'label']),
        ToDeviced(keys=["image", "image2", "label"], device=device),
        EnsureChannelFirstd(['image', 'image2', 'label']),
        Orientationd(keys=["image", "image2"], axcodes="RAS"),
        Spacingd(
            keys=["image", "image2"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        HecktorCropNeckRegion(keys=["image", "image2", "label"], source_key="image"),
        RandSpatialCropd(keys=["image", "image2", "label"], roi_size=[192, 192, 192], random_size=False),
        RandFlipd(keys=["image", "image2", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "image2", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "image2", "label"], prob=0.5, spatial_axis=2),
#         RandRotated(keys=["image", "image2", "label"], prob = 0.5),
        SaveImaged(
        keys = ['image', 'image2', 'label'],
        output_dir='data/output',
        output_postfix="crop",
        resample=False,
        output_dtype=np.int16,
        separate_folder=False,
        )
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=['image', 'image2', 'label']),
        ToDeviced(keys=["image", "image2", "label"], device=device),
        EnsureChannelFirstd(['image', 'image2', 'label']),
        Orientationd(keys=["image", "image2"], axcodes="RAS"),
        Spacingd(
            keys=["image", "image2"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        HecktorCropNeckRegion(keys=["image", "image2", "label"], source_key="image")
    ]
)

In [17]:
train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1)

val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1)

In [18]:
for batch_data in train_loader:
    batch_data

2023-04-24 10:29:24,120 INFO image_writer.py:194 - writing: data/output/CHUM-001__CT_crop.nii.gz
2023-04-24 10:29:24,376 INFO image_writer.py:194 - writing: data/output/CHUM-001__PT_crop.nii.gz
2023-04-24 10:29:24,436 INFO image_writer.py:194 - writing: data/output/CHUM-001_crop.nii.gz
2023-04-24 10:29:25,189 INFO image_writer.py:194 - writing: data/output/CHUM-002__CT_crop.nii.gz
2023-04-24 10:29:25,394 INFO image_writer.py:194 - writing: data/output/CHUM-002__PT_crop.nii.gz
2023-04-24 10:29:25,453 INFO image_writer.py:194 - writing: data/output/CHUM-002_crop.nii.gz


In [20]:
max_epochs = 1
val_interval = 1
VAL_AMP = True

# standard PyTorch program style: create SegResNet, DiceLoss and Adam optimizer

model = SegResNet(
    blocks_down=[1, 2, 2, 4, 4],
    init_filters=16,
    in_channels=1,
    out_channels=2,
).to(device)
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# define inference method
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(192, 192, 192),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)

# use amp to accelerate training
scaler = torch.cuda.amp.GradScaler()
print(scaler)
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

<torch.cuda.amp.grad_scaler.GradScaler object at 0x7ff29adca350>


In [None]:
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    print("Model Training complete")
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        print("I am here")
        step_start = time.time()
        step += 1
        # inputs, labels = (
        #     batch_data["image"].to(device),
        #     batch_data["label"].to(device),
        # )
        # print("I am here")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(batch_data["image"])
            loss = loss_function(outputs, batch_data["label"])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}"
            f", train_loss: {loss.item():.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )
    lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = inference(val_inputs)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            metric_values.append(metric)
            metric_batch = dice_metric_batch.aggregate()
            metric_tc = metric_batch[0].item()
            metric_values_tc.append(metric_tc)
            metric_wt = metric_batch[1].item()
            metric_values_wt.append(metric_wt)
            metric_et = metric_batch[2].item()
            metric_values_et.append(metric_et)
            dice_metric.reset()
            dice_metric_batch.reset()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join(root_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

----------
epoch 1/1
Model Training complete
2023-04-24 10:31:08,786 INFO image_writer.py:194 - writing: data/output/CHUM-001__CT_crop.nii.gz
2023-04-24 10:31:09,036 INFO image_writer.py:194 - writing: data/output/CHUM-001__PT_crop.nii.gz
2023-04-24 10:31:09,097 INFO image_writer.py:194 - writing: data/output/CHUM-001_crop.nii.gz
