In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import sys
import glob
import math
sys.path.append('../..')

import numpy as np
import scipy.ndimage
import pandas as pd
import tensorflow as tf

import ipywidgets
from ipywidgets import interact
from IPython.display import clear_output

from lung_cancer.dataset import FilesIndex, Dataset, action, model, any_action_failed
from lung_cancer.preprocessing import CTImagesBatch, CTImagesMaskedBatch

In [None]:
from lung_cancer.models.metrics import log_loss, accuracy, tpr, fpr, precision, recall
from lung_cancer.models.keras_resnet import KerasResNet

In [None]:
# Build the Keras ResNet wrapper (the string argument is presumably the model's
# name - confirm against KerasResNet) and compile it. The resulting module-level
# object is what CustomBatch.keras_resnet returns.
resnet_model = KerasResNet('keras_resnet')
resnet_model.compile()

In [None]:
class CustomBatch(CTImagesMaskedBatch):
    """ Batch of CT-scans with notebook-oriented training and evaluation actions.

    Adds to CTImagesMaskedBatch:
    - a `labels` attribute, initialised to None;
    - the `keras_resnet` model definition backed by the module-level `resnet_model`;
    - `visualize`, an interactive slice viewer;
    - `train_on_crop` and `test_on_crop_dataset` actions that print metrics
      into the notebook output.
    """

    def __init__(self, *args, **kwargs):
        """ Forward all arguments to the parent batch and start without labels. """
        super().__init__(*args, **kwargs)
        self.labels = None

    @model()
    def keras_resnet():
        """ Return the module-level `resnet_model` object. """
        return resnet_model

    def visualize(self, component):
        """ Show an interactive viewer for slices of one component of the batch.

        Args:
        - component: str, name of the component passed to `self.get`; the special
                     value 'nodule' shows `images` multiplied by `masks`;

        Returns:
        - the inner function decorated with `ipywidgets.interact`;
        """
        # NOTE(review): `plt` is not imported in the visible part of this notebook;
        # presumably matplotlib.pyplot has to be in scope already - confirm before use.
        size = len(self)

        @interact(item=ipywidgets.IntSlider(value=0, min=0, max=size-1),
                  slc=ipywidgets.FloatSlider(0, min=0, max=0.99, step=0.01))
        def visualizer(item, slc):
            if component == 'nodule':
                image = self.get(item, 'images') * self.get(item, 'masks')
            else:
                image = self.get(item, component)
            # `slc` is a fraction in [0, 0.99]; scale it to an index along the first axis.
            image_slice = int(slc * image.shape[0])
            plt.imshow(image[image_slice, :, :])
            plt.show()
        return visualizer

    @action
    def train_on_crop(self, model_name, y_component='labels', dim_ordering='channels_last', **kwargs):
        """ Train model on crops of CT-scans contained in batch and report the result.

        Training itself is delegated to the parent `train_on_crop`; afterwards the
        model is run on the same data and the log loss together with a table of
        predictions vs. targets is written to stdout.

        Args:
        - model_name: str, name of classification model;
        - y_component: str, name of y component, can be 'masks' or 'labels';
        - dim_ordering: str, dimension ordering, can be 'channels_first' or 'channels_last';
        - kwargs: accepted for call compatibility, not used by this method;

        Returns:
        - self, unchanged CTImagesMaskedBatch;
        """
        super().train_on_crop(model_name, y_component, dim_ordering)
        # Named `trained_model` so the imported `model` decorator is not shadowed.
        trained_model = self.get_model_by_name(model_name)
        x, y_true = self.unpack_data(dim_ordering=dim_ordering, y_component=y_component)
        y_pred = trained_model.predict_on_batch(x)
        sys.stdout.write("Log loss on train: " + str(log_loss(y_pred, y_true)) + '\n')
        sys.stdout.write(str(pd.DataFrame(np.array([y_pred.ravel(), y_true.ravel()]).T)))
        sys.stdout.flush()
        # wait=True postpones clearing until the next output arrives.
        clear_output(wait=True)

        return self

    @action
    def test_on_crop_dataset(self, model_name, dataset, batch_size, callbacks=(), test_freq=10):
        """ Periodically evaluate model on a separate dataset of crops.

        The evaluation state lives on `dataset` itself: `dataset.counter` counts
        calls of this action and `dataset.metrics` accumulates one dict of metric
        values per evaluation. Evaluation runs on every `test_freq`-th call,
        starting with the first one; the accumulated metrics are printed afterwards.

        Args:
        - model_name: str, name of the model to evaluate;
        - dataset: dataset object providing `gen_batch(batch_size)`; the attributes
                   `metrics` and `counter` are created on it when missing;
        - batch_size: int, size of batches generated from `dataset`;
        - callbacks: iterable of metric callables invoked as `m(y_pred, y_true)`;
                     results are stored under `m.__name__`;
        - test_freq: int, evaluate on every `test_freq`-th call of this action;

        Returns:
        - self, unchanged batch;
        """
        import warnings

        if not hasattr(dataset, 'metrics'):
            dataset.metrics = []
        if not hasattr(dataset, 'counter'):
            dataset.counter = 0

        # The evaluation period is driven by `test_freq` rather than a hard-coded
        # constant, so the parameter actually controls how often the model is tested.
        if dataset.counter % test_freq == 0:
            tested_model = self.get_model_by_name(model_name)
            y_pred_list, y_true_list = [], []
            for batch in dataset.gen_batch(batch_size):
                batch.load(fmt='blosc').create_labels_by_mask(threshold=10)
                x, batch_y_true = batch.unpack_data(dim_ordering='channels_last', y_component='labels')
                batch_y_pred = np.asarray(tested_model.predict_on_batch(x))

                y_pred_list.append(batch_y_pred.ravel())
                y_true_list.append(batch_y_true.ravel())

            if y_pred_list:
                y_pred = np.concatenate(y_pred_list, axis=0)
                y_true = np.concatenate(y_true_list, axis=0)

                metrics_values = {m.__name__: m(y_pred, y_true) for m in callbacks}
                dataset.metrics.append(metrics_values)

                print(pd.DataFrame(dataset.metrics))
                sys.stdout.flush()
                clear_output(wait=True)
            else:
                # np.concatenate raises on an empty list; report the real cause instead.
                warnings.warn("test_on_crop_dataset: dataset produced no batches, "
                              "evaluation skipped")

        dataset.counter += 1
        return self

## Pipeline for training classification model

In [2]:
# Index of the dumped nodule crops; with dirs=True presumably every matching
# directory is one item - confirm against FilesIndex.
luna_index = FilesIndex(path='/data/final_nodules_dump/*', dirs=True)
luna_dataset = Dataset(luna_index, batch_class=CustomBatch)
# Split into luna_dataset.train / luna_dataset.test (both are used below).
# NOTE(review): shares=0.9 presumably is the train share and shuffle=81 looks
# like a random seed - confirm against Dataset.cv_split.
luna_dataset.cv_split(shares=0.9, shuffle=81)

# Training pipeline: load crops, build labels from masks, train 'keras_resnet'
# on every batch and evaluate it on the test part via test_on_crop_dataset.
train_pipeline = \
(
    luna_dataset.train
                .pipeline()
                .load(fmt='blosc', src_blosc=['images', 'masks',
                                              'origin', 'spacing'])
    
                # threshold=10 matches the value hard-coded in
                # CustomBatch.test_on_crop_dataset.
                .create_labels_by_mask(threshold=10)
                .train_on_crop(model_name='keras_resnet',
                               y_component='labels',
                               dim_ordering='channels_last')

                # Metrics are accumulated on luna_dataset.test itself
                # (attributes `metrics` and `counter`).
                .test_on_crop_dataset('keras_resnet',
                                      luna_dataset.test, batch_size=4,
                                      callbacks=(log_loss, accuracy),
                                      test_freq=20)
)

### Run pipeline with batch_size=16

In [36]:
# Execute the training pipeline with batches of 16 items.
# NOTE(review): `epochs=10` is presumably the number of passes over the training
# data - confirm the keyword name against Pipeline.run.
train_pipeline.run(batch_size=16, epochs=10)

<lung_cancer.dataset.dataset.pipeline.Pipeline at 0x7f2174d9fcc0>

## Pipeline for getting predictions on full scans

In [None]:
# Nodule annotations table, passed to fetch_nodules_info below.
NODULES_DF = pd.read_csv('/notebooks/data/MRT/luna/CSVFILES/annotations.csv')
# Index of full scans stored as .mhd files; no_ext=True presumably strips the
# extension from item names - confirm against FilesIndex.
scans_index = FilesIndex(path= '/notebooks/data/MRT/luna/s*/*.mhd', no_ext=True)
scans_dataset = Dataset(scans_index, batch_class=CustomBatch)
# Not used anywhere in the visible part of the notebook.
scans_result = {}

# Inference pipeline on full scans: load raw data, normalize, attach nodule
# info and masks, bring scans to a common shape/spacing and run 'keras_resnet'.
scans_pipeline = \
(
        scans_dataset.pipeline()
                     .load(fmt='raw')
                     .normalize_hu()
                     .fetch_nodules_info(nodules_df=NODULES_DF)
                     .create_mask()
                     .unify_spacing(shape=(256, 256, 256), spacing=(1.3, 1.3, 1.3), padding='reflect')
                     # NOTE(review): strides presumably set the step of the crop
                     # window along each axis - confirm against predict_on_scan.
                     .predict_on_scan('keras_resnet', y_component='labels', strides=(32, 64, 64))
)

### Run pipeline on 4 full scans

In [None]:
# Pull a single batch of 4 scans through scans_pipeline.
batch = scans_pipeline.next_batch(4)

In [None]:
# Interactive slice viewer (CustomBatch.visualize) for the 'images' component.
batch.visualize('images')

In [None]:
# Interactive slice viewer (CustomBatch.visualize) for the 'masks' component.
batch.visualize('masks')