### Synthetic Drift Detection ###

In [None]:
import datetime
import os
import sys
from functools import reduce

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append("../..")

import torch
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torchxrayvision.datasets import NIH_Dataset, XRayCenterCrop, XRayResizer

from drift_detection.drift_detector.detector import Detector
from drift_detection.drift_detector.experimenter import Experimenter
from drift_detection.drift_detector.plotter import (
    brightness,
    colors,
    colorscale,
    errorfill,
    linestyles,
    markers,
    plot_drift_samples_pval,
    plot_pr,
    plot_roc,
)
from drift_detection.drift_detector.reductor import Reductor
from drift_detection.drift_detector.synthetic_applicator import (
    SyntheticShiftApplicator,
    apply_predefined_shift,
)
from drift_detection.drift_detector.tester import DCTester, TSTester
from drift_detection.drift_detector.utils import scale

# Locations of the NIH ChestX-ray images and their metadata CSV. Both must be
# set in the environment; os.environ[...] raises KeyError if either is missing.
IMAGE_PATH = os.environ["NIHCXR_IMAGE_PATH"]
CSV_PATH = os.environ["NIHCXR_CSV_PATH"]

# Size of the random image subset used for the experiment, and the seed that
# makes the subset draw reproducible across notebook runs. (Other randomness
# inside the drift-detection library is not controlled by this seed.)
NUM_SAMPLES = 2000
RANDOM_SEED = 42

# load NIH dataset (AP and PA views only), center-cropped and passed through
# XRayResizer(224) via the cv2 engine
dataset = NIH_Dataset(
    IMAGE_PATH,
    CSV_PATH,
    views=["AP", "PA"],
    unique_patients=False,
    transform=transforms.Compose([XRayCenterCrop(), XRayResizer(224, engine="cv2")]),
)

# Grab a random subset of images. Sample WITHOUT replacement so that no image
# appears twice in the subset, and clamp the size so a dataset smaller than
# NUM_SAMPLES does not make the draw fail.
rng = np.random.default_rng(RANDOM_SEED)
indices = rng.choice(len(dataset), size=min(NUM_SAMPLES, len(dataset)), replace=False)
dataset = Subset(dataset, indices)

# Dimensionality-reduction stage of the detector. "TAE_txrv_CNN" selects a
# TorchXRayVision-based CNN reduction method (exact model: see Reductor).
reductor = Reductor(dr_method="TAE_txrv_CNN")

# Two-sample test applied to the reduced representations (MMD statistic).
tester = TSTester(tester_method="mmd")

# Combine both stages; drift is reported when the test p-value falls below
# the threshold. Fit on the sampled dataset, which serves as the reference.
detector = Detector(reductor=reductor, tester=tester, p_val_threshold=0.05)
detector.fit(dataset)


# Synthetic shift injected into the drift experiment ("gn_shift" is presumably
# a Gaussian-noise shift; TODO confirm in SyntheticShiftApplicator).
shiftapplicator = SyntheticShiftApplicator(shift_type="gn_shift")

# Two sensitivity tests sharing the fitted detector: the baseline applies no
# shift, while the drift run is given the shift applicator.
baseline_experiment = Experimenter("sensitivity_test", detector=detector)
drift_experiment = Experimenter(
    "sensitivity_test", detector=detector, shiftapplicator=shiftapplicator
)

# NOTE(review): both runs use the same `dataset` the detector was fitted on;
# confirm Experimenter draws separate samples internally, otherwise the
# baseline compares the reference data against itself.
baseline_results = baseline_experiment.run(dataset)
drift_results = drift_experiment.run(dataset)

# Per-run results keyed by run name.
results = dict(baseline=baseline_results, experiment=drift_results)

In [None]:
# Plot the baseline vs. drift experiment results; 0.05 matches the detector's
# p_val_threshold (presumably drawn as the significance line — confirm in plotter).
plot_drift_samples_pval(results, 0.05)