In [1]:
from __future__ import division, print_function

import os, glob, re
import dicom
import numpy as np
from PIL import Image, ImageDraw
from keras import backend as K

import argparse
from math import ceil
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter

from keras import utils
from keras.preprocessing import image as keras_image
from keras.preprocessing.image import ImageDataGenerator


ModuleNotFoundError: No module named 'dicom'

In [None]:
def maybe_rotate(image):
    """Return `image` in landscape orientation.

    A 2-D array that is taller than it is wide is rotated by 90 degrees
    with ``np.rot90``; any other array is handed back unchanged (the very
    same object, not a copy).
    """
    n_rows, n_cols = image.shape
    if n_cols < n_rows:
        return np.rot90(image)
    return image

In [None]:
class PatientData(object):
    """Images and manual contours for a single patient directory.

    Data directory structure (for patient 01):
    directory/
      P01list.txt
      P01dicom/
        P01-0000.dcm
        P01-0001.dcm
        ...
      P01contours-manual/
        P01-0080-icontour-manual.txt
        P01-0120-ocontour-manual.txt
        ...

    ``P01list.txt`` is the contour listing file located with the
    ``P*list.txt`` glob; the two characters after ``P`` give the patient
    index.  Its lines are read as alternating inner/outer contour paths.

    Attributes set while loading:
      all_images / all_dicoms - every frame (landscape array / DICOM object)
      labeled                 - frame numbers that have a contour pair
      endocardium_contours, epicardium_contours - (x, y) coordinate arrays
      endocardium_masks, epicardium_masks       - uint8 masks with values 0/1
    """
    def __init__(self, directory):
        """Load all images, and the masks when contour data is available.

        Raises:
          FileNotFoundError: no ``P*list.txt`` file or no ``*.dcm`` file.
          ValueError: the listing file name does not contain a patient index.
        """
        self.directory = os.path.normpath(directory)

        # get patient index from contour listing file
        glob_search = os.path.join(self.directory, "P*list.txt")
        # sorted() makes the choice deterministic if several files match
        files = sorted(glob.glob(glob_search))
        if not files:
            raise FileNotFoundError(
                "no contour listing file matching {}".format(glob_search))

        self.contour_list_file = files[0]
        match = re.search(r"P(..)list\.txt",
                          os.path.basename(self.contour_list_file))
        if match is None:
            raise ValueError("cannot parse patient index from {}".format(
                self.contour_list_file))
        self.index = int(match.group(1))

        # load all data into memory
        self.load_images()

        # some patients do not have contour data, and that's ok; load_masks
        # leaves every contour/mask attribute as an empty list in that case
        try:
            self.load_masks()
        except FileNotFoundError:
            pass

    @property
    def images(self):
        """Frames that have contours, in the order of ``self.labeled``."""
        return [self.all_images[i] for i in self.labeled]

    @property
    def dicoms(self):
        """DICOM objects of the labeled frames, in ``self.labeled`` order."""
        return [self.all_dicoms[i] for i in self.labeled]

    @property
    def dicom_path(self):
        """Directory holding this patient's ``*.dcm`` files."""
        return os.path.join(self.directory, "P{:02d}dicom".format(self.index))

    def load_images(self):
        """Read every DICOM file (sorted by name) into memory.

        Raises:
          FileNotFoundError: the DICOM directory contains no ``*.dcm`` file.
        """
        glob_search = os.path.join(self.dicom_path, "*.dcm")
        dicom_files = sorted(glob.glob(glob_search))
        if not dicom_files:
            # without this check the code below fails on unbound locals
            raise FileNotFoundError(
                "no DICOM files matching {}".format(glob_search))

        self.all_images = []
        self.all_dicoms = []
        for dicom_file in dicom_files:
            plan = dicom.read_file(dicom_file)
            image = maybe_rotate(plan.pixel_array)
            self.all_images.append(image)
            self.all_dicoms.append(plan)
        # dimensions and rotation flag are taken from the last frame read
        self.image_height, self.image_width = image.shape
        self.rotated = (plan.pixel_array.shape != image.shape)

    def load_contour(self, filename):
        """Load one contour file and return its ``(x, y)`` coordinates.

        `filename` is a listing entry; everything after its "patientXX/"
        head is resolved relative to ``self.directory``.  When the images
        were rotated, the coordinates are mapped into the rotated frame.

        Raises:
          ValueError: `filename` has no "patientXX/" head.
        """
        # strip out path head "patientXX/"
        match = re.search("patient../(.*)", filename)
        if match is None:
            raise ValueError(
                "unexpected contour path: {!r}".format(filename))
        path = os.path.join(self.directory, match.group(1))
        x, y = np.loadtxt(path).T
        if self.rotated:
            x, y = y, self.image_height - x
        return x, y

    def contour_to_mask(self, x, y, norm=255):
        """Rasterize the polygon ``(x, y)`` into a uint8 mask.

        Pixels on or inside the polygon get the value `norm`, all others 0.
        """
        BW_8BIT = 'L'
        polygon = list(zip(x, y))
        image_dims = (self.image_width, self.image_height)
        img = Image.new(BW_8BIT, image_dims, color=0)
        ImageDraw.Draw(img).polygon(polygon, outline=1, fill=1)
        return norm * np.array(img, dtype='uint8')

    def load_masks(self):
        """Load every contour pair from the listing file and build masks.

        The contour attributes are reset to empty lists first and replaced
        only after every pair loaded, so a failure part-way through never
        leaves ``labeled`` out of step with the mask lists.

        Raises:
          FileNotFoundError: a listed contour file does not exist.
          ValueError: a listing entry does not have the expected form.
        """
        self.labeled = []
        self.endocardium_contours = []
        self.epicardium_contours = []
        self.endocardium_masks = []
        self.epicardium_masks = []

        with open(self.contour_list_file, 'r') as f:
            # skip blank lines so they cannot shift the inner/outer pairing
            files = [line.strip().replace("\\", "/")
                     for line in f if line.strip()]

        inner_files = files[0::2]
        outer_files = files[1::2]

        labeled = []
        endocardium_contours = []
        epicardium_contours = []
        endocardium_masks = []
        epicardium_masks = []
        for inner_file, outer_file in zip(inner_files, outer_files):
            # frame number comes from the inner contour's file name
            match = re.search(r"P..-(....)-.contour", inner_file)
            if match is None:
                raise ValueError(
                    "unexpected contour file name: {!r}".format(inner_file))

            inner_x, inner_y = self.load_contour(inner_file)
            outer_x, outer_y = self.load_contour(outer_file)

            labeled.append(int(match.group(1)))
            endocardium_contours.append((inner_x, inner_y))
            epicardium_contours.append((outer_x, outer_y))
            endocardium_masks.append(
                self.contour_to_mask(inner_x, inner_y, norm=1))
            epicardium_masks.append(
                self.contour_to_mask(outer_x, outer_y, norm=1))

        # commit only after everything loaded successfully
        self.labeled = labeled
        self.endocardium_contours = endocardium_contours
        self.epicardium_contours = epicardium_contours
        self.endocardium_masks = endocardium_masks
        self.epicardium_masks = epicardium_masks

    # NOTE: video export is disabled; kept for reference.
    # def write_video(self, outfile, FPS=24):
    #     import cv2
    #     image_dims = (self.image_width, self.image_height)
    #     video = cv2.VideoWriter(outfile, -1, FPS, image_dims)
    #     for image in self.all_images:
    #         grayscale = np.asarray(image * (255 / image.max()), dtype='uint8')
    #         video.write(cv2.cvtColor(grayscale, cv2.COLOR_GRAY2BGR))
    #     video.release()

In [None]:
def load_images(data_dir, mask='both'):
    """Load all patient images and contours from TrainingSet, Test1Set or
    Test2Set directory. The directories and images are read in sorted order.

    Arguments:
      data_dir - path to data directory (TrainingSet, Test1Set or Test2Set)
      mask     - which contour(s) to turn into labels: 'inner', 'outer' or
                 'both' (2 = endocardium, 1 = cardiac wall, 0 = elsewhere)

    Output:
      tuple (images, masks), both of which are 4-d arrays of shape
      (batchsize, height, width, channels). Images keep the dtype of the
      stored pixel data; masks are one-hot encoded along the last axis.

    Raises:
      ValueError - `mask` is not a supported value, or no labeled image was
                   found under `data_dir`.
    """
    if mask not in ('inner', 'outer', 'both'):
        raise ValueError(
            "mask must be 'inner', 'outer' or 'both', got {!r}".format(mask))

    glob_search = os.path.join(data_dir, "patient*")
    patient_dirs = sorted(glob.glob(glob_search))

    # load all images into memory (dataset is small)
    images = []
    inner_masks = []
    outer_masks = []
    for patient_dir in patient_dirs:
        # PatientData is defined in this file (no separate `patient` module)
        p = PatientData(patient_dir)
        # zip keeps image / inner / outer aligned even when a patient's
        # contour data is incomplete
        for image, inner, outer in zip(p.images, p.endocardium_masks,
                                       p.epicardium_masks):
            images.append(image)
            inner_masks.append(inner)
            outer_masks.append(outer)

    if not images:
        raise ValueError(
            "no labeled images found under {!r}".format(data_dir))

    # reshape to account for channel dimension
    images = np.asarray(images)[:, :, :, None]
    if mask == 'inner':
        masks = np.asarray(inner_masks)
    elif mask == 'outer':
        masks = np.asarray(outer_masks)
    else:
        # mask = 2 for endocardium, 1 for cardiac wall, 0 elsewhere
        masks = np.asarray(inner_masks) + np.asarray(outer_masks)

    # one-hot encode masks; count classes over the whole set (max label + 1,
    # which is also what to_categorical infers) instead of trusting the
    # first image to contain every class
    classes = int(masks.max()) + 1
    new_shape = masks.shape + (classes,)
    masks = utils.to_categorical(masks).reshape(new_shape)

    return images, masks
"""
def random_elastic_deformation(image, alpha, sigma, mode='nearest',
                               random_state=None):
    Elastic deformation of images as described in [Simard2003]_.
    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    
    assert len(image.shape) == 3

    if random_state is None:
        random_state = np.random.RandomState(None)

    height, width, channels = image.shape

    dx = gaussian_filter(2*random_state.rand(height, width) - 1,
                         sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter(2*random_state.rand(height, width) - 1,
                         sigma, mode="constant", cval=0) * alpha

    x, y = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    indices = (np.repeat(np.ravel(x+dx), channels),
               np.repeat(np.ravel(y+dy), channels),
               np.tile(np.arange(channels), height*width))
    
    values = map_coordinates(image, indices, order=1, mode=mode)

    return values.reshape((height, width, channels))
"""
class Iterator(object):
    """Endless batch iterator that augments images and masks together.

    Each image is stacked with its mask along the channel axis, passed
    through one random affine transform of a Keras ``ImageDataGenerator``,
    and split apart again, so both receive the same transform.  After the
    last batch of a pass the position wraps to the start (reshuffling when
    `shuffle` is set); the iterator never raises ``StopIteration``.

    Arguments:
      images, masks - indexable collections of equal length; each item is
                      unpacked as (height, width, channels)
      batch_size    - maximum number of samples per batch (the last batch
                      of a pass may be smaller)
      shuffle       - randomize the sample order for every pass
      rotation_range ... fill_mode - forwarded to ``ImageDataGenerator``
      alpha, sigma  - elastic-deformation parameters; stored but unused
                      while the elastic deformation step is disabled
    """
    def __init__(self, images, masks, batch_size,
                 shuffle=True,
                 rotation_range=180,
                 width_shift_range=0.1,
                 height_shift_range=0.1,
                 shear_range=0.1,
                 zoom_range=0.01,
                 fill_mode='nearest',
                 alpha=500,
                 sigma=20):
        self.images = images
        self.masks = masks
        self.batch_size = batch_size
        self.shuffle = shuffle
        augment_options = {
            'rotation_range': rotation_range,
            'width_shift_range': width_shift_range,
            'height_shift_range': height_shift_range,
            'shear_range': shear_range,
            'zoom_range': zoom_range,
            'fill_mode': fill_mode,
        }
        self.idg = ImageDataGenerator(**augment_options)
        self.alpha = alpha
        self.sigma = sigma
        self.fill_mode = fill_mode
        # position of the next batch within self.index
        self.i = 0
        self.index = np.arange(len(images))
        if shuffle:
            np.random.shuffle(self.index)

    def __iter__(self):
        # required by the iterator protocol so iter(it) / for-loops work
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the next ``(images, masks)`` batch as two arrays."""
        # compute how many images to output in this batch
        start = self.i
        end = min(start + self.batch_size, len(self.images))

        augmented_images = []
        augmented_masks = []
        for n in self.index[start:end]:
            image = self.images[n]
            mask = self.masks[n]

            _, _, channels = image.shape

            # stack image + mask together to simultaneously augment
            stacked = np.concatenate((image, mask), axis=2)

            # apply simple affine transforms first using Keras
            augmented = self.idg.random_transform(stacked)

            # NOTE: elastic deformation is disabled:
            # if self.alpha != 0 and self.sigma != 0:
            #     augmented = random_elastic_deformation(
            #         augmented, self.alpha, self.sigma, self.fill_mode)

            # split image and mask back apart; the mask channels are
            # rounded after the transform
            augmented_image = augmented[:, :, :channels]
            augmented_images.append(augmented_image)
            augmented_mask = np.round(augmented[:, :, channels:])
            augmented_masks.append(augmented_mask)

        # advance, wrapping (and reshuffling) once the pass is complete
        self.i += self.batch_size
        if self.i >= len(self.images):
            self.i = 0
            if self.shuffle:
                np.random.shuffle(self.index)

        return np.asarray(augmented_images), np.asarray(augmented_masks)

def normalize(x, epsilon=1e-7, axis=(1,2)):
    """Standardize `x` in place along `axis`; nothing is returned.

    The mean over `axis` is subtracted first, then the result is divided by
    its standard deviation over `axis` plus `epsilon` (which guards against
    division by zero).
    """
    center = np.mean(x, axis=axis, keepdims=True)
    x -= center
    # spread is measured on the already-centered data
    spread = np.std(x, axis=axis, keepdims=True)
    x /= spread + epsilon

def create_generators(data_dir, batch_size, validation_split=0.0, mask='both',
                      shuffle_train_val=True, shuffle=True, seed=None,
                      normalize_images=True, augment_training=False,
                      augment_validation=False, augmentation_args={}):
    """Load a dataset and wrap it in training / validation batch generators.

    Arguments:
      data_dir, mask     - forwarded to ``load_images``
      batch_size         - samples per batch
      validation_split   - fraction of samples (taken from the end) used
                           for validation
      shuffle_train_val  - shuffle images and masks together before the split
      shuffle            - forwarded to the generators
      seed               - if given, seeds ``np.random`` before shuffling
      normalize_images   - normalize each image over its height/width axes
      augment_training, augment_validation - use the augmenting ``Iterator``
                           instead of a plain ``ImageDataGenerator`` flow
      augmentation_args  - extra keyword arguments for ``Iterator``

    Returns:
      (train_generator, train_steps_per_epoch,
       val_generator, val_steps_per_epoch); ``val_generator`` is None when
      ``validation_split`` is not positive.
    """
    images, masks = load_images(data_dir, mask)

    # images are promoted to double precision; masks are left as loaded
    images = images.astype('float64')

    if normalize_images:
        normalize(images, axis=(1,2))

    if seed is not None:
        np.random.seed(seed)

    if shuffle_train_val:
        # replay the same RNG state for both arrays so that images and
        # masks receive the identical permutation
        saved_state = np.random.get_state()
        np.random.shuffle(images)
        np.random.set_state(saved_state)
        np.random.shuffle(masks)

    # the last `validation_split` fraction becomes the validation set
    n_total = len(images)
    split_index = int((1-validation_split) * n_total)

    def make_generator(subset_images, subset_masks, augment):
        # augmenting Iterator on request, otherwise a plain Keras flow
        if augment:
            return Iterator(subset_images, subset_masks,
                            batch_size, shuffle=shuffle, **augmentation_args)
        return ImageDataGenerator().flow(subset_images, subset_masks,
                                         batch_size=batch_size,
                                         shuffle=shuffle)

    train_generator = make_generator(
        images[:split_index], masks[:split_index], augment_training)
    train_steps_per_epoch = ceil(split_index / batch_size)

    val_generator = None
    if validation_split > 0.0:
        val_generator = make_generator(
            images[split_index:], masks[split_index:], augment_validation)
    val_steps_per_epoch = ceil((n_total - split_index) / batch_size)

    return (train_generator, train_steps_per_epoch,
            val_generator, val_steps_per_epoch)


In [None]:
def soft_sorensen_dice(y_true, y_pred, axis=None, smooth=1):
    """Soft Sorensen-Dice coefficient of `y_true` and `y_pred`.

    Computes ``(2 * sum(y_true * y_pred) + smooth) /
    (sum(y_true) + sum(y_pred) + smooth)`` with the sums taken over `axis`
    using the Keras backend.
    """
    overlap = K.sum(y_true * y_pred, axis=axis)
    total_area = K.sum(y_true, axis=axis) + K.sum(y_pred, axis=axis)
    numerator = 2 * overlap + smooth
    denominator = total_area + smooth
    return numerator / denominator

def sorensen_dice_loss(y_true, y_pred, weights):
    """Weighted soft-dice loss: one minus the weighted mean dice coefficient.

    Input tensors have shape (batch_size, height, width, classes).
    `weights` is a list with one entry per class; it is normalized by its
    sum before use.

    Ex: for simple binary classification, with the 0th mask corresponding
    to the background and the 1st mask corresponding to the object of
    interest, set weights = [0, 1].
    """
    # dice per sample and class (summed over height and width) ...
    per_sample = soft_sorensen_dice(y_true, y_pred, axis=[1, 2])
    # ... averaged over the batch to give one coefficient per class
    per_class = K.mean(per_sample, axis=0)
    class_weights = K.constant(weights) / sum(weights)
    weighted_dice = K.sum(class_weights * per_class)
    return 1 - weighted_dice

In [None]:
def select_optimizer(optimizer_name, optimizer_args):
    """Instantiate a Keras optimizer by its lowercase name.

    Arguments:
      optimizer_name - one of 'sgd', 'rmsprop', 'adagrad', 'adadelta',
                       'adam', 'adamax', 'nadam'
      optimizer_args - dict of keyword arguments passed to the optimizer's
                       constructor

    Returns:
      the optimizer instance.

    Raises:
      KeyError - `optimizer_name` is not one of the supported names.
    """
    class_names = {
        'sgd': 'SGD',
        'rmsprop': 'RMSprop',
        'adagrad': 'Adagrad',
        'adadelta': 'Adadelta',
        'adam': 'Adam',
        'adamax': 'Adamax',
        'nadam': 'Nadam',
    }
    if optimizer_name not in class_names:
        raise KeyError("unknown optimizer {!r}; expected one of {}".format(
            optimizer_name, sorted(class_names)))

    # The optimizer classes are not imported at the top of the file, so
    # resolve them from keras.optimizers here.
    from keras import optimizers as keras_optimizers
    optimizer_class = getattr(keras_optimizers, class_names[optimizer_name])
    return optimizer_class(**optimizer_args)

def train():
    """Build, train and save a segmentation model using the settings in `args`.

    Steps: create the training / validation generators, build the model
    selected by ``args.model``, compile it with the chosen loss plus
    accuracy and dice metrics, optionally checkpoint the best weights, fit,
    and save the final model to ``args.outdir/args.outfile``.

    NOTE(review): `args` is read as a module-level name -- the
    ``opts.parse_arguments()`` call that would create it is commented out --
    and `models` and ``loss.weighted_categorical_crossentropy`` are not
    defined in this file as seen.  All three must be provided before this
    function can run.

    Raises:
      ValueError: ``args.loss`` is neither 'pixel' nor 'dice'.
    """
    # Function-scope imports: neither name is imported at the top of the file.
    import logging
    from keras.callbacks import ModelCheckpoint

    #logging.basicConfig(level=logging.INFO)

    #args = opts.parse_arguments()

    # fail early instead of hitting an unbound `lossfunc` after the model
    # has already been built
    if args.loss not in ('pixel', 'dice'):
        raise ValueError(
            "loss must be 'pixel' or 'dice', got {!r}".format(args.loss))

    #logging.info("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode' : args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }
    # create_generators is defined in this file (no `dataset` module here)
    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # get image dimensions from first batch
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    logging.info("Building model...")
    # NOTE(review): `models` is not defined in this file as seen
    string_to_model = {
        "unet": models.unet,
    }
    model = string_to_model[args.model]
    m = model(height=height, width=width, channels=channels, classes=classes,
              features=32, depth=3, padding='same',
              temperature=1.0, batchnorm=False,
              dropout=0.0)

    m.summary()

    # NOTE: loading of previously saved weights is disabled:
    # if args.load_weights:
    #     logging.info("Loading saved weights from file: {}".format(args.load_weights))
    #     m.load_weights(args.load_weights)

    # instantiate optimizer, and only keep args that have been set
    # (not all optimizers have args like `momentum' or `decay')
    optimizer_args = {
        'lr':       None,
        'momentum': None,
        'decay':    None
    }
    optimizer_args = {k: v for k, v in optimizer_args.items()
                      if v is not None}
    optimizer = select_optimizer('adam', optimizer_args)

    # select loss function: pixel-wise crossentropy or soft dice
    if args.loss == 'pixel':
        def lossfunc(y_true, y_pred):
            # NOTE(review): `loss.weighted_categorical_crossentropy` has no
            # definition in this file as seen
            return loss.weighted_categorical_crossentropy(
                y_true, y_pred, args.loss_weights)
    else:
        def lossfunc(y_true, y_pred):
            return sorensen_dice_loss(y_true, y_pred, args.loss_weights)

    # the function name `dice` matters: Keras derives the metric name
    # (and hence 'val_dice' below) from it
    def dice(y_true, y_pred):
        # NOTE(review): the original called `loss.sorensen_dice`; the only
        # dice coefficient defined in this file is the soft variant --
        # confirm this is the intended metric
        batch_dice_coefs = soft_sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]    # HACK for 2-class case

    metrics = ['accuracy', dice]

    m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

    # automatic saving of model during training
    if args.checkpoint:
        if args.loss == 'pixel':
            filepath = os.path.join(
                args.outdir, "weights-{epoch:02d}-{val_acc:.4f}.hdf5")
            monitor = 'val_acc'
        else:
            filepath = os.path.join(
                args.outdir, "weights-{epoch:02d}-{val_dice:.4f}.hdf5")
            monitor = 'val_dice'
        checkpoint = ModelCheckpoint(
            filepath, monitor=monitor, verbose=1,
            save_best_only=True, mode='max')
        callbacks = [checkpoint]
    else:
        callbacks = []

    # train
    logging.info("Begin training.")
    m.fit_generator(train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)

    m.save(os.path.join(args.outdir, args.outfile))

# Entry point: run training when this file is executed as the main module.
if __name__ == '__main__':
    train()
