In [13]:
import warnings 
warnings.filterwarnings("ignore")


from monai.utils import first, set_determinism
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    EnsureType,
    Invertd,
    LoadImaged,
    RandFlipd,
    RandSpatialCropd,
    RandZoomd,
    Resized,
    ScaleIntensityRanged,
    SpatialCrop,
    SpatialCropd,
    ToTensord,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
from glob import glob

import numpy as np

import itk

import site
site.addsitedir('../../../ARGUS')
from ARGUSUtils_Transforms import *

In [14]:
# Glob pattern for the data directories; the '*' is expanded by glob() below,
# so every directory matching BAMC-PTX*Sliding-Annotations-Linear is searched.
img1_dir = "../../../Data/VFoldData/BAMC-PTX*Sliding-Annotations-Linear/"

# Image volumes and their label overlays. Both lists are sorted so that, after
# filtering by prefix, entries at the same position can be paired up.
all_images = sorted(glob(os.path.join(img1_dir, '*_?????.mha')))
all_labels = sorted(glob(os.path.join(img1_dir, '*.interpolated-overlay-3class.mha')))

# Network / loader configuration.
num_classes = 3
num_workers_tr = 8
batch_size_tr = 8
num_workers_vl = 8
batch_size_vl = 2

# Size of the sub-volume fed to the network: size_x x size_y pixels, num_slices deep.
num_slices = 16
size_x = 240
size_y = 240

model_filename_base = "BAMC_PTX_3DUNet-3Class.best_model.vfold"

num_images = len(all_images)
print(num_images)
num_folds = 10


# File-name prefixes of the two groups of cases.
# NOTE(review): presumably 'ns' = no-sliding and 's' = sliding cases, going by the
# directory name -- confirm.
ns_prefix = ['025ns','026ns','027ns','035ns','048ns','055ns','117ns',
             '135ns','193ns','210ns','215ns','218ns','219ns','221ns','247ns']
s_prefix = ['004s','019s','030s','034s','037s','043s','065s','081s',
            '206s','208s','211s','212s','224s','228s','236s','237s']

# Distribute the prefixes over the folds. Even folds take 2 'ns' + 1 's' prefix,
# odd folds take 1 'ns' + 2 's'; the last even fold takes a second 's' prefix so
# that all 15 'ns' and 16 's' prefixes are used when num_folds == 10.
# Each fold is a list of one-element lists, e.g. [['025ns'], ['026ns'], ['004s']].
fold_prefix_list = []
ns_count = 0
s_count = 0
for i in range(num_folds):
    if i % 2 == 0:
        num_ns = 2
        num_s = 2 if i > num_folds - 3 else 1
    else:
        num_ns = 1
        num_s = 2
    fold = [[ns_prefix[ns_count + k]] for k in range(num_ns)]
    ns_count += num_ns
    fold += [[s_prefix[s_count + k]] for k in range(num_s)]
    s_count += num_s
    fold_prefix_list.append(fold)


def _flatten_folds(folds):
    """Flatten a list of folds (lists of one-element prefix lists) into a flat prefix list."""
    return [prefix for fold in folds for group in fold for prefix in group]


def _paired_files(prefixes):
    """Build the {"image", "label"} dicts for every file whose path contains a prefix.

    Images and labels are selected independently from the sorted global lists and
    then paired by position, so the two selections must have the same length.

    Raises:
        ValueError: if the number of selected images and labels differs; pairing
            them positionally would silently attach labels to the wrong images.
    """
    images = [im for im in all_images if any(pref in im for pref in prefixes)]
    labels = [se for se in all_labels if any(pref in se for pref in prefixes)]
    if len(images) != len(labels):
        raise ValueError(
            f"Found {len(images)} images but {len(labels)} labels for prefixes "
            f"{prefixes}; refusing to pair them positionally."
        )
    return [{"image": img, "label": seg} for img, seg in zip(images, labels)]


# For split i: folds i .. i+num_folds-3 (cyclically) are used for training, the
# next fold for validation and the one after that for testing.
train_files = []
val_files = []
test_files = []
for i in range(num_folds):
    tr_folds = _flatten_folds(
        [fold_prefix_list[f % num_folds] for f in range(i, i + num_folds - 2)])
    va_folds = _flatten_folds([fold_prefix_list[(i + num_folds - 2) % num_folds]])
    te_folds = _flatten_folds([fold_prefix_list[(i + num_folds - 1) % num_folds]])
    train_files.append(_paired_files(tr_folds))
    val_files.append(_paired_files(va_folds))
    test_files.append(_paired_files(te_folds))
    print(len(train_files[i]),len(val_files[i]),len(test_files[i]))

62
49 8 5
51 5 6
50 6 6
50 6 6
51 6 5
51 5 6
50 6 6
49 6 7
48 7 7
47 7 8


In [15]:
# Read the first training image of fold 0 and record the shape of its array;
# the validation transforms use imgshape to place a centered in-plane crop.
# NOTE(review): ITK's NumPy bridge usually reports axes in reversed (z, y, x)
# order, while the loader used in the transforms may present them as (x, y, z).
# Confirm that imgshape[0] / imgshape[1] really are the two axes cropped to
# size_x / size_y in val_transforms.
img = itk.imread(train_files[0][0]["image"])
arr = itk.GetArrayFromImage(img)
imgshape = list(arr.shape)

In [16]:
# Training pipeline: load -> add channel -> rescale -> random augmentation -> tensor.
train_transforms = Compose([
    # Load image and label, then give each a leading channel axis.
    LoadImaged(keys=["image", "label"]),
    AddChanneld(keys=["image", "label"]),
    # Rescale image intensities from [0, 255] to [0.0, 1.0]; labels are left as-is.
    ScaleIntensityRanged(
        keys=["image"],
        a_min=0,
        a_max=255,
        b_min=0.0,
        b_max=1.0,
    ),
    # Project-specific transform; presumably picks num_slices slices along axis 2
    # at a random position (no center_slice is given here) -- confirm in ARGUS utils.
    ARGUS_RandSpatialCropSlicesd(
        keys=['image', 'label'],
        num_slices=num_slices,
        axis=2,
    ),
    # Random flips along spatial axis 2 and spatial axis 0, each with probability 0.5.
    RandFlipd(
        keys=['image', 'label'],
        prob=0.5,
        spatial_axis=2,
    ),
    RandFlipd(
        keys=['image', 'label'],
        prob=0.5,
        spatial_axis=0,
    ),
    # Random zoom-in of up to 20%, keeping the array size; the image is
    # interpolated trilinearly, the label with nearest-neighbour.
    RandZoomd(
        keys=['image', 'label'],
        prob=0.5,
        min_zoom=1.0,
        max_zoom=1.2,
        keep_size=True,
        mode=['trilinear', 'nearest'],
    ),
    # Fixed-size random crop down to the network input size.
    RandSpatialCropd(
        keys=['image', 'label'],
        roi_size=(size_x, size_y, num_slices),
        random_size=False,
    ),
    ToTensord(keys=["image", "label"]),
])
# Validation pipeline: same loading and rescaling as training, but with a fixed
# slice window and a fixed, centered in-plane crop instead of random augmentation.
_val_roi_x0 = (imgshape[0]-size_x)//2
_val_roi_y0 = (imgshape[1]-size_y)//2
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    AddChanneld(keys=["image", "label"]),
    # Rescale image intensities from [0, 255] to [0.0, 1.0].
    ScaleIntensityRanged(
        keys=["image"],
        a_min=0,
        a_max=255,
        b_min=0.0,
        b_max=1.0,
    ),
    # Project-specific transform; presumably takes num_slices slices along axis 2
    # around slice 30 -- confirm in ARGUS utils.
    ARGUS_RandSpatialCropSlicesd(
        keys=['image', 'label'],
        num_slices=num_slices,
        axis=2,
        center_slice=30,
    ),
    # Crop a size_x x size_y window centered with respect to imgshape, keeping
    # the first num_slices entries of the last axis.
    SpatialCropd(
        keys=['image', 'label'],
        roi_start=(_val_roi_x0, _val_roi_y0, 0),
        roi_end=(_val_roi_x0+size_x, _val_roi_y0+size_y, num_slices),
    ),
    ToTensord(keys=["image", "label"]),
])

In [17]:
# One cached dataset and one loader per fold, for training ...
train_ds = [
    CacheDataset(
        data=train_files[fold],
        transform=train_transforms,
        cache_rate=1.0,  # cache every item of the fold
        num_workers=num_workers_tr,
    )
    for fold in range(num_folds)
]
train_loader = [
    DataLoader(
        train_ds[fold],
        batch_size=batch_size_tr,
        shuffle=True,
        num_workers=num_workers_tr,
    )
    for fold in range(num_folds)
]

# ... and for validation (not shuffled).
val_ds = [
    CacheDataset(
        data=val_files[fold],
        transform=val_transforms,
        cache_rate=1.0,
        num_workers=num_workers_vl,
    )
    for fold in range(num_folds)
]
val_loader = [
    DataLoader(
        val_ds[fold],
        batch_size=batch_size_vl,
        num_workers=num_workers_vl,
    )
    for fold in range(num_folds)
]

Loading dataset: 100%|██████████████████████████| 49/49 [00:00<00:00, 81.71it/s]
Loading dataset: 100%|██████████████████████████| 51/51 [00:00<00:00, 80.51it/s]
Loading dataset: 100%|██████████████████████████| 50/50 [00:00<00:00, 81.51it/s]
Loading dataset: 100%|██████████████████████████| 50/50 [00:00<00:00, 80.41it/s]
Loading dataset: 100%|██████████████████████████| 51/51 [00:00<00:00, 76.76it/s]
Loading dataset: 100%|██████████████████████████| 51/51 [00:02<00:00, 22.89it/s]
Loading dataset: 100%|██████████████████████████| 50/50 [00:03<00:00, 14.60it/s]
Loading dataset: 100%|██████████████████████████| 49/49 [00:03<00:00, 12.26it/s]
Loading dataset: 100%|██████████████████████████| 48/48 [00:03<00:00, 13.43it/s]
Loading dataset: 100%|██████████████████████████| 47/47 [00:04<00:00, 10.55it/s]
Loading dataset: 100%|███████████████████████████| 8/8 [00:00<00:00, 153.87it/s]
Loading dataset: 100%|████████████████████████████| 5/5 [00:00<00:00, 91.92it/s]
Loading dataset: 100%|██████

In [18]:
# Device used by vfold_train for the model and every batch. Prefer the first
# CUDA GPU, but fall back to the CPU so the notebook still runs (slowly) on a
# machine without a usable GPU instead of failing at the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def vfold_train(vfold_num, train_loader, val_loader, max_epochs=500, val_interval=2):
    """Train a fresh 3D UNet on one fold and checkpoint its best validation model.

    A new UNet, Dice loss, Adam optimizer (learning rate 1e-4) and Dice metric
    are created on every call, so each fold is trained independently. Every
    ``val_interval`` epochs the model is evaluated on ``val_loader``; whenever
    the mean Dice improves, the weights are saved to
    ``<model_filename_base>_<vfold_num>.pth`` in the current directory. After
    the last epoch the per-epoch mean training loss and the validation Dice
    history are saved as ``.npy`` files.

    Relies on the notebook-level names ``device``, ``num_classes``, ``size_x``,
    ``size_y``, ``num_slices``, ``batch_size_vl`` and ``model_filename_base``.

    Args:
        vfold_num: Fold index; used in log lines and output file names.
        train_loader: DataLoader yielding dict batches with "image" and "label".
        val_loader: DataLoader yielding validation batches of the same form.
        max_epochs: Number of training epochs (default 500, the former fixed value).
        val_interval: Run validation every this many epochs (default 2, the
            former fixed value).

    Returns:
        None. Results are written to disk.
    """
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=num_classes,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4)
    dice_metric = DiceMetric(include_background=False, reduction="mean")

    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # Post-processing applied per sample before the Dice metric: predictions are
    # arg-maxed and one-hot encoded, labels are one-hot encoded.
    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=num_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=num_classes)])

    root_dir = "."

    # Number of batches per epoch for THIS fold's loader. The previous code used
    # len(train_ds), the notebook-level list of per-fold datasets, which printed
    # a wrong denominator such as "7/1".
    steps_per_epoch = len(train_loader)

    # Loop-invariant sliding-window settings used during validation.
    roi_size = (size_x, size_y, num_slices)
    sw_batch_size = batch_size_vl

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"{vfold_num}: epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(
                f"{step}/{steps_per_epoch}, "
                f"train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"{vfold_num} epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    # accumulate the metric for the current validation batch
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()

                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(
                        root_dir, model_filename_base+'_'+str(vfold_num)+'.pth'))
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )

    # Persist the training-loss and validation-Dice curves for this fold.
    np.save(model_filename_base+"_loss_"+str(vfold_num)+".npy", epoch_loss_values)
    np.save(model_filename_base+"_val_dice_"+str(vfold_num)+".npy", metric_values)

In [19]:
# Train the folds one after another, each with its own train/validation loader.
# NOTE(review): the range starts at 2, so folds 0 and 1 are not trained by this
# cell -- presumably they were completed in an earlier run; confirm before
# relying on a clean Restart & Run All.
for i in range(2,num_folds):
    vfold_train(i, train_loader[i], val_loader[i])

----------
2: epoch 1/500
1/1, train_loss: 0.7795
2/1, train_loss: 0.7835
3/1, train_loss: 0.7898
4/1, train_loss: 0.7807
5/1, train_loss: 0.7897
6/1, train_loss: 0.7716
7/1, train_loss: 0.7737
2 epoch 1 average loss: 0.7812
----------
2: epoch 2/500
1/1, train_loss: 0.7725
2/1, train_loss: 0.7761
3/1, train_loss: 0.7597
4/1, train_loss: 0.7629
5/1, train_loss: 0.7669
6/1, train_loss: 0.7657
7/1, train_loss: 0.7634
2 epoch 2 average loss: 0.7668
saved new best metric model
current epoch: 2 current mean dice: 0.0144
best mean dice: 0.0144 at epoch: 2
----------
2: epoch 3/500
1/1, train_loss: 0.7735
2/1, train_loss: 0.7531
3/1, train_loss: 0.7654
4/1, train_loss: 0.7543
5/1, train_loss: 0.7536
6/1, train_loss: 0.7471
7/1, train_loss: 0.7790
2 epoch 3 average loss: 0.7609
----------
2: epoch 4/500


Exception ignored in: <function _releaseLock at 0x7f1d22e2a940>
Traceback (most recent call last):
  File "/home/local/KHQ/stephen.aylward/anaconda3/lib/python3.8/logging/__init__.py", line 223, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 
Exception ignored in: <function _releaseLock at 0x7f1d22e2a940>
Traceback (most recent call last):
  File "/home/local/KHQ/stephen.aylward/anaconda3/lib/python3.8/logging/__init__.py", line 223, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


RuntimeError: DataLoader worker (pid(s) 82462, 82534, 82606, 82678) exited unexpectedly