# Chest X-Ray Disease Classification

This notebook demonstrates chest X-ray classification on the NIH dataset using a pretrained model from the TorchXRayVision library, and uses CyclOps to evaluate the model and generate a model card.

### Import Libraries

In [None]:
import os
import shutil
from datetime import date
from functools import partial
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from datasets.combine import concatenate_datasets  # noqa: E402
from dateutil.relativedelta import relativedelta
from monai.transforms import AddChanneld, Compose, Lambdad, Resized, ToDeviced
from torchxrayvision.models import DenseNet

from cyclops.data.loader import load_nihcxr
from cyclops.data.slicer import filter_value  # noqa: E402
from cyclops.data.slicer import SliceSpec
from cyclops.data.utils import apply_transforms
from cyclops.evaluate import evaluator
from cyclops.evaluate.fairness import evaluate_fairness  # noqa: E402
from cyclops.evaluate.metrics.factory import create_metric
from cyclops.monitor.utils import get_device
from cyclops.report import ModelCardReport
from cyclops.report.plot.classification import ClassificationPlotter
from cyclops.report.utils import get_metrics_trends
from cyclops.utils.file import join

# Compute device shared by the preprocessing pipeline (ToDeviced) and the model
# (model.to) in the cells below.
device = get_device()

CyclOps offers a package for documentation of the model through a model card. The `ModelCardReport` class is used to populate and generate the model card as an HTML file. The model card has the following sections:
- Model Details: This section contains descriptive metadata about the model such as the owners, version, license, etc.
- Model Parameters: This section contains the technical details of the model such as the model architecture, training parameters, etc.
- Considerations: This section contains descriptions of the considerations involved in developing and using the model such as the intended use, limitations, etc.
- Quantitative Analysis: This section contains the performance metrics of the model for different sets of the data and subpopulations.
- Explainability Analysis: This section contains the explainability metrics of the model.
- Fairness Analysis: This section contains the fairness metrics of the model.

We will use this to document the model development process as we go along and generate the model card at the end.

`The model card tool is a work in progress and is subject to change.`

In [None]:
# Model card report: figures, metrics and descriptive fields are logged to it
# throughout the notebook and it is exported in the final code cell.
report = ModelCardReport()

### Load Dataset

In [None]:
# Location of the NIH chest X-ray data on disk.
data_dir = "/mnt/data/clinical_datasets/NIHCXR"

# Load the dataset and keep only its first 4000 examples.
nih_ds = load_nihcxr(data_dir).select(range(4000))


def _rescale_pixels(x):
    """Linearly rescale values so that 0 maps to -1024 and 255 maps to 1024."""
    return ((2 * (x / 255.0)) - 1.0) * 1024


# Every transform acts on the "features" entry only and tolerates its absence.
_image_keys = ("features",)

# Preprocessing pipeline: add a channel axis, resize to (1, 224, 224), rescale the
# pixel values, then move the result to the compute device.
transforms = Compose(
    [
        AddChanneld(keys=_image_keys, allow_missing_keys=True),
        Resized(keys=_image_keys, spatial_size=(1, 224, 224), allow_missing_keys=True),
        Lambdad(keys=_image_keys, func=_rescale_pixels, allow_missing_keys=True),
        ToDeviced(keys=_image_keys, device=device, allow_missing_keys=True),
    ]
)

### Load Model and get Predictions

In [None]:
# Build the DenseNet with the "densenet121-res224-nih" weights, move it to the
# compute device and switch it to evaluation mode (each call returns the module).
model = DenseNet(weights="densenet121-res224-nih").to(device).eval()


def get_predictions_torch(examples):
    """Run the pretrained model on one batch of preprocessed images.

    Parameters
    ----------
    examples : dict
        Batch of examples; ``examples["features"]`` must be a sequence of
        tensors that can be stacked. Assumes each tensor carries a singleton
        axis at position 0 (it is squeezed out after stacking) -- TODO confirm
        against the preprocessing pipeline.

    Returns
    -------
    dict
        ``{"predictions": preds}`` where ``preds`` is the raw output of the
        module-level ``model`` for the batch.

    """
    images = torch.stack(examples["features"]).squeeze(1)
    # Inference only: disable autograd so no computation graph is recorded or
    # kept alive for each batch. The predicted values themselves are unchanged.
    with torch.no_grad():
        preds = model(images)
    return {"predictions": preds}


# Preprocessing callable applied to the "features" column while the dataset is
# temporarily switched to the "custom" format.
image_transform = partial(apply_transforms, transforms=transforms)

with nih_ds.formatted_as("custom", columns=["features"], transform=image_transform):
    # Batched inference; all original columns are dropped from the mapped result,
    # leaving only what get_predictions_torch returns.
    preds_ds = nih_ds.map(
        get_predictions_torch,
        batched=True,
        batch_size=32,
        remove_columns=nih_ds.column_names,
    )

    # Attach the predictions to the original rows column-wise.
    nih_ds = concatenate_datasets([nih_ds, preds_ds], axis=1)

# remove any rows with No Finding == 1
drop_no_finding = partial(filter_value, column_name="No Finding", value=1, negate=True)
nih_ds = nih_ds.filter(drop_no_finding, batched=True)


def _keep_first_14_predictions(example):
    """Return the example's predictions truncated to their first 14 entries."""
    return {"predictions": example["predictions"][:14]}


# remove the No Finding column and adjust the predictions to account for it
nih_ds = nih_ds.map(_keep_first_14_predictions, remove_columns=["No Finding"])
nih_ds.features

### Multilabel AUROC by Pathology and Sex

In [None]:
# Labels evaluated: the first 14 pathologies reported by the model.
pathologies = model.pathologies[:14]

# Multilabel AUROC with one label per pathology, evaluated on thresholds
# 0.00, 0.01, ..., 0.99.
auroc = create_metric(
    metric_name="auroc",
    task="multilabel",
    num_labels=len(pathologies),
    thresholds=np.arange(0, 1, 0.01),
)

# One slice per "Patient Gender" value of interest.
slices = [{"Patient Gender": {"value": gender}} for gender in ("M", "F")]
slice_spec = SliceSpec(spec_list=slices)

nih_eval_results_gender = evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    feature_columns="features",
    target_columns=pathologies,
    prediction_column_prefix="predictions",
    remove_columns="features",
    slice_spec=slice_spec,
)

# One marker trace per slice; the "overall" slice gets a capitalised legend entry.
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["MultilabelAUROC"],
        name="Overall" if slice_name == "overall" else slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_eval_results_gender.items()
]

perf_metric_gender = go.Figure(data=plots)
perf_metric_gender.update_layout(
    title="Multilabel AUROC by Pathology and Sex",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Multilabel AUROC",
    width=1024,
    height=768,
)
# Enlarge and outline the markers; as the last expression this also renders the figure.
perf_metric_gender.update_traces(
    marker={"size": 12, "line": {"width": 2}},
    selector={"mode": "markers"},
)

### Multilabel AUROC by Pathology and Age

In [None]:
# Labels evaluated: the first 14 pathologies reported by the model.
pathologies = model.pathologies[:14]

# Multilabel AUROC with one label per pathology, evaluated on thresholds
# 0.00, 0.01, ..., 0.99.
auroc = create_metric(
    metric_name="auroc",
    task="multilabel",
    num_labels=len(pathologies),
    thresholds=np.arange(0, 1, 0.01),
)

# One slice per (min, max) "Patient Age" range.
# NOTE(review): 35 and 65 each appear as both an upper and a lower bound; whether a
# patient of exactly that age falls into two slices depends on how SliceSpec treats
# range bounds -- confirm.
age_ranges = [(19, 35), (35, 65), (65, 100)]
slices = [
    {"Patient Age": {"min_value": min_age, "max_value": max_age}}
    for min_age, max_age in age_ranges
]
slice_spec = SliceSpec(spec_list=slices)

nih_eval_results_age = evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    feature_columns="features",
    target_columns=pathologies,
    prediction_column_prefix="predictions",
    remove_columns="features",
    slice_spec=slice_spec,
)

# One marker trace per slice; the "overall" slice gets a capitalised legend entry.
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["MultilabelAUROC"],
        name="Overall" if slice_name == "overall" else slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_eval_results_age.items()
]

perf_metric_age = go.Figure(data=plots)
perf_metric_age.update_layout(
    title="Multilabel AUROC by Pathology and Age",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Multilabel AUROC",
    width=1024,
    height=768,
)
# Enlarge and outline the markers; as the last expression this also renders the figure.
perf_metric_age.update_traces(
    marker={"size": 12, "line": {"width": 2}},
    selector={"mode": "markers"},
)

In [None]:
# Count how many rows carry each "Patient Gender" code, keyed by display label.
gender_column = nih_ds["Patient Gender"]
gender_counts = {
    "Male": gender_column.count("M"),
    "Female": gender_column.count("F"),
}

fig = px.pie(values=list(gender_counts.values()), names=list(gender_counts.keys()))
fig.update_layout(title="Gender Distribution")

# Record the chart in the model card's "datasets" section, then render it inline.
report.log_plotly_figure(fig=fig, caption="Gender Distribution", section_name="datasets")
fig.show()

In [None]:
# Histogram of the "Patient Age" column with the legend hidden; the plotly update
# methods return the figure, so the configuration is chained.
fig = (
    px.histogram(nih_ds["Patient Age"])
    .update_traces(showlegend=False)
    .update_layout(
        title="Age Distribution",
        xaxis_title="Age",
        yaxis_title="Count",
        bargap=0.2,
    )
)

# Record the chart in the model card's "datasets" section, then render it inline.
report.log_plotly_figure(fig=fig, caption="Age Distribution", section_name="datasets")
fig.show()

In [None]:
# Sum of each pathology's label column, in the same order as `pathologies`.
pathology_counts = [np.array(nih_ds[pathology]).sum() for pathology in pathologies]

fig = px.bar(x=pathologies, y=pathology_counts)
fig.update_layout(
    title="Pathology Distribution",
    xaxis_title="Pathology",
    yaxis_title="Count",
    bargap=0.2,
)

# Record the chart in the model card's "datasets" section, then render it inline.
report.log_plotly_figure(
    fig=fig,
    caption="Pathology Distribution",
    section_name="datasets",
)
fig.show()

### Balanced Error Rate by Pathology and Age

In [None]:
# Multilabel specificity and sensitivity, one label per pathology.
multilabel_kwargs = {"task": "multilabel", "num_labels": len(pathologies)}
specificity = create_metric(metric_name="specificity", **multilabel_kwargs)
sensitivity = create_metric(metric_name="sensitivity", **multilabel_kwargs)

# Balanced error rate = mean of (1 - specificity) and (1 - sensitivity),
# composed directly from the metric objects.
fpr = 1 - specificity
fnr = 1 - sensitivity
balanced_error_rate = (fpr + fnr) / 2

# Fairness evaluation grouped by "Patient Age", with bin edges at 35 and 65.
# NOTE(review): the base value 50 presumably designates the reference age group
# used for the parity values -- confirm against evaluate_fairness.
nih_fairness_result_age = evaluate_fairness(
    metrics=balanced_error_rate,
    metric_name="BalancedErrorRate",
    dataset=nih_ds,
    remove_columns="features",
    target_columns=pathologies,
    prediction_columns="predictions",
    groups=["Patient Age"],
    group_bins={"Patient Age": [35, 65]},
    group_base_values={"Patient Age": 50},
)

# plot metrics per slice: one marker trace for each entry of the fairness result
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["BalancedErrorRate"],
        name=slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_fairness_result_age.items()
]

fairness_age = go.Figure(data=plots)
fairness_age.update_layout(
    title="Balanced Error Rate by Age vs. Pathology",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate",
    width=1024,
    height=768,
)
# Enlarge and outline the markers; as the last expression this also renders the figure.
fairness_age.update_traces(
    marker={"size": 12, "line": {"width": 2}},
    selector={"mode": "markers"},
)

### Balanced Error Rate Parity by Pathology and Age

In [None]:
# plot parity difference per slice: one marker trace for each entry of the
# age fairness result, reading its "BalancedErrorRate Parity" values
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["BalancedErrorRate Parity"],
        name=slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_fairness_result_age.items()
]

fairness_age_parity = go.Figure(data=plots)
fairness_age_parity.update_layout(
    title="Balanced Error Rate Parity by Age vs. Pathology",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate Parity",
    width=1024,
    height=768,
)
# Enlarge and outline the markers; as the last expression this also renders the figure.
fairness_age_parity.update_traces(
    marker={"size": 12, "line": {"width": 2}},
    selector={"mode": "markers"},
)

### Balanced Error Rate by Pathology and Age+Sex

In [None]:
# Fairness evaluation grouped jointly by "Patient Age" (bin edges at 35 and 65)
# and "Patient Gender".
# NOTE(review): the base values (age 50, gender "M") presumably designate the
# reference group used for the parity values -- confirm against evaluate_fairness.
nih_fairness_result_age_gender = evaluate_fairness(
    metrics=balanced_error_rate,
    metric_name="BalancedErrorRate",
    dataset=nih_ds,
    remove_columns="features",
    target_columns=pathologies,
    prediction_columns="predictions",
    groups=["Patient Age", "Patient Gender"],
    group_bins={"Patient Age": [35, 65]},
    group_base_values={"Patient Age": 50, "Patient Gender": "M"},
)

# plot metrics per slice: one marker trace for each entry of the fairness result
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["BalancedErrorRate"],
        name=slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_fairness_result_age_gender.items()
]

fairness_age_gender = go.Figure(data=plots)
fairness_age_gender.update_layout(
    title="Balanced Error Rate by Age&Gender vs. Pathology",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate",
    width=1024,
    height=768,
)
# Enlarge and outline the markers; as the last expression this also renders the figure.
fairness_age_gender.update_traces(
    marker={"size": 12, "line": {"width": 2}},
    selector={"mode": "markers"},
)

### Balanced Error Rate Parity by Pathology and Age+Sex

In [None]:
# plot parity difference per slice: one marker trace for each entry of the
# age+gender fairness result, reading its "BalancedErrorRate Parity" values
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["BalancedErrorRate Parity"],
        name=slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_fairness_result_age_gender.items()
]

fairness_age_gender_parity = go.Figure(data=plots)
fairness_age_gender_parity.update_layout(
    title="Balanced Error Rate Parity by Age&Gender vs. Pathology",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate Parity",
    width=1024,
    height=768,
)
# Enlarge and outline the markers; as the last expression this also renders the figure.
fairness_age_gender_parity.update_traces(
    marker={"size": 12, "line": {"width": 2}},
    selector={"mode": "markers"},
)

### Log Performance Metrics as Tests w/ Thresholds

In [None]:
def _flatten_eval_results(eval_results, label_names):
    """Flatten one evaluation result into ``"<slice>/<metric>"`` keyed values.

    Parameters
    ----------
    eval_results : mapping
        Maps a slice name to a mapping of metric name -> per-label values. The
        values must support ``.mean()`` and iteration.
    label_names : sequence of str
        Label name for each position of the per-label values.

    Returns
    -------
    dict
        For every slice and metric: one ``"<slice>/<metric>"`` entry holding the
        mean over labels, plus one ``"<slice> (<label>)/<metric>"`` entry per label.

    """
    flat = {}
    for slice_name, metrics in eval_results.items():
        for metric_name, values in metrics.items():
            flat[f"{slice_name}/{metric_name}"] = values.mean()
            for idx, value in enumerate(values):
                flat[f"{slice_name} ({label_names[idx]})/{metric_name}"] = value
    return flat


# Merge the age and gender results into a single flat mapping. Both now use the
# same key format: previously the gender keys carried a stray trailing space
# after the metric name, so their metrics were logged under a different name and
# the "overall" per-pathology entries appeared twice.
results_flat = {}
for eval_results in (nih_eval_results_age, nih_eval_results_gender):
    results_flat.update(_flatten_eval_results(eval_results, pathologies))

# Log every flattened value as a performance test that passes at >= 0.7.
for key, value in results_flat.items():
    metric_slice, metric_name = key.split("/")
    report.log_quantitative_analysis(
        "performance",
        name=metric_name,
        value=value,
        metric_slice=metric_slice,
        pass_fail_thresholds=0.7,
        pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),
    )

### Performance over time

We can monitor the change of performance metrics over time by leveraging historical reports. This involves using the `get_metrics_trends` function to gather prior metrics and merge them with recent results. We can specify which metrics and slices we wish to observe. Once the data is collected, we generate the plot using the `metrics_trends` method from the plotter, which can then be integrated into the report.

Please note, for the purpose of this tutorial, we will create three dummy reports to demonstrate the process of plotting these metric trends.

In [None]:
# Generating dummy reports
dummy_report_num = 3
dummy_report_dir = join(os.getcwd(), "dummy_reports_cxr")

for report_idx in range(dummy_report_num):
    # Each dummy report is written below the shared dummy directory.
    dummy_report = ModelCardReport(output_dir=dummy_report_dir)

    # A single random offset per report simulates a change in model performance;
    # shifted values are clipped at 0 from below.
    noise = np.random.uniform(-0.1, 0.1)
    for key, value in results_flat.items():
        metric_slice, metric_name = key.split("/")
        dummy_report.log_quantitative_analysis(
            "performance",
            name=metric_name,
            value=max(0, value - noise),
            metric_slice=metric_slice,
            pass_fail_thresholds=0.7,
            pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),
        )

    # Rename the report folder to the dummy date to simulate the time change.
    # The folder renamed is the one two levels above the exported file.
    dummy_report_path = dummy_report.export()
    date_dir = Path(dummy_report_path).parents[1]
    months_back = 6 * (report_idx + 1)
    dummy_date = date.today() + relativedelta(months=-months_back)
    new_dir = f"{dummy_report_dir}/{dummy_date}"
    # Clear any folder left over from a previous run so the rename can succeed.
    if os.path.exists(new_dir):
        shutil.rmtree(new_dir)
    os.rename(date_dir, new_dir)

# Collecting performance metrics from previous reports and current report
trends = get_metrics_trends(
    report_directory=dummy_report_dir,
    flat_results=results_flat,
    slice_names=["overall"],
)
# The dummy reports are no longer needed once the trends have been collected.
shutil.rmtree(dummy_report_dir)

# NOTE(review): task_type is "binary" although the logged metrics are multilabel;
# confirm that metrics_trends does not depend on the task type.
plotter = ClassificationPlotter(task_type="binary")
plotter.set_template("plotly_white")

# Plotting the performance over time
trends_plot = plotter.metrics_trends(trends)
report.log_plotly_figure(
    fig=trends_plot,
    caption="Performance over time",
    section_name="quantitative analysis",
)
trends_plot.show()

## Populate Model Card Fields

In [None]:
# Model details for the NIH Chest X-Ray model: name, description and reference
# link, logged into the "Model Details" section of the model card.
report.log_from_dict(
    data={
        "name": "NIH Chest X-Ray Multi-label Classification Model",
        # NOTE: the backslash continuations below keep each following line's
        # leading indentation inside the string literal.
        "description": "This model is a DenseNet121 model trained on the NIH Chest \
            X-Ray dataset, which contains 112,120 frontal-view X-ray images of 30,805 \
            unique patients with the fourteen text-mined disease labels from the \
            associated radiological reports. The labels are Atelectasis, Cardiomegaly, \
            Effusion, Infiltration, Mass, Nodule, Pneumonia, Pneumothorax, \
            Consolidation, Edema, Emphysema, Fibrosis, Pleural Thickening, and Hernia. \
            The model was trained on 80% of the data and evaluated on the remaining \
            20%.",
        "references": [{"link": "https://arxiv.org/abs/2111.00595"}],
    },
    section_name="Model Details",
)

# Citation (BibTeX) for the TorchXRayVision library, which provides the model.
report.log_citation(
    citation="""@inproceedings{Cohen2022xrv,
    title = {{TorchXRayVision: A library of chest X-ray datasets and models}},
    author = {Cohen, Joseph Paul and Viviano, Joseph D. and Bertin, \
    Paul and Morrison,Paul and Torabian, Parsa and Guarrera, \
    Matteo and Lungren, Matthew P and Chaudhari,\
    Akshay and Brooks, Rupert and Hashir, \
    Mohammad and Bertrand, Hadrien},
    booktitle = {Medical Imaging with Deep Learning},
    url = {https://github.com/mlmed/torchxrayvision},
    arxivId = {2111.00595},
    year = {2022}
    }""",
)

# Citation (BibTeX) for the paper on cross-domain generalization in automated
# X-ray prediction.
report.log_citation(
    citation="""@inproceedings{cohen2020limits,
    title={On the limits of cross-domain generalization\
            in automated X-ray prediction},
    author={Cohen, Joseph Paul and Hashir, Mohammad and Brooks, \
        Rupert and Bertrand, Hadrien},
    booktitle={Medical Imaging with Deep Learning},
    year={2020},
    url={https://arxiv.org/abs/2002.02497}
    }""",
)

# Owner of the model.
report.log_owner(name="Machine Learning and Medicine Lab", contact="mlmed.org")

# considerations: intended users, primary use case, limitations, tradeoffs,
# ethical risk and fairness assessment
report.log_user(description="Radiologists")
report.log_user(description="Data Scientists")

report.log_use_case(
    description="The model can be used to predict the presence of 14 pathologies \
        in chest X-ray images.",
    kind="primary",
)
# Free-text limitations, stored under the "considerations" section.
report.log_descriptor(
    name="limitations",
    description="The limitations of this model include its inability to detect \
                pathologies that are not included in the 14 labels of the NIH \
                Chest X-Ray dataset. Additionally, the model may not perform \
                well on images that are of poor quality or that contain \
                artifacts. Finally, the model may not generalize well to\
                populations that are not well-represented in the training \
                data, such as patients from different geographic regions or \
                with different demographics.",
    section_name="considerations",
)
# Free-text tradeoffs, stored under the "considerations" section.
report.log_descriptor(
    name="tradeoffs",
    description="The model can help radiologists to detect pathologies in \
        chest X-ray images, but it may not generalize well to populations \
        that are not well-represented in the training data.",
    section_name="considerations",
)
# Ethical risk together with its mitigation strategy.
report.log_risk(
    risk="One ethical risk of the model is that it may not generalize well to \
        populations that are not well-represented in the training data,\
        such as patients from different geographic regions \
        or with different demographics. ",
    mitigation_strategy="A mitigation strategy for this risk is to ensure \
        that the training data is diverse and representative of the population \
            that the model will be used on. Additionally, the model should be \
            regularly evaluated and updated to ensure that it continues to \
            perform well on diverse populations. Finally, the model should \
            be used in conjunction with human expertise to ensure that \
            any biases or limitations are identified and addressed.",
)
# Fairness assessment: affected group, benefit, harm and mitigation strategy.
report.log_fairness_assessment(
    affected_group="Patients with rare pathologies",
    benefit="The model can help radiologists to detect pathologies in \
        chest X-ray images.",
    harm="The model may not generalize well to populations that are not \
        well-represented in the training data.",
    mitigation_strategy="A mitigation strategy for this risk is to ensure that \
        the training data is diverse and representative of the population.",
)


# quantitative and fairness analysis figures, logged in the order listed.
# Each caption matches the title of the figure it describes, so the two AUROC
# plots (by sex and by age) can be told apart in the generated model card.
figures_to_log = [
    (
        perf_metric_gender,
        "MultiLabel AUROC by Pathology and Sex",
        "Quantitative Analysis",
    ),
    (
        perf_metric_age,
        "MultiLabel AUROC by Pathology and Age",
        "Quantitative Analysis",
    ),
    (
        fairness_age,
        "Balanced Error Rate by Age vs. Pathology",
        "Fairness Analysis",
    ),
    (
        fairness_age_parity,
        "Balanced Error Rate Parity by Age vs. Pathology",
        "Fairness Analysis",
    ),
    (
        fairness_age_gender,
        "Balanced Error Rate by Age&Gender vs. Pathology",
        "Fairness Analysis",
    ),
    (
        fairness_age_gender_parity,
        "Balanced Error Rate Parity by Age&Gender vs. Pathology",
        "Fairness Analysis",
    ),
]

for figure, caption, section_name in figures_to_log:
    report.log_plotly_figure(fig=figure, caption=caption, section_name=section_name)

In [None]:
# Export the model card, then copy the exported file into the current directory.
report_path = report.export()
shutil.copy(str(report_path), ".")

You can view the generated HTML [report](./model_card.html).