# Identify Pneumothorax Disease in Chest X-Rays, Inference Notebook

More information about training the networks can be found [here](https://github.com/reyvaz/pneumothorax_detection).

#### Uses a 2-step approach for the identification of pneumothorax disease on the test dataset images. 

The 1st step attempts to classify x-rays as presenting pneumothorax disease or not. To do so, it uses an ensemble of EfficientNet based image classifiers. The ensemble predictions are the simple average across all classifiers in the ensemble. 

In the 2nd step, if the image was classified as likely having the disease in step 1, it tries to identify the location of the disease within the x-ray image. To do this, it uses an ensemble of Unet and Unet++ segmentation CNNs, all with EfficientNet encoders. The mask predictions are the simple average of the predictions of all CNNs in the ensemble.

**Credits**:

This notebook was inspired by Siddhartha’s [Unet Plus Plus with EfficientNet Encoder](https://www.kaggle.com/meaninglesslives/nested-unet-with-efficientnet-encoder) notebook.

**References**:

Kaiming He, Xiangyu Zhang, Shaoqing Ren, & Jian Sun. (2015). Deep Residual Learning for Image Recognition.

Mingxing Tan, & Quoc V. Le. (2020). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.

Olaf Ronneberger, Philipp Fischer, & Thomas Brox. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation.

Zhou, Z., Siddiquee, M., Tajbakhsh, N., & Liang, J. (2019). UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation IEEE - Transactions on Medical Imaging.

## Required Packages

In [None]:
import os, sys, re
from time import time, strftime, gmtime
start_notebook = time()  # wall-clock start; the total run time is printed at the end of the notebook

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets

import tensorflow as tf
import tensorflow.keras.layers as L

# Fetch the helper repository and install its requirements (IPython shell magics).
!git clone -q https://github.com/reyvaz/tpu_segmentation.git
!pip config set global.disable-pip-version-check true >/dev/null
!pip install -qr tpu_segmentation/requirements.txt >/dev/null
# Star import from the helper package. Judging by later use it presumably supplies
# count_data_items, build_classifier, unet, xnet, dice_coef, dice_avg and
# contoured_mask -- TODO confirm against the repo.
from tpu_segmentation import *

# RLE helpers (mask2rle / rle2mask); presumably made importable by the steps above -- TODO confirm.
import mask_functions as mf

print('Tensorflow version: ', tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE 

## Distribution Strategy

In [None]:
# Use a TPU when one can be resolved; otherwise fall back to the default
# distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None
    print('TPU not found')

if tpu is None:
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

## Constants and TFRecs File Paths

In [None]:
# Original size of the images.
IMAGE_SIZE = [1024, 1024]
N_CLASSES = 1
N_CHANNELS = 1
N_REPLICAS = strategy.num_replicas_in_sync  # global batch size = per-replica batch * N_REPLICAS

# Locate the TEST TFRecord shards in the Kaggle dataset bucket.
gcs_path = KaggleDatasets().get_gcs_path('siimacr-pneumothorax-segmentation-tfrecs')
TFRECS_TEST = tf.io.gfile.glob(f'{gcs_path}/tfrecs/*test*.tfrec')
n_test_examples = count_data_items(TFRECS_TEST)

print('Number of TEST TFRecs: ', len(TFRECS_TEST))
print('Number of TEST examples: ', n_test_examples)

## Dataset Pipeline

In [None]:
def read_test_tfrecord(example, str_feat):
    """Parse one serialized TEST example and return its ``str_feat`` string feature."""
    feature_spec = {str_feat: tf.io.FixedLenFeature([], tf.string)}
    parsed = tf.io.parse_single_example(example, feature_spec)
    return parsed[str_feat]
        
def load_test_dataset(filenames, str_feat = 'image'):
    """Stream the ``str_feat`` string feature of every record in ``filenames``."""
    def _extract(record):
        return read_test_tfrecord(record, str_feat)

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.map(_extract, num_parallel_calls=AUTO)

def decode_resize_image(image_data, target_size, image_size = IMAGE_SIZE,
                        make_rgb = True, n_channels = N_CHANNELS):
    """Decode one JPEG into a float32 image of shape ``(*target_size, C)``.

    Args:
        image_data: encoded JPEG bytes (scalar string tensor).
        target_size: (H, W) to resize to.
        image_size: (H, W) of the decoded images; the resize is skipped when
            it already equals ``target_size``.
        make_rgb: replicate the grayscale channel into 3 channels.
        n_channels: number of channels to decode.

    Returns:
        Tensor with values in [0, 1]; C is 3 if ``make_rgb`` else ``n_channels``.
    """
    image = tf.image.decode_jpeg(image_data, channels=n_channels)
    image = tf.cast(image, tf.float32) / 255.0
    # Compare as tuples: callers pass (H, W) tuples while IMAGE_SIZE is a list,
    # and a tuple never equals a list, so a raw `!=` would always resize.
    if tuple(target_size) != tuple(image_size):
        image = tf.image.resize(image, target_size)
    if make_rgb:
        image = tf.image.grayscale_to_rgb(image)
        n_channels = 3
    return tf.reshape(image, [*target_size, n_channels])

def describe_ds(x):
    """Print the string form of ``x`` with every '<' and '>' removed."""
    print(re.sub('[<>]', '', str(x)))

def get_test_dataset(filenames, target_size, imgs_per_replica, make_rgb = True):
    """Build the batched, prefetched TEST image pipeline.

    Args:
        filenames: TFRecord files to read.
        target_size: (H, W) images are resized to.
        imgs_per_replica: per-replica batch size (scaled by N_REPLICAS).
        make_rgb: forwarded to ``decode_resize_image``.

    Returns:
        ``(dataset, n_test, min_steps)``: the dataset, the example count
        reported by ``count_data_items`` and ``ceil(n_test / batch_size)``.
    """
    batch_size = N_REPLICAS * imgs_per_replica
    n_test = count_data_items(filenames)
    min_steps = np.ceil(n_test / batch_size).astype(int)

    def _decode(image_data):
        return decode_resize_image(image_data, target_size, make_rgb=make_rgb)

    dataset = (load_test_dataset(filenames)
               .map(_decode, AUTO)
               .batch(batch_size)
               .prefetch(AUTO))
    describe_ds(dataset)
    return dataset, n_test, min_steps

## Visualize Test Examples

In [None]:
n_rows, n_cols = 2, 5

# Preview a few TEST images from the first shard, decoded at 256x256 and kept
# single-channel so they can be drawn with a bone colormap.
temp_dataset = get_test_dataset(TFRECS_TEST[:1], (256, 256), 1, make_rgb=False)[0]
temp_dataset = temp_dataset.unbatch().take(n_rows * n_cols)

fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 4 * n_rows))
for ax, item in zip(fig.axes, temp_dataset):
    ax.imshow(item, cmap=plt.cm.bone)
    ax.axis('off')

## Pretrained Weights Metadata

In [None]:
base_dir = '../input/pneumothorax-segmentation-base/'
bin_meta = pd.read_csv(base_dir + 'bin_weights_meta.csv')
seg_meta = pd.read_csv(base_dir + 'seg_weights_meta.csv')
# Stack classifier and segmentation metadata so any weights key can be mapped
# to its file. DataFrame.append was deprecated in pandas 1.4 and removed in
# 2.0; pd.concat is the supported equivalent.
weigths_meta = pd.concat([bin_meta, seg_meta], ignore_index = True)
weights_names = dict(zip(weigths_meta.key, weigths_meta.filename))

## Functions to Build Ensembles

In [None]:
weights_dir = base_dir + 'weights/'

def load_pretrained_model(weights_id, compile_model = True, opt = None, loss = None,
                          metrics = 'default', details = ()):
    """Build a network from its weights id and load its pretrained weights.

    ``weights_id`` is split as ``'<base>-<base_ver>-<model_type>-<x>_<H>x<W>'``.
    ``model_type`` selects the builder ('bin' -> build_classifier,
    'unetpp' -> xnet, other 'unet' -> unet). ``HxW`` is the model input size.

    Args:
        weights_id: key into ``weights_names`` (maps to the weights filename).
        compile_model: compile the model after loading the weights.
        opt, loss: passed to ``model.compile``; ``None`` means ``[]``.
        metrics: passed to ``model.compile``; ``'default'`` picks
            accuracy/AUC for classifiers and dice metrics for U-Nets.
        details: names of local variables (e.g. ``'size'``) to return
            alongside the model.

    Returns:
        The model, or ``(model, *details)`` when ``details`` is non-empty.

    Raises:
        ValueError: if ``model_type`` matches no known architecture.
    """
    # Fresh containers per call instead of shared mutable defaults.
    opt = [] if opt is None else opt
    loss = [] if loss is None else loss

    prefix, size = weights_id.split('_')
    # 'HxW' -> (H, W), parsed explicitly rather than eval'ing part of a filename.
    size = tuple(int(v) for v in size.split('x'))
    base, base_ver, model_type, _  = prefix.split('-')
    wname = weights_names[weights_id]
    weights_path = weights_dir + wname

    base_name = base + base_ver

    if 'bin' in model_type:
        builder = build_classifier
        if metrics == 'default': metrics = ['accuracy', 'AUC']
    elif 'unetpp' in model_type:
        builder = xnet
    elif 'unet' in model_type:
        builder = unet
    else:
        raise ValueError('Unrecognised model type {!r} in weights id {!r}'.format(
            model_type, weights_id))

    if 'unet' in model_type and metrics == 'default':
        metrics = [dice_coef, dice_avg]

    with strategy.scope():
        model = builder(base_name, 1, input_shape=(*size, 3), weights = None)
        model.load_weights(weights_path)
        if compile_model: model.compile(optimizer=opt, loss=loss, metrics=metrics)

    if len(details) > 0:
        # Each detail string is evaluated against this function's locals, so
        # callers can retrieve intermediate values such as the parsed 'size'.
        scope = locals()
        return (model, *[eval(d, scope) for d in details])
    return model

def assemble_ensemble(weights_ids, outter_size = (1024, 1024), 
                      ensemble_type = 'binary', metrics = 'default'):
    """Combine several pretrained networks into one simple-average Keras model.

    Each member is loaded with ``load_pretrained_model``. The ensemble takes
    ``(*outter_size, 3)`` inputs. A member built for a different input size
    receives a resized copy of the input; the resize op is shared between
    members of the same size. For ``ensemble_type='segmentation'``, such
    members' outputs are resized back to ``outter_size`` before averaging.

    Args:
        weights_ids: iterable of weights ids (see ``load_pretrained_model``).
        outter_size: (H, W) of the ensemble input.
        ensemble_type: 'binary' or 'segmentation'; also used in the model name.
        metrics: passed to ``compile``; 'default' picks dice metrics for
            segmentation and accuracy/AUC otherwise.

    Returns:
        A compiled ``tf.keras.Model`` named '<Ensemble_type>_Ensemble'.

    Raises:
        ValueError: if ``weights_ids`` is empty.
    """
    weights_ids = list(weights_ids)
    if not weights_ids:
        raise ValueError('assemble_ensemble needs at least one weights id.')
    # Normalise so that e.g. a list [544, 544] still matches member sizes.
    outter_size = tuple(outter_size)

    ensemble_outputs = []
    resized_inputs = {}  # member input size -> resized ensemble input
    with strategy.scope():
        x = L.Input(shape=(*outter_size, 3))
        for i, w in enumerate(weights_ids):
            model, size = load_pretrained_model(w, compile_model = False, details = ['size'])
            size = tuple(size)
            # Suffix the member index so the nested models get distinct names.
            model._name = '{}-M{}'.format(model.name, i)
            if size == outter_size:
                model_output = model(x)
            else:
                if size not in resized_inputs:
                    resized_inputs[size] = tf.image.resize(x, size)
                model_output = model(resized_inputs[size])
                if ensemble_type == 'segmentation':
                    model_output = tf.image.resize(model_output, outter_size)

            ensemble_outputs.append(model_output)

        # Averaging a single output is the identity, and some Keras versions
        # reject merge layers called on fewer than 2 inputs.
        if len(ensemble_outputs) == 1:
            y = ensemble_outputs[0]
        else:
            y = L.Average(name = 'Simple_Average')(ensemble_outputs)

        if metrics == 'default':
            if ensemble_type == 'segmentation':
                metrics = [dice_coef, dice_avg]
            else:
                metrics = ['accuracy', 'AUC']

        name = '{}_Ensemble'.format(ensemble_type.title())
        ensemble = tf.keras.Model(inputs=x, outputs=y, name=name)
        ensemble.compile(optimizer=[], loss=[], metrics=metrics)
    return ensemble

## Retrieving Ordered Image IDs from TFRecs

In [None]:
# Read the image ids in the same order the images are streamed, so the i-th
# prediction can be matched with test_ids[i].
ids_ds = load_test_dataset(TFRECS_TEST, str_feat = 'img_id').batch(512)
test_ids_bytes = []
for item in ids_ds:
    test_ids_bytes.extend(item.numpy())

# Check against the count derived from the TFRecs instead of a hard-coded
# total; raise (not assert) so the check survives `python -O`.
if len(test_ids_bytes) != n_test_examples:
    raise ValueError('Read {} test ids but expected {} test examples'.format(
        len(test_ids_bytes), n_test_examples))

test_ids = [raw_id.decode() for raw_id in test_ids_bytes]
print('Num Test Examples: ', len(test_ids))

# Binary Predictions

### Dataset for Binary Predictions

In [None]:
# The classifiers are fed 1024x1024 inputs, 8 images per replica.
filenames = TFRECS_TEST
target_size = (1024, 1024)
imgs_per_replica = 8

test_ds, n_test, min_steps = get_test_dataset(
    filenames, target_size, imgs_per_replica=imgs_per_replica)

### Build Ensemble for Binary Predictions

CNNs were trained across 5 cross-validation folds. The ensemble for binary predictions includes the 2 best-performing classifiers from each of the folds used during training.

In [None]:
# For every fold keep the 2 classifiers with the highest 'metric'.
temp = bin_meta.groupby('fold')['metric'].nlargest(2)
# The result is indexed by (fold, original row label); level 1 holds the
# bin_meta rows to keep.
idxs = list(temp.index.get_level_values(1))
binary_members = bin_meta.loc[idxs].reset_index(drop=True)

binary_ensemble_members = binary_members.key.values
print(f'Number of Binary Ensemble Members: {len(binary_ensemble_members)}')

binary_ensemble = assemble_ensemble(binary_ensemble_members)
binary_ensemble.summary()

### Perform Binary Predictions

In [None]:
# Run the classifier ensemble over every TEST batch and time it.
start_binary_preds = time()
bin_preds = binary_ensemble.predict(x=test_ds, verbose=1)

time_binary_preds = time() - start_binary_preds
def min_secs(secs):
    """Format a duration in seconds as ``'MM:SS'``.

    Minutes are counted in total (3700 s -> ``'61:40'``). They do not wrap at
    the hour the way ``strftime('%M:%S', gmtime(secs))`` does, so runs longer
    than an hour are still reported correctly under the '(MM:SS)' label.
    """
    minutes, seconds = divmod(int(secs), 60)
    return '{:02d}:{:02d}'.format(minutes, seconds)
print(f'Time to make {n_test} binary predictions: {min_secs(time_binary_preds)} (MM:SS)')

In [None]:
# One probability per image, keyed by image id in streaming order.
bin_preds = np.squeeze(bin_preds)
binary_probs = {img_id: prob for img_id, prob in zip(test_ids, bin_preds)}
binary_preds_df = pd.DataFrame(list(binary_probs.items()),
                               columns=['ImageId', 'pred_prob'])

display(binary_preds_df.head())

In [None]:
# Drop the binary classifier ensemble; only its outputs (bin_preds,
# binary_probs, binary_preds_df) are used from here on.
del binary_ensemble

# Mask Predictions

### Build Segmentation Ensemble

In [None]:
size = (544, 544)  # input (and output) resolution of the segmentation ensemble
segmentation_ensemble_members = seg_meta.key.values
print(f'Number of Segmentation Ensemble Members: {len(segmentation_ensemble_members)}')

segmentation_ensemble = assemble_ensemble(
    segmentation_ensemble_members,
    outter_size=size,
    ensemble_type='segmentation',
)
segmentation_ensemble.summary()

## Mask Prediction Dataset(s)

Unlike for the binary predictions, there is not enough memory to hold the predicted masks for the entire test dataset at once. These mask predictions are required for post-processing before encoding into RLE.

Although mask predictions are not needed for the entire test dataset (i.e. masks for images predicted negative for pneumothorax are not needed), mask predictions will be made for all test images in order to take advantage of the TPU during inference.

For this, the test dataset will be divided into several parts. Predictions will be made for each part in turn, then post-processed and RLE-encoded.

In [None]:
# Rebuild the TEST pipeline at the segmentation input size.
test_ds, n_test, total_steps = get_test_dataset(filenames, size, imgs_per_replica)
print('Total prediction steps: ', total_steps)

n_parts = 3
# Size the parts from this pipeline's own step count (`min_steps` belongs to
# the binary-prediction pipeline and only matches while both use the same
# files and batch size).
len_part = int(max(1, np.ceil(total_steps / n_parts)))
# With ceil-sized parts the last requested part can come out empty
# (e.g. 4 steps in 3 parts -> 2, 2, 0), so keep only as many parts as needed.
n_parts = int(max(1, np.ceil(total_steps / len_part)))
print('Max steps per part:', len_part)

# Carve the dataset into consecutive, non-overlapping chunks of len_part batches.
ds_remain = test_ds
test_ds_parts = []
for d in range(n_parts):
    dset = ds_remain.take(len_part)
    test_ds_parts.append(dset)
    ds_remain = ds_remain.skip(len_part)

## Predict Mask, Post-Process, Encode and Record RLE

In [None]:
binary_treshhold = 0.60   # classifier probability above which an image may get a mask
thresh_max = 0.75         # high pixel threshold, used only for the min_area gate
thresh_min = 0.40         # low pixel threshold that defines the final mask
min_area = 200            # pixels above thresh_max required, counted at model resolution

# Every image starts as -1 (no mask); masked images get their RLE string below.
pred_rles = dict.fromkeys(test_ids, -1)
prelim_masked = binary_preds_df.ImageId[binary_preds_df.pred_prob > binary_treshhold].values
start_mask_preds = time()


def _prediction_to_rle(prob_map):
    """Turn one predicted probability map into a full-size RLE string.

    Returns None when no more than ``min_area`` pixels exceed ``thresh_max``.
    Otherwise the mask is taken at the lower ``thresh_min``, upsampled to
    IMAGE_SIZE, re-binarised to {0, 255} and run-length encoded.
    """
    if (prob_map > thresh_max).sum() <= min_area:
        return None
    mask = (prob_map > thresh_min).astype(int)
    mask = tf.image.resize(np.expand_dims(mask, axis = 2), IMAGE_SIZE)
    mask = (np.round(np.squeeze(mask)) * 255).astype(int)
    return mf.mask2rle(mask.T, *IMAGE_SIZE)


i = 0  # running position in test_ids across all parts
masked_ids = []
for p, test_ds_part in enumerate(test_ds_parts):
    print('\nPredicting and processing part {} of {}'.format(p+1, n_parts))
    preds = segmentation_ensemble.predict(test_ds_part, verbose = 1)
    # Drop only a trailing size-1 channel axis; a bare np.squeeze would also
    # drop the batch axis if a part ever held a single image.
    if preds.ndim == 4 and preds.shape[-1] == 1:
        preds = preds[..., 0]
    print('Shape of predictions matrix {}: {}\n'.format(p+1, preds.shape))

    for pred in preds:
        test_id = test_ids[i]
        if binary_probs[test_id] > binary_treshhold:
            mask_rle = _prediction_to_rle(pred)
            if mask_rle is not None:
                pred_rles[test_id] = mask_rle
                masked_ids.append(test_id)
        i += 1
    # `pred` is a view into `preds`; release both so this part's predictions
    # can be freed before the next part (rebinding never raises NameError).
    preds = pred = None

if i != len(test_ids):
    raise RuntimeError('Made {} mask predictions for {} test ids'.format(i, len(test_ids)))

time_mask_preds = time() - start_mask_preds
del segmentation_ensemble
print('Time to predict and post-process {} images: {} (MM:SS)'.format(n_test, min_secs(time_mask_preds)))

### Visualize Predicted Examples

In [None]:
skip = 100
demo_ds = test_ds_parts[0].unbatch().skip(skip)

n_rows = 2
n_cols = 5

masked_examples = {}
unmasked_examples = {}
max_examples = n_rows*n_cols
# Set view of masked_ids for constant-time membership tests inside the loop.
masked_id_set = set(masked_ids)

# Walk the first part (offset by `skip`) and collect up to max_examples images
# with and without a predicted mask, stopping once both groups are full.
for i, image in enumerate(demo_ds):
    test_id = test_ids[i+skip]
    has_mask = test_id in masked_id_set
    if has_mask and len(masked_examples) < max_examples:
        masked_examples[test_id] = image
    elif not has_mask and len(unmasked_examples) < max_examples:
        unmasked_examples[test_id] = image
    if len(masked_examples) == len(unmasked_examples) == max_examples: break


#### Predictions with Pneumothorax Disease

In [None]:
fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 4*n_rows))
for ax, (img_id, image) in zip(fig.axes, masked_examples.items()):
    # Bring the model-resolution image up to IMAGE_SIZE so the decoded mask
    # (also IMAGE_SIZE) overlays it directly.
    image = tf.image.resize(image, IMAGE_SIZE)
    overlay = mf.rle2mask(pred_rles[img_id], *IMAGE_SIZE) / 255
    overlay = contoured_mask(overlay.T, rgb_color=(200, 0, 150), alpha=0.35)

    ax.imshow(image, cmap=plt.cm.bone)
    ax.imshow(overlay)
    ax.axis('off')
    ax.set_title(f'Image ID: {img_id}', fontdict={'fontsize': 13})

#### Predictions with no Pneumothorax Disease

In [None]:
fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 4*n_rows))
for ax, (img_id, image) in zip(fig.axes, unmasked_examples.items()):
    ax.imshow(image, cmap=plt.cm.bone)
    ax.axis('off')
    ax.set_title(f'Image ID: {img_id}', fontdict={'fontsize': 13})

## Create Submission File

In [None]:
# One row per test image; images without a predicted mask keep the value -1.
rows = list(pred_rles.items())
sub_df = pd.DataFrame(rows, columns=['ImageId', 'EncodedPixels'])
sub_df.to_csv('submission.csv', index=False)

In [None]:
# Remove the cloned tpu_segmentation repo, mask_functions.py and __pycache__ from the working directory.
!rm -rf tpu_segmentation mask_functions.py __pycache__

In [None]:
time_notebook = time() - start_notebook  # start_notebook is set in the first cell
print(f'Time to run notebook {min_secs(time_notebook)} (MM:SS)')