<h2>Train and validate a Faster-RCNN detection model</h2>

This notebook demonstrates how to train and validate a detection model using histomics_detect. We first illustrate how to formulate training and validation datasets, then use them with a Keras model. Different options for inference and training are demonstrated, as well as visualization of results.

In [None]:
import sys
import tensorflow as tf

# install histomics_detect
!pip install -e /tf/notebooks/histomics_detect

# add to system path
sys.path.append("/tf/notebooks/histomics_detect/")

<h2>Build tf.data.Datasets for training and validation</h2>

Download a dataset consisting of training and testing folders. Each sample consists of a paired .png and .csv file defining the locations of objects in the image. A parser function is defined to interpret information like the case and lab identifier from the dataset files for matching file pairs. During training, input images are randomly cropped on the fly to have uniform size.
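
One way to picture the area threshold used during random cropping is sketched below. This is not the histomics_detect `crop` implementation: the `[x, y, w, h]` box layout, the helper names, and the choice to shift boxes without clipping them are all assumptions made for illustration. An object is kept only if at least `min_area_thresh` of its area falls inside the crop window.

```python
def fraction_inside(box, crop):
    """Fraction of a box's area that lies inside the crop window.

    Both arguments are (x, y, w, h) tuples; this layout is an assumption.
    """
    x, y, w, h = box
    cx, cy, cw, ch = crop
    overlap_w = max(0.0, min(x + w, cx + cw) - max(x, cx))
    overlap_h = max(0.0, min(y + h, cy + ch) - max(y, cy))
    return (overlap_w * overlap_h) / (w * h)


def boxes_in_crop(boxes, crop, min_area_thresh):
    """Keep boxes with enough area inside the crop, in crop coordinates."""
    kept = []
    for box in boxes:
        if fraction_inside(box, crop) >= min_area_thresh:
            x, y, w, h = box
            # shift to the crop's origin; boxes are not clipped in this sketch
            kept.append((x - crop[0], y - crop[1], w, h))
    return kept


crop_window = (100, 100, 224, 224)  # hypothetical crop position
example_boxes = [
    (90, 150, 24, 24),   # 14 of 24 columns inside the crop: kept
    (80, 150, 24, 24),   # 4 of 24 columns inside the crop: dropped
    (200, 200, 24, 24),  # fully inside: kept
]
kept_boxes = boxes_in_crop(example_boxes, crop_window, 0.5)
print(kept_boxes)
```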

In [None]:
# import dataset related packages
from histomics_detect.io import dataset
from histomics_detect.augmentation import crop, flip, jitter, shrink
import numpy as np
import os
import pooch

# training parameters
anchor_sizes = [24, 48, 64]  # width/height of square 1:1 anchors in pixels at input mag.
train_tile = 224  # input image size
min_area_thresh = 0.5  # fraction of object area that must be in random crop to be included
width = tf.constant(train_tile, tf.int32)
height = tf.constant(train_tile, tf.int32)
min_area = tf.constant(min_area_thresh, tf.float32)

# download data and unzip - pooch returns a list of all files in the archive
path = pooch.retrieve(
    url="https://northwestern.box.com/shared/static/m5n9zqnyoxtwb0xc5a08tg2ubmd5uowk",
    known_hash="152dbc5711b3c20adc52abd775072b6607e83572aed71befe9d7609131581e61",
    path=str(pooch.os_cache("pooch")) + os.sep + "data",
    processor=pooch.Unzip(),
)

# parse training and testing paths
sep = os.path.sep
root = sep.join(os.path.split(path[0])[0].split(sep)[0:-1]) + sep

# define filename parser for dataset generation
def parser(file):
    roi = file.split(".")[0]
    lab = file.split("_")[0].split("-")[1]
    return roi, lab


# generate training, validation datasets
ds_train = dataset(root + "train" + sep, parser, parser, train_tile)
ds_validate = dataset(root + "test" + sep, parser, parser, 0)

# build training dataset
ds_train = ds_train.map(lambda x, y, z: (*crop(x, y, width, height, min_area_thresh), z))
ds_train = ds_train.map(lambda x, y, z: (*flip(x, y), z))
ds_train = ds_train.map(lambda x, y, z: (x, jitter(y, 0.05), z))
ds_train = ds_train.map(lambda x, y, z: (x, shrink(y, 0.05), z))
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

# build validation datasets
ds_validate = ds_validate.prefetch(tf.data.experimental.AUTOTUNE)
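
To see what the `parser` function above returns, it can be run on a file name by hand. The file names below are hypothetical and only illustrate the pattern the parser expects: a dash-separated case identifier, followed by an underscore and a region label. The actual names in the downloaded archive may differ.

```python
def parser(file):
    # region name: everything before the first "."
    roi = file.split(".")[0]
    # lab identifier: second dash-separated field of the text before the first "_"
    lab = file.split("_")[0].split("-")[1]
    return roi, lab


# hypothetical file names, used only to illustrate the pairing
png_name = "TCGA-A1-A0SK_roi-0.png"
csv_name = "TCGA-A1-A0SK_roi-0.csv"

roi_png, lab_png = parser(png_name)
roi_csv, lab_csv = parser(csv_name)

print(roi_png, lab_png)    # TCGA-A1-A0SK_roi-0 A1
print(roi_png == roi_csv)  # True
```

Because the image and its annotation table parse to the same region name, that name can serve as the key for matching the two files. Note that splitting on the first "." assumes the file name contains no other dots.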

<h2>Create and train detection model</h2>

Generate a Faster-RCNN Keras model using default hyperparameters. Assign losses, an optimizer, and fit the model. Basic performance metrics of the region proposal classifier and the box regressors are displayed during training epochs. In validation epochs, more extensive metrics including mean average precision are displayed.

In [None]:
# import network generation and training packages
from histomics_detect.models.faster_rcnn import FasterRCNN, faster_rcnn_config

# get default network configurations
backbone_args, rpn_args, frcnn_args, train_args, validation_args = faster_rcnn_config()

# lower non-max suppression iou
validation_args["nms_iou"] = 0.2

# create FasterRCNN keras model
model = FasterRCNN(backbone_args, rpn_args, frcnn_args, train_args, validation_args, anchor_sizes)

# compile FasterRCNN model with losses
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=[tf.keras.losses.Hinge(), tf.keras.losses.Huber()],
)

# fit FasterRCNN model
model.fit(x=ds_train, batch_size=1, epochs=40, verbose=1, validation_data=ds_validate, validation_freq=40)
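
For readers unfamiliar with the two losses passed to `compile`, the sketch below writes out the standard hinge and Huber formulas in plain Python. It is a reference for the math only, under the usual definitions (hinge labels in {-1, +1}, Huber `delta` of 1.0); how histomics_detect assigns each loss to the model outputs is not shown here.

```python
def hinge(y_true, y_pred):
    """Mean hinge loss, max(1 - y_true * y_pred, 0), for labels in {-1, +1}."""
    terms = [max(1.0 - t * p, 0.0) for t, p in zip(y_true, y_pred)]
    return sum(terms) / len(terms)


def huber(y_true, y_pred, delta=1.0):
    """Mean Huber loss: quadratic for small errors, linear for large ones."""
    terms = []
    for t, p in zip(y_true, y_pred):
        err = abs(t - p)
        if err <= delta:
            terms.append(0.5 * err ** 2)
        else:
            terms.append(delta * (err - 0.5 * delta))
    return sum(terms) / len(terms)


# a confident correct score and a wrong-signed score
hinge_value = hinge([1.0, -1.0], [0.8, 0.3])

# a small regression error (quadratic branch) and a large one (linear branch)
huber_value = huber([0.0, 0.0], [0.5, 3.0])

print(hinge_value, huber_value)
```

The linear branch of the Huber loss limits the influence of badly placed boxes on the regression gradient, which is a common reason to prefer it over a squared error for box regression.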

<h2>Inference using model.call() </h2>

Calling the model directly on an image applies all steps including roialign and non-max suppression. Arguments can be passed to control the sensitivities of the region proposal and the non-maximum suppression.

In [None]:
from histomics_detect.visualization import plot_inference

# generate and visualize thresholded, roialign outputs
data = ds_validate.shuffle(100).take(1).get_single_element()
rgb = tf.cast(data[0], tf.uint8)
regressions = model(rgb, tau=0.5, nms_iou=0.1)
plot_inference(rgb, regressions)
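
The effect of the `nms_iou` argument can be illustrated with a minimal greedy non-max suppression in plain Python. This is a sketch of the general technique, not the routine used inside the model; the `[x, y, w, h]` box layout and the function names are assumptions.

```python
def iou(a, b):
    """Intersection over union of two (x, y, w, h) boxes."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    overlap_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    overlap_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    intersection = overlap_w * overlap_h
    union = aw * ah + bw * bh - intersection
    return intersection / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_threshold):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # keep a box only if it does not overlap any kept box too strongly
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


boxes = [(0, 0, 10, 10), (2, 0, 10, 10), (20, 20, 10, 10)]
scores = [0.9, 0.8, 0.7]

print(greedy_nms(boxes, scores, 0.1))  # [0, 2]
print(greedy_nms(boxes, scores, 0.7))  # [0, 1, 2]
```

The first two boxes overlap with an IoU of about 0.67, so a low threshold such as 0.1 removes the lower-scoring one, while a threshold of 0.7 keeps both. Lower values therefore suppress more aggressively.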

<h2>Pass a margin parameter to clear the border </h2>

You can clear predictions at the border by passing a margin parameter to call. This is helpful when performing inference on a tiled version of a whole-slide image and stitching results from overlapping tiles.

In [None]:
# repeat call providing margin parameter
regressions = model(rgb, tau=0.5, nms_iou=0.3, margin=32)
plot_inference(rgb, regressions)
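
The idea behind a border margin can be sketched as a simple filter. The rule below, which drops any box whose center lies within `margin` pixels of the tile edge, is an assumption chosen for illustration; the exact criterion used by histomics_detect is not stated in this notebook. Boxes are taken to be `(x, y, w, h)` tuples.

```python
def clear_border(boxes, tile_width, tile_height, margin):
    """Keep boxes whose centers are at least `margin` pixels from every edge."""
    kept = []
    for x, y, w, h in boxes:
        center_x = x + w / 2
        center_y = y + h / 2
        inside_x = margin <= center_x <= tile_width - margin
        inside_y = margin <= center_y <= tile_height - margin
        if inside_x and inside_y:
            kept.append((x, y, w, h))
    return kept


tile_boxes = [
    (0, 0, 20, 20),       # center (10, 10): too close to the top-left corner
    (100, 100, 24, 24),   # center (112, 112): well inside the tile
    (210, 50, 20, 20),    # center (220, 60): too close to the right edge
]
interior = clear_border(tile_boxes, 224, 224, 32)
print(interior)  # [(100, 100, 24, 24)]
```

When tiles overlap, objects cut off at the edge of one tile appear whole in the interior of a neighboring tile, so discarding border predictions avoids stitching partial or duplicated detections.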

<h2>Generating intermediate outputs using model.raw()</h2>

Outputs from the region-proposal network can be obtained by calling model.raw. These proposals can be further processed using custom functions or the provided methods for objectness thresholding, non-max suppression, or roialign. These outputs can also be used for test time augmentation where multiple inferences are aggregated prior to non-max suppression.

In [None]:
# generate raw rpn outputs
objectness, boxes, features = model.raw(rgb)

# threshold rpn proposals
boxes_positive, objectness_positive, positive = model.threshold(boxes, objectness, model.tau)

# perform non-max suppression on rpn positive predictions
boxes_nms, objectness_nms, selected = model.nms(boxes_positive, objectness_positive, 0.3)

# generate roialign predictions for rpn positive predictions
align_boxes = model.align(boxes_nms, features, model.field, model.pool, model.tiles)

# visualize the thresholded, suppressed, and roialigned boxes
plot_inference(rgb, align_boxes)
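
As a sketch of the test time augmentation idea, the snippet below pools proposals from an image and its horizontally mirrored copy before non-max suppression. The box values are made up, the `(x, y, w, h)` layout is an assumption, and the helper `unflip_horizontal` is hypothetical rather than part of histomics_detect.

```python
def unflip_horizontal(boxes, image_width):
    """Map (x, y, w, h) boxes predicted on a mirrored image back to the original."""
    return [(image_width - x - w, y, w, h) for x, y, w, h in boxes]


image_width = 224

# made-up proposals from the original image
original_boxes = [(10, 20, 24, 24)]
original_scores = [0.9]

# made-up proposals from the horizontally mirrored image
flipped_boxes = [(188, 21, 24, 24)]
flipped_scores = [0.8]

# pool both sets in the original coordinate frame
pooled_boxes = original_boxes + unflip_horizontal(flipped_boxes, image_width)
pooled_scores = original_scores + flipped_scores

print(pooled_boxes)  # [(10, 20, 24, 24), (12, 21, 24, 24)]
```

The pooled boxes and scores would then go through a single round of non-max suppression, so that near-duplicate proposals from the two passes collapse to one detection.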

<h2>Batch inference - performance comparison</h2>

Using .predict is much faster than calling model() on each image in a list comprehension, but it combines all results in a single array. We can keep the outputs separated by wrapping the model to add the index of the input sequence to the results. This also allows other metadata to be passed through and captured as outputs (in this case the input image name).

In [None]:
import time

# sample validation dataset
trial_ds = ds_validate.take(10).prefetch(tf.data.experimental.AUTOTUNE)

# mapping model using data.Dataset.map keeps outputs from different images separate
start = time.time()
map_output = [element for element in trial_ds.map(lambda x, y, z: (model(x), y, z))]
print("dataset.map " + str(time.time() - start) + " seconds.")

# compare to using model.predict which merges the outputs from all images
start = time.time()
predict_output = model.predict(trial_ds)
print(".predict " + str(time.time() - start) + " seconds.")

# examine predict output
print(".predict output: " + str(tf.shape(predict_output)))

# define passthrough model
class WrappedModel(tf.keras.Model):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model

    def call(self, inputs, *args, **kwargs):
        boxes = self.model(inputs[0], *args, **kwargs)
        index = tf.cast(inputs[3], tf.float32) * tf.ones((tf.shape(boxes)[0], 1))
        return (tf.concat([boxes, index], 1), inputs[1], inputs[2])


# wrap
wrapped = WrappedModel(model)

# combine model inputs with a tile index value
index_ds = tf.data.Dataset.range(len(trial_ds))
trial_ds = tf.data.Dataset.zip((trial_ds, index_ds))
trial_ds = trial_ds.map(lambda x, y: ((x[0], x[1], x[2], y), None, None))

# generate indexed predictions
start = time.time()
indexed_output = wrapped.predict(trial_ds)
print("wrapped .predict " + str(time.time() - start) + " seconds.")

# tile index column is added to predictions
print(indexed_output[0])
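
Once the tile index is stored as the last column, the merged predictions can be separated again by grouping rows on that column. The sketch below uses plain Python lists with made-up values in place of the array returned by `wrapped.predict`; the row layout (box values followed by the index) mirrors the concatenation in `WrappedModel.call`, and `split_by_index` is a hypothetical helper.

```python
from collections import defaultdict


def split_by_index(rows):
    """Group prediction rows by their trailing tile index column."""
    groups = defaultdict(list)
    for row in rows:
        *box, index = row
        groups[int(index)].append(box)
    return dict(groups)


# made-up merged output: four box values followed by the tile index
rows = [
    [10.0, 20.0, 24.0, 24.0, 0.0],
    [50.0, 60.0, 24.0, 24.0, 0.0],
    [15.0, 25.0, 48.0, 48.0, 2.0],
]

per_tile = split_by_index(rows)
print(sorted(per_tile))  # [0, 2]
print(len(per_tile[0]))  # 2
```

A tile with no detections, such as index 1 here, contributes no rows and so has no entry in the result; callers that need every tile should look up indices with a default of an empty list.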

<h2>Save and load as Keras.Model</h2>

Saving as a Keras model allows the restored model to be trained for additional cycles, and preserves access to Keras functions for inference. Parallel inference with a Keras model is also easier than with other formats.

In [None]:
# compute model output shape to trigger build - this shape can be changed after loading
model.compute_output_shape([224, 224, 3])

# save keras model
model.save("tcga_brca_model")

# load model
restored = tf.keras.models.load_model("tcga_brca_model", custom_objects={"FasterRCNN": FasterRCNN})

# check that outputs are same
assert tf.math.reduce_all(tf.math.equal(restored(rgb), model(rgb)))

<h2>Save and load as SavedModel</h2>

The model can also be saved and loaded in the SavedModel format for use with TensorFlow or NVIDIA inference servers.

In [None]:
# compute model output shape to trigger build - this shape can be changed after loading
model.compute_output_shape([224, 224, 3])

# save SavedModel
tf.saved_model.save(model, "tcga_brca_model")

# load model
restored = tf.saved_model.load("tcga_brca_model")
inference = restored.signatures["serving_default"]

# check that outputs are same within a tolerance
tf.debugging.assert_near(inference(tf.cast(rgb, tf.float32))["output_1"], model(rgb), 1e-6)