In [9]:
%matplotlib inline

import os
import sys
import tempfile
from glob import glob

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint, EarlyStopping

# assumes the framework is found here, change as necessary
sys.path.append("..")


import monai.data.transforms.compose as transforms
from monai import application, data, networks, utils
from monai.data.readers import NiftiDataset
from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset
from monai.networks.metrics.mean_dice import MeanDice
from monai.utils.stopperutils import stopping_fn_from_metric


application.config.print_config()

MONAI version: 0.0.1
Python version: 3.8.1 (default, Jan  8 2020, 22:29:32)  [GCC 7.3.0]
Numpy version: 1.18.1
Pytorch version: 1.4.0
Ignite version: 0.3.0


In [10]:
def create_test_image_3d(height, width, depth, numObjs=12, radMax=30, noiseMax=0.0, numSegClasses=5):
    '''Return a noisy 3D image and segmentation.'''
    image = np.zeros((width, height,depth))

    for i in range(numObjs):
        x = np.random.randint(radMax, width - radMax)
        y = np.random.randint(radMax, height - radMax)
        z = np.random.randint(radMax, depth - radMax)
        rad = np.random.randint(5, radMax)
        spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z]
        circle = (spx * spx + spy * spy + spz * spz) <= rad * rad

        if numSegClasses > 1:
            image[circle] = np.ceil(np.random.random() * numSegClasses)
        else:
            image[circle] = np.random.random() * 0.5 + 0.5

    labels = np.ceil(image).astype(np.int32)

    norm = np.random.uniform(0, numSegClasses * noiseMax, size=image.shape)
    noisyimage = utils.arrayutils.rescale_array(np.maximum(image, norm))

    return noisyimage, labels

In [11]:
tempdir = tempfile.mkdtemp()

for i in range(50):
    im, seg = create_test_image_3d(256,256,256)
    
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz'%i))
    
    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz'%i))

In [12]:
images = sorted(glob(os.path.join(tempdir,'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir,'seg*.nii.gz')))

imtrans=transforms.Compose([
    Rescale(),
    AddChannel(),
    UniformRandomPatch((64, 64, 64)),
    ToTensor()
])    

segtrans=transforms.Compose([
    AddChannel(),
    UniformRandomPatch((64, 64, 64)),
    ToTensor()
])    
    
ds = NiftiDataset(images, segs, imtrans, segtrans)

loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = utils.mathutils.first(loader)
print(im.shape, seg.shape)

torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])


In [13]:
lr = 1e-3

net = networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    num_classes=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

loss = networks.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), lr)

In [14]:
trainEpochs = 30

loss_fn = lambda i, j: loss(i[0], j)
device = torch.device("cuda:0")

trainer = create_supervised_trainer(net, opt, loss_fn, device, False)

checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(
    event_name=Events.EPOCH_COMPLETED,
    handler=checkpoint_handler,
    to_save={'net': net}
)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    print("Epoch", engine.state.epoch, "Loss:", engine.state.output)


loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())
    


In [15]:
validation_every_n_epochs = 1

val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
                                        output_transform=lambda x, y, y_pred: (y_pred[0], y))


early_stopper = EarlyStopping(patience=4, 
                              score_function=stopping_fn_from_metric('Mean Dice'),
                              trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

@evaluator.on(Events.EPOCH_COMPLETED)
def log_validation_metrics(engine):
    for name, value in engine.state.metrics.items():
        print("Validation --", name, ":", value)

@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)



In [16]:
state = trainer.run(loader, trainEpochs)

Epoch 1 Loss: 0.8975554704666138
Validation -- Mean Dice : 0.11846490800380707
Epoch 2 Loss: 0.8451039791107178
Validation -- Mean Dice : 0.12091563045978546
Epoch 3 Loss: 0.9355515241622925
Validation -- Mean Dice : 0.12139833569526673
Epoch 4 Loss: 0.843208909034729
Validation -- Mean Dice : 0.12108306288719177
Epoch 5 Loss: 0.8225834965705872
Validation -- Mean Dice : 0.12179622799158096
Epoch 6 Loss: 0.957372784614563
Validation -- Mean Dice : 0.12193384170532226
Epoch 7 Loss: 0.9011092782020569
Validation -- Mean Dice : 0.1230143740773201
Epoch 8 Loss: 0.8651387691497803
Validation -- Mean Dice : 0.1254110112786293
Epoch 9 Loss: 0.8767974972724915
Validation -- Mean Dice : 0.12633273899555206
Epoch 10 Loss: 0.8193061947822571
Validation -- Mean Dice : 0.12657881826162337
Epoch 11 Loss: 0.9466649293899536
Validation -- Mean Dice : 0.12699378579854964
Epoch 12 Loss: 0.8258659243583679
Validation -- Mean Dice : 0.12790720015764237
Epoch 13 Loss: 0.8661612868309021
Validation -- Mean 