In [None]:
%matplotlib inline

import os, sys
from functools import partial

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from ignite.engine import Events, create_supervised_trainer

# Make the parent directory importable so the `monai` package can be found
# when the notebook is run from its own directory.
# Assumes the framework is found there; change the path as necessary.
sys.path.append("..")

from monai import application, data, networks, utils
import monai.data.augments.augments as augments

# NOTE(review): presumably prints the library/dependency configuration — confirm.
application.config.print_config()

Download the downsampled segmented images from the Sunnybrook Cardiac Dataset. This is a simple low-res dataset I put together for a workshop. The task is to segment the left ventricle in the image, which shows up as an annulus.

In [None]:
# Shell step: fetch the dataset archive with wget only when scd_lvsegs.npz is
# not already present in the working directory.
! [ ! -f scd_lvsegs.npz ] &&  wget -q https://github.com/ericspod/VPHSummerSchool2019/raw/master/scd_lvsegs.npz

Create the reader to bring the images in. These are initially in uint16 format with no channel dimension:

In [None]:
# Reader over the "images" and "segs" arrays stored in the downloaded archive.
# NOTE(review): OrderType.CHOICE presumably makes the reader pick items at
# random — confirm against the streams module.
imSrc = data.readers.NPZReader(
    "scd_lvsegs.npz",
    ["images", "segs"],
    other_values=data.streams.OrderType.CHOICE,
)

Define a stream to convert the image format, apply some basic augments using multiple threads, and buffer the stream behind a thread so that batching can be done in parallel with the training process.

In [None]:
def normalizeImg(im, seg):
    """Prepare one image/segmentation pair for the network.

    The image is passed through ``utils.arrayutils.rescale_array`` and cast to
    float32; the segmentation is cast to int32. Both arrays gain a new leading
    axis, which supplies the channel dimension the raw data lacks.

    Args:
        im: image array.
        seg: segmentation array matching ``im``.

    Returns:
        The converted ``(im, seg)`` tuple.
    """
    scaled = utils.arrayutils.rescale_array(im)
    im_out = scaled[np.newaxis].astype(np.float32)
    seg_out = seg[np.newaxis].astype(np.int32)
    return im_out, seg_out


# Per-item processing chain: format conversion first, then the augmentations.
shift = partial(augments.shift, dim_fract=5, order=0, nonzero_index=1)
augs = [normalizeImg, augments.rot90, augments.transpose, augments.flip, shift]
del shift  # keep the notebook namespace unchanged

# Apply the chain via ThreadAugmentStream, then wrap the result in a
# ThreadBufferStream so batches are produced behind a thread.
# NOTE(review): 200 is presumably the batch size — confirm against ThreadAugmentStream.
src = data.streams.ThreadBufferStream(
    data.augments.augmentstream.ThreadAugmentStream(imSrc, 200, augments=augs)
)

# Pull one batch to sanity-check shapes/dtypes and show the first image next to
# its segmentation.
im, seg = utils.mathutils.first(src)
print(im.shape, im.dtype, seg.shape, seg.dtype)
plt.imshow(np.hstack((im[0, 0], seg[0, 0])))

Define the network, loss, and optimizer:

In [None]:
# Learning rate handed to the Adam optimizer below.
lr = 1e-3

# 2D U-Net with one input channel and one output class; five channel levels
# joined by four stride-2 steps, with two residual units per level.
net = networks.nets.UNet(
    dimensions=2, in_channels=1, num_classes=1,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
    num_res_units=2,
)

# Dice loss on the network output, optimised with Adam.
loss = networks.losses.DiceLoss()
opt = torch.optim.Adam(net.parameters(), lr=lr)

Train using an Ignite Engine:

In [None]:
# Number of batches drawn from the stream per epoch (passed to FiniteStream below).
trainSteps = 100
# Number of epochs passed to trainer.run below.
trainEpochs = 20
# NOTE(review): not referenced anywhere in the visible cells — confirm whether it is still needed.
trainSubsteps = 1


def _prepare_batch(batch, device=None, non_blocking=False):
    x, y = batch
    return torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)


def loss_fn(i, j):
    """Apply ``loss`` to the first element of the network output ``i`` and the target ``j``."""
    return loss(i[0], j)

# Train on the first GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Place the model on the training device explicitly; this is harmless if the
# trainer factory also moves it.
net.to(device)

# Keyword arguments keep the call independent of the positional order of the
# factory's optional parameters.
trainer = create_supervised_trainer(
    net,
    opt,
    loss_fn,
    device=device,
    non_blocking=False,
    prepare_batch=_prepare_batch,
)


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    """Print the epoch number and the engine's latest output when an epoch completes."""
    state = engine.state
    print("Epoch", state.epoch, "Loss:", state.output)


# Finite stream so that each epoch trains only for as many steps as we specify.
fsrc = data.streams.FiniteStream(src, trainSteps)
state = trainer.run(fsrc, max_epochs=trainEpochs)

In [None]:
# Take one item straight from the reader (bypassing the augmentation stream)
# and add two leading axes before rescaling.
# NOTE(review): assumes the reader yields a single 2D image, so the result is
# (batch, channel, H, W) — confirm.
im, seg = utils.mathutils.first(imSrc)
testim = utils.arrayutils.rescale_array(im[None, None])

# Run inference on the CPU in evaluation mode and without autograd tracking:
# eval() switches layers such as dropout/normalisation to inference behaviour,
# and no_grad() avoids building a graph that is never used.
net.cpu().eval()
with torch.no_grad():
    pred = net(torch.from_numpy(testim))

# NOTE(review): element 1 of the network output is presumably the predicted
# segmentation (element 0 is what the loss consumes) — confirm against UNet.
pseg = pred[1].detach().numpy()

# Show the input image next to the prediction.
plt.imshow(np.hstack([testim[0, 0], pseg[0]]))