In [None]:
import base64
import copy
import inspect
import os
import re
from functools import partial
from io import BytesIO

import jinja2
import numpy as np
import plotly.graph_objects as go
import torch
from datasets.combine import concatenate_datasets  # noqa: E402
from monai.transforms import AddChanneld, Compose, Lambdad, Resized, ToDeviced
from pybtex.database import parse_string
from pybtex.plugin import find_plugin
from pydantic import BaseModel
from pydantic.main import ModelMetaclass
from torchxrayvision.models import DenseNet

from cyclops.data.loader import load_nihcxr
from cyclops.data.slicer import filter_value  # noqa: E402
from cyclops.data.slicer import SliceSpec
from cyclops.data.utils import apply_transforms
from cyclops.evaluate import evaluator
from cyclops.evaluate.fairness import evaluate_fairness  # noqa: E402
from cyclops.evaluate.metrics.factory import create_metric
from cyclops.monitor import ClinicalShiftApplicator, Detector, Reductor, TSTester
from cyclops.monitor.plotter import plot_drift_experiment
from cyclops.monitor.utils import get_device
from cyclops.report.model_card.model_card import (
    Citation,
    ExplainabilityReport,
    FairnessAssessment,
    FairnessReport,
    Graphic,
    GraphicsCollection,
    Limitation,
    ModelCard,
    Owner,
    PerformanceMetric,
    Reference,
    Risk,
    Tradeoff,
    UseCase,
    User,
)

# Torch device shared by the feature transforms (ToDeviced), the drift
# detector's Reductor and model inference below.
device = get_device()


def plot_to_str(fig, dpi=300, transparent=True):
    """Serialize a matplotlib figure into a base64 PNG ``data:`` URI.

    Args:
        fig: Object exposing matplotlib's
            ``savefig(fname, format=..., dpi=..., transparent=...)``.
        dpi: Resolution passed to ``savefig``.
        transparent: Whether the saved image has a transparent background.

    Returns:
        A string ``data:image/png;base64,<payload>`` suitable for an HTML
        ``<img src=...>`` attribute.
    """
    buffer = BytesIO()
    fig.savefig(buffer, format="png", dpi=dpi, transparent=transparent)
    # b64encode (unlike encodebytes) inserts no newlines, and the payload is
    # appended directly after the comma so no whitespace ends up in the URI.
    payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{payload}"


def plot_to_str_plotly(fig, scale=2):
    """Serialize a plotly figure into a base64 PNG ``data:`` URI.

    Args:
        fig: Object exposing plotly's ``write_image(file, format=..., scale=...)``.
        scale: Scale factor passed to ``write_image`` (higher = sharper image).

    Returns:
        A string ``data:image/png;base64,<payload>`` suitable for an HTML
        ``<img src=...>`` attribute.
    """
    buffer = BytesIO()
    fig.write_image(buffer, format="png", scale=scale)
    # b64encode (unlike encodebytes) inserts no newlines, and the payload is
    # appended directly after the comma so no whitespace ends up in the URI.
    payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{payload}"

In [None]:
# Root directory of the NIH chest X-ray data; set the NIHCXR_DIR environment
# variable to run on a machine with a different storage layout.
NIHCXR_DIR = os.environ.get("NIHCXR_DIR", "/mnt/data/clinical_datasets/NIHCXR")
nih_ds = load_nihcxr(NIHCXR_DIR)

In [None]:
def _rescale_to_pm1024(pixels):
    """Linearly map intensities from [0, 255] onto [-1024, 1024].

    This is the input range the torchxrayvision DenseNet is fed in this
    notebook; assumes 8-bit input intensities — TODO confirm for all images.
    """
    return ((2 * (pixels / 255.0)) - 1.0) * 1024


# Preprocessing applied to the "features" (image) entry only; every step
# tolerates examples without that key (allow_missing_keys=True).
feature_keys = ("features",)
transforms = Compose(
    [
        AddChanneld(keys=feature_keys, allow_missing_keys=True),
        Resized(
            keys=feature_keys, spatial_size=(1, 224, 224), allow_missing_keys=True
        ),
        Lambdad(keys=feature_keys, func=_rescale_to_pm1024, allow_missing_keys=True),
        ToDeviced(keys=feature_keys, device=device, allow_missing_keys=True),
    ]
)

model = DenseNet(weights="densenet121-res224-nih")

# source=None is passed to every experiment; only the target slice varies.
source_slice = None
target_slices = {
    "SEX: MALE": SliceSpec(spec_list=[{"Patient Gender": {"value": "M"}}]),
    "SEX: FEMALE": SliceSpec(spec_list=[{"Patient Gender": {"value": "F"}}]),
    "AGE: 18-35": SliceSpec(
        spec_list=[{"Patient Age": {"min_value": 18, "max_value": 35}}]
    ),
    "AGE: 35-65": SliceSpec(
        spec_list=[{"Patient Age": {"min_value": 35, "max_value": 65}}]
    ),
}

# The same on-access image preprocessing is attached to the source and target
# splits of every experiment, so build the callable once.
transform_fn = partial(apply_transforms, transforms=transforms)

results = {}
for name, target_slice in target_slices.items():
    shifter = ClinicalShiftApplicator(
        "custom", source=source_slice, target=target_slice
    )
    ds_source, ds_target = shifter.apply_shift(nih_ds, num_proc=6)

    ds_source = ds_source.with_transform(
        transform_fn, columns=["features"], output_all_columns=True
    )
    ds_target = ds_target.with_transform(
        transform_fn, columns=["features"], output_all_columns=True
    )

    # Reduce inputs with the classifier ("bbse-soft") and compare 1000 source
    # samples against increasing target sample sizes with an MMD test,
    # repeating each size 3 times.
    detector = Detector(
        "sensitivity_test",
        reductor=Reductor(dr_method="bbse-soft", model=model, device=device),
        tester=TSTester(tester_method="mmd"),
        source_sample_size=1000,
        target_sample_size=[50, 100, 200, 400, 800, 1000],
        num_runs=3,
    )
    results[name] = detector.detect_shift(ds_source, ds_target)

fig = plot_drift_experiment(results, axes_color="white")

drift_plot = plot_to_str(fig)

In [None]:
# Move the classifier to the selected device and put it in evaluation mode
# (both calls return the module itself, so they can be chained).
model.to(device).eval()


def get_predictions_torch(examples):
    """Run the module-level ``model`` on one batch of examples.

    Args:
        examples: Batch dict (as passed by ``Dataset.map(batched=True)``) whose
            ``"features"`` entry is a list of per-example tensors.

    Returns:
        ``{"predictions": tensor}`` with the model outputs for the batch.
    """
    # Stack into one batch tensor; squeeze(1) drops each example's leading
    # dimension — assumes every example is shaped (1, ...) — TODO confirm.
    images = torch.stack(examples["features"]).squeeze(1)
    # Inference only: don't build an autograd graph for every batch (saves
    # memory and time, and the outputs do not require grad).
    with torch.no_grad():
        preds = model(images)
    return {"predictions": preds}


# Run inference with the image transforms applied on the fly; the temporary
# "custom" format is only active inside this block.
with nih_ds.formatted_as(
    "custom",
    columns=["features"],
    transform=partial(apply_transforms, transforms=transforms),
):
    preds_ds = nih_ds.map(
        get_predictions_torch,
        batched=True,
        batch_size=64,
        remove_columns=nih_ds.column_names,
    )

# Join the predictions back only after nih_ds's original format is restored.
# `map` may carry the temporary custom format (and its image transform) over
# to preds_ds, so drop it before concatenating column-wise.
nih_ds = concatenate_datasets([nih_ds, preds_ds.with_format(None)], axis=1)

In [None]:
# remove any rows with No Finding == 1
nih_ds = nih_ds.filter(
    partial(filter_value, column_name="No Finding", value=1, negate=True), batched=True
)

# Drop the "No Finding" column and keep only the first 14 model outputs so the
# predictions line up with the `model.pathologies[:14]` targets used in the
# evaluation cells. Batched (like the filter above) to avoid a Python call per row.
nih_ds = nih_ds.map(
    lambda batch: {"predictions": [row[:14] for row in batch["predictions"]]},
    batched=True,
    remove_columns=["No Finding"],
)
nih_ds.features

In [None]:
pathologies = model.pathologies[:14]

# Binned multilabel AUROC over the thresholds 0.00, 0.01, ..., 0.99.
auroc = create_metric(
    metric_name="auroc",
    task="multilabel",
    num_labels=len(pathologies),
    thresholds=np.arange(0, 1, 0.01),
)

# One slice per patient sex.
slices = [{"Patient Gender": {"value": sex}} for sex in ("M", "F")]
slice_spec = SliceSpec(spec_list=slices)

nih_eval_results = evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    feature_columns="features",
    target_columns=pathologies,
    prediction_column_prefix="predictions",
    remove_columns="features",
    slice_spec=slice_spec,
)

# One marker trace per slice (an "overall" entry is relabelled "Overall"):
# AUROC on y for each pathology on x. Transparent background and white text
# so the figure sits on the dark report template.
fig = go.Figure(
    data=[
        go.Scatter(
            x=pathologies,
            y=metrics_by_name["MultilabelAUROC"],
            name="Overall" if slice_key == "overall" else slice_key,
            mode="markers",
        )
        for slice_key, metrics_by_name in nih_eval_results.items()
    ],
    layout=go.Layout(
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        font={"color": "white"},
    ),
)
fig.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Multilabel AUROC",
    width=1024,
    height=768,
).update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)
perf_metric_gender = plot_to_str_plotly(fig)

In [None]:
pathologies = model.pathologies[:14]

# Binned multilabel AUROC over the thresholds 0.00, 0.01, ..., 0.99.
auroc = create_metric(
    metric_name="auroc",
    task="multilabel",
    num_labels=len(pathologies),
    thresholds=np.arange(0, 1, 0.01),
)

# Age brackets, given to SliceSpec as min_value/max_value pairs.
# NOTE(review): the first bracket starts at 19 here but at 18 in the drift
# experiment's "AGE: 18-35" slice — confirm which lower bound is intended.
slices = [
    {"Patient Age": {"min_value": lower, "max_value": upper}}
    for lower, upper in ((19, 35), (35, 65), (65, 100))
]
slice_spec = SliceSpec(spec_list=slices)

nih_eval_results = evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    feature_columns="features",
    target_columns=pathologies,
    prediction_column_prefix="predictions",
    remove_columns="features",
    slice_spec=slice_spec,
)

# One marker trace per slice (an "overall" entry is relabelled "Overall"):
# AUROC on y for each pathology on x. Transparent background and white text
# so the figure sits on the dark report template.
fig = go.Figure(
    data=[
        go.Scatter(
            x=pathologies,
            y=metrics_by_name["MultilabelAUROC"],
            name="Overall" if slice_key == "overall" else slice_key,
            mode="markers",
        )
        for slice_key, metrics_by_name in nih_eval_results.items()
    ],
    layout=go.Layout(
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        font={"color": "white"},
    ),
)
fig.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Multilabel AUROC",
    width=1024,
    height=768,
).update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)
perf_metric_age = plot_to_str_plotly(fig)

In [None]:
# Multilabel specificity / sensitivity over the same pathologies as above.
metric_kwargs = {"task": "multilabel", "num_labels": len(pathologies)}
specificity = create_metric(metric_name="specificity", **metric_kwargs)
sensitivity = create_metric(metric_name="sensitivity", **metric_kwargs)

# Balanced error rate = mean of the false-positive and false-negative rates.
fpr = 1 - specificity
fnr = 1 - sensitivity
balanced_error_rate = (fpr + fnr) / 2

nih_fairness_result = evaluate_fairness(
    dataset=nih_ds,
    metrics=balanced_error_rate,
    metric_name="BalancedErrorRate",
    target_columns=pathologies,
    prediction_columns="predictions",
    remove_columns="features",
    # Groups are formed from age (binned at the given edges) and sex.
    groups=["Patient Age", "Patient Gender"],
    group_bins={"Patient Age": [20, 40, 60, 80]},
    # Presumably the reference group for the "... Parity" results plotted
    # below — confirm against evaluate_fairness.
    group_base_values={"Patient Age": 20, "Patient Gender": "M"},
)

In [None]:
# Balanced error rate of every fairness group (one marker trace per group),
# drawn with a transparent background and white text for the dark report.
fig = go.Figure(
    data=[
        go.Scatter(
            x=pathologies,
            y=group_metrics["BalancedErrorRate"],
            name=group_name,
            mode="markers",
        )
        for group_name, group_metrics in nih_fairness_result.items()
    ],
    layout=go.Layout(
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        font={"color": "white"},
    ),
)
fig.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate",
    width=1024,
    height=768,
).update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)
fairness_1 = plot_to_str_plotly(fig)

In [None]:
# Balanced-error-rate parity of every fairness group (one marker trace per
# group), drawn with a transparent background and white text for the dark report.
fig = go.Figure(
    data=[
        go.Scatter(
            x=pathologies,
            y=group_metrics["BalancedErrorRate Parity"],
            name=group_name,
            mode="markers",
        )
        for group_name, group_metrics in nih_fairness_result.items()
    ],
    layout=go.Layout(
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        font={"color": "white"},
    ),
)
fig.update_layout(
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate Parity",
    width=1024,
    height=768,
).update_traces(
    marker={"size": 12, "line": {"width": 2, "color": "DarkSlateGrey"}},
    selector={"mode": "markers"},
)
fairness_2 = plot_to_str_plotly(fig)

## Bootstrap model card from VerifyML model card editor and scaffold assets

In [None]:
# Initialize the model card
def scaffold_model(base_model: BaseModel) -> BaseModel:
    """Recursively initialize a pydantic model with default values.

    Every field of ``base_model`` is (re)assigned in place:

    * a field whose type is a pydantic model class and that has no
      ``default_factory`` gets a fresh instance of that class, itself
      scaffolded recursively;
    * a field with a ``default_factory`` gets a freshly produced value;
    * any other field gets a deep copy of its declared ``default``.

    Args:
        base_model: The pydantic model instance to populate.

    Returns:
        ``base_model`` itself (mutated in place).

    Raises:
        AssertionError: if ``base_model`` is not a pydantic ``BaseModel`` instance.
    """
    assert isinstance(
        base_model, BaseModel
    ), f"Expected a pydantic BaseModel instance, got {type(base_model)} instead."

    for field_name, model_field in base_model.__fields__.items():
        field_type = model_field.type_

        if type(field_type) is ModelMetaclass and model_field.default_factory is None:
            value = scaffold_model(field_type())
        elif model_field.default_factory is not None:
            value = model_field.default_factory()
        else:
            # `default` is the one object stored on the class-level field
            # definition; assigning it directly would let in-place edits
            # (e.g. list.append) leak into every other model. Copy it instead.
            value = copy.deepcopy(model_field.default)
        setattr(base_model, field_name, value)
    return base_model

## Populate model card fields

In [None]:
# Create the model card and fill every (nested) field with its default value.
mc = scaffold_model(ModelCard())

# Model details for the NIH Chest X-Ray model: name, overview, long-form
# documentation, references, citations and owners.
mc.model_details.name = "NIH Chest X-Ray Multi-label Classification Model"

mc.model_details.overview = (
    "This model is a DenseNet121 model trained on the NIH Chest X-Ray dataset."
)

# The backslashes below are line continuations *inside* the string literal, so
# the leading spaces of each continuation line become part of the stored text.
# NOTE(review): this text says the model was evaluated on a held-out 20% split,
# while the evaluation cells in this notebook score the whole (filtered)
# `nih_ds` — confirm the wording before publishing the card.
mc.model_details.documentation = "The model was trained on the NIH Chest X-Ray dataset,\
    which contains 112,120 frontal-view X-ray images of 30,805 unique patients with the\
    fourteen text-mined disease labels from the associated radiological reports.\
    The labels are Atelectasis, Cardiomegaly, Effusion, Infiltration, Mass, Nodule,\
    Pneumonia, Pneumothorax, Consolidation, Edema, Emphysema, Fibrosis,\
    Pleural Thickening, and Hernia. The model was trained on 80% of the data\
    and evaluated on the remaining 20%."
# arXiv paper for the TorchXRayVision library (same id as the first citation).
mc.model_details.references.append(
    Reference(reference="https://arxiv.org/abs/2111.00595")
)
# Citations are stored as BibTeX; `style` is handed to bib2html (through the
# render_citation Jinja filter) to select the pybtex output style.
mc.model_details.citations.append(
    Citation(
        style="APA",
        citation="""@inproceedings{Cohen2022xrv,
        title = {{TorchXRayVision: A library of chest X-ray datasets and models}},
        author = {Cohen, Joseph Paul and Viviano, Joseph D. and Bertin, \
        Paul and Morrison,Paul and Torabian, Parsa and Guarrera, \
        Matteo and Lungren, Matthew P and Chaudhari,\
        Akshay and Brooks, Rupert and Hashir, \
        Mohammad and Bertrand, Hadrien},
        booktitle = {Medical Imaging with Deep Learning},
        url = {https://github.com/mlmed/torchxrayvision},
        arxivId = {2111.00595},
        year = {2022}
        }""",
    )
)

mc.model_details.citations.append(
    Citation(
        style="APA",
        citation="""@inproceedings{cohen2020limits,
        title={On the limits of cross-domain generalization\
             in automated X-ray prediction},
        author={Cohen, Joseph Paul and Hashir, Mohammad and Brooks, \
            Rupert and Bertrand, Hadrien},
        booktitle={Medical Imaging with Deep Learning},
        year={2020},
        url={https://arxiv.org/abs/2002.02497}
        }""",
    )
)


# Assign (rather than append to) the owners list.
mc.model_details.owners = [
    Owner(name="Machine Learning and Medicine Lab", contact="mlmed.org")
]

# Considerations: intended users and use cases, limitations, trade-offs,
# ethical risks and a fairness assessment. The lists are extended/appended in
# place, which relies on scaffold_model() having initialised them as lists.
# NOTE(review): the backslash continuations sit inside the string literals, so
# each continuation line's indentation is part of the stored text.
mc.considerations.users.extend(
    [User(description="Radiologists"), User(description="Data Scientists")]
)
# Intended use case.
mc.considerations.use_cases.append(
    UseCase(
        description="The model can be used to predict the presence of 14 pathologies \
            in chest X-ray images."
    )
)
# Known limitations.
mc.considerations.limitations.append(
    Limitation(
        # describe limits of chest x-ray classification model
        description="The limitations of this model include its inability to detect \
                    pathologies that are not included in the 14 labels of the NIH \
                    Chest X-Ray dataset. Additionally, the model may not perform \
                    well on images that are of poor quality or that contain \
                    artifacts. Finally, the model may not generalize well to\
                    populations that are not well-represented in the training \
                    data, such as patients from different geographic regions or \
                    with different demographics."
    )
)
# Benefit vs. risk trade-off.
mc.considerations.tradeoffs.append(
    Tradeoff(
        description="The model can help radiologists to detect pathologies in \
            chest X-ray images, but it may not generalize well to populations \
            that are not well-represented in the training data."
    )
)
# Ethical risk: the full risk statement is stored in `name`, the response in
# `mitigation_strategy`.
mc.considerations.ethical_considerations.append(
    Risk(
        name="One ethical risk of the model is that it may not generalize well to \
            populations that are not well-represented in the training data,\
            such as patients from different geographic regions \
            or with different demographics. ",
        mitigation_strategy="A mitigation strategy for this risk is to ensure \
            that the training data is diverse and representative of the population \
              that the model will be used on. Additionally, the model should be \
                regularly evaluated and updated to ensure that it continues to \
                perform well on diverse populations. Finally, the model should \
                be used in conjunction with human expertise to ensure that \
                any biases or limitations are identified and addressed.",
    )
)
# Fairness assessment for one group at risk.
mc.considerations.fairness_assessment.append(
    FairnessAssessment(
        group_at_risk="Patients with rare pathologies",
        benefits="The model can help radiologists to detect pathologies in \
            chest X-ray images.",
        harms="The model may not generalize well to populations that are not \
            well-represented in the training data.",
        mitigation_strategy="A mitigation strategy for this risk is to ensure that \
            the training data is diverse and representative of the population.",
    )
)


# One PerformanceMetric per AUROC figure. The metrics are built from this spec
# list, so their number always matches the plots stored in them:
# (slice description, graphic name, figure as a data URI).
performance_metric_specs = [
    ("Male/Female", "auroc_sex", perf_metric_gender),
    ("Age Brackets", "auroc_age", perf_metric_age),
]
mc.quantitative_analysis.performance_metrics = [
    PerformanceMetric() for _ in performance_metric_specs
]

for performance_metric, (slice_label, graphic_name, graphic_image) in zip(
    mc.quantitative_analysis.performance_metrics, performance_metric_specs
):
    performance_metric.type = "MultiLabel AUROC by Pathology"
    performance_metric.slice = slice_label
    # instantiate GraphicsCollection as workaround to store graphics for the plots.
    performance_metric.graphics = GraphicsCollection()
    performance_metric.graphics.collection = [
        Graphic(name=graphic_name, image=graphic_image)
    ]

# (report type, graphic name, figure data URI) for each fairness report, in order.
fairness_report_specs = [
    ("Balanced Error Rate by Pathology", "fairness_ber", fairness_1),
    ("Balanced Error Rate Parity by Pathology", "fairness_berp", fairness_2),
]
mc.fairness_analysis.fairness_reports = [FairnessReport() for _ in range(2)]

for fairness_report, (report_type, graphic_name, graphic_image) in zip(
    mc.fairness_analysis.fairness_reports, fairness_report_specs
):
    fairness_report.type = report_type
    fairness_report.slice = None
    fairness_report.segment = "Age and Gender"
    fairness_report.description = None
    # instantiate GraphicsCollection as workaround to store graphics for the plots.
    fairness_report.graphics = GraphicsCollection()
    fairness_report.graphics.collection = [
        Graphic(name=graphic_name, image=graphic_image)
    ]

# A single explainability report holding the drift sensitivity experiment.
mc.explainability_analysis.explainability_reports = [ExplainabilityReport()]

drift_report = mc.explainability_analysis.explainability_reports[0]
drift_report.type = "Drift Sensitivity Experiment"
drift_report.slice = "Age and Sex"
drift_report.description = "Conduct sensitivity experiments to determine if the model is \
    sensitive to changes in the input data by slicing the data along patient \
    attributes and increasing the prevalence of the attribute in the data."
# instantiate GraphicsCollection as workaround to store graphics for the plots.
drift_report.graphics = GraphicsCollection()
drift_report.graphics.collection = [Graphic(name="drift_exp", image=drift_plot)]


# Jinja environment for the report templates. autoescape=True HTML-escapes
# rendered values; cache_size=0 disables Jinja's template cache and
# auto_reload=True re-checks template files, so template edits are picked up
# without restarting the kernel.
# NOTE(review): the loader path is relative to the kernel's working directory.
template_loader = jinja2.FileSystemLoader("../model_card/template/")
jinja_env = jinja2.Environment(
    loader=template_loader,
    autoescape=True,
    auto_reload=True,
    cache_size=0,
)


# Custom filter method
def regex_replace(s, find, replace):
    """Jinja filter: replace every match of the regex ``find`` in ``s`` with ``replace``.

    ``replace`` follows :func:`re.sub` replacement syntax (backreferences allowed).
    """
    pattern = re.compile(find)
    return pattern.sub(replace, s)


# Expose the regex filter and two type tests ("is list", "is class") to templates.
jinja_env.filters["regex_replace"] = regex_replace
jinja_env.tests.update(
    {
        "list": lambda value: isinstance(value, list),
        "class": inspect.isclass,
    }
)


def empty(x):
    """Jinja test: whether a model-card section holds no content.

    ``x`` is iterated as ``(name, value)`` pairs (as iterating a pydantic model
    yields). A value counts as empty when it is ``None``, an empty list, or a
    ``GraphicsCollection`` with no graphics (an unset ``collection`` also
    counts as empty). Any other value — including ``""`` or ``0`` — is content.

    Returns:
        ``True`` if no value holds content, ``False`` as soon as one does.
    """

    def _has_content(value):
        # Truthiness of `collection` covers both [] and None without len(None).
        if value is None:
            return False
        if isinstance(value, list):
            return len(value) > 0
        if isinstance(value, GraphicsCollection):
            return bool(value.collection)
        return True

    return not any(_has_content(value) for _, value in x)


# Make `empty` available to templates as `... is empty`.
jinja_env.tests.update(empty=empty)


def bib2html(citation, style, exclude_fields=None):
    """Render a BibTeX citation string as HTML with pybtex.

    Args:
        citation: BibTeX source containing one or more entries.
        style: Name of a pybtex formatting-style plugin (case-insensitive),
            e.g. ``"APA"``. A falsy value selects pybtex's default style.
        exclude_fields: Optional field names (e.g. ``"url"``) removed from
            every entry before formatting; matching is case-insensitive.

    Returns:
        The formatted entries rendered with pybtex's HTML backend, joined by ``<br>``.

    Raises:
        pybtex.plugin.PluginNotFound: if ``style`` names a formatting style
            that is not installed.
    """
    html_backend = find_plugin("pybtex.backends", "html")()
    # find_plugin(group, None) returns the group's default plugin, so every
    # style is formatted (an unformatted BibliographyData has no `.text`).
    formatter = find_plugin(
        "pybtex.style.formatting", style.lower() if style else None
    )()
    bibliography = parse_string(citation, "bibtex")
    for entry in bibliography.entries.values():
        for field_name in exclude_fields or []:
            # Public mapping API: keeps pybtex's case-insensitive key index and
            # field order consistent (poking the private `_dict` did not).
            entry.fields.pop(field_name, None)
    formatted = formatter.format_bibliography(bibliography)
    return "<br>".join(entry.text.render(html_backend) for entry in formatted)


def render_citation(obj):
    """Jinja filter: render a citation object (``.citation`` BibTeX text and
    ``.style`` name) to HTML via ``bib2html``."""
    citation_text, style_name = obj.citation, obj.style
    return bib2html(citation_text, style_name)


jinja_env.filters["render_citation"] = render_citation

template = jinja_env.get_template("cyclops_generic_template_dark.jinja")

content = template.render(model_card=mc)

# Write-only, with an explicit encoding: the rendered HTML may contain
# non-ASCII characters that the platform default encoding cannot represent.
with open("report.html", "w", encoding="utf-8") as f:
    f.write(content)