In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets
import tqdm

In [2]:
# Ops in the model towers are pinned to explicit '/gpu:N' devices further down.
# Soft placement lets TensorFlow fall back to the CPU for any such op that has
# no GPU kernel, instead of raising a placement error when the graph is run.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

# Loading NIH metadata

## Raw Metadata

In [3]:
# Root directory of the NIH chest X-ray download. It must contain
# Data_Entry_2017.csv and an images/ sub-directory. Set NIH_CHEST_XRAYS_DIR to
# point at another location instead of editing this hard-coded default.
path = os.environ.get('NIH_CHEST_XRAYS_DIR', '/mnt/storage/data/nih-chest-xrays/')
raw_data = pd.read_csv(os.path.join(path, 'Data_Entry_2017.csv'))

## Extract paths and labels

In [4]:
# Keep only the image file name and its '|'-separated finding labels, under short names.
data = raw_data[['Image Index', 'Finding Labels']].rename(
    columns={'Image Index': 'image', 'Finding Labels': 'label'}
)
# Turn bare file names into full paths. Bracket assignment is used because
# attribute-style assignment (data.image = ...) creates a plain attribute, not a
# column, whenever the column does not already exist.
data['image'] = os.path.join(path, 'images/') + data['image']
# Shuffle all rows once. The fixed seed makes the row order (and hence the order
# the input pipeline sees) reproducible across runs.
data = data.sample(frac=1, random_state=0)

## Produce a binary matrix of labels

In [5]:
# One 0/1 column per distinct finding (labels are split on '|'), with columns in
# sorted order. get_dummies keeps the index of `data`, so rows stay aligned with it.
label_dummies = data['label'].str.get_dummies(sep='|')
encoded_labels = label_dummies.reindex(columns=sorted(label_dummies.columns))

# Read data with a `tf.data.Dataset` input pipeline

## Load individual items directly from metadata

In [6]:
# One dataset element per row of `data`:
#   'index' - the row's original DataFrame index,
#   'path'  - the full image path,
#   'label' - the float32 multi-hot vector from the matching row of `encoded_labels`.
dataset = tf.data.Dataset.from_tensor_slices(dict(
    index=data.index,
    path=data['image'].values,
    label=encoded_labels.values.astype(np.float32),
))

## Read and decode the corresponding image files

In [7]:
def read_file(item):
    """Load the raw (still encoded) bytes of the file named by ``item['path']``.

    The bytes are stored under ``item['image']``. The same dict is updated in
    place and returned, so the call can be chained inside ``Dataset.map``.
    """
    item.update(image=tf.read_file(item['path']))
    return item

def decode_image(item):
    """Decode ``item['image']`` into a float32, single-channel image tensor.

    ``convert_image_dtype`` rescales the decoded integer pixels to float32.
    The result replaces ``item['image']``; the dict is updated in place and
    returned.
    """
    pixels = tf.image.convert_image_dtype(
        tf.image.decode_image(item['image']), tf.float32
    )
    # All images are B&W, but some seem to have the channel replicated. To
    # avoid issues, keep only the first channel, as a trailing axis of size 1.
    single_channel = tf.expand_dims(pixels[:, :, 0], axis=-1)
    # decode_image provides no static shape; pin height/width as unknown and
    # the channel count as 1 for the ops downstream.
    single_channel.set_shape([None, None, 1])
    item['image'] = single_channel
    return item

def _load_example(item):
    """Read the file named by item['path'] and decode it into item['image']."""
    return decode_image(read_file(item))


dataset = dataset.map(_load_example, num_parallel_calls=32)

## Prepare data for training

In [8]:
batch_size = 12

# Shuffle before repeating so every pass over the data is reshuffled.
# Repeat before batching so batches may span pass boundaries. With batch()
# before repeat(), each pass ended in a short batch (len(data) % batch_size
# examples). The per-GPU slicing of `batch` would then give some towers
# undersized or empty slices. This order makes every batch full.
dataset = dataset.shuffle(100)
dataset = dataset.repeat()  # repeat indefinitely
dataset = dataset.batch(batch_size)
# Prepare the next batch while the current training step runs.
dataset = dataset.prefetch(1)

iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

## Build model

In [9]:
def build_model(images, labels, reuse):
    """Build a ResNet-v2-152 multi-label classifier and return its training loss.

    Args:
        images: Batch of images fed to the network. The input pipeline in this
            notebook produces float32, single-channel images.
        labels: float32 multi-hot targets. The size of the last dimension sets
            the number of output classes and must be statically known.
        reuse: Whether to reuse the existing 'resnet_v2_152' variables (True
            for every tower after the first).

    Returns:
        Scalar sigmoid cross-entropy loss for this batch.
    """
    resnet_v2 = tf.contrib.slim.nets.resnet_v2
    # The slim ResNet is meant to be built inside its own arg scope, which
    # configures the conv layers (normalization, initializers, weight decay).
    # Without it, the inner convolutions fall back to slim's defaults.
    # NOTE(review): any regularization losses this scope creates are not added
    # to the loss returned below; add tf.losses.get_regularization_loss() if
    # weight decay is wanted.
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        _, end_points = resnet_v2.resnet_v2_152(
            # int() turns the static Dimension into a plain int and fails fast
            # if the number of classes is unknown.
            images, num_classes=int(labels.shape[-1]), reuse=reuse
        )
    # The raw 'logits' end point keeps two singleton spatial axes; drop them to
    # get [batch, num_classes].
    logits = tf.squeeze(end_points['resnet_v2_152/logits'], axis=[1, 2])
    # A new head with per-class sigmoids instead of softmax: an image can carry
    # several findings at once.
    return tf.losses.sigmoid_cross_entropy(labels, logits)


# Data-parallel training. The global batch is split into equal, contiguous
# slices, one "tower" per GPU. All towers share a single set of variables
# (placed on the CPU, created by tower 0 and reused by the others).
gpus = range(0, 4)
# Integer division would silently leave the tail of every batch untrained,
# so require an even split up front.
assert batch_size % len(gpus) == 0, 'batch_size must be divisible by the number of GPUs'
batch_slice_size = batch_size // len(gpus)
losses = list()
update_ops = list()
for gpu in gpus:
    i_start = batch_slice_size * gpu
    gpu_slice = slice(i_start, i_start + batch_slice_size)

    # Place operations on a GPU and variables on the CPU
    with tf.device('/gpu:%d' % gpu):
        with tf.name_scope('tower_%d' % gpu) as scope:
            with slim.arg_scope([slim.model_variable, slim.variable], device='/cpu:0'):
                losses.append(build_model(
                    batch['image'][gpu_slice],
                    batch['label'][gpu_slice],
                    gpu > 0  # reuse variables in every tower but the first
                ))
            if gpu == 0:
                # The ResNet's batch-norm layers register their moving-mean and
                # moving-variance updates in UPDATE_OPS rather than in the loss's
                # dependencies, so minimize() alone never runs them. The towers
                # share those variables, so the first tower's updates suffice.
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

with tf.device('/cpu:0'):
    # Each tower loss is already a mean over its slice. Average across towers
    # so the reported loss does not grow with the number of GPUs.
    loss = tf.add_n(losses) / len(losses)

optimizer = tf.train.AdamOptimizer()
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, colocate_gradients_with_ops=True)

## Initialize and start training

In [10]:
# Give every global variable created so far (model weights and the optimizer's
# state) its initial value.
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# Train for 100 "epochs". Each epoch is len(data) // batch_size steps drawn from
# the endlessly repeating input pipeline. The progress bar shows the latest loss.
steps_per_epoch = len(data) // batch_size
for epoch in range(100):
    print('Epoch %d' % epoch)
    progress = tqdm.trange(steps_per_epoch, unit='batch', smoothing=1)
    for _step in progress:
        step_loss, _ = sess.run([loss, train_op])
        progress.set_description('loss: %.3f' % step_loss)