# SSBR Baseline

In [60]:
%load_ext autoreload
%autoreload 2

In [61]:
from ssbr.datasets.ircad import IrcadData
from ssbr.datasets.utils import DicomVolumeStore, stack_sampler, SSBRDataset
from ssbr.datasets.ops import grey2rgb, resize, rescale, image2np
from pathlib import Path
import numpy as np
import os
import h5py
from ssbr.model import ssbr_model
from keras.callbacks import ModelCheckpoint

## Data sources

### Dataset

We use IRCAD's liver dataset for training a baseline model. This dataset is composed of 20 abdominal CT scans and a number of segmented organs, mainly the liver.

`IrcadData` is a utility class that downloads and extracts the data. It also provides a mapping between volume_ids and their corresponding DICOM folders.

In [62]:
# Location of the IRCAD data on disk, relative to the notebook.
ircad_folder = Path('../data/ircad')
ircad = IrcadData(ircad_folder)

# List every entry exposed by the dataset mapping (volume id and its value).
for volume_id, dicom_dir in ircad.items():
    print(f'{volume_id} -{dicom_dir}')


### Data generator

The data source for the SSBR model must provide a 5-dimensional array:

`[BATCH_SIZE, NUM_SLICES, WIDTH, HEIGHT, CHANNEL]`

In this case, a sample is a stack of equidistant slices sampled from within a DICOM volume. Each slice is a 2-dimensional image with 3 channels.

When loading the volumes, a series of volume transformations is applied before sampling equidistant slices. The transformed volumes are cached in an HDF5 file so that slices are only loaded from disk when needed.

The `DicomVolumeStore` class provides a mapping between volume_ids and transformed volumes (the HDF5 cache interface). This store also manages the split ratio for training and validation.

Finally, the `SSBRDataset` provides the training and validation generators (`train` and `valid`), each of which yields batches of image stacks.

In [63]:
BATCH_SIZE = 4  # Number of image stacks per batch
NUM_SLICES = 8  # Number of equidistant images per stack

# Transformations applied, in order, to each volume before it is cached.
# NOTE(review): the per-step remarks are inferred from the helper names and
# arguments -- confirm against ssbr.datasets.ops.
volume_transforms = [
    resize((64, 64)),  # presumably resizes every slice to 64x64
    image2np,
    # presumably maps the [-300, 700] intensity window onto uint8 [0, 255]
    rescale(low=-300, high=700, scale=255, dtype=np.uint8),
    grey2rgb,  # presumably expands grey images to 3 channels
]

# Mode 'a' opens the HDF5 cache read/write and creates the file if missing.
cache = h5py.File(str(ircad_folder.joinpath('ircad.h5')), mode='a')
volumes = DicomVolumeStore(ircad, cache=cache, transforms=volume_transforms)
dataset = SSBRDataset(volumes=volumes, split=0.2)

# The training and validation generators share the same batch geometry.
stack_shape = {'batch_size': BATCH_SIZE, 'num_slices': NUM_SLICES}
datagen_train = dataset.train(**stack_shape)
datagen_valid = dataset.valid(**stack_shape)

# Since this training dataset is small, we build the cache right away
volumes.build_cache()

In [71]:
%matplotlib notebook
from ipywidgets import interact
import ipywidgets as widgets
import matplotlib.pyplot as plt

# Draw one batch from the training generator; only the first element of the
# yielded pair is displayed, the second is discarded.
batch, _ = next(datagen_train)
stack = batch[0]  # first image stack of the batch

fig = plt.figure()
ax = plt.imshow(stack[0])  # image artist, initialised with the first slice

# One slider position per slice of the stack.
slider = widgets.IntSlider(value=0, min=0, max=NUM_SLICES - 1, step=1)


def update(sli=slider):
    """Replace the displayed image with slice ``sli`` of the current stack.

    The slider widget is the default value so that ``interact`` can use it
    as the control for ``sli``.
    """
    ax.set_data(stack[sli])


interact(update)


## Model

`ssbr_model` returns two entry points. The first, `m`, is the model used by the training procedure, which requires a stack of slices as input. The second, `score_extractor`, is used at prediction time to compute the score of individual slices.

In [72]:
# Optional overrides for the hyper-parameters; empty means "use the defaults".
config = {}
# Path of the weights file used for checkpointing and evaluation.
MODEL = '../data/model.h5'

# Gather the arguments first so the defaults are visible in one place.
model_kwargs = {
    'lr': config.get('lr', 0.0001),
    'batch_size': BATCH_SIZE,
    'num_slices': NUM_SLICES,
    'alpha': config.get('alpha', 0.5),
}
m, score_extractor = ssbr_model(**model_kwargs)

In [73]:
# Make sure the folder that will receive the weights file exists.
checkpoint_dir = os.path.dirname(MODEL)
os.makedirs(checkpoint_dir, exist_ok=True)

# Checkpoint callback: monitors 'val_loss' and, with save_best_only=True,
# only keeps the best weights seen so far.
mcp = ModelCheckpoint(
    filepath=MODEL,
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

In [85]:
# Training schedule: 30 epochs of 50 training steps, each followed by
# 30 validation steps; the checkpoint callback runs during training.
fit_options = {
    'generator': datagen_train,
    'steps_per_epoch': 50,
    'epochs': 30,
    'callbacks': [mcp],
    'validation_data': datagen_valid,
    'validation_steps': 30,
}
hist = m.fit_generator(**fit_options)

In [86]:
# Plot every series recorded in the training history on a single figure,
# using the history keys as legend entries.
fig = plt.figure()
labels = list(hist.history)
for series in hist.history.values():
    plt.plot(series)
plt.legend(labels)
plt.show()

## Model evaluation


In [87]:
from ssbr.datasets.utils import batcher

# Load the weights stored at MODEL before scoring.
m.load_weights(MODEL)

# Score every volume of the store; results maps volume id -> array of scores.
results = {}
for vid in volumes:
    print(f'Computing results for volume {vid}')
    slice_scores = []
    # NOTE(review): batcher presumably yields chunks of up to 10 slices -- confirm.
    for batch in batcher(volumes[vid], 10):
        slice_scores.extend(score_extractor.predict_on_batch(batch))
    results[vid] = np.asarray(slice_scores)


In [88]:
import matplotlib
%matplotlib inline

In [89]:
import matplotlib.pyplot as plt
# Overlay the score curve of every volume on the same axes.
for volume_scores in results.values():
    plt.plot(volume_scores)