# NIHCXR Synthetic Drift - Categorical Shift

## Load Libraries

In [5]:
from cyclops.monitor.datasets import NIHCXRDataset

from cyclops.monitor import (
    Detector,
    Experimenter,
    Reductor,
    SyntheticShiftApplicator,
    TSTester,
)

from cyclops.monitor.plotter import plot_drift_samples_pval
from torchxrayvision.models import DenseNet

## Query Data

In [2]:
# Build the NIHCXR loader from its YAML config and immediately unpack the three
# values returned by get_data(); later cells use all three.
dataset, metadata, metadata_mapping = NIHCXRDataset(
    cfg_path="../../../cyclops/monitor/datasets/configs/nihcxr.yaml"
).get_data()

## Initialize Reductor, Tester & Detector

In [7]:
import torch

# Use the GPU when one is available; otherwise fall back to CPU instead of
# failing on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dimensionality reduction with the "bbse-soft" method, using torchxrayvision's
# DenseNet with the "densenet121-res224-all" weights.
# NOTE(review): "bbse" presumably refers to black-box shift estimation — confirm.
reductor = Reductor(
    dr_method="bbse-soft",
    model=DenseNet,
    weights="densenet121-res224-all",
)

# Statistical test applied by the detector ("mmd"; presumably maximum mean
# discrepancy — confirm against TSTester).
tester = TSTester(
    tester_method="mmd",
)

detector = Detector(
    reductor=reductor,
    tester=tester,
    device=device,
)

# Fit the detector on the loaded dataset before running any experiments.
detector.fit(dataset, progress=False)

## Setup Baseline Experiment

In [None]:
# Baseline run: the same "sensitivity_test" experiment but with no shift
# applicator, so its results serve as the reference in the final plot.
baseline_experiment = Experimenter("sensitivity_test", detector=detector)

## Setup Drift Experiments (Categorical Shift)

In [None]:
# Each drift experiment applies a categorical shift toward one target value of
# one categorical column; cat_col[i] pairs with target_categories[i].
cat_col = ["gender", "view", "age"]
target_categories = ["M", "PA", "18-35"]
if len(cat_col) != len(target_categories):
    raise ValueError(
        "cat_col and target_categories must have the same length; "
        f"got {len(cat_col)} and {len(target_categories)}"
    )

# Derive the length from cat_col rather than hard-coding it, so adding a column
# cannot silently drop an experiment when the lists are zipped together.
shift_type = ["categorical_shift"] * len(cat_col)

shiftapplicators = [
    SyntheticShiftApplicator(
        shift_type=s_type,
        categorical_column=col,
        target_category=target,
    )
    for s_type, col, target in zip(shift_type, cat_col, target_categories)
]

# One experiment per shift applicator, all sharing the detector fitted above.
experiments = [
    Experimenter(
        "sensitivity_test",
        detector=detector,
        shiftapplicator=shiftapplicator,
    )
    for shiftapplicator in shiftapplicators
]

## Run Experiments

In [None]:
# Run the unshifted baseline first, then every drift experiment in order,
# all against the same dataset, metadata and metadata mapping.
baseline_results = baseline_experiment.run(dataset, metadata, metadata_mapping)
drift_results = [
    experiment.run(dataset, metadata, metadata_mapping)
    for experiment in experiments
]

## Gather Results

In [None]:
# Label each result for plotting: "baseline" first, then one
# "<column>: <target category>" entry per drift experiment, in run order.
if not len(drift_results) == len(cat_col) == len(target_categories):
    # A mismatch would attach the wrong label to a result (or drop one).
    raise ValueError(
        "drift_results, cat_col and target_categories must have the same "
        f"length; got {len(drift_results)}, {len(cat_col)} and "
        f"{len(target_categories)}"
    )

results_dict = {"baseline": baseline_results}
for col, target, result in zip(cat_col, target_categories, drift_results):
    results_dict[f"{col}: {target}"] = result

## Plot Experimental Results

In [None]:
# Plot the collected experiment results.
# NOTE(review): 0.05 is presumably the p-value significance level shown on the
# plot — confirm against plot_drift_samples_pval's signature.
pval_threshold = 0.05
plot_drift_samples_pval(results_dict, pval_threshold)