In [7]:
import os
import numpy as np
import torch
import pandas as pd
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import protector as protect
from utils.cli_utils import softmax_ent

from tent import Tent, configure_model, collect_params
from typing import Sequence, Tuple, Dict, Optional
import argparse

from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

from utilities import *  ## created by me
from plotting import *  ## created by me

%load_ext autoreload
%autoreload 2

In [2]:
CORRUPTIONS = (
    "shot_noise",
    "motion_blur",
    "snow",
    "pixelate",
    "gaussian_noise",
    "defocus_blur",
    "brightness",
    "fog",
    "zoom_blur",
    "frost",
    "glass_blur",
    "impulse_noise",
    "contrast",
    "jpeg_compression",
    "elastic_transform",
)

In [None]:
## ENTER PARAMETERS ##

# Manual settings for arguments
args = type("Args", (), {})()  # Create a simple namespace object
args.device = "cpu"  # Change this manually as needed
args.method = "none"  # Options: 'none' or 'tent'
args.corruption = "gaussian_noise"  # Choose from CORRUPTIONS
args.all_corruptions = False  # Set to True to test all corruptions
args.n_examples = 1000
args.batch_size = 64

# Set torch seed for replicability (don't know if this preserves consistency when using different devices)
torch.manual_seed(42)

Basic set up where we load clean CIFAR-10 and then test on corrupted version. This is a good reference to get a feel for how everyting works together.

In [None]:
# Dynamically set device to best available option
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Define normalization transform using CIFAR-10 mean and std values
transform = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))

# Load pre-trained model and move to appropriate device
print("🚀 Loading model...")
model = get_model(args.method, device)

🚀 Loading model...


Using cache found in C:\Users\Mikel/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master


#### BBSE ODS False Positives

In [None]:
from itertools import combinations
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader, Subset
from weighted_cdf_bbse_ods import (
    BBSEODSWeightedCDF,
    estimate_confusion_matrix,
    estimate_target_distribution_from_preds,
)

fpr_by_size = {}

for num_classes in [2]:  # Adjust as needed
    candidate_class_sets = list(combinations(range(10), num_classes))
    entropy_peaks = {}
    threshold_crossed = {}

    for subset in candidate_class_sets[:3]:
        print(f"\n🔎 Evaluating label shift ({num_classes}-class): {subset}")

        split = 4000

        # 1) Load clean source stream using only this subset of classes
        x_src, y_src = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
        source_dataset = BasicDataset(x_src, y_src, transform=transform)

        source_loader = DataLoader(Subset(source_dataset, range(0, split)), batch_size=args.batch_size, shuffle=False)

        source_ents = []
        source_labels = []
        with torch.no_grad():
            for x_batch, labels in source_loader:
                x_batch = x_batch.to(device)
                logits = model(x_batch)
                probs = torch.softmax(logits, dim=1)
                entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
                y_hat = torch.argmax(probs, dim=1)

                source_ents.extend(entropy.cpu().tolist())
                # We have access to ground truth labels for init of CDF, so we are using them here
                source_labels.extend(labels.cpu().tolist())

        # BBSE: Estimate confusion matrix and p_source from clean source stream
        confusion_matrix, p_source = estimate_confusion_matrix(model, source_loader, device)

        # Load label-shifted stream (same subset)
        x_shift, y_shift = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
        full_dataset = BasicDataset(x_shift, y_shift, transform=transform)

        test_loader = DataLoader(
            Subset(full_dataset, range(split, len(full_dataset))), batch_size=args.batch_size, shuffle=False
        )

        # 4. Estimate p_test from shifted half of the stream
        p_test = estimate_target_distribution_from_preds(model, test_loader, device)

        # 5. Create BBSE/ODS weighted protector
        protector = protect.get_bbse_ods_weighted_protector_from_ents(
            source_ents,
            p_test,
            p_source,
            source_labels,
            confusion_matrix,
            0.05,  # ods_alpha
            argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device),
        )

        # LOUIS: This is for stat tracking i guess?
        mask = torch.tensor([i in subset for i in range(10)], dtype=torch.bool)
        p_source = p_source.to("cpu") * mask
        p_source /= p_source.sum() + 1e-8

        p_test_true = protector.cdf.p_test_true.to("cpu") * mask
        p_test_true /= p_test_true.sum() + 1e-8

        # 6) Evaluate full test stream and run martingale
        loader = DataLoader(full_dataset, batch_size=args.batch_size, shuffle=False)
        # ents, acc, logits_list, labels_list = evaluate(model, loader, device)

        # Initialize tracking variables
        entropies = []
        correct, total = 0, 0
        logits_list, labels_list = [], []
        logs, eps = [], []

        # Reset the protector for a fresh start
        protector.reset()

        # Process batch-by-batch and update ODS as we go
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                batch_entropies = softmax_ent(logits).tolist()

                # Get predictions and update accuracy tracking
                preds = torch.argmax(logits, dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

                # Store logits and labels for later analysis
                logits_list.append(logits.cpu())
                labels_list.append(y.cpu())

                # KEY STEP: Update protector with new pseudo-labels (ODS update)
                protector.cdf.batch_ods_update(preds.cpu().tolist())

                # Process each entropy value through the martingale detector
                for z in batch_entropies:
                    # Get CDF value (which now uses updated weights)
                    u = protector.cdf(z)
                    # Update martingale
                    protector.protect_u(u)
                    # Store results
                    logs.append(np.log(protector.martingales[-1] + 1e-8))
                    eps.append(protector.epsilons[-1])
                    # Store entropy for later reference
                    entropies.append(z)

        # Finalize results
        acc = correct / total
        key = f"labelshift_{num_classes}cls_{'_'.join(map(str, subset))}"
        ents = np.array(entropies)
        results = {key: {"log_sj": logs, "eps": eps}}
        result = results[key]

        # Record statistics for the main analysis
        entropy_peaks[key] = np.max(ents)
        threshold_crossed[key] = np.max(result["log_sj"]) > np.log(100)

        print("Subset:", subset)
        print("p_source:", p_source.numpy())
        print("p_test_true:", p_test_true.numpy())
        print("min p_source:", p_source.min().item(), "min p_t:", p_test_true.min().item())

    # 8) Calculate and store FPR
    n_tests = len(threshold_crossed)
    n_crossed = sum(threshold_crossed.values())
    fpr = n_crossed / n_tests
    fpr_by_size[num_classes] = fpr

    print(f"\n➡️ {n_crossed} / {n_tests} ({num_classes}-class) shifts triggered detection — FPR = {fpr:.3f}\n")

# fpr_by_size now holds clean, valid FPRs per class subset size.


🔎 Evaluating label shift (2-class): (0, 1)


[32m2025-05-22 14:15:24.900[0m | [1mINFO    [0m | [36mprotector[0m:[36mset_gamma[0m:[36m47[0m - [1msetting gamma val to 0.07216878364870323[0m
[32m2025-05-22 14:15:24.902[0m | [1mINFO    [0m | [36mprotector[0m:[36mset_eps_clip_val[0m:[36m43[0m - [1msetting epsilon clip val to 1.8[0m


Subset: (0, 1)
p_source: [0.49807933 0.5019206  0.         0.         0.         0.
 0.         0.         0.         0.        ]
p_test_true: [0.5026567  0.49734333 0.         0.         0.         0.
 0.         0.         0.         0.        ]
min p_source: 0.0 min p_t: 0.0

🔎 Evaluating label shift (2-class): (0, 2)


[32m2025-05-22 14:16:23.652[0m | [1mINFO    [0m | [36mprotector[0m:[36mset_gamma[0m:[36m47[0m - [1msetting gamma val to 0.07216878364870323[0m
[32m2025-05-22 14:16:23.653[0m | [1mINFO    [0m | [36mprotector[0m:[36mset_eps_clip_val[0m:[36m43[0m - [1msetting epsilon clip val to 1.8[0m


Subset: (0, 2)
p_source: [0.47730058 0.         0.52269936 0.         0.         0.
 0.         0.         0.         0.        ]
p_test_true: [0.5079023  0.         0.49209768 0.         0.         0.
 0.         0.         0.         0.        ]
min p_source: 0.0 min p_t: 0.0

🔎 Evaluating label shift (2-class): (0, 3)


[32m2025-05-22 14:17:22.381[0m | [1mINFO    [0m | [36mprotector[0m:[36mset_gamma[0m:[36m47[0m - [1msetting gamma val to 0.07216878364870323[0m
[32m2025-05-22 14:17:22.382[0m | [1mINFO    [0m | [36mprotector[0m:[36mset_eps_clip_val[0m:[36m43[0m - [1msetting epsilon clip val to 1.8[0m


Subset: (0, 3)
p_source: [0.49999997 0.         0.         0.49999997 0.         0.
 0.         0.         0.         0.        ]
p_test_true: [0.5061103  0.         0.         0.49388963 0.         0.
 0.         0.         0.         0.        ]
min p_source: 0.0 min p_t: 0.0

➡️ 0 / 3 (2-class) shifts triggered detection — FPR = 0.000

