# Data Modeling

### Imports

In [None]:
# Enable IPython's autoreload extension (mode 2) so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from typing import Tuple, Dict, Optional
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
    balanced_accuracy_score,
    precision_recall_curve,
    PrecisionRecallDisplay,
)
from itertools import cycle

In [None]:
# Make the project's source directory importable from this notebook
# (presumably where the `data_modules` package imported below lives).
src_path: str = "../src"
# Guard against duplicates: re-running this cell would otherwise append the
# same entry to `sys.path` on every execution.
if src_path not in sys.path:
    sys.path.append(src_path)

In [None]:
from data_modules.pneumonia_data_module import PneumoniaDataModule

In [None]:
# --- Experiment settings ----------------------------------------------------
# Optionally train on a random subset of this many images (None disables
# the sub-sampling step).
SAMPLE_SIZE: Optional[int] = 10000
IMG_SIZE: int = 260  # for EfficientNetB2
EPOCHS: int = 25
BATCH_SIZE: int = 128
RANDOM_SEED: int = 8080

# --- Input locations --------------------------------------------------------
DATA_ROOT: Path = Path("../data")
# NOTE(review): machine-specific absolute path; adjust it to wherever the
# X-ray images are stored locally.
XRAY_IMAGES_ROOT: Path = Path("/home/uziel/Downloads/nih_chest_x_rays")

# --- Output artifacts (all stored under OUTPUTS_DIR) ------------------------
OUTPUTS_DIR: Path = DATA_ROOT / "model_outputs"
CHECKPOINT_PATH: Path = OUTPUTS_DIR / "model_checkpoint"
MODEL_PATH: Path = OUTPUTS_DIR / "pneumonia_xray_classifier"
HISTORY_PATH: Path = OUTPUTS_DIR / "training_history.csv"
BEST_TH_PATH: Path = OUTPUTS_DIR / "best_th.txt"

## 1. Load samples and images metadata

In [None]:
# Read the pre-processed annotations table and show it as the cell output.
annotations_csv = DATA_ROOT / "processed_annotations.csv"
annot_df = pd.read_csv(annotations_csv)
annot_df

In [None]:
# Index every PNG found (recursively) under the images root by its file name,
# then attach the matching path to each annotation row. Image names without a
# matching file end up as missing values in `image_path`.
image_paths_by_name = {}
for img_file in XRAY_IMAGES_ROOT.glob("**/*.png"):
    image_paths_by_name[img_file.name] = img_file

annot_df["image_path"] = annot_df["image_name"].map(image_paths_by_name)

## 2. Get train, val and test data loaders

We instantiate a PyTorch Lightning data module that takes care of the following under the hood:

1. Split data into train, val and test sets.
2. Set pre-processing and data augmentation transforms.
3. Initialize train, val and test datasets.

The data module can be used to extract the relevant data loaders of each set as needed.

In [None]:
# Optionally down-sample the annotations to speed up experimentation.
if SAMPLE_SIZE is not None:
    # Seed the draw with RANDOM_SEED so the same subset is selected on every
    # run. `n` is capped at the table size so that a SAMPLE_SIZE larger than
    # the available data simply keeps all (shuffled) rows.
    annot_df = annot_df.sample(
        n=min(SAMPLE_SIZE, len(annot_df)), random_state=RANDOM_SEED
    )

In [None]:
# Build the data module from the annotations table and run its setup step
# before requesting any loader.
data_module = PneumoniaDataModule(annot_df)
# NOTE(review): an empty string is passed as the setup stage; confirm how
# PneumoniaDataModule interprets it.
data_module.setup("")

# One loader per split; the three calls are evaluated left to right.
train_loader, val_loader, test_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

In [None]:
# Keep handles to the tables backing the training and validation datasets;
# they are used below to compare metadata distributions.
train_data = train_loader.dataset.data
val_data = val_loader.dataset.data

### Check some key metadata distributions

In [None]:
# Compare the normalized value frequencies of a few metadata columns between
# the training and validation splits, side by side (one column per split).
metadata_columns = ["pneumonia", "patient_gender", "view_position"]

split_frequencies = []
for split_name, split_df in (("train_data", train_data), ("val_data", val_data)):
    # Stack the per-column frequency tables of this split into one Series.
    per_column = [
        split_df[col].value_counts(normalize=True).rename(col)
        for col in metadata_columns
    ]
    split_frequencies.append(pd.concat(per_column).rename(split_name))

pd.concat(split_frequencies, axis=1)

The relevant metadata fields are distributed similarly across the training and validation sets.

### Inspect data augmentations on training data

In [None]:
# Pull one batch from the training loader and show up to 16 of its images on
# a 4x4 grid, each titled with its label.
t_x, t_y = next(iter(train_loader))
fig, m_axs = plt.subplots(4, 4, figsize=(16, 16))
for image, label, axis in zip(t_x, t_y, m_axs.flatten()):
    # Move the first axis last before plotting (assumes channel-first image
    # tensors -- TODO confirm against the dataset transforms).
    axis.imshow(image.permute(1, 2, 0), cmap="bone")
    axis.set_title("Pneumonia" if label == 1 else "No Pneumonia")
    axis.axis("off")

## 4. Build model

Useful source: https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/

In [None]:
def load_pretrained_model(*args, **kwargs) -> tf.keras.Model:
    """Build an ImageNet-initialized EfficientNetB2 truncated at ``block7b_add``.

    Args:
        *args: Extra positional arguments forwarded to ``EfficientNetB2``.
        **kwargs: Extra keyword arguments forwarded to ``EfficientNetB2``.
            ``include_top`` and ``weights`` are fixed here and must not be
            passed again.

    Returns:
        A model mapping the EfficientNetB2 input to the output of its layer
        named ``"block7b_add"``.
    """
    # NOTE(review): `tf` and `EfficientNetB2` are not imported in the visible
    # part of this notebook -- confirm they are in scope before running.
    backbone = EfficientNetB2(*args, include_top=True, weights="imagenet", **kwargs)
    truncated_output = backbone.get_layer("block7b_add").output
    return tf.keras.Model(inputs=backbone.input, outputs=truncated_output)


def build_model(
    base_model: tf.keras.Model, preprocessing_layers: tf.keras.Sequential
) -> tf.keras.Sequential:
    """Assemble and compile the binary classifier around ``base_model``.

    All layers of ``base_model`` except its last 28 are frozen in place (the
    caller's object is modified). The final network is
    ``preprocessing_layers -> base_model -> Flatten -> Dense(1, sigmoid)``.

    Args:
        base_model: Feature extractor whose leading layers are frozen here.
        preprocessing_layers: Layers applied to the inputs before the base
            model.

    Returns:
        The compiled sequential model (focal binary cross-entropy loss and a
        set of confusion-matrix, accuracy, precision/recall and AUC metrics).
    """
    # Freeze everything but the trailing 28 layers; per the original author
    # these correspond to the last EfficientNet block (Block 7).
    trainable_tail = 28
    for frozen_layer in base_model.layers[:-trainable_tail]:
        frozen_layer.trainable = False

    # Stack the full network: preprocessing, backbone and a one-unit head.
    stages = [
        preprocessing_layers,
        base_model,
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
    classifier = tf.keras.Sequential(stages)

    # NOTE(review): `Adam` is not imported in the visible part of this
    # notebook -- confirm it is in scope before running.
    optimizer = Adam()

    # The metric names below become the keys of the training history
    # (validation values get a "val_" prefix).
    keras_metrics = tf.keras.metrics
    tracked_metrics = [
        keras_metrics.TruePositives(name="tp"),
        keras_metrics.FalsePositives(name="fp"),
        keras_metrics.TrueNegatives(name="tn"),
        keras_metrics.FalseNegatives(name="fn"),
        keras_metrics.BinaryAccuracy(name="accuracy"),
        keras_metrics.Precision(name="precision"),
        keras_metrics.Recall(name="recall"),
        keras_metrics.AUC(name="auc"),
        keras_metrics.AUC(name="prc", curve="PR"),
    ]

    classifier.compile(
        optimizer=optimizer,
        loss="binary_focal_crossentropy",
        metrics=tracked_metrics,
    )
    return classifier

In [None]:
# Load the truncated, pretrained backbone and wrap it into the compiled
# classifier.
# NOTE(review): `preprocessing_layers` is not defined in the visible cells of
# this notebook -- confirm where it comes from.
base_model = load_pretrained_model()
model = build_model(base_model, preprocessing_layers)

## 5. Train model

In [None]:
def train_model(
    model: tf.keras.Model,
    train_dataset: tf.data.Dataset,
    val_dataset: tf.data.Dataset,
    checkpoint_path: Path,
    epochs: int = 100,
    monitor: str = "val_recall",
    patience: int = 10,
):
    """Fit ``model`` while checkpointing the best weights and stopping early.

    Both callbacks watch the same logged quantity, ``monitor``, and treat
    higher values as better. Keras matches the monitor name against the
    logged metric names exactly (case-sensitive): a name that is never logged
    (e.g. ``"Recall"`` when the metric is registered as ``"recall"``) means
    no checkpoint is ever written.

    Args:
        model: Compiled model to train. It must log the metric referenced by
            ``monitor``.
        train_dataset: Batched training data.
        val_dataset: Batched validation data, evaluated after every epoch.
        checkpoint_path: Where the best weights seen so far are written.
        epochs: Maximum number of training epochs.
        monitor: Name of the logged metric to maximize. The default is the
            validation recall, so that model selection and early stopping are
            driven by held-out data rather than by training data.
        patience: Number of epochs without improvement of ``monitor`` after
            which training stops.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Keep only the weights of the best epoch according to `monitor`.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor=monitor,
        verbose=1,
        save_best_only=True,
        mode="max",
        save_weights_only=True,
    )

    # Stop once `monitor` has not improved for `patience` consecutive epochs.
    early = tf.keras.callbacks.EarlyStopping(
        monitor=monitor, mode="max", patience=patience
    )

    return model.fit(
        train_dataset,
        epochs=epochs,
        validation_data=val_dataset,
        verbose=1,
        callbacks=[checkpoint, early],
    )

In [None]:
# Train for at most EPOCHS epochs, writing the best weights to CHECKPOINT_PATH.
# NOTE(review): `train_dataset` and `val_dataset` are not defined in the
# visible cells of this notebook -- confirm where they come from.
training_hist = train_model(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    checkpoint_path=CHECKPOINT_PATH,
    epochs=EPOCHS,
)

In [None]:
# Persist the per-epoch training metrics so the curves can be inspected again
# without re-training.
history_df = pd.DataFrame(training_hist.history)
# `to_csv` does not create missing parent directories, so make sure the
# outputs folder exists before writing.
HISTORY_PATH.parent.mkdir(parents=True, exist_ok=True)
history_df.to_csv(HISTORY_PATH)

## 6. Evaluate model

In [None]:
def get_performance_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, y_score: Optional[np.ndarray] = None
) -> Dict[str, float]:
    """Compute multiple performance metrics

    Precision, recall, F1 and balanced accuracy are computed from the hard
    labels in ``y_pred``. ROC AUC and average precision are ranking metrics:
    they are only informative when computed from continuous scores, so they
    use ``y_score`` when it is provided and fall back to ``y_pred`` otherwise.

    Args:
        y_true: Ground truth labels for each observation.
        y_pred: Predicted (thresholded) labels for each observation.
        y_score: Optional predicted scores/probabilities of the positive
            class for each observation. Any shape is accepted as long as it
            flattens to one value per observation.

    Returns:
        A dictionary containing multiple performance metrics
    """
    # Flatten so that column-shaped model outputs, e.g. (n, 1), are accepted.
    ranking_input = y_pred if y_score is None else np.ravel(y_score)

    return {
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1_score": f1_score(y_true, y_pred),
        "roc_auc_score": roc_auc_score(y_true, ranking_input),
        "average_precision_score": average_precision_score(y_true, ranking_input),
        "balanced_accuracy_score": balanced_accuracy_score(y_true, y_pred),
    }


def plot_pr_curve(y_true: np.ndarray, y_pred: np.ndarray) -> None:
    """Plot the precision-recall curve on top of reference iso-F1 curves.

    Args:
        y_true: Ground truth binary labels for each observation.
        y_pred: Predicted scores for each observation (continuous values, not
            thresholded labels), as expected by
            ``PrecisionRecallDisplay.from_predictions``.
    """
    _, ax = plt.subplots(figsize=(8, 8))

    # Background: curves of constant F1 in the (recall, precision) plane,
    # i.e. precision = f1 * recall / (2 * recall - f1).
    recall_grid = np.linspace(0.01, 1)  # 50 points by default
    iso_f1_line = None
    for f_score in np.linspace(0.2, 0.8, num=4):
        iso_precision = f_score * recall_grid / (2 * recall_grid - f_score)
        visible = iso_precision >= 0
        (iso_f1_line,) = ax.plot(
            recall_grid[visible], iso_precision[visible], color="gray", alpha=0.2
        )
        # Label each curve near x=0.9; index 45 of the 50-point grid is the
        # grid point just right of it (recall ~ 0.92).
        ax.annotate(f"f1={f_score:0.1f}", xy=(0.9, iso_precision[45] + 0.02))

    display = PrecisionRecallDisplay.from_predictions(y_true, y_pred, ax=ax)

    # Add a single legend entry for the iso-F1 curves.
    handles, labels = display.ax_.get_legend_handles_labels()
    handles.append(iso_f1_line)
    labels.append("iso-f1 curves")

    # Set the legend and the axes.
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.legend(handles=handles, labels=labels, loc="best")
    ax.set_title("Precision-Recall curve")

    plt.show()


def plot_history(history_df: pd.DataFrame) -> None:
    """Plot loss and recall/precision curves from a training history.

    Two panels are drawn side by side: training vs. validation loss on the
    left, training vs. validation recall and precision on the right.

    Args:
        history_df: History dataframe with one row per epoch. It must contain
            the columns ``loss``, ``val_loss``, ``recall``, ``val_recall``,
            ``precision`` and ``val_precision``.
    """
    # Human-readable names used in the plot legends.
    legend_names = {
        "loss": "Training loss",
        "val_loss": "Validation loss",
        "recall": "Training recall",
        "val_recall": "Validation recall",
        "precision": "Training precision",
        "val_precision": "Validation precision",
    }
    renamed = history_df.rename(columns=legend_names)

    fig, (loss_ax, score_ax) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    panels = (
        (loss_ax, "Loss during training", ["Training loss", "Validation loss"]),
        (
            score_ax,
            "Performance during training",
            [
                "Training recall",
                "Validation recall",
                "Training precision",
                "Validation precision",
            ],
        ),
    )
    for axis, title, columns in panels:
        renamed[columns].plot(ax=axis, title=title, xlabel="Epoch", ylabel="Score")

    fig.tight_layout()

In [None]:
# Restore the weights written to CHECKPOINT_PATH during training.
model.load_weights(CHECKPOINT_PATH)

# NOTE(review): labels and scores are gathered in two separate passes over
# `val_dataset`; this assumes the dataset yields its elements in the same
# order on every iteration (no reshuffling) -- confirm.
y_true = np.concatenate([y for _, y in val_dataset], axis=0)
# No `batch_size` here: the dataset already yields batches, and Keras
# documents that `batch_size` must not be specified for dataset inputs.
y_scores = model.predict(val_dataset)
# Hard labels at the default 0.5 cut-off.
y_pred = (y_scores.flatten() > 0.5).astype(int)

performance_metrics = get_performance_metrics(y_true, y_pred)

In [None]:
# Show the loss and recall/precision curves recorded during training.
plot_history(history_df=history_df)

### 6.1. Precision-recall curve

In [None]:
# Use the raw validation scores (not thresholded labels) for the PR curve.
plot_pr_curve(y_true=y_true, y_pred=y_scores)

### 6.2. Find threshold that optimizes recall (sensitivity or true positive rate)

> When a high recall test returns a negative result, you can be confident that the result is truly negative since a high recall test has low false negatives. Recall does not take false positives into account though, so you may have high recall but are still labeling a lot of negative cases as positive. Because of this, high recall tests are good for things like screening studies, where you want to make sure someone _doesn’t_ have a disease or worklist prioritization where you want to make sure that people _without_ the disease are being de-prioritized.

In [None]:
# Compute the curve once and reuse the arrays for the summary table.
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# `precision` and `recall` have one more entry than `thresholds`: the final
# point (precision=1, recall=0) has no associated threshold, so pad the
# thresholds with NaN to align the three columns explicitly.
pr_stats = pd.DataFrame(
    {
        "precision": precision,
        "recall": recall,
        "threshold": np.append(thresholds, np.nan),
    }
).sort_values(["recall", "precision"], ascending=False)
pr_stats

Because we prioritize recall over precision, the best threshold is obtained by sorting the candidate thresholds by recall in descending order, breaking ties by precision (also in descending order), and taking the first row.

In [None]:
# `pr_stats` is sorted best-first, so take the first row *by position*.
# Indexing with `[0]` would look up the row whose index label is 0 (the
# lowest threshold of the unsorted curve), regardless of the sort order.
best_th = pr_stats["threshold"].iloc[0]

# `write_text` does not create missing parent directories.
BEST_TH_PATH.parent.mkdir(parents=True, exist_ok=True)
BEST_TH_PATH.write_text(str(best_th))

print(f"The best threshold found was: {best_th}")

### 6.3. Visualize true vs. predicted labels with the best threshold found

Each image title reads `true label, predicted label`.

In [None]:
fig, m_axs = plt.subplots(10, 10, figsize=(16, 16))

# Show up to 100 images of the first validation batch; each title reads
# "<true label>, <predicted label>".
# NOTE(review): this assumes the first batch yielded by `val_dataset` matches
# the first rows of `y_scores` (same, deterministic iteration order).
first_images, first_labels = next(iter(val_dataset))
for i, (c_x, c_y, c_ax) in enumerate(
    zip(first_images, first_labels, m_axs.flatten())
):
    c_ax.imshow(c_x[:, :, 0], cmap="bone")
    true_label = 1 if c_y == 1 else 0
    # Use `>=`: thresholds returned by `precision_recall_curve` count a sample
    # as positive when score >= threshold, so a strict `>` would flip the
    # sample(s) sitting exactly on the selected threshold.
    predicted_label = 1 if y_scores[i] >= best_th else 0
    c_ax.set_title(f"{true_label}, {predicted_label}")
    c_ax.axis("off")

## 7. Persist model architecture

In [None]:
# Write the trained model to MODEL_PATH.
model.save(filepath=MODEL_PATH)