In [None]:
import base64
from functools import partial
from io import BytesIO

import numpy as np
import plotly.graph_objects as go
import torch
from datasets.combine import concatenate_datasets  # noqa: E402
from monai.transforms import AddChanneld, Compose, Lambdad, Resized, ToDeviced
from torchxrayvision.models import DenseNet

from cyclops.data.loader import load_nihcxr
from cyclops.data.slicer import filter_value  # noqa: E402
from cyclops.data.slicer import SliceSpec
from cyclops.data.utils import apply_transforms
from cyclops.evaluate import evaluator
from cyclops.evaluate.fairness import evaluate_fairness  # noqa: E402
from cyclops.evaluate.metrics.factory import create_metric
from cyclops.monitor import ClinicalShiftApplicator, Detector, Reductor, TSTester
from cyclops.monitor.plotter import plot_drift_experiment
from cyclops.monitor.utils import get_device
from cyclops.report import ModelCardReport

# Device handle shared by the rest of the notebook: the preprocessing pipeline
# moves the "features" tensors onto it and the model is moved onto it before
# inference.
device = get_device()


def plot_to_str(fig, dpi=300, transparent=True):
    """Render a figure to PNG and return it as a base64 ``data:`` URI.

    Parameters
    ----------
    fig
        Any object exposing ``savefig(file, format=..., dpi=..., transparent=...)``
        (for example a matplotlib figure).
    dpi : int, default 300
        Forwarded to ``fig.savefig``.
    transparent : bool, default True
        Forwarded to ``fig.savefig``.

    Returns
    -------
    str
        ``"data:image/png;base64,<payload>"`` with no embedded whitespace.
    """
    image_format = "png"
    img = BytesIO()
    fig.savefig(img, format=image_format, dpi=dpi, transparent=transparent)
    # b64encode (unlike encodebytes) emits a single line without newlines, and
    # the URI is built on one logical line so no indentation leaks into it.
    payload = base64.b64encode(img.getvalue()).decode("utf-8")
    return f"data:image/{image_format};base64,{payload}"


def plot_to_str_plotly(fig, scale=2):
    """Render a plotly-style figure to PNG and return it as a base64 ``data:`` URI.

    Parameters
    ----------
    fig
        Any object exposing ``write_image(file, format=..., scale=...)``
        (for example a plotly figure).
    scale : int or float, default 2
        Forwarded to ``fig.write_image``.

    Returns
    -------
    str
        ``"data:image/png;base64,<payload>"`` with no embedded whitespace.
    """
    image_format = "png"
    img = BytesIO()
    fig.write_image(img, format=image_format, scale=scale)
    # b64encode (unlike encodebytes) emits a single line without newlines, and
    # the URI is built on one logical line so no indentation leaks into it.
    payload = base64.b64encode(img.getvalue()).decode("utf-8")
    return f"data:image/{image_format};base64,{payload}"

In [None]:
# Directory holding the NIH chest X-ray data; adjust to the local setup.
NIHCXR_DIR = "/mnt/data/clinical_datasets/NIHCXR"

nih_ds = load_nihcxr(NIHCXR_DIR)

In [None]:
# Every transform below targets the same single column.
FEATURE_KEYS = ("features",)


def _rescale_pixels(x):
    """Linearly map values so that 0 becomes -1024 and 255 becomes 1024."""
    return ((2 * (x / 255.0)) - 1.0) * 1024


# Preprocessing for the "features" column: add a channel axis, resize,
# rescale the intensities, then move the result onto `device`.
transforms = Compose(
    [
        AddChanneld(keys=FEATURE_KEYS, allow_missing_keys=True),
        Resized(keys=FEATURE_KEYS, spatial_size=(1, 224, 224), allow_missing_keys=True),
        Lambdad(keys=FEATURE_KEYS, func=_rescale_pixels, allow_missing_keys=True),
        ToDeviced(keys=FEATURE_KEYS, device=device, allow_missing_keys=True),
    ]
)

# Model under evaluation, instantiated from the named weights.
model = DenseNet(weights="densenet121-res224-nih")

In [None]:
def _gender_slice(value):
    """Slice spec selecting rows whose "Patient Gender" equals `value`."""
    return SliceSpec(spec_list=[{"Patient Gender": {"value": value}}])


def _age_slice(low, high):
    """Slice spec selecting rows by a "Patient Age" min/max range."""
    return SliceSpec(
        spec_list=[{"Patient Age": {"min_value": low, "max_value": high}}]
    )


# The source side is never restricted; only the target slice changes per run.
source_slice = None
target_slices = {
    "SEX: MALE": _gender_slice("M"),
    "SEX: FEMALE": _gender_slice("F"),
    "AGE: 18-35": _age_slice(18, 35),
    "AGE: 35-65": _age_slice(35, 65),
}

# The same on-the-fly preprocessing is attached to both halves of each split.
feature_transform = partial(apply_transforms, transforms=transforms)

results = {}
for name, target_slice in target_slices.items():
    shifter = ClinicalShiftApplicator(
        "custom", source=source_slice, target=target_slice
    )
    ds_source, ds_target = (
        split.with_transform(
            feature_transform,
            columns=["features"],
            output_all_columns=True,
        )
        for split in shifter.apply_shift(nih_ds, num_proc=6)
    )

    # A fresh detector per slice, so each experiment starts from a clean state.
    detector = Detector(
        "sensitivity_test",
        reductor=Reductor(dr_method="bbse-soft", model=model, device=device),
        tester=TSTester(tester_method="mmd"),
        source_sample_size=1000,
        target_sample_size=[50, 100, 200, 400, 800, 1000],
        num_runs=3,
    )
    results[name] = detector.detect_shift(ds_source, ds_target)

# Single figure summarising the results of every slice.
drift_plot = plot_drift_experiment(results, axes_color="white")

In [None]:
# Put the model on the shared device and switch it to evaluation mode
# (both calls return the module itself, so they can be chained).
model.to(device).eval()


def get_predictions_torch(examples):
    """Run the module-level ``model`` on a batch and return its outputs.

    Parameters
    ----------
    examples : dict
        Batch mapping that must contain a ``"features"`` entry holding a
        sequence of tensors that ``torch.stack`` can combine.

    Returns
    -------
    dict
        ``{"predictions": <model output for the stacked batch>}``.
    """
    # Stack the per-example tensors into one batch and drop the size-1 axis at
    # position 1 (squeeze is a no-op if that axis is not of size 1).
    images = torch.stack(examples["features"]).squeeze(1)
    # Inference only: disable autograd so no graph is built and the returned
    # tensor does not track gradients.
    with torch.no_grad():
        preds = model(images)
    return {"predictions": preds}


# Preprocessing applied to the "features" column while predictions are computed.
feature_transform = partial(apply_transforms, transforms=transforms)

with nih_ds.formatted_as("custom", columns=["features"], transform=feature_transform):
    # Keep only the new "predictions" column in the mapped dataset ...
    preds_ds = nih_ds.map(
        get_predictions_torch,
        remove_columns=nih_ds.column_names,
        batched=True,
        batch_size=64,
    )

    # ... and attach it column-wise to the original dataset.
    nih_ds = concatenate_datasets([nih_ds, preds_ds], axis=1)

In [None]:
# Drop every row whose "No Finding" value is 1 (negate=True inverts the match).
_without_no_finding = partial(
    filter_value, column_name="No Finding", value=1, negate=True
)
nih_ds = nih_ds.filter(_without_no_finding, batched=True)


def _truncate_predictions(example):
    """Keep only the first 14 entries of an example's prediction vector."""
    return {"predictions": example["predictions"][:14]}


# Remove the now-constant "No Finding" column and trim the predictions to match.
nih_ds = nih_ds.map(_truncate_predictions, remove_columns=["No Finding"])
nih_ds.features

In [None]:
# Labels evaluated below: the first 14 entries of the model's pathology list.
pathologies = model.pathologies[:14]

auroc = create_metric(
    metric_name="auroc",
    task="multilabel",
    num_labels=len(pathologies),
    thresholds=np.arange(0, 1, 0.01),
)

# One slice per recorded patient gender value.
slices = [{"Patient Gender": {"value": gender}} for gender in ("M", "F")]
slice_spec = SliceSpec(spec_list=slices)

nih_eval_results = evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    slice_spec=slice_spec,
    feature_columns="features",
    target_columns=pathologies,
    prediction_column_prefix="predictions",
    remove_columns="features",
)

# One marker trace per slice; the "overall" entry gets a capitalised legend name.
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["MultilabelAUROC"],
        name="Overall" if slice_name == "overall" else slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_eval_results.items()
]

# Transparent backgrounds with white text.
layout = go.Layout(
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    font={"color": "white"},
)

perf_metric_gender = go.Figure(data=plots, layout=layout)
perf_metric_gender.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Multilabel AUROC",
    width=1024,
    height=768,
)
perf_metric_gender.update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

In [None]:
# Labels evaluated below: the first 14 entries of the model's pathology list.
pathologies = model.pathologies[:14]

auroc = create_metric(
    metric_name="auroc",
    task="multilabel",
    num_labels=len(pathologies),
    thresholds=np.arange(0, 1, 0.01),
)

# One slice per age range.
# NOTE(review): the drift experiment earlier in this notebook starts its first
# age range at 18, while this one starts at 19 - confirm which is intended.
age_ranges = [(19, 35), (35, 65), (65, 100)]
slices = [
    {"Patient Age": {"min_value": low, "max_value": high}}
    for low, high in age_ranges
]
slice_spec = SliceSpec(spec_list=slices)

nih_eval_results = evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    slice_spec=slice_spec,
    feature_columns="features",
    target_columns=pathologies,
    prediction_column_prefix="predictions",
    remove_columns="features",
)

# One marker trace per slice; the "overall" entry gets a capitalised legend name.
plots = [
    go.Scatter(
        x=pathologies,
        y=slice_results["MultilabelAUROC"],
        name="Overall" if slice_name == "overall" else slice_name,
        mode="markers",
    )
    for slice_name, slice_results in nih_eval_results.items()
]

# Transparent backgrounds with white text.
layout = go.Layout(
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    font={"color": "white"},
)

perf_metric_age = go.Figure(data=plots, layout=layout)
perf_metric_age.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Multilabel AUROC",
    width=1024,
    height=768,
)
perf_metric_age.update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

In [None]:
# Both base metrics share the same multilabel configuration.
multilabel_kwargs = {"task": "multilabel", "num_labels": len(pathologies)}
specificity = create_metric(metric_name="specificity", **multilabel_kwargs)
sensitivity = create_metric(metric_name="sensitivity", **multilabel_kwargs)

# Compose the balanced error rate from the metric objects using arithmetic:
# the mean of (1 - specificity) and (1 - sensitivity).
fpr = 1 - specificity
fnr = 1 - sensitivity
balanced_error_rate = (fpr + fnr) / 2

nih_fairness_result = evaluate_fairness(
    metrics=balanced_error_rate,
    metric_name="BalancedErrorRate",
    dataset=nih_ds,
    target_columns=pathologies,
    prediction_columns="predictions",
    remove_columns="features",
    groups=["Patient Age", "Patient Gender"],
    group_bins={"Patient Age": [20, 40, 60, 80]},
    group_base_values={"Patient Age": 20, "Patient Gender": "M"},
)

In [None]:
# One marker trace per entry of the fairness result, showing the metric itself.
plots = [
    go.Scatter(
        x=pathologies,
        y=group_results["BalancedErrorRate"],
        name=group_name,
        mode="markers",
    )
    for group_name, group_results in nih_fairness_result.items()
]

# Transparent backgrounds with white text.
layout = go.Layout(
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    font={"color": "white"},
)

fairness_1 = go.Figure(data=plots, layout=layout)
fairness_1.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate",
    width=1024,
    height=768,
)
fairness_1.update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

In [None]:
# One marker trace per entry of the fairness result, showing the parity values.
plots = [
    go.Scatter(
        x=pathologies,
        y=group_results["BalancedErrorRate Parity"],
        name=group_name,
        mode="markers",
    )
    for group_name, group_results in nih_fairness_result.items()
]

# Transparent backgrounds with white text.
layout = go.Layout(
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    font={"color": "white"},
)

fairness_2 = go.Figure(data=plots, layout=layout)
fairness_2.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate Parity",
    width=1024,
    height=768,
)
fairness_2.update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)

## Populate model card fields

In [None]:
# Create the report object that the remaining calls in this cell populate.
report = ModelCardReport()

# model details for NIH Chest X-Ray model
# NOTE: the backslash continuations sit inside the string literal, so the
# leading indentation of each continued line is part of the stored text.
report.log_from_dict(
    data={
        "name": "NIH Chest X-Ray Multi-label Classification Model",
        "description": "This model is a DenseNet121 model trained on the NIH Chest \
            X-Ray dataset, which contains 112,120 frontal-view X-ray images of 30,805 \
            unique patients with the fourteen text-mined disease labels from the \
            associated radiological reports. The labels are Atelectasis, Cardiomegaly, \
            Effusion, Infiltration, Mass, Nodule, Pneumonia, Pneumothorax, \
            Consolidation, Edema, Emphysema, Fibrosis, Pleural Thickening, and Hernia. \
            The model was trained on 80% of the data and evaluated on the remaining \
            20%.",
        "references": [{"link": "https://arxiv.org/abs/2111.00595"}],
    },
    section_name="Model Details",
)

# Citation for the TorchXRayVision library (BibTeX entry passed as one string).
report.log_citation(
    citation="""@inproceedings{Cohen2022xrv,
    title = {{TorchXRayVision: A library of chest X-ray datasets and models}},
    author = {Cohen, Joseph Paul and Viviano, Joseph D. and Bertin, \
    Paul and Morrison,Paul and Torabian, Parsa and Guarrera, \
    Matteo and Lungren, Matthew P and Chaudhari,\
    Akshay and Brooks, Rupert and Hashir, \
    Mohammad and Bertrand, Hadrien},
    booktitle = {Medical Imaging with Deep Learning},
    url = {https://github.com/mlmed/torchxrayvision},
    arxivId = {2111.00595},
    year = {2022}
    }""",
)

# Citation for the cross-domain generalization paper (BibTeX entry as one string).
report.log_citation(
    citation="""@inproceedings{cohen2020limits,
    title={On the limits of cross-domain generalization\
            in automated X-ray prediction},
    author={Cohen, Joseph Paul and Hashir, Mohammad and Brooks, \
        Rupert and Bertrand, Hadrien},
    booktitle={Medical Imaging with Deep Learning},
    year={2020},
    url={https://arxiv.org/abs/2002.02497}
    }""",
)

# Owner / contact of the model.
report.log_owner(name="Machine Learning and Medicine Lab", contact="mlmed.org")

# considerations
# Intended users of the model.
report.log_user(description="Radiologists")
report.log_user(description="Data Scientists")

# Primary use case.
report.log_use_case(
    description="The model can be used to predict the presence of 14 pathologies \
        in chest X-ray images.",
    kind="primary",
)
# Free-text descriptors filed under the "considerations" section.
# NOTE: the backslash continuations sit inside the string literals, so the
# leading indentation of each continued line is part of the stored text.
report.log_descriptor(
    name="limitations",
    description="The limitations of this model include its inability to detect \
                pathologies that are not included in the 14 labels of the NIH \
                Chest X-Ray dataset. Additionally, the model may not perform \
                well on images that are of poor quality or that contain \
                artifacts. Finally, the model may not generalize well to\
                populations that are not well-represented in the training \
                data, such as patients from different geographic regions or \
                with different demographics.",
    section_name="considerations",
)
report.log_descriptor(
    name="tradeoffs",
    description="The model can help radiologists to detect pathologies in \
        chest X-ray images, but it may not generalize well to populations \
        that are not well-represented in the training data.",
    section_name="considerations",
)
# Ethical risk together with its mitigation strategy.
report.log_risk(
    risk="One ethical risk of the model is that it may not generalize well to \
        populations that are not well-represented in the training data,\
        such as patients from different geographic regions \
        or with different demographics. ",
    mitigation_strategy="A mitigation strategy for this risk is to ensure \
        that the training data is diverse and representative of the population \
            that the model will be used on. Additionally, the model should be \
            regularly evaluated and updated to ensure that it continues to \
            perform well on diverse populations. Finally, the model should \
            be used in conjunction with human expertise to ensure that \
            any biases or limitations are identified and addressed.",
)
# Fairness assessment for one affected group: benefit, harm and mitigation.
report.log_fairness_assessment(
    affected_group="Patients with rare pathologies",
    benefit="The model can help radiologists to detect pathologies in \
        chest X-ray images.",
    harm="The model may not generalize well to populations that are not \
        well-represented in the training data.",
    mitigation_strategy="A mitigation strategy for this risk is to ensure that \
        the training data is diverse and representative of the population.",
)


# quantitative analysis: attach the per-gender and per-age AUROC figures.
# NOTE(review): the first call labels the figure with `alt=` while the second
# uses `caption=` - confirm which keyword log_plotly_figure accepts and make
# the two calls consistent.
report.log_plotly_figure(
    fig=perf_metric_gender,
    alt="MultiLabel AUROC by Pathology",
    section_name="Quantitative Analysis",
)
report.log_plotly_figure(
    fig=perf_metric_age,
    caption="MultiLabel AUROC by Pathology",
    section_name="Quantitative Analysis",
)

# fairness analysis
# Two report entries (metric and metric parity), both for the
# "Age and Gender" segment, with no slice or description set.
report.log_from_dict(
    data={
        "fairness_reports": [
            {
                "type": "Balanced Error Rate by Pathology",
                "slice": None,
                "segment": "Age and Gender",
                "description": None,
            },
            {
                "type": "Balanced Error Rate Parity by Pathology",
                "slice": None,
                "segment": "Age and Gender",
                "description": None,
            },
        ]
    },
    section_name="Fairness Analysis",
)

# Attach the two fairness figures to the same section.
# NOTE(review): these calls use `alt=`, whereas one of the Quantitative
# Analysis calls in this cell uses `caption=` - confirm the expected keyword.
report.log_plotly_figure(
    fig=fairness_1,
    alt="Balanced Error Rate by Pathology",
    section_name="Fairness Analysis",
)
report.log_plotly_figure(
    fig=fairness_2,
    alt="Balanced Error Rate Parity by Pathology",
    section_name="Fairness Analysis",
)

# explainability analysis
# Textual description of the drift sensitivity experiment.
# NOTE: the backslash continuations sit inside the string literal, so the
# leading indentation of each continued line is part of the stored text.
report.log_from_dict(
    data={
        "explainability_reports": [
            {
                "type": "Drift Sensitivity Experiment",
                "slice": "Age and Sex",
                "description": "Conduct sensitivity experiments to determine if \
                    the model is sensitive to changes in the input data by slicing \
                    the data along patient attributes and increasing the prevalence \
                    of the attribute in the data.",
            }
        ]
    },
    section_name="Explainability Analysis",
)

# Disabled call: `drift_plot` (built in the drift experiment cell) is currently
# not attached to the report.
# report.log_graphic(
#     fig=drift_plot,
#     caption="Drift Sensitivity Experiment",
#     section_name="Explainability Analysis",
# )

In [None]:
# Export the populated report; as the last expression of the cell, whatever
# export() returns is shown as the cell output.
report.export()