# Data Modeling

### Imports

In [None]:
# Reload imported modules (e.g. the project packages under ../src) before each
# cell runs, so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from itertools import chain
from pathlib import Path
from typing import Iterable

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from rich import traceback
from sklearn.metrics import (
    precision_recall_curve,
    PrecisionRecallDisplay,
)
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.models import MobileNet_V3_Large_Weights

In [None]:
# Make the project packages under ../src importable from this notebook.
# A relative sys.path entry is resolved against the current working directory.
src_path: str = "../src"
# Guard so that re-running this cell does not append duplicate entries.
if src_path not in sys.path:
    sys.path.append(src_path)
# Install rich's pretty traceback handler; its return value is not needed.
_ = traceback.install()

In [None]:
from data_modules.pneumonia_data_module import PneumoniaDataModule
from models.pneumonia_classifier import PneumoniaClassifier

In [None]:
import os

# Experiment configuration.
EPOCHS: int = 50
# Preprocessing pipeline that torchvision bundles with the pretrained
# MobileNetV3-Large weights.
REQUIRED_TRANSFORMS = MobileNet_V3_Large_Weights.DEFAULT.transforms()
BATCH_SIZE: int = 6
RANDOM_SEED: int = 8080
DATA_ROOT: Path = Path("../data")
OUTPUTS_DIR: Path = DATA_ROOT.joinpath("model_outputs")
OUTPUTS_DIR.mkdir(exist_ok=True, parents=True)
# Machine-specific location of the raw X-ray images. Set the XRAY_IMAGES_ROOT
# environment variable to point elsewhere without editing the notebook.
XRAY_IMAGES_ROOT: Path = Path(
    os.environ.get("XRAY_IMAGES_ROOT", "/home/uziel/Downloads/nih_chest_x_rays")
)
LOGS_PATH: Path = OUTPUTS_DIR.joinpath("mobilenet_v3_large")
HISTORY_PATH: Path = OUTPUTS_DIR.joinpath("training_history.csv")
# The selected decision threshold is written here (see section 6.2).
BEST_TH_PATH: Path = OUTPUTS_DIR.joinpath("best_th.txt")

## 1. Load samples and image metadata

In [None]:
# Load the pre-processed annotations table and display it.
annotations_csv = DATA_ROOT / "processed_annotations.csv"
annot_df = pd.read_csv(annotations_csv)
annot_df

In [None]:
import warnings

# Index every PNG below XRAY_IMAGES_ROOT by file name. If the same file name
# occurs in several sub-folders, the last one yielded by glob wins.
image_paths_by_name = {
    img_file.name: img_file for img_file in XRAY_IMAGES_ROOT.glob("**/*.png")
}
annot_df["image_path"] = annot_df["image_name"].map(image_paths_by_name)

# Annotations with no matching file get NaN in image_path. Report them here
# rather than letting NaN paths flow downstream unnoticed.
n_missing = int(annot_df["image_path"].isna().sum())
if n_missing:
    warnings.warn(
        f"{n_missing} of {len(annot_df)} annotated images were not found "
        f"under {XRAY_IMAGES_ROOT}"
    )

## 2. Get train, val and test data loaders

We instantiate a PyTorch Lightning data module that takes care of the following under the hood:

1. Split data into train, val and test sets.
2. Set pre-processing and data augmentation transforms.
3. Initialize train, val and test datasets.

The data module can be used to extract the relevant data loaders of each set as needed.

In [None]:
# Build the data module from the annotations and obtain one loader per split.
# NOTE(review): balance_train=True presumably rebalances classes in the
# training split only — confirm in PneumoniaDataModule.
data_module_kwargs = dict(
    required_transforms=REQUIRED_TRANSFORMS,
    batch_size=BATCH_SIZE,
    balance_train=True,
    random_seed=RANDOM_SEED,
)
data_module = PneumoniaDataModule(annot_df, **data_module_kwargs)
# NOTE(review): an empty stage string is passed. Presumably setup() prepares
# all splits regardless of stage — confirm in PneumoniaDataModule.setup.
data_module.setup("")

train_loader, val_loader, test_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

In [None]:
# Per-split annotation tables held by the datasets (used as DataFrames below).
train_data = train_loader.dataset.data
val_data = val_loader.dataset.data

### Check some key metadata distributions

In [None]:
# Compare per-split value proportions of key metadata fields.
# Keying the inner concat by field name yields a (field, value) MultiIndex.
# Every row therefore shows which field it belongs to, and identical values
# from different fields (e.g. two boolean columns) cannot collide when the two
# splits are aligned column-wise.
metadata_cols = ["pneumonia", "patient_gender", "view_position"]
pd.concat(
    {
        split_name: pd.concat(
            {col: split_df[col].value_counts(normalize=True) for col in metadata_cols}
        )
        for split_name, split_df in (("train_data", train_data), ("val_data", val_data))
    },
    axis=1,
)

All relevant metadata fields have similar distributions in the training and validation sets.

## 3. Inspect data augmentations on training data

In [None]:
# Show the first training batch (up to six images) with its labels.
t_x, t_y = next(iter(train_loader))
fig, m_axs = plt.subplots(2, 3, figsize=(12, 8))
for c_x, c_y, c_ax in zip(t_x, t_y, m_axs.flatten()):
    # Image tensors are channel-first (hence the permute). They are presumably
    # normalized by REQUIRED_TRANSFORMS, so values can fall outside [0, 1] and
    # would be clipped by imshow. Min-max rescale each image for display only.
    img = c_x.detach().cpu().permute(1, 2, 0)
    img = (img - img.min()) / (img.max() - img.min()).clamp(min=1e-8)
    c_ax.imshow(img, cmap="bone")
    c_ax.set_title("Pneumonia" if c_y == 1 else "No Pneumonia")
    c_ax.axis("off")

## 4. Build model

In [None]:
# Instantiate the project's classifier (models.pneumonia_classifier under ../src)
# with its default constructor arguments.
model = PneumoniaClassifier()

## 5. Train model

In [None]:
def train_model(
    model: LightningModule,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    logs_path: Path,
    epochs: int = 100,
    patience: int = 10,
    **kwargs
):
    """Fit a model with early stopping, then evaluate its best checkpoint.

    Args:
        model: Lightning module to train. It is expected to log ``val_loss``
            during validation, because both early stopping and checkpoint
            selection monitor it.
        train_loader: Training data loader.
        val_loader: Validation data loader (drives early stopping and
            checkpoint selection).
        test_loader: Test data loader, evaluated once after fitting.
        logs_path: Root directory for logs and checkpoints
            (passed as ``Trainer(default_root_dir=...)``).
        epochs: Maximum number of training epochs.
        patience: Number of validation checks without ``val_loss`` improvement
            before training stops early.
        **kwargs: Extra keyword arguments forwarded to ``Trainer``
            (e.g. ``accelerator``).

    Returns:
        A ``(trainer, test_results)`` tuple: the fitted ``Trainer`` and the
        list of metric dictionaries returned by ``Trainer.test``.
    """
    # Without a checkpoint callback that monitors val_loss, Lightning's default
    # checkpointing keeps only the latest epoch. ckpt_path="best" would then
    # test the last epoch (up to `patience` epochs past the best one).
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    trainer = Trainer(
        default_root_dir=logs_path,
        callbacks=[
            EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
            checkpoint_callback,
        ],
        max_epochs=epochs,
        **kwargs
    )

    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    test_results = trainer.test(model=model, dataloaders=test_loader, ckpt_path="best")

    return trainer, test_results

In [None]:
# Train, then test the best checkpoint. accelerator="auto" uses a GPU when one
# is available and falls back to the CPU otherwise, instead of failing on
# GPU-less machines.
trainer, test_results = train_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    logs_path=LOGS_PATH,
    epochs=EPOCHS,
    accelerator="auto",
)

## 6. Evaluate model

In [None]:
def plot_pr_curve(y_true: Iterable[int], y_pred: Iterable[float]):
    """Plot a precision-recall curve with iso-F1 reference contours.

    Args:
        y_true: Binary ground-truth labels (array-like, one per sample).
        y_pred: Predicted scores for the positive class (array-like, aligned
            with ``y_true``). Passed to
            ``PrecisionRecallDisplay.from_predictions``.
    """
    _, ax = plt.subplots(figsize=(8, 8))

    # Iso-F1 contours: for F1 = f, precision y and recall x satisfy
    # f = 2xy / (x + y), i.e. y = f * x / (2x - f). Negative y (x < f / 2) is
    # meaningless, so those points are masked out.
    x = np.linspace(0.01, 1)  # np.linspace default: 50 points
    iso_f1_line = None
    for f_score in np.linspace(0.2, 0.8, num=4):
        y = f_score * x / (2 * x - f_score)
        (iso_f1_line,) = ax.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
        # Label each contour near the right edge (x[45] is about 0.92).
        ax.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))

    display = PrecisionRecallDisplay.from_predictions(y_true, y_pred, ax=ax)

    # Add one legend entry standing for all iso-F1 curves.
    handles, labels = display.ax_.get_legend_handles_labels()
    if iso_f1_line is not None:
        handles.append(iso_f1_line)
        labels.append("iso-f1 curves")

    # Set the legend and the axes.
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.legend(handles=handles, labels=labels, loc="best")
    ax.set_title("Precision-Recall curve")

    plt.show()


def smooth(scalars: Iterable[float], weight: float = 0.5) -> Iterable[float]:
    """Exponentially smooth a sequence of values.

    Each output is ``weight * running + (1 - weight) * value``, where
    ``running`` is the previous smoothed value. The running value is seeded
    with the first non-NaN input, so the first finite output equals that
    input. For weights between 0 and 1, larger weights smooth more strongly.

    NaN inputs are emitted as NaN and do not update the running value. A gap
    (e.g. a row where a metric was not logged) therefore no longer turns every
    later output into NaN.

    Args:
        scalars: Values to smooth. Any iterable works (list, array, or pandas
            Series with any index); it is consumed once, in iteration order.
        weight: Weight given to the previous smoothed value.

    Returns:
        A list with one smoothed value per input value (empty for empty input).
    """
    import math

    # Materialize once; positional iteration avoids label-based indexing on
    # pandas Series whose index does not contain 0.
    values = list(scalars)
    running = next((v for v in values if not math.isnan(v)), math.nan)

    smoothed = []
    for point in values:
        if math.isnan(point):
            smoothed.append(math.nan)
            continue
        running = running * weight + (1 - weight) * point
        smoothed.append(running)

    return smoothed


def plot_metrics(metrics: pd.DataFrame) -> None:
    """Plot smoothed training and validation metrics on three stacked axes.

    Panels:
        1. Losses: columns whose name contains "loss".
        2. Stats: columns whose name contains "true" or "false".
        3. Metrics: every remaining train/val column.

    Each curve is smoothed with ``smooth(..., 0.9)``. A panel whose column
    group is empty is left blank (titled only) instead of raising.

    Args:
        metrics: Logged metrics, one row per logger row (e.g. a Lightning
            ``metrics.csv`` read with ``pd.read_csv``). Only columns whose name
            contains "train" or "val" are used. The caller's frame is not
            modified.
    """
    # 0. Prune metrics. Work on an explicit copy so the in-place shift below
    # neither mutates the caller's frame nor triggers SettingWithCopyWarning.
    metrics = metrics.loc[:, metrics.columns.str.contains("train|val")].copy()
    train_cols = [c for c in metrics.columns if "train" in c]
    # NOTE(review): shifting train columns up one row presumably aligns them
    # with the validation row logged for the same epoch — confirm against the
    # metrics.csv layout.
    metrics[train_cols] = metrics[train_cols].shift(-1)
    metrics = metrics.dropna(how="all")

    # 1. Split the columns into the three panels.
    loss_cols = metrics.columns[metrics.columns.str.contains("loss")]
    stats_cols = metrics.columns[metrics.columns.str.contains("true|false")]
    binary_metrics_cols = metrics.columns.difference(loss_cols).difference(stats_cols)

    # 2. Plot one panel per column group.
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10, 15))
    panels = (
        (loss_cols, "Loss during training"),
        (stats_cols, "Stats during training"),
        (binary_metrics_cols, "Metrics during training"),
    )
    for ax, (cols, title) in zip(axes, panels):
        _plot_smoothed_columns(metrics, cols, ax, title)

    fig.tight_layout()


def _plot_smoothed_columns(metrics: pd.DataFrame, cols: pd.Index, ax, title: str) -> None:
    """Draw smoothed ``metrics[cols]`` on ``ax``; if ``cols`` is empty, only set the title."""
    if len(cols) == 0:
        ax.set_title(title)
        return
    metrics[cols].apply(lambda col: smooth(col, 0.9)).plot(
        ax=ax, title=title, xlabel="Training step", ylabel="Score"
    ).legend(loc="upper right")

### Test results

In [None]:
# Metrics returned by trainer.test for the checkpoint selected with ckpt_path="best".
print(test_results)

In [None]:
# Plot the most recently modified metrics.csv under LOGS_PATH.
# max() returns the first file with the newest mtime, matching the previous
# sorted(..., reverse=True)[0], without sorting the whole list.
metrics_file = max(
    LOGS_PATH.glob("**/metrics.csv"),
    key=lambda file: file.stat().st_mtime,
    default=None,
)
if metrics_file is None:
    raise FileNotFoundError(f"No metrics.csv found under {LOGS_PATH}")
metrics = pd.read_csv(metrics_file)
plot_metrics(metrics)

### 6.1. Precision-recall curve

In [None]:
# Score the test set with the checkpoint the trainer reports as best.
predictions = trainer.predict(model=model, dataloaders=test_loader, ckpt_path="best")
# Ground-truth labels in loader order.
# NOTE(review): this iterates test_loader a second time and assumes it yields
# samples in the same order as during predict (no shuffling) — confirm in
# PneumoniaDataModule.
y_true = [
    int(label)
    for label in chain.from_iterable(batch_targets for _, batch_targets in test_loader)
]
# One float score per test sample, flattened across prediction batches.
y_scores = [float(score) for score in chain.from_iterable(predictions)]

In [None]:
# Precision-recall curve of the test-set scores against the test labels.
plot_pr_curve(y_true, y_scores)

### 6.2. Find threshold that optimizes recall (sensitivity or true positive rate)

> When a high recall test returns a negative result, you can be confident that the result is truly negative since a high recall test has low false negatives. Recall does not take false positives into account though, so you may have high recall but are still labeling a lot of negative cases as positive. Because of this, high recall tests are good for things like screening studies, where you want to make sure someone _doesn’t_ have a disease or worklist prioritization where you want to make sure that people _without_ the disease are being de-prioritized.

In [None]:
# precision and recall have one more entry than thresholds: the final
# (precision=1, recall=0) point has no threshold, so it gets NaN here.
# The curve is computed once and reused, instead of recomputed for the table.
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

pr_stats = pd.DataFrame(
    {
        "precision": precision,
        "recall": recall,
        "threshold": np.append(thresholds, np.nan),
    }
).sort_values(["recall", "precision"], ascending=False)
pr_stats

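Because we prioritize recall over precision, we choose the threshold by sorting the candidates by recall (descending) and breaking ties by precision (descending). The first row therefore holds the threshold with the highest recall and, among those, the highest precision.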

In [None]:
# pr_stats is sorted by recall, then precision (both descending), so the first
# row holds the threshold with the highest recall and best tie-breaking precision.
top_row = pr_stats.iloc[0]
best_th = top_row["threshold"]

# Persist the chosen threshold as plain text.
BEST_TH_PATH.write_text(str(best_th))

print(f"The best threshold found was: {best_th}")

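### 6.3. Visualize true vs. predicted labels using the best threshold found

Each panel title reads `true label, predicted label`, where the prediction is 1 when the score exceeds the best threshold.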

In [None]:
# Show the first test batch, titled "true label, predicted label", with the
# prediction obtained by applying best_th to the test-set scores.
# y_scores was computed on test_loader, so the images and labels must come
# from test_loader too (previously val_loader was used, which paired
# validation images with test predictions). This assumes test_loader iterates
# in a fixed order.
fig, m_axs = plt.subplots(2, 3, figsize=(12, 8))
fig.suptitle("Title: true label, predicted label")

test_images, test_targets = next(iter(test_loader))
for i, (c_x, c_y, c_ax) in enumerate(zip(test_images, test_targets, m_axs.flatten())):
    # Min-max rescale for display only; normalized tensors would be clipped.
    img = c_x.detach().cpu().permute(1, 2, 0)
    img = (img - img.min()) / (img.max() - img.min()).clamp(min=1e-8)
    c_ax.imshow(img, cmap="bone")
    true_label = 1 if c_y == 1 else 0
    pred_label = 1 if y_scores[i] > best_th else 0
    c_ax.set_title(f"{true_label}, {pred_label}")
    c_ax.axis("off")