# NIHCXR Synthetic Drift - Categorical Shift

## Load Libraries

In [None]:
from cyclops.monitor import (
    Detector,
    Experimenter,
    Reductor,
    SyntheticShiftApplicator,
    TSTester,
)
from cyclops.monitor.plotter import plot_drift_samples_pval
from cyclops.monitor.utils import Loader

## Query Data

In [None]:
import os
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import PIL
import torch
from datasets import Dataset, Image
from monai.transforms import (
    AddChanneld,
    Compose,
    EnsureChannelFirstd,
    Lambdad,
    Resized,
    ToDeviced,
)
from torchvision.transforms import PILToTensor
from torchxrayvision.datasets import XRayCenterCrop, XRayResizer
from torchxrayvision.models import DenseNet

from cyclops.monitor.utils import nihcxr_preprocess

In [None]:
# Prefer the GPU when PyTorch reports one; otherwise run on the CPU.
# To force CPU execution, overwrite `device` with "cpu" after this cell.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Every transform below targets the "features" entry and tolerates examples
# that do not carry it.
_features_only = {"keys": ("features",), "allow_missing_keys": True}


def _rescale_pixels(x):
    """Linearly rescale values so that 0 maps to -1024 and 255 maps to 1024."""
    return ((2 * (x / 255.0)) - 1.0) * 1024


# NOTE (from the original author): TorchVisiond(keys=("image",),
# name="PILToTensor") did not work here, and XRayCenterCrop / XRayResizer
# were tried and left out of the pipeline.
transforms = Compose(
    [
        AddChanneld(**_features_only),
        Resized(spatial_size=(1, 224, 224), **_features_only),
        Lambdad(func=_rescale_pixels, **_features_only),
        ToDeviced(device=device, **_features_only),
    ],
)


def apply_transforms(examples: Dict[str, List], transforms: Callable) -> dict:
    """Apply per-example transforms to a column-oriented batch.

    Parameters
    ----------
    examples : Dict[str, List]
        Batch given as a mapping from column name to a list of values. The
        batch size is taken from the length of the first column.
    transforms : Callable
        Called once per example with a ``{column: value}`` dict. Its return
        value is indexed by key when the batch is re-assembled.

    Returns
    -------
    dict
        Mapping from each key of the first transformed example to the list of
        that key's values across all transformed examples. An empty batch is
        returned as empty columns without calling ``transforms``.

    """
    columns = list(examples)
    batch_size = len(examples[columns[0]]) if columns else 0
    if batch_size == 0:
        # Nothing to transform; indexing the first column/example below would
        # raise IndexError, so hand back the (empty) columns as-is.
        return {name: [] for name in columns}

    # The batch is a dict of lists; convert it to a list of dicts. PIL images
    # are converted to tensors here, which is necessary when working with the
    # Image feature type.
    to_tensor = PILToTensor()
    rows = [
        {
            name: (
                to_tensor(values[i])
                if isinstance(values[i], PIL.Image.Image)
                else values[i]
            )
            for name, values in examples.items()
        }
        for i in range(batch_size)
    ]

    # Apply the transforms to each example.
    transformed = [transforms(row) for row in rows]

    # Convert back to a dict of lists, keyed by the first example's keys.
    return {key: [row[key] for row in transformed] for key in transformed[0]}

In [None]:
# Root directory of the NIH Chest X-ray data. Set the NIHCXR_DIR environment
# variable to point at a different location; the original path is the default.
nihcxr_dir = os.environ.get("NIHCXR_DIR", "/home/akore/NIHCXR")

# Read the metadata CSV and run the cyclops NIHCXR preprocessing on it.
metadata_csv = os.path.join(nihcxr_dir, "Data_Entry_2017.csv")
df = pd.read_csv(metadata_csv)
df = nihcxr_preprocess(df, nihcxr_dir)

# Build a Hugging Face dataset and treat the "features" column as images.
nih_ds = Dataset.from_pandas(df, preserve_index=False)
nih_ds = nih_ds.cast_column("features", Image())

In [None]:
# random sample from huggingface arrow dataset
# Draw a random subset (without replacement) of at most 1000 rows from the
# Hugging Face Arrow dataset. The size is clamped to the dataset length
# because np.random.choice raises ValueError when asked for more unique
# samples than exist.
num_rows = nih_ds.shape[0]
sample_size = min(1000, num_rows)
nih_ds = nih_ds.select(np.random.choice(num_rows, sample_size, replace=False))

In [None]:
# Bind the transform pipeline to apply_transforms and register it as the
# dataset's transform for the "features" column, while still returning the
# remaining columns.
batch_transform = partial(apply_transforms, transforms=transforms)
nih_ds = nih_ds.with_transform(
    batch_transform,
    columns=["features"],
    output_all_columns=True,
)

In [None]:
# get transforms used in nih_ds

# tr = nih_ds.format['format_kwargs']['transform'].keywords['transforms'].transforms
# comp.transforms += (EnsureChannelFirstd(keys=("image",)),)
# comp.transforms = comp.transforms[1:]

In [None]:
# torchxrayvision DenseNet loaded with the "densenet121-res224-all" weights.
model = DenseNet(weights="densenet121-res224-all")

# Use the `device` selected earlier in the notebook instead of hard-coding
# "cuda", so the reductor agrees with the transform pipeline and the notebook
# still runs on machines without a GPU.
reductor = Reductor(dr_method="bbse-soft", model=model, device=device)

## Initialize Reductor, Tester & Detector

In [None]:
# Run the reductor over the sampled dataset in batches of 32.
features = reductor.transform(
    nih_ds,
    batch_size=32,
    num_workers=1,
)

In [None]:
# Two-sample tester configured with the "mmd" method.
tester = TSTester(
    tester_method="mmd",
)

# Use the `device` selected earlier in the notebook instead of hard-coding
# "cuda", so the detector runs on the same device as the rest of the pipeline
# and the notebook still works on CPU-only machines.
detector = Detector(reductor=reductor, tester=tester, device=device)

# Fit the detector on the sampled dataset.
detector.fit(nih_ds, batch_size=32, num_workers=1)

## Setup Baseline Experiment

In [None]:
# Baseline: a "sensitivity_test" experiment with no shift applicator attached.
baseline_experiment = Experimenter("sensitivity_test", detector=detector)

## Setup Drift Experiments (Categorical Shift)

In [None]:
# The three categorical shifts are described by parallel lists: entry i of
# each list configures shift i.
shift_type = ["categorical_shift"] * 3
cat_col = ["gender", "view", "age"]
target_categories = ["M", "PA", "18-35"]

# One shift applicator per (shift type, column, target category) triple.
shiftapplicators = [
    SyntheticShiftApplicator(
        shift_type=kind,
        categorical_column=column,
        target_category=category,
    )
    for kind, column, category in zip(shift_type, cat_col, target_categories)
]

# One "sensitivity_test" experiment per shift applicator, all sharing the
# same fitted detector.
experiments = [
    Experimenter(
        "sensitivity_test",
        detector=detector,
        shiftapplicator=applicator,
    )
    for applicator in shiftapplicators
]

## Run Experiments

In [None]:
# Run the baseline first, then each drift experiment in order, all on the
# same sampled dataset.
baseline_results = baseline_experiment.run(nih_ds)
drift_results = [experiment.run(nih_ds) for experiment in experiments]

## Gather Results

In [None]:
# Collect all results under readable labels: "baseline" first, then one
# "<column>: <target category>" entry per drift experiment, in run order.
results_dict = {"baseline": baseline_results}
for column, category, result in zip(cat_col, target_categories, drift_results):
    results_dict[f"{column}: {category}"] = result

## Plot Experimental Results

In [None]:
# Second positional argument to the plotter; presumably the p-value
# threshold used to mark significance (confirm against the plotter's API).
pval_threshold = 0.05
plot_drift_samples_pval(results_dict, pval_threshold)