# 'PatchCamelyon' image classification using Keras


### About the ML task and dataset

This notebook shows an example of training an _image classification_ [Keras](https://keras.io/) model.

The notebook **works best with GPU(s)**. It also runs using only CPUs, but training takes much longer. Given the size of the dataset and the model architecture, use at least a 2-core notebook VM with an attached GPU so that training finishes in a reasonable time. On Terra, you can use the default GATK image, customized to use **2 CPUs and 1 GPU**.

The [PatchCamelyon benchmark](https://www.tensorflow.org/datasets/catalog/patch_camelyon) consists of 327,680 color images (96 x 96 px) extracted from histopathologic scans of lymph node sections. Each image is annotated with a binary label indicating the presence of metastatic tissue.
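The 327,680 images are divided into train, validation, and test splits. The split sizes below are taken from the TFDS catalog entry for `patch_camelyon` (an assumption you can check against the `print(ds_info)` output later in this notebook). This short sketch shows the resulting proportions and how many batches one pass over each split takes at the batch size of 32 used later in the notebook.

```python
import math

# Split sizes as listed in the TFDS catalog for patch_camelyon (verify via ds_info).
SPLITS = {"train": 262_144, "validation": 32_768, "test": 32_768}
BATCH_SIZE = 32  # matches batch_size in the tf.data pipeline of this notebook

total = sum(SPLITS.values())
for name, n in SPLITS.items():
    fraction = n / total
    batches = math.ceil(n / BATCH_SIZE)  # a final partial batch still counts as a step
    print(f"{name:>10}: {n:>7} images ({fraction:.0%}), {batches} batches per pass")
```

So one training epoch is 8,192 steps, which is why a GPU makes such a difference here.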

The model uses one of Keras's prebuilt model architectures, [Xception](https://keras.io/api/applications/xception/). Training uses [_transfer learning_](https://en.wikipedia.org/wiki/Transfer_learning), starting from model weights trained on the [ImageNet](https://en.wikipedia.org/wiki/ImageNet) dataset, and then runs a [fine-tuning](https://d2l.ai/chapter_computer-vision/fine-tuning.html) stage.

<img src="https://storage.googleapis.com/tfds-data/visualization/fig/patch_camelyon-2.0.0.png" width="60%">

You can use this notebook as a template for experimenting with image classification on your own image data.

## Do some imports and set some variables

In [None]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from tensorflow import keras
from tensorflow.keras import layers

print(tf.__version__)

Get your workspace's GCS bucket from the environment variables that Terra sets.

In [None]:
if "GOOGLE_PROJECT" in os.environ:  # This env var is set when running in a Terra workspace
    WORKSPACE_NAME = os.environ["WORKSPACE_NAME"]
    WORKSPACE_NAMESPACE = os.environ["WORKSPACE_NAMESPACE"]
    WORKSPACE_BUCKET = os.environ["WORKSPACE_BUCKET"]
else:
    print("Not running on Terra: you will need to set your GCS bucket manually.")
    # Placeholder: replace with your own bucket path, e.g. "gs://my-bucket".
    WORKSPACE_BUCKET = "gs://YOUR_BUCKET_NAME"

In [None]:
BUCKET = WORKSPACE_BUCKET
print(BUCKET)

In [None]:
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
print(TIMESTAMP)

## Create the tissue datasets

Downloading and preparing the dataset takes a while. We download it to the notebook's persistent disk (PD), so you only need to do the download once per PD.

In [None]:
# load the input data from tensorflow_datasets
ds, ds_info = tfds.load(
    "patch_camelyon",
    with_info=True,
    as_supervised=True,
    data_dir="/home/jupyter/tensorflow_datasets",
)

# get the train, validation and test datasets
train_data = ds["train"]
valid_data = ds["validation"]
test_data = ds["test"]

In [None]:
print(ds_info)

In [None]:
# shuffle the train_data
buffer_size = 1000
train_data = train_data.shuffle(buffer_size)

# batch and prefetch
batch_size = 32
train_data = train_data.batch(batch_size).prefetch(1)
valid_data = valid_data.batch(batch_size).prefetch(1)
test_data = test_data.batch(batch_size).prefetch(1)
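`Dataset.shuffle(buffer_size)` does not shuffle the whole training split at once: it fills a buffer with `buffer_size` elements, emits one chosen uniformly at random, and refills the freed slot with the next input element. The pure-Python sketch below mimics that behavior for illustration (it is not TensorFlow's implementation; `buffered_shuffle` is a hypothetical helper) and checks the consequence: the k-th element out can only come from the first k + `buffer_size` elements in.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Illustrative re-implementation of buffered shuffling (not TensorFlow's code)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a uniformly chosen element, freeing its slot
            # for the next input element.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


out = list(buffered_shuffle(range(10_000), buffer_size=1000))
# The k-th output was drawn from the first k + 1000 inputs.
print(all(value < k + 1000 for k, value in enumerate(out)))
```

A larger buffer mixes the data more thoroughly, at the cost of holding more images in memory.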

We can view a few of the images:

In [None]:
for images, labels in train_data.take(3):
    plt.figure(figsize=(4, 4))
    first_image = images[0]
    plt.imshow(first_image.numpy().astype("int32"))
    plt.axis("off")

## Define a Keras image classification model

In this section, we define the Keras model that we'll use for training. We use [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning) for this example: we start with the [Xception](https://keras.io/api/applications/xception/) convolutional neural network architecture, with weights trained on [ImageNet](https://www.image-net.org/) data, and add some layers on top of it. We 'freeze' the Xception base model so that its weights don't change during training; only the weights of the new layers are updated.

In [None]:
def get_compiled_model():
    base_model = keras.applications.Xception(
        weights="imagenet", input_shape=(96, 96, 3), include_top=False
    )

    base_model.trainable = False

    inputs = keras.Input(shape=(96, 96, 3))

    # Xception's ImageNet weights expect pixels scaled to [-1, 1]
    # (as keras.applications.xception.preprocess_input does), not [0, 1].
    x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(inputs)
    # training=False keeps the base model's BatchNorm layers in inference mode,
    # including later, when the base model is unfrozen for fine-tuning.
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(2, activation="softmax")(x)

    model = keras.Model(inputs, outputs)
    loss = keras.losses.SparseCategoricalCrossentropy()

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss=loss,
        metrics=["accuracy"],
    )
    return base_model, model
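The Keras documentation for Xception states that its ImageNet weights expect pixel values scaled to [-1, 1], which is what `keras.applications.xception.preprocess_input` does. A `Rescaling` layer applies the affine map `pixels * scale + offset`; this numpy sketch (`rescale` is a hypothetical helper, not a Keras function) checks the arithmetic for `scale=1/127.5, offset=-1` against the [0, 1] map produced by `Rescaling(1.0 / 255)`.

```python
import numpy as np


def rescale(pixels: np.ndarray, scale: float, offset: float = 0.0) -> np.ndarray:
    """Affine map applied by a Keras Rescaling layer: pixels * scale + offset."""
    return pixels.astype(np.float32) * scale + offset


pixels = np.array([0, 127.5, 255], dtype=np.float32)  # min, mid, max of 8-bit RGB

to_unit = rescale(pixels, 1.0 / 255)               # maps to [0, 1]
to_symmetric = rescale(pixels, 1.0 / 127.5, -1.0)  # maps to [-1, 1], Xception's expected range

print(to_unit)
print(to_symmetric)
```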

In [None]:
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

In [None]:
if strategy.num_replicas_in_sync > 1:
    print("Using mirrored strategy.")
    with strategy.scope():
        base_model, model = get_compiled_model()
else:
    base_model, model = get_compiled_model()
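With `MirroredStrategy`, each batch from the dataset is the *global* batch: it is split across the replicas, each replica computes gradients on its shard, and the gradients are combined (all-reduced) before the shared weights are updated. The numpy sketch below uses a toy linear model with a mean-squared-error loss (an illustrative assumption; nothing here calls TensorFlow) to show that, for equal-sized shards, averaging the per-replica gradients reproduces the full-batch gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
global_batch = 32
num_replicas = 2  # e.g. two GPUs; each replica sees global_batch // num_replicas examples

X = rng.normal(size=(global_batch, 4))
y = rng.normal(size=global_batch)
w = rng.normal(size=4)


def mse_grad(X: np.ndarray, y: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)


full_grad = mse_grad(X, y, w)

# Each replica gets an equal, contiguous shard of the global batch.
shard_grads = [
    mse_grad(Xs, ys, w)
    for Xs, ys in zip(np.array_split(X, num_replicas), np.array_split(y, num_replicas))
]
synced_grad = np.mean(shard_grads, axis=0)  # the all-reduce (mean) step

print(np.allclose(full_grad, synced_grad))
```

This is also why `batch_size = 32` in this notebook is the total across all GPUs, not the batch each GPU processes.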

In [None]:
model.summary()
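After the base model, the new head reduces each image's feature map to a single vector and maps it to two class probabilities. The numpy sketch below reproduces that forward pass under stated assumptions: the feature-map shape (3, 3, 2048) is what we expect Xception to produce for 96 x 96 inputs (confirm it in the `model.summary()` output above), the weights are random rather than trained, and dropout is omitted because Keras only applies it during training.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, h, w, channels, num_classes = 4, 3, 3, 2048, 2  # assumed base-model output shape

features = rng.normal(size=(batch, h, w, channels))            # stand-in for base_model(x)
kernel = rng.normal(scale=0.01, size=(channels, num_classes))  # random Dense weights
bias = np.zeros(num_classes)

# GlobalAveragePooling2D: average over the two spatial axes -> (batch, channels)
pooled = features.mean(axis=(1, 2))

# Dense(2) followed by a numerically stable softmax
logits = pooled @ kernel + bias
logits -= logits.max(axis=1, keepdims=True)
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

print(pooled.shape, probs.shape)
print(probs.sum(axis=1))
```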

Define some training 'callbacks'. One writes logs in a format used by [TensorBoard](https://www.tensorflow.org/tensorboard). The other sets up model checkpointing: if training is interrupted, you can restore the last-saved model from the checkpoint directory.

In [None]:
# LOG_DIR = f'./logs/{TIMESTAMP}'
LOG_DIR = f"{BUCKET}/logs/pc/{TIMESTAMP}"

print(LOG_DIR)
CHECKPOINT_DIR = f"./checkpoints/{TIMESTAMP}/checkpoints"
print(CHECKPOINT_DIR)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, update_freq=300)

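model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=CHECKPOINT_DIR,
    # save_weights_only=True,
    # monitor and mode only take effect if save_best_only=True is uncommented;
    # otherwise the checkpoint is overwritten at the end of every epoch.
    monitor="val_accuracy",
    mode="max",
    save_freq="epoch",
    # save_best_only=True,
)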
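With `save_best_only=True`, `ModelCheckpoint` only saves when the monitored value (`val_accuracy` here) improves on the best value seen so far, in the direction given by `mode`. The pure-Python class below is a hypothetical, simplified version of that decision rule (not Keras's code), run on made-up per-epoch values.

```python
from typing import List, Optional


class BestSoFar:
    """Hypothetical helper mirroring the save_best_only decision rule."""

    def __init__(self, mode: str = "max") -> None:
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.mode = mode
        self.best: Optional[float] = None  # no value seen yet

    def should_save(self, value: float) -> bool:
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best)
            or (self.mode == "min" and value < self.best)
        )
        if improved:
            self.best = value
        return improved


tracker = BestSoFar(mode="max")
val_accuracy_per_epoch: List[float] = [0.81, 0.84, 0.83, 0.86]  # made-up values
decisions = [tracker.should_save(v) for v in val_accuracy_per_epoch]
print(decisions)
```

The third epoch's checkpoint would be skipped because its validation accuracy is lower than the second epoch's.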

## Train the model

Train the model, using transfer learning, for a few epochs. The base model weights are 'frozen' for this training run, and won't be updated.

In [None]:
model.fit(
    train_data,
    epochs=4,
    callbacks=[tensorboard_callback, model_checkpoint_callback],
    validation_data=valid_data,
)

## Fine-tune the trained model

Next, we'll do some model [fine-tuning](https://d2l.ai/chapter_computer-vision/fine-tuning.html), unfreezing the rest of the model weights.

In [None]:
# 'Unfreeze' the rest of the model
for layer in model.layers:
    layer.trainable = True

# we need to recompile the model for these modifications to take effect
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)
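Fine-tuning usually uses a low learning rate so that the pretrained weights are only nudged; the Keras transfer-learning guide, for example, recompiles with `Adam(1e-5)` for this stage. With Adam, the learning rate bounds the step size quite directly: after bias correction, the first update to each weight has magnitude close to the learning rate, whatever the gradient's scale. The numpy sketch below applies the standard Adam update equations once, using Keras's default `beta_1`, `beta_2`, and `epsilon`, to illustrate this; `adam_first_step` is a hypothetical helper.

```python
import numpy as np


def adam_first_step(grad: np.ndarray, lr: float, beta_1: float = 0.9,
                    beta_2: float = 0.999, eps: float = 1e-7) -> np.ndarray:
    """Parameter change produced by the first Adam update (t = 1), from zero moments."""
    m = (1 - beta_1) * grad          # first-moment estimate
    v = (1 - beta_2) * grad ** 2     # second-moment estimate
    m_hat = m / (1 - beta_1)         # bias correction at t = 1
    v_hat = v / (1 - beta_2)
    return -lr * m_hat / (np.sqrt(v_hat) + eps)


grads = np.array([1e-3, -0.5, 20.0])  # gradients spanning four orders of magnitude
for lr in (1e-4, 1e-5):
    print(lr, adam_first_step(grads, lr))
```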

In [None]:
model.fit(
    train_data,
    epochs=3,
    callbacks=[tensorboard_callback, model_checkpoint_callback],
    validation_data=valid_data,
)

In [None]:
model.evaluate(test_data)

### Save the trained model

#### Save to the local file system

In [None]:
model_path = f"./saved_model/{TIMESTAMP}"
print(f"model path: {model_path}")

In [None]:
# save the model
model.save(model_path)

#### Save the model to GCS

Alternatively, you can save the model to a GCS bucket.

In [None]:
model_path_gcs = f"{BUCKET}/pcam/saved_models/{TIMESTAMP}"
print(f"GCS model path: {model_path_gcs}")
model.save(model_path_gcs)

#### Load a saved model

In [None]:
# later, you can load and use the saved model by providing a local or GCS path, e.g.:

model2 = keras.models.load_model(model_path_gcs)
model2.summary()

## Model prediction

We can now use the trained model for prediction.

In [None]:
LABELS = ["non_metastatic", "metastatic"]

In [None]:
for images, labels in test_data.take(1):
    print(f"labels: {labels}")
    predictions = model.predict(images)
    print(f"predictions: {predictions}")

In [None]:
# Show the first four images in the batch with predicted and true labels.
for i in range(4):
    idx = int(np.argmax(predictions[i]))  # index of the most probable class
    plt.figure(figsize=(4, 4))
    plt.imshow(images[i].numpy().astype("int32"))
    plt.axis("off")
    plt.title(
        f"image is predicted to be: {LABELS[idx]}, with label {LABELS[int(labels[i])]}"
    )
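For a two-class softmax, taking the argmax is the same as predicting "metastatic" whenever that class's probability exceeds 0.5 (ties aside). Making the threshold explicit lets you trade precision for recall, for example flagging more patches for review by lowering it. The numpy sketch below uses made-up probabilities; `predict_labels` and its `threshold` parameter are hypothetical, not part of the original notebook.

```python
import numpy as np

CLASS_NAMES = ["non_metastatic", "metastatic"]


def predict_labels(probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Predict class 1 (metastatic) when its probability exceeds `threshold`."""
    return (probs[:, 1] > threshold).astype(int)


# Made-up softmax outputs for five patches (each row sums to 1).
probs = np.array([[0.9, 0.1], [0.6, 0.4], [0.45, 0.55], [0.2, 0.8], [0.7, 0.3]])

default_preds = predict_labels(probs)              # matches np.argmax(probs, axis=1) here
sensitive_preds = predict_labels(probs, threshold=0.3)
print([CLASS_NAMES[i] for i in default_preds])
print([CLASS_NAMES[i] for i in sensitive_preds])
```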

## Model metrics

Now let's derive some model metrics.  We'll get the predictions from the validation set, and use those for building a [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix), as well as [precision, recall](https://en.wikipedia.org/wiki/Precision_and_recall), and [AUC](https://en.wikipedia.org/wiki/AUC) metrics.
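Treating "metastatic" (label 1) as the positive class, precision, recall, and accuracy all follow from the four cells of a 2x2 confusion matrix laid out as in this notebook, with rows for true labels and columns for predictions. The numpy sketch below uses made-up counts, not results from this model.

```python
import numpy as np

# Made-up counts: rows = true label, columns = predicted label,
# both in the order [non_metastatic, metastatic].
toy_cm = np.array([[900, 100],
                   [50, 950]])

tn, fp = toy_cm[0]
fn, tp = toy_cm[1]

precision = tp / (tp + fp)  # of the patches predicted metastatic, the fraction that are
recall = tp / (tp + fn)     # of the truly metastatic patches, the fraction found
accuracy = (tp + tn) / toy_cm.sum()
print(f"precision={precision:.3f} recall={recall:.3f} accuracy={accuracy:.3f}")
```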

In [None]:
ma = tf.keras.metrics.AUC()
mp = tf.keras.metrics.Precision()
mr = tf.keras.metrics.Recall()

all_preds = []
all_labels = []

for images, labels in valid_data:
    predictions = model.predict(images, verbose=0)
    all_preds.extend(np.argmax(predictions, axis=1))
    all_labels.extend(labels.numpy())
    # Score the positive ("metastatic", index 1) class only. Passing one-hot labels
    # for both columns would make precision and recall collapse to plain accuracy.
    metastatic_probs = predictions[:, 1]
    ma.update_state(labels, metastatic_probs)
    mp.update_state(labels, metastatic_probs)
    mr.update_state(labels, metastatic_probs)

We'll show two different ways to create a confusion matrix -- `tf.math.confusion_matrix` and `sklearn.metrics.confusion_matrix`.

In [None]:
def show_confusion_matrix(cm, labels):
    # Create a single figure/axes pair and draw the heatmap on it.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        np.asarray(cm), xticklabels=labels, yticklabels=labels, annot=True, fmt="g", ax=ax
    )
    ax.set_xlabel("Prediction")
    ax.set_ylabel("Label")
    plt.show()


cm = tf.math.confusion_matrix(all_labels, all_preds, num_classes=len(LABELS))
# print(cm)

In [None]:
show_confusion_matrix(cm, LABELS)

We can optionally normalize each row of the confusion matrix by its total, so that each cell shows the fraction of that true label's examples rather than a raw count, and plot that instead, as shown below.

In [None]:
scm = confusion_matrix(all_labels, all_preds)
scm = scm.astype("float") / scm.sum(axis=1)[:, np.newaxis]
disp = ConfusionMatrixDisplay(scm, display_labels=LABELS)
fig, ax = plt.subplots(figsize=(12, 12))
disp.plot(ax=ax)

In [None]:
print(f"AUC: {ma.result().numpy()}")
print(f"Precision: {mp.result().numpy()}")
print(f"Recall: {mr.result().numpy()}")
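ROC AUC can be read as the probability that a randomly chosen positive (metastatic) example receives a higher score than a randomly chosen negative one, with ties counting one half. `tf.keras.metrics.AUC` approximates the curve on a fixed grid of thresholds (200 by default), so its value can differ slightly from the exact pairwise computation sketched below in numpy on made-up scores; `pairwise_auc` is a hypothetical helper.

```python
import numpy as np


def pairwise_auc(labels: np.ndarray, scores: np.ndarray) -> float:
    """Exact ROC AUC: fraction of (positive, negative) pairs ranked correctly."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]  # every positive/negative pair
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


toy_labels = np.array([1, 1, 0, 0])
toy_scores = np.array([0.9, 0.8, 0.1, 0.85])  # made-up metastatic probabilities
print(pairwise_auc(toy_labels, toy_scores))
```

Here three of the four positive/negative pairs are ranked correctly (0.8 falls below the negative scored 0.85), giving 0.75.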

## Provenance

In [None]:
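# Reuse the datetime class imported at the top; `import datetime` here would
# shadow it and break the TIMESTAMP cell if that cell were re-run.
print(datetime.now())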

In [None]:
!pip3 freeze

--------------------------------
Copyright 2021 Verily Life Sciences LLC

Use of this source code is governed by a BSD-style  
license that can be found in the LICENSE file or at  
https://developers.google.com/open-source/licenses/bsd