# NIHCXR Clinical Drift

## Load Libraries

In [None]:
from functools import partial

from monai.transforms import AddChanneld, Compose, Lambdad, Resized, ToDeviced
from torchxrayvision.models import DenseNet

from cyclops.datasets.loader import load_nihcxr
from cyclops.datasets.slicer import SliceSpec
from cyclops.datasets.utils import apply_transforms
from cyclops.monitor import ClinicalShiftApplicator, Detector, Reductor, TSTester
from cyclops.monitor.plotter import plot_drift_experiment, plot_drift_timeseries
from cyclops.monitor.utils import get_device

## Query Data

In [None]:
# Load the NIHCXR data from /mnt/data/NIHCXR; every source/target split in
# the cells below is carved out of `nih_ds`.
NIHCXR_ROOT = "/mnt/data/NIHCXR"
nih_ds = load_nihcxr(NIHCXR_ROOT)

## Split Source/Target Datasets

In [None]:
# Split: source_slice=None (presumably no filtering, i.e. the whole dataset —
# confirm against ClinicalShiftApplicator), target = rows whose
# "Patient Gender" is "F".
# NOTE: the next cell rebinds source_ds/target_ds with a different target
# slice; the experiments below use whichever split cell ran last.
source_slice, target_slice = None, SliceSpec(
    spec_list=[{"Patient Gender": {"value": "F"}}]
)
shifter = ClinicalShiftApplicator(
    "custom",
    source=source_slice,
    target=target_slice,
)
source_ds, target_ds = shifter.apply_shift(nih_ds, num_proc=6)

In [None]:
# Alternative split: target is the slice {"timestamp": {"hour": 0}}
# (presumably rows whose timestamp falls in hour 0 — confirm SliceSpec
# semantics). Running this cell overwrites the source_ds/target_ds produced
# by the previous cell.
source_slice, target_slice = None, SliceSpec(
    spec_list=[{"timestamp": {"hour": 0}}]
)
shifter = ClinicalShiftApplicator(
    "custom",
    source=source_slice,
    target=target_slice,
)
source_ds, target_ds = shifter.apply_shift(nih_ds, num_proc=6)

## Set Transforms

In [None]:
device = get_device()

# Preprocessing for the "features" column. Every step sets
# allow_missing_keys=True, so a sample without "features" is not an error.
transforms = Compose(
    [
        # Add a channel axis to the image.
        # NOTE(review): AddChanneld is deprecated in newer MONAI releases
        # (EnsureChannelFirstd(channel_dim="no_channel") replaces it) —
        # confirm against the pinned MONAI version before upgrading.
        AddChanneld(allow_missing_keys=True, keys=("features",)),
        # Resize with spatial_size=(1, 224, 224).
        Resized(
            allow_missing_keys=True,
            keys=("features",),
            spatial_size=(1, 224, 224),
        ),
        # Affine rescale x -> (2 * x / 255 - 1) * 1024, which maps [0, 255]
        # onto [-1024, 1024] (presumably the input range the torchxrayvision
        # DenseNet expects — confirm).
        Lambdad(
            allow_missing_keys=True,
            keys=("features",),
            func=lambda x: ((2 * (x / 255.0)) - 1.0) * 1024,
        ),
        # Move the result to the device chosen by get_device().
        ToDeviced(allow_missing_keys=True, keys=("features",), device=device),
    ]
)

# Attach the pipeline to both splits via with_transform: only "features" is
# transformed, and all other columns are still returned.
source_ds, target_ds = (
    split.with_transform(
        partial(apply_transforms, transforms=transforms),
        columns=["features"],
        output_all_columns=True,
    )
    for split in (source_ds, target_ds)
)

## Sensitivity Test

#### Experiment with Dimensionality Reduction Techniques


In [None]:
# DenseNet with the "densenet121-res224-all" weights; reused by the BBSE
# reductors here and by later cells through the global `model`.
model = DenseNet(weights="densenet121-res224-all")
results = {}

# Dimensionality-reduction methods under comparison. The two BBSE variants
# wrap the DenseNet above; TXRV-AE is constructed without a model argument.
dr_methods = {
    "BBSE": Reductor(dr_method="bbse-soft", model=model, device=device),
    "BBSE + TXRV-AE": Reductor(
        dr_method="bbse-soft+txrv-ae", model=model, device=device
    ),
    "TXRV-AE": Reductor(dr_method="txrv-ae", device=device),
}

# Every reductor gets the same MMD tester, source sample size, target
# sample-size sweep and number of runs, so results differ only by reductor.
for name, reductor in dr_methods.items():
    detector = Detector(
        "sensitivity_test",
        reductor=reductor,
        tester=TSTester(tester_method="mmd"),
        source_sample_size=1000,
        target_sample_size=[50, 100, 200, 300, 400, 600, 800, 1000],
        num_runs=3,
    )
    results[name] = detector.detect_shift(source_ds, target_ds)

plot_drift_experiment(results)

#### Experiment with Different Clinical Shifts


In [None]:
# Sensitivity test across several clinical target slices. Each slice is
# compared against the same source (source_slice=None) with the same BBSE
# reductor, MMD tester and sample-size sweep.
source_slice = None
target_slices = {
    "SEX: MALE": SliceSpec(spec_list=[{"Patient Gender": {"value": "M"}}]),
    "SEX: FEMALE": SliceSpec(spec_list=[{"Patient Gender": {"value": "F"}}]),
    # NOTE(review): if SliceSpec min/max bounds are inclusive, age 35 lands in
    # both age slices — confirm the intended boundary handling.
    "AGE: 18-35": SliceSpec(
        spec_list=[{"Patient Age": {"min_value": 18, "max_value": 35}}]
    ),
    "AGE: 35-65": SliceSpec(
        spec_list=[{"Patient Age": {"min_value": 35, "max_value": 65}}]
    ),
}
results = {}

for name, target_slice in target_slices.items():
    # Build the split from this iteration's target_slice. Do not rebind
    # source_slice/target_slice here, or every entry of target_slices would
    # be replaced by the same shift.
    shifter = ClinicalShiftApplicator(
        "custom", source=source_slice, target=target_slice
    )
    source_ds, target_ds = shifter.apply_shift(nih_ds, num_proc=6)

    # apply_shift returns fresh datasets, so the preprocessing pipeline from
    # the transforms cell has to be attached again for each split.
    source_ds = source_ds.with_transform(
        partial(apply_transforms, transforms=transforms),
        columns=["features"],
        output_all_columns=True,
    )
    target_ds = target_ds.with_transform(
        partial(apply_transforms, transforms=transforms),
        columns=["features"],
        output_all_columns=True,
    )

    detector = Detector(
        "sensitivity_test",
        reductor=Reductor(dr_method="bbse-soft", model=model, device=device),
        tester=TSTester(tester_method="mmd"),
        source_sample_size=1000,
        target_sample_size=[50, 100, 200, 300, 400, 600, 800, 1000],
        num_runs=3,
    )
    results[name] = detector.detect_shift(source_ds, target_ds)

# NOTE: after the loop, source_ds/target_ds hold the split for the last
# entry of target_slices; later cells that read them inherit that split.
plot_drift_experiment(results)

#### Experiment with Models Trained on Different Datasets

In [None]:
# Sensitivity test with BBSE reductors built on DenseNets pretrained on
# different datasets, named by torchxrayvision weights identifier.
# NOTE: source_ds/target_ds are whatever split an earlier cell bound last.
models = {
    "MODEL: NIH": "densenet121-res224-nih",
    "MODEL: CHEXPERT": "densenet121-res224-chex",
    "MODEL: PADCHEST": "densenet121-res224-pc",
}
results = {}

# The loop variable is deliberately not called `model`: that global holds the
# DenseNet instance other cells pass to Reductor, and rebinding it to a
# weights string would break them on re-run.
for model_name, weights_name in models.items():
    detector = Detector(
        "sensitivity_test",
        # Pass the weights identifier by keyword (as the other cells do):
        # DenseNet's first positional parameter is not `weights`.
        reductor=Reductor(
            dr_method="bbse-soft",
            model=DenseNet(weights=weights_name),
            device=device,
        ),
        tester=TSTester(tester_method="mmd"),
        source_sample_size=1000,
        target_sample_size=[50, 100, 200, 300, 400, 600, 800, 1000],
        num_runs=3,
    )
    results[model_name] = detector.detect_shift(source_ds, target_ds)
plot_drift_experiment(results)

## Rolling Window

#### Experiment with Synthetic Timestamps Using a Two-Week Window

In [None]:
# Reload the "densenet121-res224-all" DenseNet so this cell does not rely on
# whatever an earlier cell left bound to `model`.
model = DenseNet(weights="densenet121-res224-all")

# Rolling-window drift detection: 1000 source samples versus 50-sample target
# draws, windowed on the "timestamp" column. window_size="2W" is presumably a
# pandas-style offset alias for two weeks — confirm against Detector.
detector = Detector(
    "rolling_window_drift",
    reductor=Reductor(model=model, dr_method="bbse-soft", device=device),
    tester=TSTester(tester_method="mmd"),
    source_sample_size=1000,
    target_sample_size=50,
    timestamp_column="timestamp",
    window_size="2W",
)
results = detector.detect_shift(source_ds, target_ds)
plot_drift_timeseries(results)