In [1]:
%matplotlib inline

import os
import sys
import tempfile
from glob import glob

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

from ignite.engine import Events, create_supervised_trainer
from ignite.handlers import ModelCheckpoint

# assumes the framework is found here, change as necessary
sys.path.append("..")

from monai import application, data, networks, utils
from monai.data.readers import NiftiDataset
from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset


application.config.print_config()

MONAI version: 0.0.1
Python version: 3.7.3 (default, Mar 27 2019, 22:11:17)  [GCC 7.3.0]
Numpy version: 1.16.4
Pytorch version: 1.3.1
Ignite version: 0.2.1


In [2]:
def create_test_image_3d(height, width, depth, numObjs=12, radMax=30, noiseMax=0.0, numSegClasses=5):
    '''Return a noisy 3D image and segmentation.'''
    image = np.zeros((width, height,depth))

    for i in range(numObjs):
        x = np.random.randint(radMax, width - radMax)
        y = np.random.randint(radMax, height - radMax)
        z = np.random.randint(radMax, depth - radMax)
        rad = np.random.randint(5, radMax)
        spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z]
        circle = (spx * spx + spy * spy + spz * spz) <= rad * rad

        if numSegClasses > 1:
            image[circle] = np.ceil(np.random.random() * numSegClasses)
        else:
            image[circle] = np.random.random() * 0.5 + 0.5

    labels = np.ceil(image).astype(np.int32)

    norm = np.random.uniform(0, numSegClasses * noiseMax, size=image.shape)
    noisyimage = utils.arrayutils.rescale_array(np.maximum(image, norm))

    return noisyimage, labels

In [3]:
tempdir = tempfile.mkdtemp()

for i in range(50):
    im, seg = create_test_image_3d(256,256,256)
    
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz'%i))
    
    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz'%i))

In [4]:
images = sorted(glob(os.path.join(tempdir,'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir,'seg*.nii.gz')))

imtrans=transforms.Compose([
    Rescale(),
    AddChannel(),
    UniformRandomPatch((64, 64, 64)),
    ToTensor()
])    

segtrans=transforms.Compose([
    AddChannel(),
    UniformRandomPatch((64, 64, 64)),
    ToTensor()
])    
    
ds = NiftiDataset(images, segs, imtrans, segtrans)

loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = utils.mathutils.first(loader)
print(im.shape, seg.shape)

torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])


In [5]:
lr = 1e-3

net = networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    num_classes=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

loss = networks.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), lr)

In [7]:
trainEpochs = 30

loss_fn = lambda i, j: loss(i[0], j)
device = torch.device("cuda:0")

trainer = create_supervised_trainer(net, opt, loss_fn, device, False)

checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(
    event_name=Events.EPOCH_COMPLETED,
    handler=checkpoint_handler,
    to_save={'net': net}
)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    print("Epoch", engine.state.epoch, "Loss:", engine.state.output)


loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())
    
state = trainer.run(loader, trainEpochs)

Epoch 1 Loss: 0.8619852662086487
Epoch 2 Loss: 0.8307779431343079
Epoch 3 Loss: 0.8064168691635132
Epoch 4 Loss: 0.7981672883033752
Epoch 5 Loss: 0.7950631976127625
Epoch 6 Loss: 0.7949732542037964
Epoch 7 Loss: 0.7963427901268005
Epoch 8 Loss: 0.7939450144767761
Epoch 9 Loss: 0.7926643490791321
Epoch 10 Loss: 0.7911991477012634
Epoch 11 Loss: 0.7886414527893066
Epoch 12 Loss: 0.7867528796195984
Epoch 13 Loss: 0.7857398390769958
Epoch 14 Loss: 0.7833380699157715
Epoch 15 Loss: 0.7791398763656616
Epoch 16 Loss: 0.7720394730567932
Epoch 17 Loss: 0.7671006917953491
Epoch 18 Loss: 0.7646064758300781
Epoch 19 Loss: 0.7672612071037292
Epoch 20 Loss: 0.7600041627883911
Epoch 21 Loss: 0.7583478689193726
Epoch 22 Loss: 0.7571365833282471
Epoch 23 Loss: 0.7545363306999207
Epoch 24 Loss: 0.7499511241912842
Epoch 25 Loss: 0.7481640577316284
Epoch 26 Loss: 0.7469437122344971
Epoch 27 Loss: 0.7460543513298035
Epoch 28 Loss: 0.74577796459198
Epoch 29 Loss: 0.7429620027542114
Epoch 30 Loss: 0.74248588