# How to easily train a V-Net or any other model for lung cancer segmentation.

This is a supplemental notebook for a Medium post in the [Data Analysis Center blog](https://medium.com/data-analysis-center).

We suggest using a GPU; the notebook was tested on an NVIDIA GTX 1080. Note that the network is quite large and needs a lot of GPU memory.

V-Net is a popular CNN architecture for segmentation of volumetric images such as CT scans; see [Milletari et al.](https://arxiv.org/abs/1606.04797).

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import pandas as pd
import numpy as np
import sys
import tensorflow as tf
sys.path.append('../')

from radio.dataset import Pipeline, FilesIndex, Dataset, F
from radio import CTImagesMaskedBatch as CTIMB


In [3]:
!nvidia-smi

Thu Dec 14 11:57:52 2017       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.26                 Driver Version: 375.26                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 1080    Off  | 0000:02:00.0     Off |                  N/A |
| 23%   49C    P8    15W / 200W |   4487MiB /  8113MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 1080    Off  | 0000:03:00.0     Off |                  N/A |
| 30%   52C    P8    15W / 200W |      2MiB /  8112MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                            

Examples in this notebook use the [LUNA16 competition dataset](https://luna16.grand-challenge.org/) in MetaImage (mhd/raw) format.

You need to specify a glob mask for the '\*.mhd' input files in DIR_LUNA. Here we use the unzipped competition dataset: the mhd files are stored in subfolders, and each file name without its extension is taken as the item id.
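To make the mask-to-id mapping concrete, here is a rough stdlib imitation of what FilesIndex(path=..., no_ext=True) is assumed to do. The helper ids_from_mask is hypothetical, not RadIO's code. Note that LUNA16 file names are series UIDs full of dots, so only the last extension must be stripped.

```python
import glob
import os


def ids_from_mask(mask):
    """Map every file matched by a glob mask to an id: its base name without extension.

    Hypothetical helper imitating FilesIndex(path=mask, no_ext=True); not RadIO's code.
    """
    paths = sorted(glob.glob(mask))
    # os.path.splitext strips only the last extension, so '1.3.6.1.mhd' -> '1.3.6.1'
    return {os.path.splitext(os.path.basename(path))[0]: path for path in paths}
```

For example, with the mask '/notebooks/data/MRT/luna/s*/*.mhd', a file subset0/1.2.3.mhd gets the id '1.2.3', while the matching .raw file is not indexed at all.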

In [4]:
DIR_LUNA = '/notebooks/data/MRT/luna/s*/*.mhd' # glob mask for LUNA16 scans (LUNA16 is a subset of LIDC-IDRI)

### Setting up the index and the dataset

Index all the data and create a Dataset, which conceptually represents all the raw data and lets us iterate over it in batches, just as we do when training a neural network.
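The idea of iterating over an index in batches, including the n_epochs=None case used later with next_batch, can be sketched in plain Python. The generator iterate_batches is hypothetical and only illustrates the concept; RadIO's batching (shuffling, dropping the last batch, and so on) may behave differently.

```python
import itertools


def iterate_batches(ids, batch_size, n_epochs=1):
    """Yield consecutive batches of ids; n_epochs=None repeats the data forever.

    Hypothetical sketch of the idea behind next_batch(), not RadIO's implementation.
    The last, possibly incomplete, batch of each epoch is yielded as is.
    """
    if not ids:
        return
    epochs = itertools.count() if n_epochs is None else range(n_epochs)
    for _ in epochs:
        for start in range(0, len(ids), batch_size):
            yield ids[start:start + batch_size]
```

With n_epochs=None the generator never ends, so a training loop simply takes as many batches as it needs (for example with itertools.islice).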

In [5]:
index = FilesIndex(path=DIR_LUNA, no_ext=True)
lunaset = Dataset(index=index, batch_class=CTIMB)

### Preprocessing everything

Here we split the LUNA dataset into train and test parts and define a preprocessing pipeline that works on the fly: it loads scans, clips and normalizes their Hounsfield Unit (HU) values, resizes images (to equalize spacing along each axis across patients), creates masks and crops patches around nodules.
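The normalize_hu(-1000, 400) step can be sketched in NumPy. We assume it clips values to the [-1000, 400] HU window and then rescales them linearly to [0, 255]; that output range is our assumption, and the function below is a stand-in rather than RadIO's implementation.

```python
import numpy as np


def normalize_hu(scan, min_hu=-1000, max_hu=400):
    """Clip a scan to the [min_hu, max_hu] HU window and rescale it linearly to [0, 255].

    Sketch under the assumption stated above; not RadIO's implementation.
    """
    clipped = np.clip(scan, min_hu, max_hu).astype(np.float64)
    return (clipped - min_hu) / (max_hu - min_hu) * 255.0


# a few HU values, including some outside the [-1000, 400] window
scan = np.array([-2000, -1000, -300, 0, 400, 3000])
print(normalize_hu(scan))
```

Everything below -1000 HU maps to 0 and everything above 400 HU maps to 255, so very dense structures cannot dominate the intensity range.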

In [11]:
lunaset.cv_split([0.9, 0.1])  # 90% of scans for training, 10% for testing

# load annotations into df
nodules = pd.read_csv('/notebooks/data/MRT/luna/CSVFILES/annotations.csv')

pipeline = (Pipeline()
        .load(fmt='raw')  # load scans
        .normalize_hu(-1000, 400)  # clip HU values to [-1000, 400] and normalize them
        .fetch_nodules_info(nodules=nodules)  # load nodules locations
        .unify_spacing(shape=(128, 256, 256), spacing=(2.5, 2.0, 2.0))  # equalize spacing of different scans
        .create_mask()  # create masks
        .sample_nodules(nodule_size=(32, 64, 64), batch_size=10, share=0.5)  # sample 10 crops, half of them with nodules
       )
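unify_spacing(shape=(128, 256, 256), spacing=(2.5, 2.0, 2.0)) can be read as two steps: resample each scan to 2.5 × 2.0 × 2.0 mm voxels, then crop or pad it to 128 × 256 × 256 voxels. That reading and the (z, y, x) axis order are our assumptions. The arithmetic for a hypothetical raw scan of shape (300, 512, 512) with spacing (1.0, 0.7, 0.7) mm:

```python
def resampled_shape(shape, spacing, new_spacing):
    """Voxel counts after resampling: the physical size n * spacing stays fixed,
    so along each axis n' = n * spacing / new_spacing (rounded)."""
    return tuple(int(round(n * s / ns)) for n, s, ns in zip(shape, spacing, new_spacing))


def pad_or_crop(shape, target):
    """Per-axis (pad, crop) amounts, in voxels, that bring `shape` to `target`."""
    return tuple((max(t - n, 0), max(n - t, 0)) for n, t in zip(shape, target))


# hypothetical raw scan: 300 slices of 512 x 512 pixels, spacing in mm as (z, y, x)
new_shape = resampled_shape((300, 512, 512), (1.0, 0.7, 0.7), (2.5, 2.0, 2.0))
print(new_shape)                                # (120, 179, 179)
print(pad_or_crop(new_shape, (128, 256, 256)))  # ((8, 0), (77, 0), (77, 0))
```

With these settings every processed scan covers 128 × 2.5 = 320 mm along z and 256 × 2.0 = 512 mm along y and x.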

### TF-model for segmentation in a loop

Neural networks are often trained in a simple loop (see https://www.tensorflow.org/get_started/mnist/mechanics#train_loop).

Say we define a toy model with inputs, targets, predictions and train_step, as we usually do in TF:

In [6]:
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64])
targets = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64])

session = tf.Session()

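# oversimplified model: a single 3D convolution
# crops arrive stacked along the first axis as (n_crops * 32, 64, 64),
# so restore the crop axis and add a channel axis for conv3d
reshaped = tf.reshape(inputs, shape=(-1, 32, 64, 64, 1))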
predictions = tf.reshape(
    tf.layers.conv3d(
        reshaped, kernel_size=(5, 5, 5), padding='same', filters=1),
    shape=(-1, 64, 64))

loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)
train_step = tf.train.AdamOptimizer().minimize(loss)

session.run(tf.global_variables_initializer())
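Why the placeholders have shape [None, 64, 64]: assuming RadIO stacks all crops of a batch along the first axis (which is also what the lower_bounds/upper_bounds in the visualization code suggest), the tf.reshape in the toy model just restores the crop axis. The same operation in NumPy, on random data standing in for batch.images:

```python
import numpy as np

# Assumption: a batch of n crops of shape (32, 64, 64) arrives stacked
# along the first axis as an array of shape (n * 32, 64, 64).
n_crops = 10
stacked = np.random.rand(n_crops * 32, 64, 64).astype(np.float32)

# Same reshape as tf.reshape in the toy model: recover the crop axis and
# add a trailing channel axis (channels_last) for a 3D convolution.
volumes = stacked.reshape(-1, 32, 64, 64, 1)

# Crop k occupies slices k*32 .. (k+1)*32 - 1 of the stacked array.
k = 3
assert np.array_equal(volumes[k, ..., 0], stacked[k * 32:(k + 1) * 32])
```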


With RadIO, you can lazily pass raw data through the preprocessing pipeline:

In [11]:
lunapipe = (lunaset.train >> pipeline)
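In lazy mode, binding a dataset to a pipeline only records what should happen; the actions run when a batch is requested. A toy version of this behaviour, using a hypothetical ToyPipeline class (not RadIO's Pipeline) and Python's reflected __rrshift__ operator so that `data >> pipeline` works:

```python
class ToyPipeline:
    """Hypothetical lazy pipeline: records actions and runs them only in next_batch()."""

    def __init__(self, actions=(), dataset=None):
        self.actions = list(actions)
        self.dataset = dataset
        self._cursor = 0

    def add(self, action):
        # Return a new pipeline with one more deferred action; nothing runs yet.
        return ToyPipeline(self.actions + [action], self.dataset)

    def __rrshift__(self, dataset):
        # `dataset >> pipeline` only binds a data source to the recorded actions.
        return ToyPipeline(self.actions, dataset)

    def next_batch(self, batch_size):
        # Data is pulled and every action is applied only when a batch is requested.
        batch = self.dataset[self._cursor:self._cursor + batch_size]
        self._cursor += batch_size
        for action in self.actions:
            batch = [action(item) for item in batch]
        return batch
```

For instance, `[1, 2, 3, 4] >> ToyPipeline().add(lambda x: 2 * x)` does no work until next_batch(2) is called, which then returns [2, 4].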

...and then train the network on the resulting batches:

In [None]:
N_ITERS = 100
for i in range(N_ITERS):
    # n_epochs=None: cycle through the training data indefinitely
    batch = lunapipe.next_batch(batch_size=10, n_epochs=None)
    _, loss_value = session.run([train_step, loss],
                                feed_dict={inputs: batch.images, targets: batch.masks})
    print(i, loss_value)

However, such loops are often inconvenient and bulky. For **TF** and **Keras** models, training can be done directly in pipelines instead.
Suppose we decide to train a V-Net in TF, which is not unusual these days:

In [7]:
from radio.dataset.models.tf.vnet import VNet
from radio.models.tf.losses import dice_loss
from radio import dataset as ds
from radio.models.metrics import dice

Using TensorFlow backend.


Models from [dataset](http://github.com/analysiscenter/dataset) require an inputs config and a model config. In the model config, we slightly change the final layers of the network so that they use a sigmoid activation and predict a single class (cancerous vs. non-cancerous). As before, training is done on patches of size (32, 64, 64).
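The model config below sets options of individual network parts with slash-separated keys such as 'head/num_classes'. Assuming the library reads them as paths into a nested config (our reading of how its config works, not confirmed here), the idea can be sketched with a hypothetical unflatten helper:

```python
def unflatten(config):
    """Turn slash-separated keys into nested dicts.

    Hypothetical helper: {'head/num_classes': 1} -> {'head': {'num_classes': 1}}.
    """
    nested = {}
    for key, value in config.items():
        *parents, leaf = key.split('/')
        node = nested
        for name in parents:
            # create intermediate levels on first use
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested
```

So 'head/num_classes', 'head/layout' and 'head/activation' all end up as settings of the network head, next to top-level options such as 'optimizer'.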

In [8]:
inputs_config = dict(
    images={'shape': (32, 64, 64, 1)},
    labels={'name': 'targets', 'shape': (32, 64, 64, 1)}
)


model_config = dict(
    inputs=inputs_config,
    optimizer='Adam',
    loss=dice_loss,
    build=True
)

model_config['input_block/inputs'] = 'images'
model_config['head/num_classes'] = 1
model_config['head/layout'] = 'ca'
model_config['head/activation'] = tf.nn.sigmoid
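With a sigmoid head, each voxel gets an independent probability of being cancerous, and dice_loss compares these probabilities with the target mask. A common soft-Dice formulation can be sketched in NumPy; RadIO's dice_loss may differ in smoothing, reduction over the batch, or argument order, so treat this as an illustration only.

```python
import numpy as np


def sigmoid(logits):
    return 1.0 / (1.0 + np.exp(-logits))


def soft_dice_loss(targets, probs, eps=1e-7):
    """1 - 2 * sum(t * p) / (sum(t) + sum(p)).

    0 for a perfect prediction, close to 1 when prediction and mask do not overlap.
    A common soft-Dice formulation, not necessarily identical to RadIO's dice_loss.
    """
    intersection = np.sum(targets * probs)
    return 1.0 - (2.0 * intersection + eps) / (np.sum(targets) + np.sum(probs) + eps)


targets = np.array([0.0, 1.0, 1.0, 0.0])
probs = sigmoid(np.array([-10.0, 10.0, 10.0, -10.0]))  # confident, correct logits
print(soft_dice_loss(targets, probs))
```

Unlike voxel-wise MSE, this loss is normalized by the mask size, which helps when nodules occupy only a tiny fraction of a crop.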


Here, the network's feed dict is specified directly; see the [documentation]() for details.
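Each F(CTIMB.unpack, component=...) entry is evaluated only when a batch reaches train_model. Assuming F acts as a deferred call (our interpretation, not the library's documented internals), the mechanism can be sketched with a hypothetical LazyCall class and a toy unpack function that stands in for CTIMB.unpack:

```python
class LazyCall:
    """Hypothetical deferred call: remembers func and kwargs, runs func(batch, **kwargs) later."""

    def __init__(self, func, **kwargs):
        self.func = func
        self.kwargs = kwargs

    def resolve(self, batch):
        return self.func(batch, **self.kwargs)


def resolve_feed_dict(feed_dict, batch):
    """Evaluate every LazyCall in a feed dict for the current batch; keep other values as is."""
    return {name: value.resolve(batch) if isinstance(value, LazyCall) else value
            for name, value in feed_dict.items()}


def unpack(batch, component):
    # toy stand-in for CTIMB.unpack: here a batch is just a dict of arrays
    return batch[component]


feed_dict = {
    'images': LazyCall(unpack, component='images'),
    'labels': LazyCall(unpack, component='segmentation_targets'),
}
batch = {'images': [[0.1, 0.2]], 'segmentation_targets': [[0.0, 1.0]]}
print(resolve_feed_dict(feed_dict, batch))
```

The feed dict is therefore written once, when the pipeline is defined, and filled with fresh data for every batch.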

In [12]:
train_pipeline = (
    pipeline
      .init_model('static', VNet, 'vnet', config=model_config)
      .train_model('vnet', feed_dict={ 
          'images': F(CTIMB.unpack, component='images'),
          'labels': F(CTIMB.unpack, component='segmentation_targets')
      })
) << lunaset.train

Once the pipeline is defined, you can start training immediately:

In [None]:
train_pipeline.run(1)

<radio.dataset.dataset.pipeline.Pipeline at 0x7f0f50d69358>

In [None]:
import matplotlib.pyplot as plt
import seaborn

model = train_pipeline.get_model_by_name('vnet')
(
    model.train_metrics
         .loc[:, ['dice']]
         .rolling(16)
         .mean()
         .plot(figsize=(10, 7))
)
plt.xlabel('iteration')
plt.ylabel('metric')
plt.grid(True)

To see how the V-Net model predicts on a patient's whole scan, use the predict_on_scan action; it is just as easy:

In [11]:
res_pipe = (
    Pipeline().load(fmt='raw')  # load scans
    .normalize_hu(-1000, 400)  # normalize hu
    .fetch_nodules_info(nodules=nodules)  # load nodules locations
    .unify_spacing(shape=(128, 256, 256), spacing=(2.5, 2.0, 2.0))
    .predict_on_scan(
        model_name='vnet',
        strides=(32, 64, 64),
        batch_size=4,
        dim_ordering='channels_last',
        y_component='labels')
)
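With strides=(32, 64, 64) equal to the patch size, a (128, 256, 256) scan splits into 4 × 4 × 4 = 64 non-overlapping patches, which with batch_size=4 means 16 forward passes per scan. We assume predict_on_scan tiles the scan this way, runs the model on the patches in batches and stitches the predictions back; the sketch below shows only that tiling logic, with the model replaced by an identity.

```python
import itertools

import numpy as np


def split_into_patches(scan, patch_shape):
    """Cut a 3D scan into non-overlapping patches, i.e. strides equal to patch_shape."""
    if any(s % p for s, p in zip(scan.shape, patch_shape)):
        raise ValueError('scan shape must be divisible by patch shape')
    corners = list(itertools.product(
        *(range(0, s, p) for s, p in zip(scan.shape, patch_shape))))
    patches = np.stack([
        scan[tuple(slice(c, c + p) for c, p in zip(corner, patch_shape))]
        for corner in corners])
    return patches, corners


def assemble(patches, corners, scan_shape):
    """Write each (predicted) patch back at its corner to get a full-size mask."""
    mask = np.zeros(scan_shape, dtype=patches.dtype)
    for patch, corner in zip(patches, corners):
        mask[tuple(slice(c, c + p) for c, p in zip(corner, patch.shape))] = patch
    return mask


scan = np.random.rand(128, 256, 256).astype(np.float32)
patches, corners = split_into_patches(scan, (32, 64, 64))
# stand-in for the model: go through patches in batches of 4 and keep them unchanged
batch_size = 4
predicted = np.concatenate([patches[i:i + batch_size]
                            for i in range(0, len(patches), batch_size)])
mask = assemble(predicted, corners, scan.shape)
```

With the identity "model", the reassembled mask equals the input scan exactly, which is a quick check that no voxel is lost or duplicated by the tiling.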

In [None]:
batch = (lunaset.test >> res_pipe).next_batch(6)

In [None]:
import ipywidgets
from ipywidgets import interact

In [None]:
def plot_arr_slices(height, *arrays, clim=(180, 255)):
    """Plot the same relative slice (height in (0, 1)) of several 3D arrays side by side."""
    fig, axes = plt.subplots(1, len(arrays), figsize=(20, len(arrays)*8), squeeze=False)

    for i, arr in enumerate(arrays):
        depth = arr.shape[0]
        n_slice = int(depth * height)

        # use the `clim` window for scans (wide value range) and (0, 1) for masks
        vmin, vmax = clim if np.max(arr) - np.min(arr) > 2.0 else (0, 1)
        axes[0, i].grid(color='w', linestyle='-', linewidth=0.5)
        axes[0, i].imshow(arr[n_slice], cmap=plt.cm.gray, vmin=vmin, vmax=vmax)
    plt.show()

In [None]:
def visualize(batch):
    """Interactive viewer: choose a scan with `item` and a relative slice with `height`."""
    size = len(batch)
    @interact(item=ipywidgets.IntSlider(value=0, min=0, max=size-1),
              height=(0.01, 0.99, 0.01))
    def visualizer(item, height):
        # all scans are stacked along axis 0; the bounds select scan number `item`
        lb, ub = batch.lower_bounds[item], batch.upper_bounds[item]
        return plot_arr_slices(height, batch.images[lb: ub, :, :],
                               batch.masks[lb: ub, :, :], batch.real_masks[lb: ub, :, :])
    return visualizer

In [None]:
visualize(batch)