## Exploring the Extensibility of the 🤗 Datasets Library for Medical Images

In [None]:
import datetime
import glob
import os

import dask
import dask.dataframe as dd
import pandas as pd
import psutil
from datasets import Dataset, Image, load_dataset
from omegaconf import OmegaConf

from cyclops.utils.file import join
from use_cases.params.mimiciv.mortality_decompensation.constants_v1 import QUERIED_DIR

### Exploring existing functionalities that are relevant to CyclOps

In [None]:
# CONSTANTS
# Number of worker processes for parallel operations in this notebook
# (passed as `num_workers` to Dask and as `num_proc` to 🤗 Datasets calls).
NUM_PROC = 4

### Tabular Data

#### Constructing a 🤗 Dataset from MIMICIV-v2.0 PostgreSQL Database

In [None]:
# Load the database connection settings from the CyclOps query config.
db_cfg = OmegaConf.load(join("..", "cyclops", "query", "configs", "config.yaml"))

# Build the connection URI `dbms://user:password@host/database`.
# An f-string is used so that config values YAML parsed as non-strings
# (e.g. an all-digit password) are formatted instead of raising TypeError.
# NOTE(review): credentials containing characters such as "@", ":" or "/"
# presumably need to be URL-encoded in the config for the URI to parse; confirm.
con_str = (
    f"{db_cfg.dbms}://{db_cfg.user}:{db_cfg.password}"
    f"@{db_cfg.host}/{db_cfg.database}"
)

# Query a 10-row sample of the patients table into a Dataset kept in memory.
ds = Dataset.from_sql(
    sql="SELECT * FROM mimiciv_hosp.patients LIMIT 10",
    con=con_str,
    keep_in_memory=True,
)
ds

#### Constructing a 🤗 Dataset from local parquet files

In [None]:
# Collect the queried parquet files. glob already returns a list, but in
# arbitrary order, so sort it to make any subset taken from it reproducible.
parquet_files = sorted(glob.glob(join(QUERIED_DIR, "*.parquet")))
len(parquet_files)

In [None]:
# take the first 100 files (fewer if the list is shorter) to keep the
# experiments below small
parquet_files = parquet_files[:100]

In [None]:
# Load the parquet files as a 🤗 Dataset, using the notebook-wide worker
# count instead of a separate hard-coded value.
mimic_md_ds = load_dataset("parquet", data_files=parquet_files, num_proc=NUM_PROC)

# clear all other cache files, except for the current cache file
mimic_md_ds.cleanup_cache_files()

# Report the size of the "train" split in GiB (bytes / 1024**3).
size_gb = mimic_md_ds["train"].dataset_size / (1024**3)
print(f"Dataset size (cache file) : {size_gb:.2f} GB")

# Resident set size of the current Python process, in MiB.
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
# Display the loaded dataset object.
mimic_md_ds

##### Benchmarking Filtering operations: 🤗 Dataset vs. Dask

In [None]:
# Configure Dask to use the process-based scheduler with NUM_PROC workers.
dask.config.set(scheduler="processes", num_workers=NUM_PROC)

# Build a Dask DataFrame over the same parquet files used for the 🤗 Dataset.
ddf = dd.read_parquet(parquet_files)
len(ddf)

1. **Filtering on 1 column**

Get all rows where the value in column `event_category` is in a list of values.

In [None]:
# Event categories to keep when filtering on the `event_category` column.
# Membership is an exact string match, so spelling must match the data.
event_filter = [
    "Cardiovascular",
    "Dialysis",
    "Hemodynamics",
    "Neurological",
    "Toxicology",
    "General",
]

In [None]:
%%timeit
# Dask: keep rows whose `event_category` is in `event_filter`, then
# materialize the result with `.compute()`.
events_ddf = ddf[ddf["event_category"].isin(event_filter)].compute()

In [None]:
%%timeit
# 🤗 Datasets: keep rows whose `event_category` is in `event_filter`.
# The callable receives a batch of columns and returns one boolean per row.
events_ds = mimic_md_ds["train"].filter(
    lambda batch: [
        category in event_filter for category in batch["event_category"]
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # recompute on every timeit repetition
)

2. **Filtering on multiple columns**

Get all items where the values in two columns are in a list of values for each column.

In [None]:
# Allowed values for the `discharge_location` column.
discharge_location_filter = ["HOME", "HOME HEALTH CARE"]
# Allowed values for the `admission_location` column.
admission_location_filter = [
    "TRANSFER FROM HOSPITAL",
    "PHYSICIAN REFERRAL",
    "CLINIC REFERRAL",
]

In [None]:
%%timeit

# Dask: keep rows where both location columns are in their allowed lists,
# then materialize the result with `.compute()`.
location_ddf = ddf[
    (ddf["discharge_location"].isin(discharge_location_filter))
    & (ddf["admission_location"].isin(admission_location_filter))
].compute()

In [None]:
%%timeit

# 🤗 Datasets: keep rows where both location columns are in their allowed
# lists. The two columns are walked in lockstep, one boolean per row.
location_ds = mimic_md_ds["train"].filter(
    lambda batch: [
        discharge in discharge_location_filter
        and admission in admission_location_filter
        for discharge, admission in zip(
            batch["discharge_location"], batch["admission_location"]
        )
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # recompute on every timeit repetition
)

3. **Filtering on a datetime condition**

Get all rows where the date of death (`dod`) is after January 1, 2020.

In [None]:
%%timeit
# Dask: keep rows whose `dod` is later than 2020-01-01, then materialize.
dod_ddf = ddf[ddf["dod"] > datetime.datetime(2020, 1, 1)].compute()

In [None]:
%%timeit

# 🤗 Datasets: keep rows whose `dod` is set and later than 2020-01-01.
# The `is not None` check comes first so missing values are never compared.
dod_ds = mimic_md_ds["train"].filter(
    lambda batch: [
        dod is not None and dod > datetime.datetime(2020, 1, 1)
        for dod in batch["dod"]
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # recompute on every timeit repetition
)

4. **Filter on a condition on a column**

In [None]:
%%timeit
# Dask: keep rows with 25 <= age <= 40 (both bounds inclusive), then materialize.
millenials_ddf = ddf[(ddf.age <= 40) & (ddf.age >= 25)].compute()

In [None]:
%%timeit
# 🤗 Datasets: keep rows with 25 <= age <= 40 (both bounds inclusive).
# Missing ages are excluded up front, mirroring the guard used for `dod`;
# comparing None with an int would otherwise raise TypeError.
millenials_ds = mimic_md_ds["train"].filter(
    lambda examples: [
        example is not None and 25 <= example <= 40
        for example in examples["age"]
    ],
    batched=True,
    num_proc=NUM_PROC,
    load_from_cache_file=False,  # timeit will run multiple times
)

### Image Data - Constructing a 🤗 Dataset from image folder

From the 🤗 Datasets documentation, there are 3 ways to load local image data into a 🤗 Dataset:
1. **Load images from a folder with the following structure:**
    ```bash
    root_folder/train/class1/img1.png
    root_folder/train/class1/img2.png
    root_folder/train/class2/img1.png
    root_folder/train/class2/img2.png
    root_folder/test/class1/img1.png
    root_folder/test/class1/img2.png
    root_folder/test/class2/img1.png
    root_folder/test/class2/img2.png
    ...
    ```
    The folder names are the class names and the dataset splits (train/test) will automatically be recognized.
    The dataset can be loaded using the following code:
    ```python
    from datasets import load_dataset
    dataset = load_dataset("imagefolder", data_dir="root_folder")
    ```
    (This method also supports loading remote image folders from URLs.)
    
    The downside of this approach is that it uses PIL to load the images, which does not support many medical image formats like DICOM and NIfTI.

2. **Load images using a list of image paths**
    ```python
    from datasets import Dataset
    from datasets.features import Image
    dataset = Dataset.from_dict({"image": ["path/to/img1.png", "path/to/img2.png", ...]}).cast_column("image", Image())
    ```
    This approach is more flexible than the previous one, but it still has the same limitation of not supporting many medical image formats.

3. **Create a dataset loading script**

    This is the most flexible way to load and share different types of datasets that are not natively supported by 🤗 Datasets library.
    In fact, the `imagefolder` dataset is an example of a dataset loading script. In essence, we can extend that script to support more image formats like DICOM and NIfTI. That solves half the problem. The other half is that we need to create a new feature to extend the `Image` class to support decoding medical image formats.

#### Case Study: MIMIC-CXR-JPG v2.0.0
For this case study, we will combine CSV metadata and the `Image` feature to create a 🤗 Dataset from the MIMIC-CXR-JPG v2.0.0 dataset. The dataset is available on [PhysioNet](https://physionet.org/content/mimic-cxr-jpg/2.0.0/).

The dataset comes with 4 compressed CSV metadata files. The metadata files are `mimic-cxr-2.0.0-split.csv.gz`, `mimic-cxr-2.0.0-chexpert.csv.gz`, `mimic-cxr-2.0.0-negbio.csv.gz`, and `mimic-cxr-2.0.0-metadata.csv.gz`. The `mimic-cxr-2.0.0-split.csv.gz` file contains the train/val/test split for each image. The `mimic-cxr-2.0.0-chexpert.csv.gz` file contains the CheXpert labels for each image. The `mimic-cxr-2.0.0-negbio.csv.gz` file contains the NegBio labels for each image. The `mimic-cxr-2.0.0-metadata.csv.gz` file contains other metadata for each image. All the metadata files can be joined on the `subject_id` and `study_id` columns.

In [None]:
# Root directory of the local MIMIC-CXR-JPG v2.0.0 copy (machine-specific path).
mimic_cxr_jpg_dir = "/mnt/data/clinical_datasets/mimic-cxr-jpg-2.0.0"

In [None]:
# Read three of the compressed CSV metadata files into pandas DataFrames.
# Per-image metadata:
metadata_df = pd.read_csv(
    os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-metadata.csv.gz")
)
# NegBio labels:
negbio_df = pd.read_csv(
    os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-negbio.csv.gz")
)
# Train/val/test split assignment:
split_df = pd.read_csv(os.path.join(mimic_cxr_jpg_dir, "mimic-cxr-2.0.0-split.csv.gz"))

In [None]:
# join the 3 dataframes: metadata with split on subject_id, study_id and
# dicom_id, then the result with the NegBio labels on subject_id and study_id
metadata_df = metadata_df.merge(
    split_df, on=["subject_id", "study_id", "dicom_id"]
).merge(negbio_df, on=["subject_id", "study_id"])

In [None]:
# select rows with images in folder 'p10' i.e. subject_id starts with 10
# (subject_id is converted to str so the prefix can be tested)
metadata_df = metadata_df[metadata_df["subject_id"].astype(str).str.startswith("10")]

In [None]:
# create HuggingFace Dataset from pandas DataFrame
mimic_cxr_ds = Dataset.from_pandas(
    metadata_df[metadata_df.split == "train"], split="train", preserve_index=False
)
mimic_cxr_ds

In [None]:
# create a new column with the full path to the image:
# <root>/files/p<first 2 digits of subject_id>/p<subject_id>/s<study_id>/<dicom_id>.jpg
def get_filename(examples, root_dir=None):
    """Add an ``image`` column of JPG file paths to a batch of examples.

    Parameters
    ----------
    examples : dict
        Batch of columns containing ``subject_id``, ``study_id`` and
        ``dicom_id`` sequences of equal length.
    root_dir : str, optional
        Dataset root directory. Defaults to the notebook-level
        ``mimic_cxr_jpg_dir`` when not given.

    Returns
    -------
    dict
        The same batch object, with the ``image`` column added.

    """
    if root_dir is None:
        root_dir = mimic_cxr_jpg_dir

    image_paths = []
    for subject_id, study_id, dicom_id in zip(
        examples["subject_id"], examples["study_id"], examples["dicom_id"]
    ):
        subject = str(subject_id)
        image_paths.append(
            os.path.join(
                root_dir,
                "files",
                # top-level folder is "p" + the first two digits of the
                # subject id (e.g. "p10" for subject ids starting with 10)
                "p" + subject[:2],
                "p" + subject,
                "s" + str(study_id),
                dicom_id + ".jpg",
            )
        )
    examples["image"] = image_paths
    return examples


# Apply get_filename batch-wise with NUM_PROC worker processes to add the
# "image" path column, removing the listed columns from the result.
mimic_cxr_ds = mimic_cxr_ds.map(
    get_filename,
    batched=True,
    num_proc=NUM_PROC,
    remove_columns=["dicom_id", "split", "Rows", "Columns"],
)
mimic_cxr_ds

In [None]:
# Cast the "image" column of file paths to the `Image` feature type, then
# show the resulting feature schema.
mimic_cxr_ds = mimic_cxr_ds.cast_column("image", Image())
mimic_cxr_ds.features

### Extending 🤗 Dataset to Load DICOM (and NIfTI) images
1. Create a new feature class that extends the `Image` class to support decoding medical image formats. Let's call it `MedicalImage`. This will use MONAI to decode the medical image formats.
2. Create a new dataset loading script that extends the `imagefolder` dataset loading script to support the `MedicalImage` feature class. We can call it `medicalimagefolder`.

In [None]:
from cyclops.datasets import medicalimagefolder  # noqa: E402
from cyclops.datasets.features import MedicalImage  # noqa: E402

In [None]:
# Recursively collect every .dcm file under the Pseudo-PHI dataset folder.
# glob returns matches in arbitrary order, so sort them to get a reproducible
# ordering (e.g. which file ends up at index 0).
dcm_files = sorted(
    glob.glob(
        "/mnt/data/clinical_datasets/pseudo_phi_dataset/Pseudo-PHI-DICOM-Data/**/*.dcm",
        recursive=True,
    )
)

In [None]:
# Build a Dataset from the list of DICOM paths and cast the "image" column
# to the `MedicalImage` feature type.
dicom_ds = Dataset.from_dict({"image": dcm_files}).cast_column("image", MedicalImage())
dicom_ds

In [None]:
# Switch the dataset's output format to "torch", then inspect the type of
# the "array" entry of the first example's image.
dicom_ds.set_format("torch")
type(dicom_ds[0]["image"]["array"])

In [None]:
# Load the same DICOM files through the `medicalimagefolder` loader.
# NOTE(review): `medicalimagefolder` is the object imported from
# `cyclops.datasets`; `load_dataset` usually takes a path/name string as its
# first argument. Confirm that passing this object is supported.
med_ds = load_dataset(medicalimagefolder, data_files=dcm_files)
med_ds

In [None]:
# Show the feature schema of the "train" split.
med_ds["train"].features

### Some Challenges

1. Handling metadata. What to do with it?
2. Encoding and decoding image bytes in the formats that are supported by the `MedicalImage` feature class.