In [None]:
import sys, os
import pandas as pd
import numpy as np
from six import raise_from
import pydicom
import matplotlib.pyplot as plt

import warnings 
# NOTE(review): this silences every warning in the kernel, which can also hide
# deprecation warnings from the libraries used below; consider narrowing it.
warnings.filterwarnings("ignore")

In [None]:
!pip install --upgrade git+https://github.com/fizyr/keras-retinanet

In [None]:
import keras
import keras.preprocessing.image
from keras.callbacks import ReduceLROnPlateau

import keras_retinanet.losses
from keras_retinanet.models.resnet import resnet50_retinanet
from keras_retinanet.utils.keras_version import check_keras_version
from keras_retinanet.preprocessing.generator import Generator

### Create datagen

In [None]:
class RSNAGenerator(Generator):
    """keras-retinanet data generator for the RSNA pneumonia detection dataset.

    Each unique ``patientId`` in ``annotations_df`` is one image, read as DICOM
    from ``<data_dir>/<patientId><image_extension>``.

    Args:
        annotations_df: DataFrame with columns ``patientId``, ``x``, ``y``,
            ``width``, ``height`` and ``Target``; one row per box, several rows
            per patient allowed. Rows whose box coordinates are missing (NaN)
            contribute no annotation.
        data_dir: directory containing the DICOM files.
        class_mapping: dict mapping class name -> integer label.
        batch_size: number of images per batch.
        image_extension: file suffix appended to each ``patientId``.
        **kwargs: forwarded to keras-retinanet's ``Generator`` initializer.
    """

    def __init__(self, annotations_df, data_dir, class_mapping, batch_size, image_extension='.dcm', **kwargs):
        self.data_dir = data_dir
        self.classes = class_mapping
        self.image_extension = image_extension
        self.annotations_df = annotations_df
        # Inverse mapping: integer label -> class name.
        self.labels = {v: k for k, v in self.classes.items()}
        self.image_names = list(annotations_df['patientId'].unique())
        # Positional row indices per patient, so load_annotations does not scan
        # the whole table for every image.
        self._rows_by_patient = annotations_df.groupby('patientId').indices
        # Attributes above are set first because the base initializer may call
        # back into size()/image_aspect_ratio() (keras-retinanet internals —
        # confirm for the installed version).
        super(RSNAGenerator, self).__init__(batch_size=batch_size, **kwargs)

    def size(self):
        """ Size of the dataset."""
        return len(self.image_names)

    def num_classes(self):
        """ Number of classes in the dataset."""
        return len(self.classes)

    def has_label(self, label):
        """ Return True if label is a known label."""
        return label in self.labels

    def has_name(self, name):
        """ Returns True if name is a known class."""
        return name in self.classes

    def name_to_label(self, name):
        """ Map name to label."""
        return self.classes[name]

    def label_to_name(self, label):
        """ Map label to name."""
        return self.labels[label]

    def _image_path(self, image_index):
        """Path of the DICOM file for the image at ``image_index``."""
        return os.path.join(self.data_dir,
            self.image_names[image_index] + self.image_extension)

    def image_aspect_ratio(self, image_index):
        """Width / height of the image, read from the DICOM header only."""
        # Only Rows/Columns are needed, so skip decoding the pixel data; for a
        # single-frame image, pixel_array has shape (Rows, Columns).
        ds = pydicom.dcmread(self._image_path(image_index), stop_before_pixels=True)
        return float(ds.Columns) / float(ds.Rows)

    def load_image(self, image_index):
        """Load the image pixels; grayscale images are replicated to 3 channels.

        A 2-D ``(H, W)`` or single-channel ``(H, W, 1)`` array becomes
        ``(H, W, 3)``; any other shape is returned as read.
        """
        image = pydicom.dcmread(self._image_path(image_index)).pixel_array
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[..., 0]
        if image.ndim == 2:
            # Grayscale: convert to RGB for consistency with the pretrained backbone input.
            image = np.stack((image,) * 3, axis=-1)
        return image

    def load_annotations(self, image_index):
        """Ground-truth boxes for the image at ``image_index``.

        Returns:
            dict with ``'labels'`` (float array of shape (n,), the integer
            ``Target`` of each box) and ``'bboxes'`` (float array of shape
            (n, 4) as x1, y1, x2, y2). Rows with missing box coordinates (the
            RSNA negative rows) are skipped, so an image without boxes gives n == 0.
        """
        patient_id = self.image_names[image_index]
        rows = self.annotations_df.iloc[self._rows_by_patient.get(patient_id, [])]

        # (n, 4) array of x, y, width, height; drop rows without a box.
        xywh = rows[['x', 'y', 'width', 'height']].values.astype(float)
        has_box = ~np.isnan(xywh).any(axis=1)
        xywh = xywh[has_box]

        # Same int truncation as before, stored as float like the original array.
        labels = rows['Target'].values[has_box].astype(int).astype(float)
        # Corner format: (x, y) and (x + width, y + height).
        bboxes = np.hstack([xywh[:, :2], xywh[:, :2] + xywh[:, 2:]])
        return {'labels': labels, 'bboxes': bboxes}

### Create model

In [None]:
def load_retinanet(weights, num_classes, freeze=True):
    """Build a ResNet50 RetinaNet and load pretrained weights into it.

    Args:
        weights: path to a Keras weights file. Layers are matched by name, and
            layers whose weights do not match in count/shape are skipped
            (presumably the classification head when ``num_classes`` differs
            from the checkpoint).
        num_classes: number of classes the classification head predicts.
        freeze: if True, pass keras-retinanet's ``freeze`` helper as the
            backbone ``modifier`` so backbone layers are not trained.

    Returns:
        The (uncompiled) Keras model.
    """
    # `modifier` must be a callable applied to the backbone model. The previous
    # code passed the boolean itself, which cannot be called when freeze=True.
    from keras_retinanet.utils.model import freeze as freeze_model

    modifier = freeze_model if freeze else None
    model = resnet50_retinanet(num_classes=num_classes, modifier=modifier)
    model.load_weights(weights, by_name=True, skip_mismatch=True)
    return model
  
def model_compile(model, lr, clipnorm=1e-5):
    """Compile a RetinaNet model with smooth-L1 box and focal class losses.

    Args:
        model: Keras model with outputs named 'regression' and 'classification'.
        lr: Adam learning rate.
        clipnorm: gradients are clipped to this maximum norm.
            NOTE(review): before this fix the value was never used as a clip
            norm (see below); 1e-5 is far smaller than the 1e-3 used by
            keras-retinanet's own training script — consider raising it.

    Returns:
        The same model object, compiled.
    """
    model.compile(
        loss={'regression'    : keras_retinanet.losses.smooth_l1(),
              'classification': keras_retinanet.losses.focal()},
        # clipnorm must be a keyword: Adam's second positional parameter is
        # beta_1, so passing it positionally set beta_1=clipnorm and no clipping.
        optimizer=keras.optimizers.Adam(lr, clipnorm=clipnorm))
    return model

### Training

In [None]:
%%time
datagen = RSNAGenerator(
    pd.read_csv('../stage_2_train_labels.csv'),
    '../stage_2_train_images/',
    {0:0, 1:1}, 
    batch_size=4)

In [None]:
!wget https://github.com/fizyr/keras-retinanet/releases/download/0.5.0/resnet50_coco_best_v2.1.0.h5
path_to_weights = '../resnet50_coco_best_v2.1.0.h5'

In [None]:
# Reset Keras' global state so re-running this cell starts from a clean graph.
keras.backend.clear_session()
# Two classes to match the {0: 0, 1: 1} mapping; freeze=False keeps the backbone trainable.
model = model_compile(
    load_retinanet(path_to_weights, num_classes=2, freeze=False),
    lr=1e-5,
)

In [None]:
# NOTE(review): validation_data is the same generator used for training, so
# val_loss is not a held-out estimate; consider a patient-level train/val split.
history = model.fit_generator(
    datagen,
    epochs=25,
    steps_per_epoch=100,
    validation_data=datagen,
    validation_steps=10,
    callbacks=None,
    verbose=1,
)

In [None]:
# Train vs. validation curves for the total loss and for each RetinaNet head.
# Each entry: (history key, train label, validation label, panel title).
loss_panels = [
    ("loss", "Train loss", "Valid loss", "Total loss"),
    ("regression_loss", "regression loss", "val regression loss", "Regression loss"),
    ("classification_loss", "classification loss", "val classification loss", "Classification loss"),
]
fig, axes = plt.subplots(1, len(loss_panels), figsize=(17, 5))
for ax, (key, train_label, val_label, title) in zip(axes, loss_panels):
    ax.plot(history.history[key], label=train_label)
    ax.plot(history.history["val_" + key], label=val_label)
    ax.set(title=title, xlabel="epoch", ylabel="loss")
    ax.legend()
plt.show()