# Exploring the Extensibility of the 🤗 Datasets Library for Medical Images

In [None]:
import datetime
import glob
import os
from functools import partial
from typing import Callable, Dict, List

import dask
import dask.dataframe as dd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import plotly.graph_objects as go
import psutil
import seaborn as sns
import torch
import torchxrayvision as xrv
import yaml
from datasets import Dataset, load_dataset
from datasets.features import ClassLabel, Image
from datasets.splits import Split
from monai.transforms import (
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    Lambdad,
    ToDeviced,
)
from omegaconf import OmegaConf
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from torchvision.transforms import PILToTensor

from cyclops.data.slicer import SliceSpec
from cyclops.evaluate.metrics import MetricCollection, create_metric
from cyclops.models.catalog import create_model, list_models
from cyclops.models.constants import CONFIG_ROOT
from cyclops.utils.file import join
from use_cases.params.mimiciv.mortality_decompensation.constants_v1 import (
    ENCOUNTERS_FILE,
    QUERIED_DIR,
    TAB_FEATURES,
)

In [None]:
# CONSTANTS
NUM_PROC = 4
TORCH_BATCH_SIZE = 64

## Exploring existing functionalities that are relevant to CyclOps

### Tabular Data

#### Constructing a 🤗 Dataset from MIMICIV-v2.0 PostgreSQL Database

In [None]:
from urllib.parse import quote_plus  # noqa: E402

db_cfg = OmegaConf.load(join("..", "cyclops", "query", "configs", "config.yaml"))

# Build an SQLAlchemy-style URL: <dbms>://<user>:<password>@<host>/<database>.
# The user name and password are percent-encoded so that characters such as
# "@", ":", "/" or "%" in the credentials cannot break URL parsing.
db_user = quote_plus(str(db_cfg.user))
db_password = quote_plus(str(db_cfg.password))
con_str = (
    f"{db_cfg.dbms}://{db_user}:{db_password}@{db_cfg.host}/{db_cfg.database}"
)

ds = Dataset.from_sql(
    sql="SELECT * FROM mimiciv_hosp.patients LIMIT 1000",
    con=con_str,
    keep_in_memory=True,
)
ds

#### Constructing a 🤗 Dataset from local parquet files

In [None]:
parquet_files = list(glob.glob(join(QUERIED_DIR, "*.parquet")))
len(parquet_files)

In [None]:
# take the first 300 files
parquet_files = parquet_files[:300]

In [None]:
mimiciv_ds = load_dataset(
    "parquet", data_files=parquet_files, split=Split.ALL, num_proc=NUM_PROC
)

# clear all other cache files, except for the current cache file
mimiciv_ds.cleanup_cache_files()

In [None]:
size_gb = mimiciv_ds.dataset_size / (1024**3)
print(f"Dataset size (cache file) : {size_gb:.2f} GB")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
mimiciv_ds.features

##### Benchmarking Filtering operations: 🤗 Dataset vs. Dask

In [None]:
dask.config.set(scheduler="processes", num_workers=NUM_PROC)

ddf = dd.read_parquet(parquet_files)
len(ddf)

1. **Filtering on 1 column**

Get all rows where the value in column `event_category` is in a given list of values.

In [None]:
event_filter = [
    "Cadiovascular",
    "Dialysis",
    "Hemodynamics",
    "Neurological",
    "Toxicology",
    "General",
]

In [None]:
%%timeit
events_ddf = ddf[ddf["event_category"].isin(event_filter)].compute()

In [None]:
%%timeit
events_ds = mimiciv_ds.filter(
    lambda examples: [
        example in event_filter for example in examples["event_category"]
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

2. **Filtering on multiple columns**

Get all rows where the values in two columns are each in a list of allowed values for that column.

In [None]:
discharge_location_filter = ["HOME", "HOME HEALTH CARE"]
admission_location_filter = [
    "TRANSFER FROM HOSPITAL",
    "PHYSICIAN REFERRAL",
    "CLINIC REFERRAL",
]

In [None]:
%%timeit

location_ddf = ddf[
    (ddf["discharge_location"].isin(discharge_location_filter))
    & (ddf["admission_location"].isin(admission_location_filter))
].compute()

In [None]:
%%timeit

location_ds = mimiciv_ds.filter(
    lambda examples: [
        example[0] in discharge_location_filter
        and example[1] in admission_location_filter
        for example in zip(
            examples["discharge_location"], examples["admission_location"]
        )
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

3. **Filtering on a datetime condition**

Get all rows where the date of death (`dod`) is after January 1, 2020.

In [None]:
%%timeit
dod_ddf = ddf[ddf["dod"] > datetime.datetime(2020, 1, 1)].compute()

In [None]:
%%timeit

dod_ds = mimiciv_ds.filter(
    lambda examples: [
        example is not None and example > datetime.datetime(2020, 1, 1)
        for example in examples["dod"]
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

4. **Filtering on a numeric range in a column**

In [None]:
%%timeit
millenials_ddf = ddf[(ddf.age <= 40) & (ddf.age >= 25)].compute()

In [None]:
%%timeit
millenials_ds = mimiciv_ds.filter(
    lambda examples: [25 <= example <= 40 for example in examples["age"]],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

### Image Data - Constructing a 🤗 Dataset from image folder

From the 🤗 Datasets documentation, there are 3 ways to load local image data into a 🤗 Dataset:
1. **Load images from a folder with the following structure:**
    ```bash
    root_folder/train/class1/img1.png
    root_folder/train/class1/img2.png
    root_folder/train/class2/img1.png
    root_folder/train/class2/img2.png
    root_folder/test/class1/img1.png
    root_folder/test/class1/img2.png
    root_folder/test/class2/img1.png
    root_folder/test/class2/img2.png
    ...
    ```
    The folder names are the class names and the dataset splits (train/test) will automatically be recognized.
    The dataset can be loaded using the following code:
    ```python
    from datasets import load_dataset
    dataset = load_dataset("imagefolder", data_dir="root_folder")
    ```
    (This method also supports loading remote image folders from URLs.)
    
    The downside of this approach is that it uses PIL to load the images, which does not support many medical image formats like DICOM and NIfTI.

2. **Load images using a list of image paths**
    ```python
    from datasets import Dataset
    from datasets.features import Image
    dataset = Dataset.from_dict({"image": ["path/to/img1.png", "path/to/img2.png", ...]}).cast_column("image", Image())
    ```
    This approach is more flexible than the previous one, but it still has the same limitation of not supporting many medical image formats.

3. **Create a dataset loading script**

    This is the most flexible way to load and share different types of datasets that are not natively supported by 🤗 Datasets library.
    In fact, the `imagefolder` dataset is an example of a dataset loading script. In essence, we can extend that script to support more image formats like DICOM and NIfTI. That solves half the problem. The other half is that we need to create a new feature to extend the `Image` class to support decoding medical image formats.

#### Case Study: MIMIC-CXR-JPG v2.0.0

For this case study, we will combine CSV metadata and the `Image` feature to create a 🤗 Dataset from the MIMIC-CXR-JPG v2.0.0 dataset. The dataset is available on [PhysioNet](https://physionet.org/content/mimic-cxr-jpg/2.0.0/).

The dataset comes with 4 compressed CSV metadata files. The metadata files are `mimic-cxr-2.0.0-split.csv.gz`, `mimic-cxr-2.0.0-chexpert.csv.gz`, `mimic-cxr-2.0.0-negbio.csv.gz`, and `mimic-cxr-2.0.0-metadata.csv.gz`. The `mimic-cxr-2.0.0-split.csv.gz` file contains the train/val/test split for each image. The `mimic-cxr-2.0.0-chexpert.csv.gz` file contains the CheXpert labels for each image. The `mimic-cxr-2.0.0-negbio.csv.gz` file contains the NegBio labels for each image. The `mimic-cxr-2.0.0-metadata.csv.gz` file contains other metadata for each image. All the metadata files can be joined on the `subject_id` and `study_id` columns.

In [None]:
mimic_cxr_jpg_dir = "/mnt/data/clinical_datasets/mimic-cxr-jpg-2.0.0"

In [None]:
# read metadata files using pandas
metadata_df = pd.read_csv(
    os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-metadata.csv.gz")
)
negbio_df = pd.read_csv(
    os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-negbio.csv.gz")
)
split_df = pd.read_csv(os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-split.csv.gz"))

In [None]:
# join the 3 metadata dataframes (split on subject_id, study_id and dicom_id; negbio on subject_id and study_id)
metadata_df = metadata_df.merge(
    split_df, on=["subject_id", "study_id", "dicom_id"]
).merge(negbio_df, on=["subject_id", "study_id"])

In [None]:
# select rows with images in folder 'p10' i.e. subject_id starts with 10
metadata_df = metadata_df[metadata_df["subject_id"].astype(str).str.startswith("10")]

In [None]:
# create HuggingFace Dataset from pandas DataFrame
mimic_cxr_ds = Dataset.from_pandas(
    metadata_df[metadata_df.split == "train"], split="train", preserve_index=False
)
mimic_cxr_ds

In [None]:
# create a new column with the full path to each image:
# mimic_cxr_jpg_dir/files/p10/p<subject_id>/s<study_id>/<dicom_id>.jpg


def get_filename(examples, root_dir=None):
    """Add an ``image`` column holding the full path to each MIMIC-CXR-JPG image.

    Each path has the form
    ``<root_dir>/files/p<XX>/p<subject_id>/s<study_id>/<dicom_id>.jpg``, where
    ``p<XX>`` is the group folder named after the first two digits of
    ``subject_id`` (e.g. ``p10`` for subject ``10000032``).

    Args:
        examples: Batch in column format (as passed by
            ``Dataset.map(..., batched=True)``) with ``subject_id``,
            ``study_id`` and ``dicom_id`` lists of equal length.
        root_dir: Dataset root directory. When ``None`` (the default), the
            notebook-level ``mimic_cxr_jpg_dir`` is used.

    Returns:
        The same ``examples`` dict, with an ``image`` list added.
    """
    if root_dir is None:
        root_dir = mimic_cxr_jpg_dir

    examples["image"] = [
        os.path.join(
            root_dir,
            "files",
            "p" + str(subject_id)[:2],  # group folder, e.g. "p10"
            "p" + str(subject_id),
            "s" + str(study_id),
            dicom_id + ".jpg",
        )
        for subject_id, study_id, dicom_id in zip(
            examples["subject_id"], examples["study_id"], examples["dicom_id"]
        )
    ]
    return examples


mimic_cxr_ds = mimic_cxr_ds.map(
    get_filename,
    batched=True,
    num_proc=NUM_PROC,
    remove_columns=["dicom_id", "split", "Rows", "Columns"],
)
mimic_cxr_ds

In [None]:
mimic_cxr_ds = mimic_cxr_ds.cast_column("image", Image())
mimic_cxr_ds.features

In [None]:
from cyclops.data.utils import set_decode  # noqa: E402

set_decode(mimic_cxr_ds, decode=False)
mimic_cxr_ds[0]["image"]

In [None]:
set_decode(dataset=mimic_cxr_ds, decode=True)
mimic_cxr_ds[0]

## Extending 🤗 Dataset to Load DICOM (and NIfTI) images

In [None]:
%matplotlib widget


# code for plotting 3D images
# Taken from: https://www.datacamp.com/tutorial/matplotlib-3d-volumetric-data
def multi_slice_viewer(volume):
    """Show a 3D volume slice by slice in an interactive matplotlib figure.

    The slice at the middle of axis 0 is displayed first. Key presses are routed
    to ``process_key`` ("a" = previous slice, "d" = next slice).

    Args:
        volume: Array-like with a ``shape`` attribute, indexable along its first
            axis; each ``volume[i]`` is drawn as a grayscale image.
    """
    figure, axes = plt.subplots()
    start = volume.shape[0] // 2
    # Keep the volume and the current position on the Axes object so the key
    # handler can reach them through the event alone.
    axes.volume, axes.index = volume, start
    axes.imshow(volume[start], cmap="gray")
    figure.canvas.mpl_connect("key_press_event", process_key)


def process_key(event):
    """Handle key presses for ``multi_slice_viewer``.

    "a" moves to the previous slice and "d" to the next one; any other key is
    ignored. The canvas is redrawn after every key press.
    """
    figure = event.canvas.figure
    axes = figure.axes[0]
    handler = {"a": previous_slice, "d": next_slice}.get(event.key)
    if handler is not None:
        handler(axes)
    figure.canvas.draw()


def previous_slice(ax):
    """Step ``ax`` back one slice, wrapping from the first slice to the last."""
    depth = ax.volume.shape[0]
    ax.index = (ax.index - 1) % depth
    ax.images[0].set_array(ax.volume[ax.index])


def next_slice(ax):
    """Step ``ax`` forward one slice, wrapping from the last slice to the first."""
    depth = ax.volume.shape[0]
    ax.index = (ax.index + 1) % depth
    ax.images[0].set_array(ax.volume[ax.index])

In [None]:
ROOT_DIR = "/mnt/data/clinical_datasets/coherent-11-07-2022/dicom/"
# ROOT_DIR = "/mnt/data/clinical_datasets/pseudo_phi_dataset/Pseudo-PHI-DICOM-Data/"

dcm_files = glob.glob(ROOT_DIR + "/**/*.dcm", recursive=True)
len(dcm_files)

1. Create a new feature class that extends the `Image` class to support decoding medical image formats. Let's call it `MedicalImage`. This will use MONAI to decode the medical image formats.

In [None]:
from cyclops.data import MedicalImage  # noqa: E402

# or
# from cyclops.data.features import MedicalImage  # noqa: E402

In [None]:
dicom_ds = Dataset.from_dict({"image": dcm_files}).cast_column("image", MedicalImage())
print("Number of rows: ", dicom_ds.num_rows)
print("Features: ", dicom_ds.features)
print("Image column contents: ", list(dicom_ds[0]["image"].keys()))

In [None]:
img = dicom_ds[0]["image"]["array"].shape

2. Create a new dataset loading script that extends the `imagefolder` dataset
loading script to support the `MedicalImage` feature class. We call it
`medicalimagefolder`.

For cyclops, the dataset loading script can be found in `cyclops/data/packaged_loading_scripts`.
Our new dataset loading script can be used with `load_dataset` by simply passing
the string `"medicalimagefolder"` to the `path` argument. This works because
we have added the path to the script to Hugging Face's `_PACKAGED_DATASETS_MODULES`
registry in `cyclops/data/__init__.py`. This means that `cyclops.data`
must be imported (as done above) for the script to be registered.

In [None]:
med_ds = load_dataset("medicalimagefolder", data_files=dcm_files, split=Split.ALL)
print("Number of rows: ", med_ds.num_rows)
print("Features: ", med_ds.features)
print("Image column contents: ", list(med_ds[0]["image"].keys()))

In [None]:
med_img = med_ds[150]["image"]["array"]
multi_slice_viewer(med_img.T)

### Some Challenges

1. Handling metadata. What to do with it?
2. Encoding and decoding image bytes in the formats that are supported by the `MedicalImage` feature class.

## Exploring Training and Evaluation of Scikit-Learn and PyTorch Models

In [None]:
import cyclops.evaluate.evaluator as evaluator  # noqa: E402
from cyclops.evaluate.fairness import FairnessConfig  # noqa: E402
from cyclops.evaluate.fairness import evaluate_fairness  # noqa: E402

### Scikit-Learn

#### Data Loading

In [None]:
encounters_ds = load_dataset(
    "parquet", data_files=ENCOUNTERS_FILE, split=Split.ALL, keep_in_memory=True
)
encounters_ds.cleanup_cache_files()
encounters_ds

In [None]:
# split into train and test - 0.6, 0.4
# NOTE: train_test_split does not work with IterableDataset objects
encounters_ds = encounters_ds.cast_column(TAB_FEATURES[-1], ClassLabel(num_classes=2))
encounters_ds = encounters_ds.train_test_split(
    test_size=0.4, seed=42, stratify_by_column=TAB_FEATURES[-1]
)
encounters_ds

In [None]:
TAB_FEATURES

#### Pre-processing

In [None]:
# pre-processing pipeline
numeric_features = [0]  # ['age']
numeric_transformer = Pipeline(
    steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]
)

categorical_features = [1, 2, 3]  # ['sex', 'admission_type', 'admission_location']
categorical_transformer = OneHotEncoder(handle_unknown="ignore")

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_features),
        ("cat", categorical_transformer, categorical_features),
    ]
)

In [None]:
# get a count of the positive and negative samples
import pyarrow.compute as pc  # noqa: E402

# `pc.value_counts` returns one {"values": label, "counts": n} struct per
# distinct label, ordered by first appearance in the column rather than by
# label value. Indexing by position could therefore swap the two classes (and
# invert `scale_pos_weight` below), so look the counts up by label instead.
value_counts = pc.value_counts(
    encounters_ds["train"].data[TAB_FEATURES[-1]]
).to_pylist()
label_counts = {item["values"]: item["counts"] for item in value_counts}
pos_count = label_counts[1]
neg_count = label_counts[0]

#### Training

In [None]:
model_dict = {}

for model_name in list_models("sklearn"):
    if "classifier" not in model_name:  # use only classifiers
        continue

    # load the config file for the model
    config_path = join(CONFIG_ROOT, model_name + ".yaml")
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

    if model_name == "xgb_classifier":
        # set the scale_pos_weight parameter to account for the class imbalance
        cfg["scale_pos_weight"] = neg_count / pos_count

    model_dict[model_name] = create_model(model_name, **cfg)

In [None]:
for model_name, model in model_dict.items():
    print(f"Training {model_name}...")
    model_dict[model_name] = model.fit(
        encounters_ds["train"],
        feature_columns=TAB_FEATURES[:-1],
        target_columns=[TAB_FEATURES[-1]],
        transforms=preprocessor,
    )

#### Evaluation

In [None]:
# specify some filters to apply to the dataset
slice_list = [
    # remove null values in column
    {"dod": {"keep_nulls": False}},
    {
        "admission_type": {"keep_nulls": True, "negate": True},
        "admission_location": {"keep_nulls": False},
    },
    # filter by exact value
    {"sex": {"value": "M"}},
    # filter numeric values by range
    {
        "age": {
            "min_value": 18,
            "max_value": 65,
            "min_inclusive": True,
            "max_inclusive": False,
        }
    },
    # filter by value in list
    {"admission_type": {"value": ["EW EMER.", "DIRECT EMER.", "URGENT"]}},
    # filter string values by substring
    {"admission_location": {"contains": "REFERRAL"}},
    # filter by date range (time string format: YYYY-MM-DD)
    {"dod": {"max_value": "2019-12-01", "keep_nulls": True}},
    # negate a filter
    {"dod": {"max_value": "2019-12-01", "negate": True}},
    # filter by month (1-12)
    {"admit_timestamp": {"month": [6, 7, 8, 9], "keep_nulls": False}},
    {
        "sex": {"value": "F"},
        "race": {"contains": ["BLACK", "WHITE"]},
        "age": {"min_value": 25, "max_value": 40},
    },  # compound slice
]

# create the slice functions
slice_spec = SliceSpec()
for slice_ in slice_list:
    slice_spec.add_slice_spec(slice_)

# or
# slice_spec = SliceSpec(spec_list=slice_list)

In [None]:
# define the metrics
metric_names = ["accuracy", "precision", "recall", "f1_score", "auroc"]
metrics = [create_metric(metric_name, task="binary") for metric_name in metric_names]
tab_metrics = MetricCollection(metrics)

In [None]:
tab_eval_result = evaluator.evaluate(
    encounters_ds,
    tab_metrics,
    split="test",
    models=model_dict,
    transforms=preprocessor,
    feature_columns=TAB_FEATURES[:-1],
    target_columns=TAB_FEATURES[-1],
    slice_spec=slice_spec,
    batch_size=None,  # load all data into memory
)

In [None]:
# plot evaluation results
# flatten {model: {slice: metrics}} into {(model, slice): metrics}
reformed_dict = {
    (model_name, slice_name): slice_metrics
    for model_name, per_slice in tab_eval_result.items()
    for slice_name, slice_metrics in per_slice.items()
}

tidy_df = pd.melt(
    pd.DataFrame(reformed_dict).T.rename_axis(["model", "slice"]),
    ignore_index=False,
    var_name="metric",
).reset_index()

sns.catplot(
    data=tidy_df,
    x="slice",
    y="value",
    hue="model",
    row="slice",
    col="metric",
    kind="bar",
    sharey=True,
    sharex=False,
)

##### Fairness

In [None]:
specificity = create_metric(metric_name="specificity", task="binary")
sensitivity = create_metric(metric_name="sensitivity", task="binary")

fpr = 1 - specificity
fnr = 1 - sensitivity
ber = (fpr + fnr) / 2  # balanced error rate

fairness_metric_collection = MetricCollection(
    {
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "FPR": fpr,
        "FNR": fnr,
        "BER": ber,
    }
)

In [None]:
fairness_config = FairnessConfig(
    metrics=fairness_metric_collection,
    dataset=None,  # dataset is passed from the evaluator
    target_columns=None,  # target columns are passed from the evaluator
    groups=["sex", "age"],
    group_bins={"age": [26, 42, 58, 68]},
    group_base_values={"sex": "M", "age": 40},
    thresholds=[0.1, 0.5, 0.9],
)

In [None]:
tab_model_analysis_results = evaluator.evaluate(
    encounters_ds,
    tab_metrics,
    split="test",
    models=model_dict,
    feature_columns=TAB_FEATURES[:-1],
    target_columns=TAB_FEATURES[-1],
    transforms=preprocessor,
    slice_spec=slice_spec,
    batch_size=-1,  # use all examples at once
    fairness_config=fairness_config,
    override_fairness_metrics=False,  # use separate metrics for evaluating fairness
)

In [None]:
# flatten {model: {slice: metrics}} into {(model, slice): metrics}
reformed_fairness_dict = {
    (model_name, slice_name): slice_metrics
    for model_name, per_slice in tab_model_analysis_results["fairness"].items()
    for slice_name, slice_metrics in per_slice.items()
}

tidy_fairness_df = pd.melt(
    pd.DataFrame(reformed_fairness_dict).T.rename_axis(["model", "slice"]),
    ignore_index=False,
    var_name="metric",
).reset_index()

sns.catplot(
    data=tidy_fairness_df,
    x="slice",
    y="value",
    hue="model",
    row="metric",
    col="slice",
    kind="bar",
    sharey=False,
    sharex=False,
)

### PyTorch

#### Data Loading

In [None]:
def nihcxr_preprocess(df: pd.DataFrame, nihcxr_dir: str) -> pd.DataFrame:
    """Preprocess the NIHCXR metadata dataframe.

    Add an ``image`` column with the path to each image file and append one-hot
    encoded pathology columns built from the ``|``-separated
    ``Finding Labels`` column. The input dataframe is left unmodified.

    Args:
        df (pd.DataFrame): NIHCXR dataframe with ``Image Index`` and
            ``Finding Labels`` columns.
        nihcxr_dir (str): Root directory of the NIHCXR dataset; image paths are
            built as ``<nihcxr_dir>/images/<Image Index>``.

    Returns:
        pd.DataFrame: A new, pre-processed NIHCXR dataframe.
    """
    # ``assign`` returns a new frame, so the caller's dataframe (often a
    # filtered slice of a larger one) is never written to. This also avoids
    # pandas' SettingWithCopyWarning.
    df = df.assign(
        image=df["Image Index"].apply(
            lambda name: os.path.join(nihcxr_dir, "images", name)
        )
    )

    # One 0/1 column per pathology name appearing in "Finding Labels".
    pathologies = df["Finding Labels"].str.get_dummies(sep="|")

    # Both frames share the same index, so a column-wise concat aligns rows.
    return pd.concat([df, pathologies], axis=1)


nihcxr_dir = "/mnt/data/clinical_datasets/NIHCXR"

test_df = pd.read_csv(
    join(nihcxr_dir, "test_list.txt"), header=None, names=["Image Index"]
)

# select only the images in the test list
df = pd.read_csv(join(nihcxr_dir, "Data_Entry_2017.csv"))
df.dropna(how="all", axis="columns", inplace=True)  # drop empty columns
df = df[df["Image Index"].isin(test_df["Image Index"])]

df = nihcxr_preprocess(df, nihcxr_dir)

# create a Dataset object
nih_ds = Dataset.from_pandas(df, preserve_index=False)
nih_ds = nih_ds.cast_column("image", Image())

In [None]:
nih_ds.features

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

#### Pre-processing

In [None]:
transforms = Compose(
    [
        # The PIL -> tensor conversion is done in `apply_transforms` instead,
        # because TorchVisiond(keys=("image",), name="PILToTensor") did not work.
        AddChanneld(keys=("image",)),
        CenterSpatialCropd(keys=("image",), roi_size=(1, 224, 224)),
        # Linearly map [0, 255] to [-1024, 1024] (assumes 8-bit pixel values);
        # presumably the input range the torchxrayvision model expects — confirm.
        Lambdad(keys=("image",), func=lambda x: ((2 * (x / 255.0)) - 1.0) * 1024),
        ToDeviced(keys=("image",), device=device),
    ],
)


def apply_transforms(
    examples: Dict[str, List], transforms: Callable[[dict], dict]
) -> Dict[str, List]:
    """Apply per-example ``transforms`` to a batch of examples.

    Args:
        examples: Batch in column format, i.e. a dict mapping each column name
            to a list with one value per example. ``PIL.Image.Image`` values are
            converted to tensors with ``PILToTensor`` before ``transforms`` runs;
            this is necessary when working with the ``Image`` feature type.
        transforms: Callable applied to each example, which is given as a dict
            mapping column name to that example's value.

    Returns:
        Dict[str, List]: The transformed batch, again in column format. A batch
        with no columns or no rows is returned unchanged.
    """
    num_examples = len(next(iter(examples.values()), []))
    if num_examples == 0:
        # Nothing to transform; indexing the first example below would fail.
        return examples

    # PILToTensor takes no configuration, so one instance serves the batch.
    pil_to_tensor = PILToTensor()

    # dict of lists -> list of dicts, converting PIL images to tensors.
    rows = [
        {
            key: (
                pil_to_tensor(values[i])
                if isinstance(values[i], PIL.Image.Image)
                else values[i]
            )
            for key, values in examples.items()
        }
        for i in range(num_examples)
    ]

    # Apply the transforms to each example.
    rows = [transforms(row) for row in rows]

    # list of dicts -> dict of lists, keyed by the first transformed example.
    return {key: [row[key] for row in rows] for key in rows[0]}

In [None]:
# from torch.utils.data import DataLoader
# from torch.utils.data.sampler import BatchSampler, RandomSampler

# nih_dl = DataLoader(
#     nih_ds.with_transform(
#         partial(apply_transforms, transforms=transforms),
#         columns=["image"],
#         output_all_columns=True,
#     ),
#     batch_size=TORCH_BATCH_SIZE,
#     drop_last=False
# )

# for batch in nih_dl:
#     print(batch)
#     break

#### Prediction

In [None]:
model = xrv.models.DenseNet(weights="densenet121-res224-nih")
model.eval()
model.to(device)

In [None]:
from datasets.combine import concatenate_datasets  # noqa: E402


def get_predictions_torch(examples):
    """Run the notebook-level ``model`` on a batch of transformed images.

    Args:
        examples: Batch in column format whose ``"image"`` entry is a list of
            same-shaped image tensors (as produced by ``apply_transforms``).

    Returns:
        dict: ``{"predictions": <model output for the batch>}``.
    """
    # Stack into one batch tensor. ``squeeze(1)`` drops dim 1 of the stacked
    # tensor; this assumes each image tensor carries a leading singleton
    # dimension — TODO confirm against the transform pipeline.
    images = torch.stack(examples["image"]).squeeze(1)
    # Inference only: disable autograd so no graph is kept alive per batch and
    # the predictions do not require grad.
    with torch.no_grad():
        preds = model(images)
    return {"predictions": preds}


with nih_ds.formatted_as(
    "custom",
    columns=["image"],
    transform=partial(apply_transforms, transforms=transforms),
):
    preds_ds = nih_ds.map(
        get_predictions_torch,
        batched=True,
        batch_size=TORCH_BATCH_SIZE,
        remove_columns=nih_ds.column_names,
    )

    nih_ds = concatenate_datasets([nih_ds, preds_ds], axis=1)

In [None]:
nih_ds.features

In [None]:
from cyclops.data.slicer import filter_value  # noqa: E402

# remove any rows with No Finding == 1
nih_ds = nih_ds.filter(
    partial(filter_value, column_name="No Finding", value=1, negate=True), batched=True
)

# remove the No Finding column and adjust the predictions to account for it
nih_ds = nih_ds.map(
    lambda x: {
        "predictions": x["predictions"][:14],
    },
    remove_columns=["No Finding"],
)
nih_ds.features

In [None]:
# get the list of pathologies
pathologies = model.pathologies[:14]
pathologies

#### Evaluation

In [None]:
# define the slices
slices = [
    {"Patient Gender": {"value": "M"}},
    {"Patient Age": {"min_value": 20, "max_value": 40}},
]

# create the slice functions
slice_spec = SliceSpec(spec_list=slices)

In [None]:
auroc = create_metric(
    metric_name="auroc",
    task="multilabel",
    num_labels=len(pathologies),
    thresholds=np.arange(0, 1, 0.01),
)

In [None]:
nih_eval_results = evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    feature_columns="image",
    target_columns=pathologies,
    prediction_column_prefix="predictions",
    remove_columns="image",
    slice_spec=slice_spec,
)

In [None]:
# plot the results
plots = []

for slice_name, slice_results in nih_eval_results.items():
    plots.append(
        go.Scatter(
            x=pathologies,
            y=slice_results["MultilabelAUROC"],
            name="Overall" if slice_name == "overall" else slice_name,
            mode="markers",
        )
    )

fig = go.Figure(data=plots)
fig.update_layout(
    title="Multilabel AUROC by Pathology and Slice",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Multilabel AUROC",
)
fig.update_traces(
    marker=dict(size=12, line=dict(width=2, color="DarkSlateGrey")),
    selector=dict(mode="markers"),
)
fig.show()

#### Fairness

In [None]:
specificity = create_metric(
    metric_name="specificity",
    task="multilabel",
    num_labels=len(pathologies),
)
sensitivity = create_metric(
    metric_name="sensitivity",
    task="multilabel",
    num_labels=len(pathologies),
)

fpr = 1 - specificity
fnr = 1 - sensitivity

balanced_error_rate = (fpr + fnr) / 2

In [None]:
nih_fairness_result = evaluate_fairness(
    metrics=balanced_error_rate,
    metric_name="BalancedErrorRate",
    dataset=nih_ds,
    remove_columns="image",
    target_columns=pathologies,
    prediction_columns="predictions",
    groups=["Patient Age", "Patient Gender"],
    group_bins={"Patient Age": [20, 40, 60, 80]},
    group_base_values={"Patient Age": 20, "Patient Gender": "M"},
)

##### Plots

In [None]:
# plot group size per slice
plots = []

for slice_name, slice_results in nih_fairness_result.items():
    plots.append(
        go.Bar(
            x=[slice_name],
            y=[slice_results["Group Size"]],
            name=slice_name,
        )
    )

fig = go.Figure(data=plots)
fig.update_layout(
    title="Size of Each Group",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Group",
    yaxis_title="Group Size",
    showlegend=False,
)

In [None]:
# plot metrics per slice
plots = []

for slice_name, slice_results in nih_fairness_result.items():
    plots.append(
        go.Scatter(
            x=pathologies,
            y=slice_results["BalancedErrorRate"],
            name=slice_name,
            mode="markers",
        )
    )

fig = go.Figure(data=plots)
fig.update_layout(
    title="Balanced Error Rate by Pathology and Group",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate",
)
fig.update_traces(
    marker=dict(size=12, line=dict(width=2, color="DarkSlateGrey")),
    selector=dict(mode="markers"),
)
fig.show()

In [None]:
# plot parity difference per slice
plots = []

for slice_name, slice_results in nih_fairness_result.items():
    plots.append(
        go.Scatter(
            x=pathologies,
            y=slice_results["BalancedErrorRate Parity"],
            name=slice_name,
            mode="markers",
        )
    )

fig = go.Figure(data=plots)
fig.update_layout(
    title="Balanced Error Rate Parity by Pathology and Group",
    title_x=0.5,
    title_font_size=20,
    xaxis_title="Pathology",
    yaxis_title="Balanced Error Rate Parity",
)
fig.update_traces(
    marker=dict(size=12, line=dict(width=2, color="DarkSlateGrey")),
    selector=dict(mode="markers"),
)
fig.show()

##### Alternative

In [None]:
fairness_config = FairnessConfig(
    metrics=balanced_error_rate,
    metric_name="BalancedErrorRate",
    dataset=None,  # dataset is passed from the evaluator
    target_columns=None,  # target columns are passed from the evaluator
    groups=["Patient Age", "Patient Gender"],
    group_bins={"Patient Age": [20, 40, 60, 80]},
    group_base_values={"Patient Age": 20, "Patient Gender": "M"},
)

evaluator.evaluate(
    dataset=nih_ds,
    metrics=auroc,
    target_columns=pathologies,
    slice_spec=slice_spec,
    remove_columns=["image"],
    fairness_config=fairness_config,
    override_fairness_metrics=False,
)