In [27]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import sys
import glob
import math
sys.path.append('../..')

import numpy as np
import scipy.ndimage
import pandas as pd
import tensorflow as tf

from IPython.display import clear_output

from lung_cancer.dataset import FilesIndex, Dataset, action, model, any_action_failed
from lung_cancer.preprocessing import CTImagesBatch, CTImagesMaskedBatch

In [13]:
from lung_cancer.models.metrics import log_loss, accuracy, tpr, fpr, precision, recall
from lung_cancer.models.keras_resnet import KerasResNet

In [14]:
# Build the Keras ResNet wrapper and compile it once; CustomBatch.keras_resnet
# (defined in a later cell) returns this notebook-global object, so this cell
# must run before the pipeline is built.
resnet_model = KerasResNet('keras_resnet')
resnet_model.compile()

keras_model.py[LINE:24]#INFO     [2017-08-21 17:09:12,395]  Building keras model...
keras_model.py[LINE:26]#INFO     [2017-08-21 17:09:16,892]  Keras model was build
keras_model.py[LINE:47]#INFO     [2017-08-21 17:09:16,892]  Compiling keras model...
keras_model.py[LINE:49]#INFO     [2017-08-21 17:09:16,935]  Model was compiled


In [34]:
class CustomBatch(CTImagesMaskedBatch):
    """ Batch of CT-scan crops with per-item binary labels and Keras actions.

    Extends CTImagesMaskedBatch with:
    - a `labels` attribute filled by `create_labels_by_mask`;
    - the `keras_resnet` model registered through the `@model()` decorator;
    - actions for training, prediction and periodic evaluation on a
      held-out dataset.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-item labels; stays None until `create_labels_by_mask` is run.
        self.labels = None

    @model()
    def keras_resnet():
        """ Return the ResNet model compiled in an earlier notebook cell. """
        # NOTE(review): relies on the notebook-global `resnet_model`; the
        # `@model()` decorator apparently expects a function without `self`.
        return resnet_model

    @action
    def create_labels_by_mask(self, threshold):
        """ Label every item 1.0 if the sum of its mask exceeds `threshold`, else 0.0.

        Args:
        - threshold: number, items whose mask sum is strictly greater than
          this value get label 1.0;

        Returns:
        - self, with `labels` set to a float64 array of length len(self);
        """
        # np.float was only an alias of the builtin float (removed in NumPy 1.24);
        # np.float64 is the very same dtype.
        self.labels = np.asarray([np.sum(self.get(i, 'masks')) > threshold
                                  for i in range(len(self))], dtype=np.float64)
        return self

    def unpack_data(self, y_component, dim_ordering='channels_last', **kwargs):
        """ Unpack data contained in batch for feeding in model.

        Args:
        - y_component: str or None, name of y_component to fetch, can be
          'masks', 'labels' or None (then only x is unpacked and y is None);
        - dim_ordering: str, can be 'channels_last' or 'channels_first';
          any other value leaves arrays without a channel axis;
        - kwargs: accepted for call-compatibility and ignored;

        Returns:
        - x, y ndarrays (y is None when y_component is None);

        Raises:
        - ValueError if y_component is not supported, or if labels are
          requested before `create_labels_by_mask` was run;
        """
        if y_component not in ('masks', 'labels', None):
            raise ValueError("y_component must be 'masks', 'labels' or None, got %r"
                             % (y_component,))
        if y_component == 'labels' and self.labels is None:
            raise ValueError("labels are not set; run create_labels_by_mask first")

        x = np.stack([self.get(i, 'images') for i in range(len(self))])
        if y_component == 'masks':
            y = np.stack([self.get(i, 'masks') for i in range(len(self))])
        elif y_component == 'labels':
            y = np.stack([self.labels[i] for i in range(len(self))])
        else:
            y = None

        if dim_ordering == 'channels_last':
            x = x[..., np.newaxis]
            if y is not None:
                y = y[..., np.newaxis]
        elif dim_ordering == 'channels_first':
            x = x[:, np.newaxis, ...]
            # Labels are per-item scalars, so only masks get a channel axis here.
            if y_component == 'masks':
                y = y[:, np.newaxis, ...]
        return x, y

    @action
    def train_on_crop(self, model_name, y_component='labels', dim_ordering='channels_last', **kwargs):
        """ Train model on crops of CT-scans contained in batch.

        After the training step the model is evaluated on the same batch and
        the log loss plus a (prediction, target) table are written to stdout.

        Args:
        - model_name: str, name of classification model;
        - y_component: str, name of y component, can be 'masks' or 'labels';
        - dim_ordering: str, dimension ordering, can be 'channels_first' or 'channels_last';

        Returns:
        - self, unchanged CTImagesMaskedBatch;
        """
        model = self.get_model_by_name(model_name)
        # Pass the caller's ordering through (it was previously hard-coded).
        x, y_true = self.unpack_data(dim_ordering=dim_ordering,
                                     y_component=y_component, **kwargs)
        model.train_on_batch(x, y_true)

        y_pred = np.asarray(model.predict_on_batch(x))
        sys.stdout.write("Log loss on train: " + str(log_loss(y_pred, y_true)) + '\n')
        sys.stdout.write(str(pd.DataFrame(np.array([y_pred.ravel(), y_true.ravel()]).T)))
        sys.stdout.flush()
        # wait=True defers clearing until the next output arrives, so the
        # report is replaced in place rather than accumulating.
        clear_output(wait=True)

        return self

    @action
    def predict_on_crop(self, model_name, dst_dict, y_component='labels', dim_ordering='channels_last', **kwargs):
        """ Get predictions of model on crops of CT-scans contained in batch.

        Args:
        - model_name: str, name of classification model;
        - dst_dict: dict, updated in place with {item index: prediction};
        - y_component: str, name of y component, can be 'masks' or 'labels';
        - dim_ordering: str, dimension ordering, can be 'channels_first' or 'channels_last';

        Returns:
        - self, unchanged CTImagesMaskedBatch;
        """
        model = self.get_model_by_name(model_name)
        x, _ = self.unpack_data(dim_ordering=dim_ordering,
                                y_component=y_component, **kwargs)
        predicted_labels = model.predict_on_batch(x)
        dst_dict.update(zip(self.indices, predicted_labels))
        return self

    @action
    def test_on_crop_dataset(self, model_name, dataset, batch_size, callbacks=(), test_freq=10,
                             threshold=10, dim_ordering='channels_last'):
        """ Periodically evaluate the model on every batch of `dataset`.

        Every `test_freq`-th call (counted across calls in `dataset.counter`)
        the whole `dataset` is run through the model, each metric in
        `callbacks` is computed on the concatenated predictions and targets,
        the result row is appended to `dataset.metrics` and the accumulated
        metrics table is printed.

        Args:
        - model_name: str, name of classification model;
        - dataset: dataset to evaluate on; `metrics` and `counter` attributes
          are attached to it to keep state between calls;
        - batch_size: int, size of batches generated from `dataset`;
        - callbacks: iterable of metric functions called as m(y_pred, y_true);
          each function's `__name__` becomes a column name;
        - test_freq: int, evaluate once every `test_freq` calls;
        - threshold: number, passed to `create_labels_by_mask` for test batches;
        - dim_ordering: str, passed to `unpack_data`;

        Returns:
        - self, unchanged;
        """
        if not hasattr(dataset, 'metrics'):
            dataset.metrics = []
        if not hasattr(dataset, 'counter'):
            dataset.counter = 0

        # `test_freq` controls both how often evaluation runs and when it is printed.
        if dataset.counter % test_freq == 0:
            model = self.get_model_by_name(model_name)
            y_pred_list, y_true_list = [], []
            for batch in dataset.gen_batch(batch_size):
                # Defensive: keep whatever the action chain returns.
                batch = batch.load(fmt='blosc').create_labels_by_mask(threshold=threshold)
                x, batch_y_true = batch.unpack_data(dim_ordering=dim_ordering, y_component='labels')
                batch_y_pred = np.asarray(model.predict_on_batch(x))

                y_pred_list.append(batch_y_pred.ravel())
                y_true_list.append(batch_y_true.ravel())

            # An empty dataset yields no batches; np.concatenate([]) would raise.
            if y_pred_list:
                y_pred = np.concatenate(y_pred_list, axis=0)
                y_true = np.concatenate(y_true_list, axis=0)

                dataset.metrics.append({m.__name__: m(y_pred, y_true) for m in callbacks})
                print(pd.DataFrame(dataset.metrics))
                sys.stdout.flush()
                clear_output(wait=True)

        dataset.counter += 1
        return self

In [35]:
# Directory with the dumped nodule crops (one sub-directory per item);
# override via the LUNA_DUMP_DIR environment variable instead of editing code.
LUNA_DUMP_DIR = os.environ.get('LUNA_DUMP_DIR', '/home/kirill/ds_bowl/final_nodules_dump')
# Crops whose mask sum exceeds this value are labelled positive.
MASK_THRESHOLD = 10
TEST_BATCH_SIZE = 4
TEST_FREQ = 20  # `test_freq` for test_on_crop_dataset

luna_index = FilesIndex(path=os.path.join(LUNA_DUMP_DIR, '*'), dirs=True)
luna_dataset = Dataset(luna_index, batch_class=CustomBatch)
# 90% train / 10% test; shuffle=81 presumably seeds the split — confirm.
luna_dataset.cv_split(shares=0.9, shuffle=81)

train_pipeline = \
(
    luna_dataset.train
                .pipeline()
                .load(fmt='blosc', src_blosc=['images', 'masks',
                                              'origin', 'spacing'])
                .create_labels_by_mask(threshold=MASK_THRESHOLD)
                .train_on_crop('keras_resnet')
                .test_on_crop_dataset('keras_resnet',
                                      luna_dataset.test, batch_size=TEST_BATCH_SIZE,
                                      callbacks=(log_loss, accuracy),
                                      test_freq=TEST_FREQ)
)

In [36]:
# Run the training pipeline; 16 is presumably the batch size — confirm against
# Pipeline.run. The trailing ';' hides the returned Pipeline object's repr.
train_pipeline.run(16);

<lung_cancer.dataset.dataset.pipeline.Pipeline at 0x7f2174d9fcc0>