In [1]:
# Colab bootstrap: the import below only succeeds where the `google.colab`
# package is available; in that case dataeval is installed into the kernel's
# environment (quietly, via the `%pip` magic).
try:
    import google.colab  # noqa: F401

    %pip install -q dataeval
except Exception:
    # Best-effort: any failure (e.g. `google.colab` not being importable) is
    # ignored and the notebook continues with whatever is already installed.
    pass

In [2]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import models

# Drift
from dataeval.detectors.drift import DriftCVM, DriftKS, DriftMMD
from dataeval.metrics.bias import label_parity
from dataeval.utils.data import collate
from dataeval.utils.data.datasets import VOCDetection

# Seed every random number generator the notebook relies on so that a
# Restart & Run All reproduces the same results
SEED = 213
rng = np.random.default_rng(SEED)
# torch's global generator drives the random initialisation of the new
# embedding layer as well as the noise drawn with torch.normal further down,
# so it has to be seeded too (the NumPy generator alone does not affect torch)
torch.manual_seed(SEED)

# Set default torch device for notebook
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

In [3]:
# Define the embedding network
class EmbeddingNet(nn.Module):
    """Pretrained ResNet-18 whose final layer is swapped for an embedding layer.

    Parameters
    ----------
    embedding_dim : int, default 128
        Number of output features of the replacement fully connected layer,
        i.e. the length of the embedding produced for each input.
    """

    def __init__(self, embedding_dim: int = 128):
        super().__init__()
        # Load in pretrained resnet18 model
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Replace the final fully connected layer with a new, randomly
        # initialised linear layer that outputs `embedding_dim` features
        self.model.fc = nn.Linear(self.model.fc.in_features, embedding_dim)
        # The network is used purely as a feature extractor: switch to eval mode
        # so BatchNorm layers use their pretrained running statistics instead of
        # per-batch statistics (call .train() explicitly if fine-tuning)
        self.eval()

    def forward(self, x):
        """Run input data through the model and return its output."""
        return self.model(x)

In [4]:
# Instantiate the embedding network; constructing it loads the pretrained
# ResNet-18 weights (downloaded into the torch hub cache when not yet present,
# which is what the progress output below shows)
embedding_net = EmbeddingNet()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/dataeval/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0% 0.00/44.7M [00:00<?, ?B/s]

 12% 5.50M/44.7M [00:00<00:00, 54.8MB/s]

 24% 10.8M/44.7M [00:00<00:00, 50.2MB/s]

 37% 16.4M/44.7M [00:00<00:00, 53.9MB/s]

 49% 22.0M/44.7M [00:00<00:00, 55.5MB/s]

 63% 28.0M/44.7M [00:00<00:00, 57.7MB/s]

 75% 33.6M/44.7M [00:00<00:00, 56.9MB/s]

 88% 39.1M/44.7M [00:00<00:00, 56.0MB/s]

100% 44.7M/44.7M [00:00<00:00, 56.1MB/s]




In [5]:
# Preprocessing transforms bundled with the pretrained ResNet-18 weights
preprocess = models.ResNet18_Weights.DEFAULT.transforms()

# Settings shared by both splits: same root, same year, no download, same transform
voc_root = "./data"
voc_options = {"year": "2011", "download": False, "transform": preprocess}

# The "train" split is the training dataset; the "val" split plays the role
# of the "operational" dataset
train_ds = VOCDetection(voc_root, image_set="train", **voc_options)
operational_ds = VOCDetection(voc_root, image_set="val", **voc_options)

# Print the summary of each dataset, training first
for dataset in (train_ds, operational_ds):
    print(dataset.info())

Train
-----
Dataset VOCDetection
    Number of datapoints: 5717
    Root location: ./data
    StandardTransform
Transform: ImageClassification(
               crop_size=[224]
               resize_size=[256]
               mean=[0.485, 0.456, 0.406]
               std=[0.229, 0.224, 0.225]
               interpolation=InterpolationMode.BILINEAR
           )

Val
---
Dataset VOCDetection
    Number of datapoints: 5823
    Root location: ./data
    StandardTransform
Transform: ImageClassification(
               crop_size=[224]
               resize_size=[256]
               mean=[0.485, 0.456, 0.406]
               std=[0.229, 0.224, 0.225]
               interpolation=InterpolationMode.BILINEAR
           )



In [6]:
# This step can take ~1 minute depending on hardware

# Run each dataset (training first, then operational) through the embedding
# network and unpack the embeddings and targets; the third item of every
# result is discarded
(train_embs, train_targets, _), (operational_embs, operational_targets, _) = (
    collate(dataset, model=embedding_net) for dataset in (train_ds, operational_ds)
)

In [7]:
# Show the shape of both embedding tensors, one per line
print(train_embs.shape, operational_embs.shape, sep="\n")

torch.Size([5717, 128])
torch.Size([5823, 128])


In [8]:
# Union of the drift detector classes used in this notebook
DriftDetector = DriftMMD | DriftCVM | DriftKS

# Build one detector per method, each constructed from the training embeddings,
# keyed by a short display name (insertion order: MMD, CVM, KS)
detectors: dict[str, DriftDetector] = {
    label: detector_cls(train_embs)
    for label, detector_cls in (("MMD", DriftMMD), ("CVM", DriftCVM), ("KS", DriftKS))
}

In [9]:
# Ask every detector for its prediction on the operational embeddings and
# report the boolean `drifted` flag next to the detector's name
for name, detector in detectors.items():
    prediction = detector.predict(operational_embs)
    print(f"{name} detected drift? {prediction.drifted}")

MMD detected drift? True
CVM detected drift? False


KS detected drift? False


In [10]:
# Draw one sample per element from a normal distribution centred on the
# operational embeddings; the standard deviation is torch.normal's default
# of 1.0, spelled out here for clarity
noisy_embs = torch.normal(mean=operational_embs, std=1.0)

In [11]:
# Repeat the drift check, this time on the noisy embeddings, and report the
# boolean `drifted` flag next to each detector's name
for name, detector in detectors.items():
    prediction = detector.predict(noisy_embs)
    print(f"{name} detected drift? {prediction.drifted}")

MMD detected drift? True


CVM detected drift? True


KS detected drift? True


In [12]:
# The VOC dataset has 20 classes
VOC_NUM_CLASSES = 20

# Run label_parity on the training and operational labels and display the
# p-value of the result
parity = label_parity(train_targets.labels, operational_targets.labels, num_classes=VOC_NUM_CLASSES)
parity.p_value

np.float64(0.949856067521638)