## Exploring the Extensibility of the 🤗 Datasets Library for Medical Images

In [None]:
import datetime
import glob
import os
from functools import partial
from typing import Dict, List, Mapping, Tuple, Union

import dask
import dask.dataframe as dd
import numpy as np
import pandas as pd
import PIL
import psutil
import torch
import torchxrayvision as xrv
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from datasets.features import Image
from datasets.splits import Split
from monai.transforms import AddChanneld, Compose, Lambdad, Resized, ToDeviced
from omegaconf import OmegaConf
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from torchvision.transforms import PILToTensor

from cyclops.datasets.slicing import SlicingConfig
from cyclops.evaluate.metrics import MetricCollection, create_metric
from cyclops.models.catalog import create_model, list_models
from cyclops.models.constants import CONFIG_ROOT
from cyclops.utils.file import join
from use_cases.params.mimiciv.mortality_decompensation.constants_v1 import (
    ENCOUNTERS_FILE,
    QUERIED_DIR,
    TAB_FEATURES,
)

In [None]:
# Point the Hugging Face datasets cache at the shared data volume.
# NOTE(review): `datasets` was already imported above and may read
# HF_DATASETS_CACHE only at import time, so this may not change the library
# default; the value is also passed explicitly as `cache_dir` below — confirm.
os.environ["HF_DATASETS_CACHE"] = "/mnt/data/.cache/huggingface/datasets"

In [None]:
# CONSTANTS
# Number of worker processes for dask and the 🤗 Dataset map/filter calls.
NUM_PROC = 4
# Batch size used when running the PyTorch model over the image dataset.
TORCH_BATCH_SIZE = 64

## Exploring existing functionalities that are relevant to CyclOps

### Tabular Data

#### Constructing a 🤗 Dataset from MIMICIV-v2.0 PostgreSQL Database

In [None]:
# Build a database URL from the query config and load a small sample of the
# MIMIC-IV `patients` table into an in-memory 🤗 Dataset.
from urllib.parse import quote_plus  # noqa: E402

db_cfg = OmegaConf.load(join("..", "cyclops", "query", "configs", "config.yaml"))

# Percent-encode the credentials: characters such as "@", ":" or "/" in a user
# name or password would otherwise corrupt the URL.
con_str = (
    f"{db_cfg.dbms}://{quote_plus(str(db_cfg.user))}:"
    f"{quote_plus(str(db_cfg.password))}@{db_cfg.host}/{db_cfg.database}"
)

ds = Dataset.from_sql(
    sql="SELECT * FROM mimiciv_hosp.patients LIMIT 10",
    con=con_str,
    keep_in_memory=True,
)
ds

#### Constructing a 🤗 Dataset from local parquet files

In [None]:
# Collect the queried parquet files. glob's result order depends on the file
# system, so sort it to make the "first N files" subset reproducible.
parquet_files = sorted(glob.glob(join(QUERIED_DIR, "*.parquet")))
len(parquet_files)

In [None]:
# take the first 100 files (in the order of the list above) to keep the
# dataset small
parquet_files = parquet_files[:100]

In [None]:
# Load the selected parquet files as a single (unsplit) 🤗 Dataset, using the
# shared worker count and the explicitly configured cache directory.
mimic_md_ds = load_dataset(
    "parquet",
    data_files=parquet_files,
    split=Split.ALL,
    num_proc=NUM_PROC,
    cache_dir=os.environ["HF_DATASETS_CACHE"],
)

# clear all other cache files, except for the current cache file
mimic_md_ds.cleanup_cache_files()

size_gb = mimic_md_ds.dataset_size / (1024**3)
print(f"Dataset size (cache file) : {size_gb:.2f} GB")

# resident memory of this (notebook kernel) process
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
mimic_md_ds

In [None]:
# Inspect the schema (column names and 🤗 feature types) of the loaded dataset.
mimic_md_ds.features

##### Benchmarking Filtering operations: 🤗 Dataset vs. Dask

In [None]:
# Use a process-based dask scheduler with NUM_PROC workers, matching the
# `num_proc` used by the 🤗 Dataset filters below.
dask.config.set(scheduler="processes", num_workers=NUM_PROC)

# Read the same parquet files into a (lazy) dask DataFrame; `len` forces the
# row count to be computed.
ddf = dd.read_parquet(parquet_files)
len(ddf)

1. **Filtering on 1 column**

Get all rows where the value in the `event_category` column is in a list of values.

In [None]:
# Event categories to keep in the single-column filter benchmarks below.
# NOTE(review): "Cadiovascular" looks like a misspelling of "Cardiovascular";
# if the data uses the correct spelling those rows are silently excluded —
# confirm against the distinct values of `event_category`.
event_filter = [
    "Cadiovascular",
    "Dialysis",
    "Hemodynamics",
    "Neurological",
    "Toxicology",
    "General",
]

In [None]:
%%timeit
# Dask: boolean-mask filter with `isin`, materialized via `.compute()`.
events_ddf = ddf[ddf["event_category"].isin(event_filter)].compute()

In [None]:
%%timeit
# 🤗 Datasets: batched filter; the function receives a batch (column -> list
# of values) and returns one boolean per row.
events_ds = mimic_md_ds.filter(
    lambda examples: [
        example in event_filter for example in examples["event_category"]
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

2. **Filtering on multiple columns**

Get all rows where the values in two columns are each in a column-specific list of values.

In [None]:
# Allowed values for the two-column filter benchmarks below; a row is kept
# only if both columns match their respective list.
discharge_location_filter = ["HOME", "HOME HEALTH CARE"]
admission_location_filter = [
    "TRANSFER FROM HOSPITAL",
    "PHYSICIAN REFERRAL",
    "CLINIC REFERRAL",
]

In [None]:
%%timeit

# Dask: combine the two `isin` masks with `&` (both conditions must hold).
location_ddf = ddf[
    (ddf["discharge_location"].isin(discharge_location_filter))
    & (ddf["admission_location"].isin(admission_location_filter))
].compute()

In [None]:
%%timeit

# 🤗 Datasets: walk both columns of the batch in lockstep and keep rows whose
# values are both in their respective filter lists.
location_ds = mimic_md_ds.filter(
    lambda examples: [
        example[0] in discharge_location_filter
        and example[1] in admission_location_filter
        for example in zip(
            examples["discharge_location"], examples["admission_location"]
        )
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

3. **Filtering on a datetime condition**

Get all rows where the date of death (`dod`) is after January 1, 2020.

In [None]:
%%timeit
# Dask: rows with `dod` after 2020-01-01; missing `dod` values compare as
# False and are therefore dropped.
dod_ddf = ddf[ddf["dod"] > datetime.datetime(2020, 1, 1)].compute()

In [None]:
%%timeit

# 🤗 Datasets: missing values arrive as None and are excluded explicitly
# before comparing (None > datetime would raise a TypeError).
dod_ds = mimic_md_ds.filter(
    lambda examples: [
        example is not None and example > datetime.datetime(2020, 1, 1)
        for example in examples["dod"]
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

4. **Filtering on a numeric range**

Get all rows where `age` is between 25 and 40 (inclusive).

In [None]:
%%timeit
# Dask: rows with 25 <= age <= 40 (both bounds inclusive).
millenials_ddf = ddf[(ddf.age <= 40) & (ddf.age >= 25)].compute()

In [None]:
%%timeit
millenials_ds = mimic_md_ds.filter(
    lambda examples: [25 <= example <= 40 for example in examples["age"]],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

### Image Data - Constructing a 🤗 Dataset from image folder

From the 🤗 Datasets documentation, there are 3 ways to load local image data into a 🤗 Dataset:
1. **Load images from a folder with the following structure:**
    ```bash
    root_folder/train/class1/img1.png
    root_folder/train/class1/img2.png
    root_folder/train/class2/img1.png
    root_folder/train/class2/img2.png
    root_folder/test/class1/img1.png
    root_folder/test/class1/img2.png
    root_folder/test/class2/img1.png
    root_folder/test/class2/img2.png
    ...
    ```
    The folder names are the class names and the dataset splits (train/test) will automatically be recognized.
    The dataset can be loaded using the following code:
    ```python
    from datasets import load_dataset
    dataset = load_dataset("imagefolder", data_dir="root_folder")
    ```
    (This method also supports loading remote image folders from URLs.)
    
    The downside of this approach is that it uses PIL to load the images, which does not support many medical image formats like DICOM and NIfTI.

2. **Load images using a list of image paths**
    ```python
    from datasets import Dataset
    from datasets.features import Image
    dataset = Dataset.from_dict({"image": ["path/to/img1.png", "path/to/img2.png", ...]}).cast_column("image", Image())
    ```
    This approach is more flexible than the previous one, but it still has the same limitation of not supporting many medical image formats.

3. **Create a dataset loading script**

    This is the most flexible way to load and share different types of datasets that are not natively supported by the 🤗 Datasets library.
    In fact, the `imagefolder` dataset is an example of a dataset loading script. In essence, we can extend that script to support more image formats like DICOM and NIfTI. That solves half the problem. The other half is that we need to create a new feature to extend the `Image` class to support decoding medical image formats.

#### Case Study: MIMIC-CXR-JPG v2.0.0
For this case study, we will combine CSV metadata and the `Image` feature to create a 🤗 Dataset from the MIMIC-CXR-JPG v2.0.0 dataset. The dataset is available on [PhysioNet](https://physionet.org/content/mimic-cxr-jpg/2.0.0/).

The dataset comes with 4 compressed CSV metadata files: `mimic-cxr-2.0.0-split.csv.gz`, `mimic-cxr-2.0.0-chexpert.csv.gz`, `mimic-cxr-2.0.0-negbio.csv.gz`, and `mimic-cxr-2.0.0-metadata.csv.gz`. The `mimic-cxr-2.0.0-split.csv.gz` file contains the train/val/test split for each image. The `mimic-cxr-2.0.0-chexpert.csv.gz` file contains the CheXpert labels for each image. The `mimic-cxr-2.0.0-negbio.csv.gz` file contains the NegBio labels for each image. The `mimic-cxr-2.0.0-metadata.csv.gz` file contains other metadata for each image. All the metadata files can be joined on the `subject_id` and `study_id` columns; the split and metadata files additionally share the per-image `dicom_id` column. Below, we use all of them except the CheXpert labels.

In [None]:
# Root directory of the local MIMIC-CXR-JPG v2.0.0 copy.
mimic_cxr_jpg_dir = "/mnt/data/clinical_datasets/mimic-cxr-jpg-2.0.0"

In [None]:
# Read three of the four compressed metadata CSVs (pandas infers gzip from the
# ".gz" extension). The CheXpert label file is not used here.
metadata_df = pd.read_csv(
    os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-metadata.csv.gz")
)
negbio_df = pd.read_csv(
    os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-negbio.csv.gz")
)
split_df = pd.read_csv(os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-split.csv.gz"))

In [None]:
# join the 3 dataframes: metadata + split on (subject_id, study_id, dicom_id),
# then the NegBio labels on (subject_id, study_id) only.
# Both merges are inner joins (the pandas default).
metadata_df = metadata_df.merge(
    split_df, on=["subject_id", "study_id", "dicom_id"]
).merge(negbio_df, on=["subject_id", "study_id"])

In [None]:
# select rows with images in folder 'p10' i.e. subject_id starts with 10
# (subject_id is compared as a string prefix)
metadata_df = metadata_df[metadata_df["subject_id"].astype(str).str.startswith("10")]

In [None]:
# create HuggingFace Dataset from pandas DataFrame, keeping only the rows of
# the "train" split; the DataFrame index is not stored as a column.
mimic_cxr_ds = Dataset.from_pandas(
    metadata_df[metadata_df.split == "train"], split="train", preserve_index=False
)
mimic_cxr_ds

In [None]:
# create a new column with the full path to the image:
# <mimic_cxr_jpg_dir>/files/p<first 2 digits of subject_id>/p<subject_id>/
#     s<study_id>/<dicom_id>.jpg
def get_filename(examples, root_dir=None):
    """Add an ``image`` column holding the JPG path of every example.

    Parameters
    ----------
    examples : dict
        A batch of examples (column name -> list of values) with the columns
        ``subject_id``, ``study_id`` and ``dicom_id``.
    root_dir : str, optional
        Root directory of MIMIC-CXR-JPG. Defaults to the notebook-level
        ``mimic_cxr_jpg_dir``.

    Returns
    -------
    dict
        The same batch object, with the ``image`` column added.
    """
    if root_dir is None:
        root_dir = mimic_cxr_jpg_dir
    examples["image"] = [
        os.path.join(
            root_dir,
            "files",
            # patient folders are grouped by the first two digits of the
            # subject id (e.g. subject 10000032 lives under p10/)
            "p" + str(subject_id)[:2],
            "p" + str(subject_id),
            "s" + str(study_id),
            str(dicom_id) + ".jpg",
        )
        for subject_id, study_id, dicom_id in zip(
            examples["subject_id"], examples["study_id"], examples["dicom_id"]
        )
    ]
    return examples


# Build the `image` path column batch-wise in NUM_PROC processes and drop the
# listed columns from the result.
mimic_cxr_ds = mimic_cxr_ds.map(
    get_filename,
    batched=True,
    num_proc=NUM_PROC,
    remove_columns=["dicom_id", "split", "Rows", "Columns"],
)
mimic_cxr_ds

In [None]:
# Cast the path strings to the `Image` feature so images are loaded from disk
# when examples are accessed.
mimic_cxr_ds = mimic_cxr_ds.cast_column("image", Image())
mimic_cxr_ds.features

## Extending 🤗 Dataset to Load DICOM (and NIfTI) images
1. Create a new feature class that extends the `Image` class to support decoding medical image formats. Let's call it `MedicalImage`. This will use MONAI to decode the medical image formats.
2. Create a new dataset loading script that extends the `imagefolder` dataset loading script to support the `MedicalImage` feature class. We call it `medicalimagefolder`.

In [None]:
from cyclops.datasets import medicalimagefolder  # noqa: E402
from cyclops.datasets.features import MedicalImage  # noqa: E402

In [None]:
# Recursively collect every DICOM file of the Pseudo-PHI dataset.
dcm_files = glob.glob(
    "/mnt/data/clinical_datasets/pseudo_phi_dataset/Pseudo-PHI-DICOM-Data/**/*.dcm",
    recursive=True,
)

In [None]:
# Build a Dataset from the file paths and decode them with the custom
# `MedicalImage` feature instead of the PIL-based `Image` feature.
dicom_ds = Dataset.from_dict({"image": dcm_files}).cast_column("image", MedicalImage())
dicom_ds

In [None]:
# Switch (in place) to torch output format and check the type of the decoded
# pixel data, which the `MedicalImage` feature exposes under the "array" key.
dicom_ds.set_format("torch")
type(dicom_ds[0]["image"]["array"])

In [None]:
# Load the same files through the custom `medicalimagefolder` loading script.
# NOTE(review): `load_dataset` normally takes a path or dataset name; confirm
# that `medicalimagefolder` resolves to the loading script's location.
med_ds = load_dataset(medicalimagefolder, data_files=dcm_files)
med_ds

In [None]:
# Inspect the features of the "train" split produced by the loading script.
med_ds["train"].features

### Some Challenges

1. Handling metadata: how and where should the metadata that accompanies medical images be stored in the dataset?
2. Encoding and decoding image bytes in the formats that are supported by the `MedicalImage` feature class.

## Exploring Training and Evaluation of Scikit-Learn and PyTorch Models

In [None]:
# Utilities
def is_out_of_core(dataset_size) -> bool:
    """Check if dataset is too large to fit in memory.

    Parameters
    ----------
    dataset_size : int or None
        Size of the dataset in bytes. ``None`` (size unknown, e.g. a 🤗 Dataset
        whose info does not record its size) is treated as fitting in memory
        instead of failing on the comparison.

    Returns
    -------
    bool
        True if ``dataset_size`` exceeds the currently available system memory.
    """
    if dataset_size is None:
        return False
    return dataset_size > psutil.virtual_memory().available


def get_pandas_df(
    dataset: Union[Dataset, DatasetDict, Mapping],
    feature_cols: Union[List[str], None] = None,
    label_col: Union[str, None] = None,
) -> Union[Tuple[pd.DataFrame, pd.Series], Dict[str, Tuple[pd.DataFrame, pd.Series]]]:
    """Convert dataset to pandas dataframe.

    NOTE: converting to pandas does not work with IterableDataset/IterableDatasetDict
    (i.e. when dataset is loaded with streaming=True). So, this function should only
    be used with datasets that are loaded with streaming=False and are small enough
    to fit in memory. Use :func:`is_out_of_core` to check if dataset is too large to
    fit in memory.

    Parameters
    ----------
    dataset : Union[Dataset, DatasetDict, Mapping]
        Dataset to convert to pandas dataframe. For a DatasetDict/Mapping, every
        value is converted and a dictionary with the same keys is returned.
    feature_cols : List[str], optional
        List of feature columns to include in the dataframe, by default None
        (all columns, except ``label_col`` when it is given).
    label_col : str, optional
        Label column to include in the dataframe, by default None

    Returns
    -------
    Union[Tuple[pd.DataFrame, pd.Series], Dict[str, Tuple[pd.DataFrame, pd.Series]]]
        A ``(features, labels)`` tuple, where ``labels`` is None when no
        ``label_col`` is given, or a dictionary of such tuples.

    Raises
    ------
    TypeError
        If dataset is not a Dataset, DatasetDict, or Mapping.
    ValueError
        If ``feature_cols``/``label_col`` are not columns of the dataset, or if
        the dataset is too large to fit in memory.
    """
    if isinstance(dataset, (DatasetDict, Mapping)):
        return {
            k: get_pandas_df(v, feature_cols=feature_cols, label_col=label_col)
            for k, v in dataset.items()
        }
    if not isinstance(dataset, Dataset):
        raise TypeError(
            f"Expected dataset to be a Dataset or DatasetDict. Got: {type(dataset)}"
        )

    # `dataset_size` may be None; fall back to the size of the underlying
    # Arrow table. An oversized dataset is reported as such rather than with
    # a misleading "wrong type" error.
    dataset_size = dataset.dataset_size
    if dataset_size is None:
        dataset_size = dataset.data.nbytes
    if is_out_of_core(dataset_size):
        raise ValueError("Dataset is too large to fit in memory.")

    # validate feature_cols and label_col
    if feature_cols is not None and not set(feature_cols).issubset(
        dataset.column_names
    ):
        raise ValueError("feature_cols must be a subset of dataset column names.")
    if label_col is not None and label_col not in dataset.column_names:
        raise ValueError("label_col must be a column name of dataset.")

    df = dataset.to_pandas(batched=False)  # set batched=True for large datasets

    if feature_cols is not None and label_col is not None:
        return df[feature_cols], df[label_col]
    if label_col is not None:
        return df.drop(label_col, axis=1), df[label_col]
    if feature_cols is not None:
        return df[feature_cols], None
    return df, None


def eval_slices(
    ds: Dataset,
    metrics: MetricCollection,
    slice_config: SlicingConfig,
    target_cols: Union[str, List[str]],
    batch_size: int = 5000,
) -> dict:
    """Evaluate slices of a dataset.

    Args:
        ds (Dataset): Dataset to evaluate. Must contain a ``predictions``
            column and the ``target_cols`` columns. It is not modified.
        metrics (MetricCollection): Metrics computed on every slice; their
            state is reset after each slice.
        slice_config (SlicingConfig): Provides the slices via ``get_slices()``.
        target_cols (Union[str, List[str]]): Name(s) of the label column(s).
        batch_size (int): Batch size used when filtering each slice.

    Returns:
        dict: Dictionary of slice names and metrics.

    Raises:
        TypeError: If ``ds`` is not a Hugging Face ``Dataset``.
    """
    if isinstance(target_cols, str):
        target_cols = [target_cols]

    if not isinstance(ds, Dataset):
        raise TypeError("`ds` must be a Hugging Face Dataset object.")

    # if present, drop the `image` column; it is not needed for evaluation
    if "image" in ds.column_names:
        ds = ds.remove_columns("image")

    # `with_format` returns a new dataset, leaving the caller's dataset format
    # untouched (`set_format` would change it in place).
    ds = ds.with_format("numpy")

    slice_metrics = {}
    for slice_name, slice_func in slice_config.get_slices().items():
        slice_ds = ds.filter(slice_func, batched=True, batch_size=batch_size)
        print(slice_name)
        print("NUM_ROWS (sliced): ", slice_ds.num_rows)

        y_true = np.stack([slice_ds[col] for col in target_cols], axis=1)
        if len(target_cols) == 1:
            # remove only the stacking axis; a blanket squeeze() would also
            # collapse a single-row slice into a scalar
            y_true = y_true[:, 0]
        y_pred = slice_ds["predictions"]
        slice_metrics[slice_name] = metrics(y_true, y_pred)
        metrics.reset_state()

    return slice_metrics

### Scikit-Learn

#### Data Loading

In [None]:
# Load the encounters parquet file into memory as a single (unsplit) Dataset.
encounters_ds = load_dataset(
    "parquet", data_files=ENCOUNTERS_FILE, split=Split.ALL, keep_in_memory=True
)
# Remove cache files other than the one currently in use (if any).
encounters_ds.cleanup_cache_files()
encounters_ds

**Casting string columns to categorical/numerical columns**

```python
# cast string columns to some numerical type
# this could be categorical or one-hot encoded, depending on the model
# XXX: this might be easier to do/generalize after converting to a DataFrame.

features_copy = encounters_ds.features.copy()

for col in TAB_FEATURES:
    if features_copy[col].dtype in ["string", "bool"]:
        features_copy[col] = ClassLabel(names=encounters_ds.unique(col))

encounters_ds = encounters_ds.cast(features_copy)
encounters_ds.features
```

In [None]:
# split into train, validation, test - 0.8, 0.1, 0.1
# NOTE: train_test_split does not work with IterableDataset objects
# First hold out 20% as "test", then split that 20% in half into the final
# validation and test sets (fixed seed for reproducibility).
encounters_ds = encounters_ds.train_test_split(test_size=0.2, seed=42)
encounters_ds_ = encounters_ds["test"].train_test_split(test_size=0.5, seed=42)
encounters_ds["validation"] = encounters_ds_.pop("train")
encounters_ds["test"] = encounters_ds_.pop("test")
encounters_ds

In [None]:
# Non-null dates of death of the test split, as a numpy datetime64 array.
dod_col = np.asanyarray(encounters_ds["test"]["dod"], dtype="datetime64")
dod_col = dod_col[pd.notnull(dod_col)]

In [None]:
# Latest non-null date of death in the test split.
dod_col.max()

In [None]:
# Refuse to convert a training split that would not fit in RAM to pandas.
if is_out_of_core(dataset_size=encounters_ds["train"].dataset_size):
    raise ValueError("Dataset is too large to fit in memory.")

# (features DataFrame, label Series): all but the last TAB_FEATURES entry are
# features, the last one is the label.
encounters_df = get_pandas_df(
    encounters_ds["train"], feature_cols=TAB_FEATURES[:-1], label_col=TAB_FEATURES[-1]
)

In [None]:
# Show the tabular feature list; its last entry is used as the label.
TAB_FEATURES

#### Pre-processing

In [None]:
# pre-processing pipeline
# NOTE: columns are selected by position because the transformer is fitted on
# a numpy array (`.to_numpy()` below drops the column names). The indices
# assume the TAB_FEATURES[:-1] order given in the comments — verify whenever
# TAB_FEATURES changes.
numeric_features = [0]  # ['age']
numeric_transformer = Pipeline(
    steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]
)

categorical_features = [1, 2, 3]  # ['sex', 'admission_type', 'admission_location']
# categories not seen during fit are encoded as all zeros instead of raising
categorical_transformer = OneHotEncoder(handle_unknown="ignore")

# columns not listed in any transformer are dropped (ColumnTransformer default)
preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_features),
        ("cat", categorical_transformer, categorical_features),
    ]
)

# fit and transform
X_train = preprocessor.fit_transform(encounters_df[0].to_numpy())
y_train = encounters_df[1].to_numpy()

##### Other Ideas

**Normalize numerical columns**

```python
_pa_table = encounters_ds.data  # pyarrow table
for feature in TAB_FEATURES[:-1]:
    if not isinstance(features_copy[feature], ClassLabel) and features_copy[
        feature
    ].dtype in ["int64", "float64"]:
        mean = pa.compute.mean(_pa_table[feature])
        std = pa.compute.stddev(_pa_table[feature])
        feature_norm = pa.compute.divide_checked(
            pa.compute.subtract_checked(_pa_table[feature], mean), std
        )
        _pa_table = _pa_table.append_column(feature, feature_norm)

encounters_ds._data = _pa_table
```

**Split out features and target columns**

```python
def split_out(examples, features, targets):
    """Split out features and targets from the dataset"""

    if not isinstance(features, list):
        features = [features]
    if not isinstance(targets, list):
        targets = [targets]

    # split out features
    # example: 'attributes': [{'feature_1': 1, 'feature_2': 2}, ...]
    examples["attributes"] = [
        {feature: examples[feature][i] for feature in features}
        for i in range(len(examples[features[0]]))
    ]

    examples["targets"] = [
        {target: examples[target][i] for target in targets}
        for i in range(len(examples[targets[0]]))
    ]

    return examples

# XXX: Check that features and targets are in the dataset before mapping
# NOTE: Applying map on a `DatasetDict` (splits) is much slower than on a `Dataset`

encounters_ds_mapped = encounters_ds.map(
    partial(split_out, features=TAB_FEATURES[:-1], targets=TAB_FEATURES[-1]),
    batched=True,
    num_proc=NUM_PROC,
    remove_columns=TAB_FEATURES,
)
encounters_ds_mapped
```

#### Training

In [None]:
# List the scikit-learn models available in the cyclops model catalog.
list_models("sklearn")

In [None]:
# Load the model's YAML config and instantiate the model from its
# `model_params` section.
model_name = "xgb_classifier"
config_path = join(CONFIG_ROOT, model_name + ".yaml")
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

model = create_model(model_name, **config["model_params"])

In [None]:
# Fit on the pre-processed training data; `fit` is expected to return the
# fitted model, which is rebound to `model`.
model = model.fit(X_train, y_train)

##### Other Ideas

```python
def train_model(examples):
    X = np.stack([examples[feature] for feature in TAB_FEATURES[:-1]], axis=1)
    y = examples["outcome_death"]
    model.partial_fit(X, y, classes=np.unique(y))
    return examples

if hasattr(model.model, "partial_fit"):
    encounters_ds["train"].with_format("numpy", columns=TAB_FEATURES).map(
        train_model,
        batched=True,
        batch_size=5000,
        num_proc=NUM_PROC,
    )
else:
    ds = encounters_ds["train"].with_format("numpy", columns=TAB_FEATURES)
    X_train = np.stack([ds[feature] for feature in TAB_FEATURES[:-1]], axis=1)
    y_train = ds[TAB_FEATURES[-1]]
    model.fit(X_train, y_train)
```

#### Evaluation

In [None]:
def get_predictions(examples: Dict[str, Union[List, np.ndarray]]) -> dict:
    """Add a ``predictions`` entry to a batch using the fitted tabular model.

    The columns ``TAB_FEATURES[:-1]`` are stacked side by side (in that order),
    transformed with the fitted ``preprocessor`` and passed to ``model``.

    Args:
        examples: A batch of examples (column name -> values).

    Returns:
        dict: The same batch with ``predictions`` set to the output of
        ``model.predict_proba``, or of ``model.predict`` when the former raises
        AttributeError.
        NOTE(review): predict_proba usually returns one column per class —
        confirm this matches what the binary metrics in `eval_slices` expect.
    """
    feature_matrix = preprocessor.transform(
        np.stack([examples[name] for name in TAB_FEATURES[:-1]], axis=1)
    )
    try:
        preds = model.predict_proba(feature_matrix)
    except AttributeError:  # some models don't have `predict_proba`
        preds = model.predict(feature_matrix)
    examples["predictions"] = preds
    return examples

In [None]:
# Add a `predictions` column to the test split. Only TAB_FEATURES are returned
# as numpy arrays; `output_all_columns=True` keeps the remaining columns too.
ds_with_preds = (
    encounters_ds["test"]
    .with_format("numpy", columns=TAB_FEATURES, output_all_columns=True)
    .map(
        get_predictions,
        batched=True,
        batch_size=5000,
    )
)
ds_with_preds

In [None]:
# define the slices
# - each entry of `feature_keys` selects rows with non-missing values
# - each dict in `feature_values` selects rows whose values meet the given
#   conditions; a dict with several features forms a compound slice
feature_keys = [
    "dod",  # non-null/non-missing values in column
    [
        "admission_type",
        "admission_location",
    ],  # non-null/non-missing values in all columns in the list
]

feature_values = [
    {"sex": {"value": "M"}},  # feature value is M
    {
        "age": {
            "min_value": 18,
            "max_value": 65,
            "min_inclusive": True,
            "max_inclusive": False,
        }
    },  # feature value is between 18 and 65, inclusive of 18, exclusive of 65
    {
        "admission_type": {"value": ["EW EMER.", "DIRECT EMER.", "URGENT"]}
    },  # feature value is in the list
    {
        "admission_location": {
            "value": ["PHYSICIAN REFERRAL", "CLINIC REFERRAL", "WALK-IN/SELF REFERRAL"],
            "negate": True,
        }
    },  # feature value is NOT in the list
    {
        "dod": {"max_value": "2019-12-01", "keep_nulls": False}
    },  # possibly before COVID-19
    {
        "dod": {"max_value": "2019-12-01", "negate": True, "keep_nulls": False}
    },  # possibly during COVID-19
    # month of `admit_timestamp` is 6-9 (June to September)
    {"admit_timestamp": {"month": [6, 7, 8, 9], "keep_nulls": False}},
    {
        "sex": {"value": "F"},
        "race": {
            "value": [
                "BLACK/AFRICAN AMERICAN",
                "BLACK/CARIBBEAN ISLAND",
                "BLACK/CAPE VERDEAN",
                "BLACK/AFRICAN",
            ]
        },
        "age": {"min_value": 25, "max_value": 40},
    },  # compound slice
]

# create the slice functions
slice_config = SlicingConfig()

for key in feature_keys:
    slice_config.add_feature_keys(key)

for feature_value in feature_values:
    slice_config.add_feature_values(feature_value)

# or
# slice_config = SlicingConfig(
#     feature_keys=feature_keys, feature_values=feature_values
# )

In [None]:
# define the metrics
# binary-task variants of each metric, evaluated together as one collection
metric_names = ["accuracy", "precision", "recall", "f1_score", "auroc"]
metrics = [create_metric(metric_name, task="binary") for metric_name in metric_names]
metric_collection = MetricCollection(metrics)

In [None]:
# Compute every metric on every slice of the test predictions; the last
# TAB_FEATURES entry is the label column.
eval_slices(
    ds=ds_with_preds,
    metrics=metric_collection,
    slice_config=slice_config,
    target_cols=TAB_FEATURES[-1],
)

### PyTorch

##### Data Loading

In [None]:
def nihcxr_preprocess(df: pd.DataFrame, nihcxr_dir: str) -> pd.DataFrame:
    """Preprocess NIHCXR dataframe.

    Add a column with the path to the image and create one-hot encoded pathologies
    from the Finding Labels column.

    Args:
        df (pd.DataFrame): NIHCXR dataframe with "Image Index" and
            "Finding Labels" columns. It is not modified.
        nihcxr_dir (str): Root directory of the NIHCXR dataset; image paths are
            built as ``<nihcxr_dir>/images/<Image Index>``.

    Returns:
        pd.DataFrame: pre-processed NIHCXR dataframe.
    """
    # Add path column on a new frame: assigning into `df` directly would mutate
    # the caller's data (and trigger a SettingWithCopyWarning on filtered views).
    df = df.assign(
        image=df["Image Index"].apply(lambda x: os.path.join(nihcxr_dir, "images", x))
    )

    # One 0/1 column per label of the "|"-separated "Finding Labels" field.
    pathologies = df["Finding Labels"].str.get_dummies(sep="|")

    # Add one-hot encoded pathologies to dataframe (aligned on the index)
    return pd.concat([df, pathologies], axis=1)


nihcxr_dir = "/mnt/data/clinical_datasets/NIHCXR"

test_df = pd.read_csv(
    join(nihcxr_dir, "test_list.txt"), header=None, names=["Image Index"]
)

# select only the images in the test list
df = pd.read_csv(join(nihcxr_dir, "Data_Entry_2017.csv"))
df.dropna(how="all", axis="columns", inplace=True)  # drop empty columns
df = df[df["Image Index"].isin(test_df["Image Index"])]

df = nihcxr_preprocess(df, nihcxr_dir)

# create a Dataset object
nih_ds = Dataset.from_pandas(df, preserve_index=False)
nih_ds = nih_ds.cast_column("image", Image())

In [None]:
# Label columns used as evaluation targets; each must exist as a one-hot
# column created from "Finding Labels" by `nihcxr_preprocess`.
# NOTE(review): the model head below has len(pathologies) outputs — confirm
# its output order is meant to match this list.
pathologies = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Mass",
    "No Finding",
    "Nodule",
    "Pleural_Thickening",
    "Pneumonia",
    "Pneumothorax",
]

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
# Inspect the features of the NIHCXR dataset (including the `image` column).
nih_ds.features

##### Pre-processing

In [None]:
# Per-example MONAI transforms on the "image" key. PIL -> tensor conversion is
# done in `apply_transforms` instead (see the disabled TorchVisiond line).
transforms = Compose(
    [
        # TorchVisiond(keys=("image",), name="PILToTensor"), doesn't work
        # add a leading channel dimension
        AddChanneld(keys=("image",)),
        # resize the dimensions after the channel dim to (1, 224, 224)
        Resized(keys=("image",), spatial_size=(1, 224, 224)),
        # rescale pixel values from [0, 255] to [-1024, 1024]
        # (note: keys=("image") is a plain string, not a tuple)
        Lambdad(keys=("image"), func=lambda x: ((2 * (x / 255.0)) - 1.0) * 1024),
        # move the image tensor to `device`
        ToDeviced(keys=("image",), device=device),
    ],
)


def apply_transforms(examples: Dict[str, List], transforms: callable) -> dict:
    """Apply transforms to examples.

    Args:
        examples (Dict[str, List]): A batch of examples (column name -> list of
            values). PIL images are converted to tensors with ``PILToTensor``
            before ``transforms`` is called.
        transforms (callable): Applied to each example individually, given as a
            dict mapping column name to value.

    Returns:
        dict: The transformed batch, again as column name -> list of values.
        A batch without rows is returned unchanged.
    """
    # `import PIL` alone does not guarantee that the `PIL.Image` submodule is
    # loaded, so import it explicitly for the isinstance check below.
    from PIL import Image as PILImage

    columns = list(examples.values())
    num_rows = len(columns[0]) if columns else 0
    if num_rows == 0:
        # nothing to transform (and no first row to take the keys from)
        return examples

    # examples is a dict of lists; convert to list of dicts.
    # doing a conversion from PIL to tensor is necessary here when working
    # with the Image feature type.
    to_tensor = PILToTensor()  # takes no arguments; build once per batch
    rows = [
        {
            key: to_tensor(values[i])
            if isinstance(values[i], PILImage.Image)
            else values[i]
            for key, values in examples.items()
        }
        for i in range(num_rows)
    ]

    # apply the transforms to each example
    rows = [transforms(row) for row in rows]

    # convert back to a dict of lists
    return {key: [row[key] for row in rows] for key in rows[0]}

In [None]:
# NOTE(review): disabled experiment that iterates `nih_ds` with a PyTorch
# DataLoader over randomly sampled batches of TORCH_BATCH_SIZE.
# from torch.utils.data import DataLoader
# from torch.utils.data.sampler import BatchSampler, RandomSampler

# batch_sampler = BatchSampler(
#     RandomSampler(nih_ds), batch_size=TORCH_BATCH_SIZE, drop_last=False
# )
# nih_dl = DataLoader(nih_ds, batch_sampler=batch_sampler)

# for batch in nih_dl:
#     print(batch)
#     break

##### Prediction

In [None]:
# Load torchxrayvision's DenseNet with the "densenet121-res224-nih" weights and
# prepare it for inference on `device`.
# NOTE(review): assigning a new torch.nn.Linear to `classifier` replaces the
# pretrained head with randomly initialised weights, so the predictions below
# do not come from a trained classifier (presumably acceptable only for
# exercising the pipeline) — confirm.
model = xrv.models.DenseNet(weights="densenet121-res224-nih")
model.classifier = torch.nn.Linear(1024, len(pathologies))
# presumably disables torchxrayvision's output-threshold calibration — verify
model.op_threshs = None
model.eval()
model.to(device)

In [None]:
def get_predictions_torch(examples):
    """Run the image model on a batch and add a ``predictions`` entry.

    Args:
        examples: A batch whose ``image`` entry is a list of tensors (as
            produced by ``apply_transforms``). After stacking, dimension 1 is
            squeezed away (only if it has size 1).

    Returns:
        The same batch with ``predictions`` holding the model outputs as CPU
        tensors.
    """
    images = torch.stack(examples["image"]).squeeze(1)
    # Inference only: skip autograd bookkeeping (saves memory) and return
    # detached CPU tensors, ready to be written to the dataset.
    with torch.no_grad():
        examples["predictions"] = model(images).cpu()
    return examples

In [None]:
# Apply `apply_transforms` on the fly to the "image" column (the other
# columns are also returned), then run the model batch-wise to add a
# `predictions` column.
nih_ds = nih_ds.with_transform(
    partial(apply_transforms, transforms=transforms),
    columns=["image"],
    output_all_columns=True,
).map(get_predictions_torch, batched=True, batch_size=TORCH_BATCH_SIZE)
nih_ds

##### Slice-wise Evaluation

In [None]:
# define the slices: by gender, two age ranges and the "PA" view position
feature_values = [
    {"Patient Gender": {"value": "M"}},
    {"Patient Gender": {"value": "F"}},
    {"Patient Age": {"min_value": 25, "max_value": 40}},
    {"Patient Age": {"min_value": 65}},
    {"View Position": {"value": "PA"}},
]

# create the slice functions
slice_config = SlicingConfig(feature_values=feature_values)

In [None]:
# define the metrics
metric_names = ["accuracy", "precision", "recall", "f1_score", "auroc"]
metrics = [
    create_metric(metric_name, task="multilabel", num_labels=len(pathologies))
    for metric_name in metric_names
]
metric_collection = MetricCollection(metrics)

In [None]:
# Compute every metric on every slice; the pathology columns are the targets.
eval_slices(
    ds=nih_ds,
    metrics=metric_collection,
    slice_config=slice_config,
    target_cols=pathologies,
)