In [1]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    EnsureType,
    Invertd,
    LoadImaged,
    RandFlipd,
    RandSpatialCropd,
    RandZoomd,
    Resized,
    ScaleIntensityRanged,
    SpatialCrop,
    SpatialCropd,
    ToTensord,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
from glob import glob

import warnings
warnings.filterwarnings("ignore")

import numpy as np

import itk

import site
site.addsitedir('../../ARGUS')
from ARGUSUtils_Transforms import *


In [2]:
img1_dir = "../../Data/VFoldData/ColumnData/"

all_images = sorted(glob(os.path.join(img1_dir, '*Class[NS]*[0123456789].mha')))
all_labels = sorted(glob(os.path.join(img1_dir, '*Class[NS]*.roi-overlay.mha')))

print("# of Images =", len(all_images))
print("# of Labels =", len(all_labels))

num_classes = 4

num_columns = 16
num_slices = 48

num_workers_tr = 8
batch_size_tr = 8
num_workers_vl = 8
batch_size_vl = 4

model_filename_base = "BAMC_PTX_3DUNet-4Class.best_model.vfold64-Columns-ROI"

num_images = len(all_images)

num_folds = 10

ns_prefix = ['025ns','026ns','027ns','035ns','048ns','055ns','117ns',
             '135ns','193ns','210ns','215ns','218ns','219ns','221ns','247ns']
s_prefix = ['004s','019s','030s','034s','037s','043s','065s','081s',
            '206s','208s','211s','212s','224s','228s','236s','237s']

# of Images = 90
# of Labels = 90


In [3]:
fold_prefix_list = []
ns_count = 0
s_count = 0
for i in range(num_folds):
    if i%2 == 0:
        num_ns = 2
        num_s = 1
        if i > num_folds-3:
            num_s = 2
    else:
        num_ns = 1
        num_s = 2
    f = []
    for ns in range(num_ns):
        f.append([ns_prefix[ns_count+ns]])
    ns_count += num_ns
    for s in range(num_s):
        f.append([s_prefix[s_count+s]])
    s_count += num_s
    fold_prefix_list.append(f)
        
train_files = []
val_files = []
test_files = []
for i in range(num_folds):
    tr_folds = []
    for f in range(i,i+num_folds-2):
        tr_folds.append(fold_prefix_list[f%num_folds])
    tr_folds = list(np.concatenate(tr_folds).flat)
    va_folds = list(np.concatenate(fold_prefix_list[(i+num_folds-2) % num_folds]).flat)
    te_folds = list(np.concatenate(fold_prefix_list[(i+num_folds-1) % num_folds]).flat)
    train_files.append(
            [
                {"image": img, "label": seg}
                for img, seg in zip(
                    [im for im in all_images if any(pref in im for pref in tr_folds)],
                    [se for se in all_labels if any(pref in se for pref in tr_folds)])
            ]
        )
    val_files.append(
            [
                {"image": img, "label": seg}
                for img, seg in zip(
                    [im for im in all_images if any(pref in im for pref in va_folds)],
                    [se for se in all_labels if any(pref in se for pref in va_folds)])
            ]
        )
    test_files.append(
            [
                {"image": img, "label": seg}
                for img, seg in zip(
                    [im for im in all_images if any(pref in im for pref in te_folds)],
                    [se for se in all_labels if any(pref in se for pref in te_folds)])
            ]
        )
    print(len(train_files[i]),len(val_files[i]),len(test_files[i]))

70 14 6
74 6 10
72 10 8
73 8 9
73 9 8
74 8 8
75 8 7
76 7 7
70 7 13
63 13 14


In [4]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"],
            a_min=0, a_max=255,
            b_min=0.0, b_max=1.0),
        ARGUS_RandSpatialCropSlicesd(
            num_slices=num_slices,
            axis=2,
            keys=['image', 'label']),
        ARGUS_RandSpatialCropSlicesd(
            num_slices=num_columns,
            axis=0,
            require_labeled=True,
            keys=['image', 'label']),
        RandFlipd(prob=0.5, 
            spatial_axis=2,
            keys=['image', 'label']),
        RandFlipd(prob=0.5, 
            spatial_axis=0,
            keys=['image', 'label']),
        RandZoomd(prob=0.5, 
            min_zoom=1.0,
            max_zoom=1.2,
            keys=['image', 'label'],
            keep_size=True,
            mode=['trilinear', 'nearest']),
        ToTensord(keys=["image", "label"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"],
            a_min=0, a_max=255,
            b_min=0.0, b_max=1.0),
        ARGUS_RandSpatialCropSlicesd(
            num_slices=num_slices,
            axis=2,
            center_slice=30,
            keys=['image', 'label']),
        ARGUS_RandSpatialCropSlicesd(
            num_slices=num_columns,
            axis=0,
            center_slice=num_columns//2,
            keys=['image', 'label']),
        ToTensord(keys=["image", "label"]),
    ]
)

In [5]:
itkForceDynamicLoading = itk.imread(train_files[0][0]["image"])
print(itkForceDynamicLoading)
arrForceDynamicLoading = itk.GetArrayFromImage(itkForceDynamicLoading)

Image (0x55a6d78d1630)
  RTTI typeinfo:   itk::Image<short, 3u>
  Reference Count: 1
  Modified Time: 442
  Debug: Off
  Object Name: 
  Observers: 
    none
  Source: (none)
  Source output name: (none)
  Release Data: Off
  Data Released: False
  Global Release Data: Off
  PipelineMTime: 246
  UpdateMTime: 441
  RealTimeStamp: 0 seconds 
  LargestPossibleRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [93, 320, 61]
  BufferedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [93, 320, 61]
  RequestedRegion: 
    Dimension: 3
    Index: [0, 0, 0]
    Size: [93, 320, 61]
  Spacing: [1, 1, 1]
  Origin: [0, 0, 0]
  Direction: 
1 0 0
0 1 0
0 0 1

  IndexToPointMatrix: 
1 0 0
0 1 0
0 0 1

  PointToIndexMatrix: 
1 0 0
0 1 0
0 0 1

  Inverse Direction: 
1 0 0
0 1 0
0 0 1

  PixelContainer: 
    ImportImageContainer (0x55a6d81e97b0)
      RTTI typeinfo:   itk::ImportImageContainer<unsigned long, short>
      Reference Count: 1
      Modified Time: 438
      Debug: Off
      Obj

In [6]:
train_ds = [CacheDataset(data=train_files[i], transform=train_transforms,cache_rate=1.0, num_workers=num_workers_tr)
            for i in range(num_folds)]
train_loader = [DataLoader(train_ds[i], batch_size=batch_size_tr, shuffle=True, num_workers=num_workers_tr) 
                for i in range(num_folds)]

val_ds = [CacheDataset(data=val_files[i], transform=val_transforms, cache_rate=1.0, num_workers=num_workers_vl)
          for i in range(num_folds)]
val_loader = [DataLoader(val_ds[i], batch_size=batch_size_vl, num_workers=num_workers_vl)
              for i in range(num_folds)]

Loading dataset: 100%|██████████████████████| 70/70 [00:01<00:00, 57.34it/s]
Loading dataset: 100%|██████████████████████| 74/74 [00:01<00:00, 60.35it/s]
Loading dataset: 100%|██████████████████████| 72/72 [00:01<00:00, 64.37it/s]
Loading dataset: 100%|██████████████████████| 73/73 [00:01<00:00, 60.50it/s]
Loading dataset: 100%|██████████████████████| 73/73 [00:01<00:00, 63.22it/s]
Loading dataset: 100%|██████████████████████| 74/74 [00:01<00:00, 66.81it/s]
Loading dataset: 100%|██████████████████████| 75/75 [00:01<00:00, 66.48it/s]
Loading dataset: 100%|██████████████████████| 76/76 [00:01<00:00, 56.59it/s]
Loading dataset: 100%|██████████████████████| 70/70 [00:01<00:00, 63.59it/s]
Loading dataset: 100%|██████████████████████| 63/63 [00:01<00:00, 56.87it/s]
Loading dataset: 100%|█████████████████████| 14/14 [00:00<00:00, 125.28it/s]
Loading dataset: 100%|███████████████████████| 6/6 [00:00<00:00, 100.38it/s]
Loading dataset: 100%|████████████████████| 10/10 [00:00<00:00, 1139.38it/s]

In [7]:
device = torch.device("cuda:1")

In [12]:
def vfold_train(vfold_num, train_loader, val_loader):
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=num_classes,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4)
    dice_metric = DiceMetric(include_background=False, reduction="mean")


    max_epochs = 1000
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=num_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=num_classes)])

    root_dir = "."

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"{vfold_num}: epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"{vfold_num} epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    roi_size = (num_columns, 320, num_slices)
                    sw_batch_size = batch_size_vl
                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()

                metric_values.append(metric)
                if metric >= best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(
                        root_dir, model_filename_base+'_'+str(vfold_num)+'.pth'))
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.8f}"
                    f"\nbest mean dice: {best_metric:.8f} "
                    f"at epoch: {best_metric_epoch}"
                )

    np.save(model_filename_base+"_loss_"+str(vfold_num)+".npy", epoch_loss_values)
    np.save(model_filename_base+"_val_dice_"+str(vfold_num)+".npy", metric_values)

In [13]:
for i in range(0,num_folds):
    vfold_train(i, train_loader[i], val_loader[i])


----------
0: epoch 1/1000
1/1, train_loss: 0.8599
2/1, train_loss: 0.8591
3/1, train_loss: 0.8547
4/1, train_loss: 0.8521
5/1, train_loss: 0.8485
6/1, train_loss: 0.8435
7/1, train_loss: 0.8400
8/1, train_loss: 0.8349
9/1, train_loss: 0.8372
0 epoch 1 average loss: 0.8477
----------
0: epoch 2/1000
1/1, train_loss: 0.8327
2/1, train_loss: 0.8285
3/1, train_loss: 0.8252
4/1, train_loss: 0.8236
5/1, train_loss: 0.8219
6/1, train_loss: 0.8315
7/1, train_loss: 0.8210
8/1, train_loss: 0.8116
9/1, train_loss: 0.8193
0 epoch 2 average loss: 0.8239
saved new best metric model
current epoch: 2 current mean dice: 0.09288500
best mean dice: 0.09288500 at epoch: 2
----------
0: epoch 3/1000
1/1, train_loss: 0.8279
2/1, train_loss: 0.8086
3/1, train_loss: 0.8177
4/1, train_loss: 0.8186
5/1, train_loss: 0.8063
6/1, train_loss: 0.8096
7/1, train_loss: 0.8139
8/1, train_loss: 0.7948
9/1, train_loss: 0.7994
0 epoch 3 average loss: 0.8108
----------
0: epoch 4/1000


Exception ignored in: <function _releaseLock at 0x7f995db1e940>
Traceback (most recent call last):
  File "/home/local/KHQ/stephen.aylward/anaconda3/lib/python3.8/logging/__init__.py", line 223, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


RuntimeError: DataLoader worker (pid(s) 2004356, 2004500, 2004574, 2004720, 2004866, 2004938, 2005094) exited unexpectedly