In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import torch
import pandas as pd
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import protector as protect
from utils.cli_utils import softmax_ent

from tent import Tent, configure_model, collect_params
from typing import Sequence, Tuple, Dict, Optional
import argparse

from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

from utilities import *  ## created by me
from plotting import *  ## created by me

In [None]:
# Corruption names that can be chosen through `args.corruption`; the whole tuple is
# iterated when `args.all_corruptions` is True.
CORRUPTIONS = (
    "shot_noise",
    "motion_blur",
    "snow",
    "pixelate",
    "gaussian_noise",
    "defocus_blur",
    "brightness",
    "fog",
    "zoom_blur",
    "frost",
    "glass_blur",
    "impulse_noise",
    "contrast",
    "jpeg_compression",
    "elastic_transform",
)

In [None]:
## ENTER PARAMETERS ##

# Hand-edited experiment settings, stored as attributes on a bare "Args" instance.
args = type("Args", (), {})()
vars(args).update(
    device="cpu",  # Change this manually as needed
    method="none",  # Options: 'none' or 'tent'
    corruption="gaussian_noise",  # Choose from CORRUPTIONS
    all_corruptions=False,  # Set to True to test all corruptions
    n_examples=1000,
    batch_size=64,
)

# Seed torch for repeatable runs (unverified whether results also match across devices)
torch.manual_seed(42)

Basic setup where we load clean CIFAR-10 and then test on a corrupted version. This is a good reference to get a feel for how everything works together.

In [None]:
# Baseline experiment: build a protector from clean CIFAR-10 entropies, then evaluate
# the model on label-shifted streams of increasing severity and run the martingale test.

# Dynamically set device to best available option.
# Note: `args.device` from the parameter cell is not consulted here.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Define normalization transform using CIFAR-10 mean and std values
transform = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))

# Load pre-trained model and move to appropriate device
print("🚀 Loading model...")
model = get_model(args.method, device)

# Load clean CIFAR-10 test data to compute source entropies
print("📦 Loading clean CIFAR-10 as source entropy")
clean_ds = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transform])
)
clean_loader = DataLoader(clean_ds, batch_size=args.batch_size, shuffle=False)
# Only the first return value (the entropies) of `evaluate` is kept here
source_ents, _, _, _ = evaluate(model, clean_loader, device)

# Initialize protector with source entropies for shift detection
protector = protect.get_protector_from_ents(
    source_ents, argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device)
)

# --- Label Shift Severity Sweep ---
# Maps a severity name to the classes passed as `keep_classes` to the stream loader
label_shift_severities = {
    "none": list(range(10)),  # no actual shift
    "medium": list(range(5)),  # moderate shift
    "severe": [0, 1, 2],  # strong shift
    "extreme": [0],  # extreme shift
}

# Per-stream outputs, all keyed by "label_shift_<severity>".
# NOTE(review): later cells use the prefix "labelshift_" (no underscore) for their keys.
entropy_streams = {}
accs = {}
logits_list_dict = {}
logits_labels_dict = {}

for severity_name, keep_classes in label_shift_severities.items():
    print(f"🔎 Evaluating label shift severity: {severity_name}")
    # 8000-example stream with the shift requested at index 4000
    x, y = load_cifar10_label_shift(keep_classes=keep_classes, n_examples=8000, shift_point=4000)
    dataset = BasicDataset(x, y, transform=transform)
    # shuffle=False keeps the stream in the order returned by the loader helper
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    ents, acc, logits_list, labels_list = evaluate(model, loader, device)

    key = f"label_shift_{severity_name}"
    entropy_streams[key] = ents
    accs[key] = acc
    logits_list_dict[key] = logits_list
    logits_labels_dict[key] = labels_list

# Run martingale-based shift detection
results = run_martingale(entropy_streams, protector)

# Print results
print("\n📊 Accuracy under label shift severities:")
for key, value in accs.items():
    print(f"{key}: {value * 100:.2f}%")


In [None]:
# Overlay the entropy streams of three label-shift severities on a single figure.
plt.figure(figsize=(12, 6))

# (stream key, legend label) pairs, drawn in this order
for stream_key, legend_label in (
    ("label_shift_none", "None"),
    ("label_shift_medium", "Medium"),
    ("label_shift_severe", "Severe"),
):
    plt.plot(entropy_streams[stream_key], label=legend_label, alpha=0.7)

# Dashed vertical marker at the shift point.
# NOTE(review): the x-axis is labelled "Batch Index" while the marker sits at 4000, the
# `shift_point` passed to the stream loader - confirm which unit the streams are indexed in.
shift_point = 4000
plt.axvline(x=shift_point, color="r", linestyle="--", alpha=0.5, label="Shift Point")

# Axis labels, title, legend and a light grid
plt.xlabel("Batch Index")
plt.ylabel("Entropy")
plt.title("Entropy Streams for Different Label Shift Severities")
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Clean -> label-shift experiment: for each class subset, build a stream with
# `load_clean_then_label_shift_sequence`, evaluate the model, then run the martingale
# test and attach accuracy-over-time to the results.

# Maps a level name to the classes passed as `keep_classes`
label_shift_levels = {
    "mild": list(range(10)),
    "medium": list(range(5)),
    "severe": [0, 1, 2],
    "test": [6, 7, 8],
    "test2": [3, 4, 5],
    "extreme": [0],
}

# Reset the per-stream containers (this discards the values from the previous experiment)
accs, entropy_streams, logits_list_dict, logits_labels_dict = {}, {}, {}, {}

for severity_name, keep_classes in label_shift_levels.items():
    print(f"🔎 Label shift severity: {severity_name}")

    # `is_clean` and `labels` are returned by the helper but not used in this cell
    loader, is_clean, labels = load_clean_then_label_shift_sequence(
        keep_classes=keep_classes,
        n_examples=8000,
        shift_point=4000,
        data_dir="./data/cifar-10-batches-py",
        transform=transform,
        batch_size=args.batch_size,
    )

    ents, acc, logits_list, labels_list = evaluate(model, loader, device)

    # Keys here use the prefix "labelshift_" (no underscore between "label" and "shift")
    key = f"labelshift_{severity_name}"
    entropy_streams[key] = ents
    accs[key] = acc
    logits_list_dict[key] = logits_list
    logits_labels_dict[key] = labels_list

# Run martingale
results = run_martingale(entropy_streams, protector)

# Add accuracy over time
for severity_name in label_shift_levels:
    key = f"labelshift_{severity_name}"
    acc_time = compute_accuracy_over_time_from_logits(logits_list_dict[key], logits_labels_dict[key])
    results[key]["accs"] = acc_time


In [None]:
# Regroup the martingale results by severity name (without the "labelshift_" prefix)
# so they can be passed to the combined plotting helper.
log_sj_dict = {}
epsilons_dict = {}
accuracies_dict = {}
entropy_dict = {}

# Extract values for each severity level
for severity in label_shift_levels.keys():
    key = f"labelshift_{severity}"
    if key in results:
        log_sj_dict[severity] = results[key]["log_sj"]
        epsilons_dict[severity] = results[key]["eps"]
        accuracies_dict[severity] = results[key]["accs"]
        entropy_dict[severity] = entropy_streams[key]

# Plot the combined results.
# The batch size comes from the configuration used to build the loaders rather than a
# hard-coded 64, so the plot stays consistent if `args.batch_size` is changed.
plot_combined_martingale_accuracy_severity(
    log_sj_dict,
    epsilons_dict,
    accuracies_dict,
    entropy_dict=entropy_dict,
    batch_size=args.batch_size,
    title="Label Shift Comparison Across Severities",
)


In [None]:
# Largest `log_sj` value in the martingale results for the "test" class subset ([6, 7, 8])
max(results["labelshift_test"]["log_sj"])

In [None]:
# Reference value: later cells count a detection when max(log_sj) exceeds np.log(100)
np.log(100)

In [None]:
# Derive detection delays from the per-severity `log_sj` series collected in `log_sj_dict`
detection_delays = compute_detection_delays_from_threshold(log_sj_dict)

# Call the function to plot detection delays
plot_detection_delays(detection_delays)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One kernel-density curve per entropy stream, labelled with the stream's key
for stream_name, stream_values in entropy_streams.items():
    sns.kdeplot(stream_values, label=stream_name)

plt.title("Entropy Distribution by Label Subset")
plt.xlabel("Entropy")
plt.ylabel("Density")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Display the per-size false-positive rates.
# NOTE(review): `fpr_by_size` is first assigned in the cell below, so on a fresh kernel
# (Restart & Run All) this cell raises NameError - it only works after that cell has run.
fpr_by_size

In [None]:
from itertools import combinations
import numpy as np

# For every class subset of the given size(s): build a label-shift stream, evaluate the
# model, run the martingale test, and record whether max(log_sj) exceeded np.log(100).
fpr_by_size = {}  # maps class-set size -> fraction of subsets that triggered detection

for num_classes in [1]:
    # All subsets of the 10 class indices with `num_classes` elements
    candidate_class_sets = list(combinations(range(10), num_classes))

    # Per-subset statistics, reset for each subset size
    entropy_peaks = {}
    threshold_crossed = {}

    for subset in candidate_class_sets:
        print(f"Evaluating label shift ({num_classes}-class): {subset}")

        # 1) Load synthetic label-shift stream
        x, y = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=4000)
        dataset = BasicDataset(x, y, transform=transform)
        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

        # 2) Forward pass
        ents, acc, logits_list, labels_list = evaluate(model, loader, device)

        # 3) Martingale test (uses whichever `protector` is currently bound)
        key = f"labelshift_{num_classes}cls_{'_'.join(map(str, subset))}"
        result = run_martingale({key: ents}, protector)[key]

        # 4) Record stats
        entropy_peaks[key] = np.max(ents)
        threshold_crossed[key] = np.max(result["log_sj"]) > np.log(100)

    # --- compute per-size false-positive rate -------------------------
    n_tests = len(threshold_crossed)
    n_crossed = sum(threshold_crossed.values())
    fpr = n_crossed / n_tests
    fpr_by_size[num_classes] = fpr  # save for later use

    print(f"\n➡️  {n_crossed} / {n_tests} ({num_classes}-class) shifts triggered detection  – FPR = {fpr:.3f}\n")

# ----------------------------------------------------------------------
# Now fpr_by_size contains one entry per evaluated size, e.g. {2: 0.20, 3: 0.35, ...}
# when the loop above is run over sizes 2, 3, ...
# You can save it to disk, log to WandB, etc.


In [None]:
from collections import defaultdict
import random


class PBRSBuffer:
    """Fixed-capacity sample buffer that keeps predicted classes roughly balanced.

    Every entry is an ``(x, entropy, y_hat)`` tuple. While the buffer has room, any
    sample is accepted. Once it is full, a sample is only accepted if its class holds
    fewer than ``capacity // num_classes`` entries, and adding it evicts one entry of
    an over-represented class (or the oldest entry if no class is over its target).
    """

    def __init__(self, capacity=64, num_classes=10):
        """Create an empty buffer.

        Args:
            capacity: maximum number of entries kept in the buffer.
            num_classes: number of classes used to derive the per-class target
                ``capacity // num_classes``.
        """
        self.capacity = capacity
        self.num_classes = num_classes
        self.target_per_class = capacity // num_classes
        self.buffer = []
        self.label_counts = defaultdict(int)

    def accept(self, y_hat):
        """Return True if a sample with predicted label ``y_hat`` may be added."""
        # Always accept if there's room
        if len(self.buffer) < self.capacity:
            return True
        # Full buffer: accept only classes that are still below their target.
        # `.get` avoids inserting a zero count for labels that were never stored.
        return self.label_counts.get(y_hat, 0) < self.target_per_class

    def add(self, x, entropy, y_hat):
        """Store ``(x, entropy, y_hat)``, evicting one entry first if the buffer is full.

        With a non-positive capacity nothing can be stored, so the sample is dropped
        instead of attempting to evict from an empty buffer.
        """
        if len(self.buffer) >= self.capacity:
            if not self.buffer:
                # Only reachable when capacity <= 0: there is nothing to evict
                return
            self._evict_one()

        self.buffer.append((x, entropy, y_hat))
        self.label_counts[y_hat] += 1

    def _evict_one(self):
        """Remove the first entry of an over-represented class, else the oldest entry."""
        for index, (_, _, label) in enumerate(self.buffer):
            if self.label_counts[label] > self.target_per_class:
                del self.buffer[index]
                self.label_counts[label] -= 1
                return
        # No class exceeds its target: fall back to first-in-first-out eviction
        _, _, oldest_label = self.buffer.pop(0)
        self.label_counts[oldest_label] -= 1

    def full(self):
        """Return True once the buffer holds at least ``capacity`` entries."""
        return len(self.buffer) >= self.capacity

    def get_entropies(self):
        """Return the stored entropies in buffer order."""
        return [entry[1] for entry in self.buffer]

    def reset(self):
        """Drop all entries and class counts."""
        self.buffer = []
        self.label_counts = defaultdict(int)

In [None]:
from itertools import combinations
import numpy as np

# Same subset sweep as the plain FPR cell, but the martingale test is run on the
# entropies left in a PBRSBuffer after streaming confident predictions through it.
fpr_by_size = {}  # maps class-set size -> fraction of subsets that triggered detection

for num_classes in [3]:
    # All subsets of the 10 class indices with `num_classes` elements
    candidate_class_sets = list(combinations(range(10), num_classes))

    entropy_peaks = {}
    threshold_crossed = {}

    for subset in candidate_class_sets:
        print(f"Evaluating label shift ({num_classes}-class): {subset}")

        # 1) Load synthetic label-shift stream
        x, y = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=4000)
        dataset = BasicDataset(x, y, transform=transform)
        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

        # Step 2: Forward pass with PBRS buffering
        confidence_threshold = 0.8  # You can tune this

        # NOTE(review): `num_classes` here is the subset size, so the buffer's per-class
        # target is 512 // num_classes - confirm this is intended rather than the number
        # of classes the model predicts over.
        buffer = PBRSBuffer(capacity=512, num_classes=num_classes)
        with torch.no_grad():
            for x_batch, _ in loader:
                x_batch = x_batch.to(device)
                logits = model(x_batch)
                probs = torch.softmax(logits, dim=1)
                # Shannon entropy of the softmax output; 1e-8 guards against log(0)
                entropies = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
                pseudo_labels = torch.argmax(probs, dim=1)
                max_probs = torch.max(probs, dim=1).values  # model confidence (top softmax probability)

                # Add only confident, accepted samples to buffer
                for entropy, y_hat, confidence in zip(
                    entropies.cpu().tolist(), pseudo_labels.cpu().tolist(), max_probs.cpu().tolist()
                ):
                    if confidence > confidence_threshold and buffer.accept(y_hat):
                        # The input itself is not stored (x=None); only entropy and label are kept
                        buffer.add(None, entropy, y_hat)

        # Step 3: Run martingale test on buffered entropies
        key = f"labelshift_{num_classes}cls_{'_'.join(map(str, subset))}_PBRS"

        # Only what is left in the buffer at the end of the stream is tested
        ents = buffer.get_entropies()
        result = run_martingale({key: np.array(ents)}, protector)[key]

        # Step 4: Store stats
        # NOTE(review): if no sample passed the confidence filter, `ents` is empty here
        entropy_peaks[key] = np.max(ents)
        threshold_crossed[key] = np.max(result["log_sj"]) > np.log(100)

    # --- compute per-size false-positive rate -------------------------
    n_tests = len(threshold_crossed)
    n_crossed = sum(threshold_crossed.values())
    fpr = n_crossed / n_tests
    fpr_by_size[num_classes] = fpr  # save for later use

    print(f"\n➡️  {n_crossed} / {n_tests} ({num_classes}-class) shifts triggered detection  – FPR = {fpr:.3f}\n")

In [None]:
## ENTER PARAMETERS ##

# Manual settings for arguments
args = type("Args", (), {})()  # Create a simple namespace object
args.device = "cpu"  # Change this manually as needed
args.method = "none"  # Options: 'none' or 'tent'
args.corruption = "gaussian_noise"  # Choose from CORRUPTIONS
args.all_corruptions = False  # Set to True to test all corruptions
args.n_examples = 1000
args.batch_size = 64

# Seed every RNG the following cells draw from. The stochastic acceptance rule below
# uses the stdlib `random` module (imported alongside PBRSBuffer), which
# torch.manual_seed alone does not control; NumPy is seeded for the same reason.
random.seed(42)
np.random.seed(42)

# Set torch seed for replicability (don't know if this preserves consistency when using different devices)
torch.manual_seed(42)

In [None]:
def confidence_accept(conf, threshold=0.8, softness=15, rng=None):
    """
    Accept sample with probability based on confidence.

    The acceptance probability is the logistic curve
    ``1 / (1 + exp(-softness * (conf - threshold)))``:
    - High confidence → near-certain acceptance
    - Medium confidence → partial acceptance
    - Low confidence → mostly rejected

    Args:
        conf: confidence value of the sample.
        threshold: confidence at which the acceptance probability is exactly 0.5.
        softness: steepness of the logistic curve around ``threshold``.
        rng: optional object exposing ``random()`` (e.g. ``random.Random(seed)``).
            Defaults to the global ``random`` module, as before; passing a seeded
            generator makes the accept/reject sequence reproducible.

    Returns:
        Truthy value if the sample is accepted, falsy otherwise.
    """
    prob = 1 / (1 + np.exp(-softness * (conf - threshold)))
    source = random if rng is None else rng
    return source.random() < prob

In [None]:
def confidence_accept(conf, threshold=0.8, softness=15, rng=None):
    """
    Accept sample with probability based on confidence.

    The acceptance probability is the logistic curve
    ``1 / (1 + exp(-softness * (conf - threshold)))``:
    - High confidence → near-certain acceptance
    - Medium confidence → partial acceptance
    - Low confidence → mostly rejected

    Args:
        conf: confidence value of the sample.
        threshold: confidence at which the acceptance probability is exactly 0.5.
        softness: steepness of the logistic curve around ``threshold``.
        rng: optional object exposing ``random()`` (e.g. ``random.Random(seed)``).
            Defaults to the global ``random`` module, as before; passing a seeded
            generator makes the accept/reject sequence reproducible.

    Returns:
        Truthy value if the sample is accepted, falsy otherwise.
    """
    prob = 1 / (1 + np.exp(-softness * (conf - threshold)))
    source = random if rng is None else rng
    return source.random() < prob


# Clean -> corrupt experiment: for each corruption and severity 1..5, stream the data
# through the model, optionally filter samples through a PBRSBuffer, then run the
# martingale test on the resulting entropy streams and attach accuracy-over-time.
accs, entropy_streams, logits_list_dict, logits_labels_dict = {}, {}, {}, {}
results = {}
corruptions = CORRUPTIONS if args.all_corruptions else [args.corruption]

use_pbrs = True
confidence_threshold = 0.8  # You can tune this

for corruption in corruptions:
    for severity in range(1, 6):
        print(f"🔎 {corruption} severity {severity} (clean → corrupt)")

        # Load clean-to-corrupt stream (`is_clean` and `labels` are not used below)
        loader, is_clean, labels = load_clean_then_corrupt_sequence(
            corruption=corruption,
            severity=severity,
            n_examples=4000,
            data_dir="./data",
            transform=transform,
            batch_size=args.batch_size,
        )

        # A fresh buffer per stream so samples never leak between severities
        if use_pbrs:
            buffer = PBRSBuffer(capacity=512, num_classes=10)

        logits_list = []
        labels_list = []
        entropy_list = []

        with torch.no_grad():
            for x_batch, y_batch in loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                logits = model(x_batch)
                probs = torch.softmax(logits, dim=1)
                # Shannon entropy of the softmax output; 1e-8 guards against log(0)
                entropies = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
                pseudo_labels = torch.argmax(probs, dim=1)
                max_probs = torch.max(probs, dim=1).values

                for entropy, y_hat, conf in zip(
                    entropies.cpu().tolist(), pseudo_labels.cpu().tolist(), max_probs.cpu().tolist()
                ):
                    if use_pbrs:
                        # Use the tunable `confidence_threshold` rather than a repeated literal
                        if confidence_accept(conf, threshold=confidence_threshold, softness=15) and buffer.accept(
                            y_hat
                        ):
                            # The input itself is not stored (x=None)
                            buffer.add(None, entropy, y_hat)
                    else:
                        entropy_list.append(entropy)

                logits_list.append(logits.cpu())
                labels_list.append(y_batch.cpu())

        key = f"{corruption}_s{severity}"

        # Store entropy stream: with PBRS this is the buffer content at the end of the
        # stream, otherwise every per-sample entropy in stream order
        ents = np.array(buffer.get_entropies()) if use_pbrs else np.array(entropy_list)
        entropy_streams[key] = ents
        logits_list_dict[key] = logits_list
        logits_labels_dict[key] = labels_list


# Run martingale detection on all entropy streams
results = run_martingale(entropy_streams, protector)

# Compute accuracy over time. A dedicated name is used so the `accs` dict initialised
# above is not rebound to an unrelated per-stream value.
for corruption in corruptions:
    for severity in range(1, 6):
        key = f"{corruption}_s{severity}"
        acc_time = compute_accuracy_over_time_from_logits(logits_list_dict[key], logits_labels_dict[key])
        results[key]["accs"] = acc_time

In [None]:
# Get max log_sj values for each severity level
# Peak log_sj per stream: printed once in result order, then again sorted by stream key.
max_log_sj = {}
for stream_key, stream_result in results.items():
    peak = max(stream_result["log_sj"])
    max_log_sj[stream_key] = peak
    print(f"Max log_sj for {stream_key}: {peak}")

# Same mapping, re-ordered by key for cleaner display
sorted_max_log_sj = {stream_key: max_log_sj[stream_key] for stream_key in sorted(max_log_sj)}

# Print results rounded to two decimals
for stream_key, peak in sorted_max_log_sj.items():
    print(f"{stream_key}: {peak:.2f}")

In [None]:
# Get max log_sj values for each severity level
# NOTE(review): this cell repeats the previous cell verbatim.
# Collect the largest log_sj value of every entry in `results` (the loop variable is
# the result key, e.g. "<corruption>_s<severity>", not a bare severity number).
max_log_sj = {}
for severity in results:
    max_log_sj[severity] = max(results[severity]["log_sj"])
    print(f"Max log_sj for {severity}: {max_log_sj[severity]}")

# Sort by key for cleaner display
sorted_max_log_sj = dict(sorted(max_log_sj.items()))

# Print results rounded to two decimals
for severity, value in sorted_max_log_sj.items():
    print(f"{severity}: {value:.2f}")

In [None]:
# Reference value: the FPR cells count a detection when max(log_sj) exceeds np.log(100)
np.log(100)

In [None]:
# Gather the martingale results of the configured corruption for severities 1..5 and
# plot them together.
log_sj_dict = {}
epsilons_dict = {}
accuracies_dict = {}

# Extract values for each severity level. The key prefix follows `args.corruption`
# (the corruption the sweep above ran for) instead of a hard-coded "gaussian_noise".
for severity in range(1, 6):
    key = f"{args.corruption}_s{severity}"
    if key in results:
        log_sj_dict[key] = results[key]["log_sj"]
        epsilons_dict[key] = results[key]["eps"]
        accuracies_dict[key] = results[key]["accs"]

# "gaussian_noise" -> "Gaussian Noise", so the default configuration keeps its title
corruption_title = args.corruption.replace("_", " ").title()

# Plot the combined results
if use_pbrs:
    # NOTE(review): buffer_size=256 differs from the PBRSBuffer capacity (512) used when
    # the streams were produced - confirm which value the plotting helper expects.
    plot_combined_martingale_accuracy_severity_pbrs(
        log_sj_dict,
        epsilons_dict,
        accuracies_dict,
        entropy_dict=entropy_streams,
        batch_size=args.batch_size,
        buffer_size=256,
        title=f"{corruption_title} Comparison Across Severities (PBRS)",
    )
else:
    plot_combined_martingale_accuracy_severity(
        log_sj_dict,
        epsilons_dict,
        accuracies_dict,
        entropy_dict=entropy_streams,
        batch_size=args.batch_size,
        title=f"{corruption_title} Comparison Across Severities",
    )

Implement Component 2: weighted CDF

In [None]:
from weighted_cdf import *
from protector import *

In [None]:
# Re-create device, normalisation transform and model for the weighted-CDF section.
# These statements repeat the setup at the start of the first experiment cell.

# Dynamically set device to best available option (`args.device` is not consulted)
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Define normalization transform using CIFAR-10 mean and std values
transform = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))

# Load pre-trained model and move to appropriate device
print("🚀 Loading model...")
model = get_model(args.method, device)

In [None]:
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Compose, ToTensor

# Build one synthetic label-shift stream, split it at the shift point, estimate the
# source/target label distributions and create the weighted CDF and protector.
# NOTE(review): `subset` is not assigned in this cell; it is whatever value the loop
# variable of a previously executed cell last held. Set it explicitly before relying on
# this cell.

split = 4000
x, y = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
full_dataset = BasicDataset(x, y, transform=transform)

# Split into the pre-shift part (clean) and the post-shift part (shifted)
clean_loader = DataLoader(Subset(full_dataset, range(0, split)), batch_size=args.batch_size, shuffle=False)

test_loader = DataLoader(
    Subset(full_dataset, range(split, len(full_dataset))), batch_size=args.batch_size, shuffle=False
)

# 1. Estimate p_s from the clean part of the stream
p_s = estimate_label_distribution(model, clean_loader, device)

# 2. Estimate p_t on the shifted part of the stream
p_t, _ = estimate_pseudo_label_distribution(model, test_loader, device)

# 3. Collect entropies and pseudo-labels on the clean part in a single pass.
#    (A separate estimate_pseudo_label_distribution call on clean_loader used to run
#    here, but its result was overwritten immediately, costing an extra forward pass.)
source_ents = []
source_pseudo_labels = []

with torch.no_grad():
    # A dedicated batch name keeps the stream array `x` above from being overwritten
    for x_batch, _ in clean_loader:
        x_batch = x_batch.to(device)
        logits = model(x_batch)
        probs = torch.softmax(logits, dim=1)
        # Shannon entropy of the softmax output; 1e-8 guards against log(0)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
        y_hat = torch.argmax(probs, dim=1)

        source_ents.extend(entropy.cpu().tolist())
        source_pseudo_labels.extend(y_hat.cpu().tolist())

# 4. Create weighted CDF object and the protector built from the same inputs
weighted_cdf = WeightedCDF(entropies=source_ents, pseudo_labels=source_pseudo_labels, p_s=p_s.numpy(), p_t=p_t.numpy())
protector = protect.get_weighted_protector_from_ents(
    source_ents,
    source_pseudo_labels,
    p_s,
    p_t,
    argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device),
)

#### Weighted CDF with BBSE weights

In [None]:
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Compose, ToTensor
from weighted_cdf_bbse import BBSEWeightedCDF, estimate_shift_weights

# Build one synthetic label-shift stream, split it at the shift point, estimate BBSE
# importance weights and create the BBSE-weighted CDF and protector.
# NOTE(review): `subset` is not assigned in this cell; it is whatever value the loop
# variable of a previously executed cell last held. Set it explicitly before relying on
# this cell.

split = 4000
x, y = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
full_dataset = BasicDataset(x, y, transform=transform)

# Split into the pre-shift part (clean) and the post-shift part (shifted)
clean_loader = DataLoader(Subset(full_dataset, range(0, split)), batch_size=args.batch_size, shuffle=False)

test_loader = DataLoader(
    Subset(full_dataset, range(split, len(full_dataset))), batch_size=args.batch_size, shuffle=False
)

# Wrapper that returns the BBSE weights together with the estimated distributions
w, p_s, p_t_true, p_t_pred = estimate_shift_weights(model, clean_loader, test_loader, device)

# Collect entropies on the clean part of the stream.
# Could probably integrate entropy calculation into estimate_shift_weights for efficiency
source_ents = compute_entropies(model, clean_loader, device)

# Ground-truth labels of the clean part, in loader order (shuffle=False), mirroring how
# the BBSE false-positive cell initialises its CDF from true source labels.
source_labels = []
for _, y_batch in clean_loader:
    source_labels.extend(y_batch.cpu().tolist())

# Create the weighted CDF object from values computed in THIS cell. Previously this
# referenced `weights` (not defined at this point; the weights are in `w`) and
# `pseudo_labels` (a leftover batch tensor from an earlier cell).
weighted_cdf = BBSEWeightedCDF(
    entropies=source_ents,
    pseudo_labels=source_labels,
    weights=w,
)

# Build the matching BBSE protector, using the same constructor and argument order as
# the BBSE false-positive cell. The earlier call relied on `source_pseudo_labels` and
# `p_t`, which are only defined by the non-BBSE cell above.
protector = protect.get_bbse_weighted_protector_from_ents(
    source_ents,
    source_labels,
    w,
    argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device),
)

False positive rate

In [None]:
from itertools import combinations
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader, Subset
from weighted_cdf import *

# False-positive sweep with the weighted protector: for each class subset, estimate
# p_s / p_t from the two parts of a label-shift stream, build a weighted protector,
# run the martingale test on the full stream and record whether it crossed np.log(100).
fpr_by_size = {}

for num_classes in [2]:  # Adjust as needed
    candidate_class_sets = list(combinations(range(10), num_classes))
    entropy_peaks = {}
    threshold_crossed = {}

    # Only the first three subsets are evaluated
    for subset in candidate_class_sets[:3]:
        print(f"\n🔎 Evaluating label shift ({num_classes}-class): {subset}")

        split = 4000

        # 1) Load the stream for this subset; the first `split` examples act as source
        x_src, y_src = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
        source_dataset = BasicDataset(x_src, y_src, transform=transform)

        source_loader = DataLoader(Subset(source_dataset, range(0, split)), batch_size=args.batch_size, shuffle=False)

        # 2) Estimate p_s and collect source entropies + pseudo-labels
        p_s = estimate_label_distribution(model, source_loader, device)

        source_ents = []
        source_pseudo_labels = []
        with torch.no_grad():
            for x_batch, _ in source_loader:
                x_batch = x_batch.to(device)
                logits = model(x_batch)
                probs = torch.softmax(logits, dim=1)
                # Shannon entropy of the softmax output; 1e-8 guards against log(0)
                entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
                y_hat = torch.argmax(probs, dim=1)

                source_ents.extend(entropy.cpu().tolist())
                source_pseudo_labels.extend(y_hat.cpu().tolist())

        # 3) Load label-shifted stream (same subset, same arguments as step 1)
        x_shift, y_shift = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
        full_dataset = BasicDataset(x_shift, y_shift, transform=transform)

        test_loader = DataLoader(
            Subset(full_dataset, range(split, len(full_dataset))), batch_size=args.batch_size, shuffle=False
        )

        # 4) Estimate p_t from shifted half of the stream
        p_t, _ = estimate_pseudo_label_distribution(model, test_loader, device)

        # Zero out every class outside the subset, then renormalise;
        # 1e-8 avoids division by zero when nothing is left
        mask = torch.tensor([i in subset for i in range(10)], dtype=torch.bool)

        p_s = p_s * mask
        p_s /= p_s.sum() + 1e-8

        p_t = p_t * mask
        p_t /= p_t.sum() + 1e-8

        # 5) Create weighted protector (rebinds the notebook-level `protector`)
        protector = protect.get_weighted_protector_from_ents(
            source_ents,
            source_pseudo_labels,
            p_s,
            p_t,
            argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device),
        )

        # 6) Evaluate full test stream and run martingale
        loader = DataLoader(full_dataset, batch_size=args.batch_size, shuffle=False)
        ents, acc, logits_list, labels_list = evaluate(model, loader, device)

        key = f"labelshift_{num_classes}cls_{'_'.join(map(str, subset))}"
        result = run_martingale({key: ents}, protector)[key]

        # 7) Record statistics
        entropy_peaks[key] = np.max(ents)
        threshold_crossed[key] = np.max(result["log_sj"]) > np.log(100)

        print("Subset:", subset)
        print("p_s:", p_s.numpy())
        print("p_t:", p_t.numpy())
        print("min p_s:", p_s.min().item(), "min p_t:", p_t.min().item())

    # 8) Calculate and store FPR (fraction of evaluated subsets that crossed the threshold)
    n_tests = len(threshold_crossed)
    n_crossed = sum(threshold_crossed.values())
    fpr = n_crossed / n_tests
    fpr_by_size[num_classes] = fpr

    print(f"\n➡️ {n_crossed} / {n_tests} ({num_classes}-class) shifts triggered detection — FPR = {fpr:.3f}\n")

# fpr_by_size now holds one FPR per evaluated class-subset size.

##### BBSE False positives

In [None]:
from itertools import combinations
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader, Subset
from weighted_cdf_bbse import BBSEWeightedCDF, estimate_shift_weights

# False-positive sweep with the BBSE-weighted protector: for each class subset, collect
# source entropies and ground-truth labels, estimate BBSE weights, build the protector,
# run the martingale test on the full stream and record whether it crossed np.log(100).
fpr_by_size = {}

for num_classes in [2]:  # Adjust as needed
    candidate_class_sets = list(combinations(range(10), num_classes))
    entropy_peaks = {}
    threshold_crossed = {}

    # Only the first three subsets are evaluated
    for subset in candidate_class_sets[:3]:
        print(f"\n🔎 Evaluating label shift ({num_classes}-class): {subset}")

        split = 4000

        # 1) Load the stream for this subset; the first `split` examples act as source
        x_src, y_src = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
        source_dataset = BasicDataset(x_src, y_src, transform=transform)

        source_loader = DataLoader(Subset(source_dataset, range(0, split)), batch_size=args.batch_size, shuffle=False)

        # 2) Collect source entropies and ground-truth labels
        source_ents = []
        source_labels = []
        with torch.no_grad():
            for x_batch, labels in source_loader:
                x_batch = x_batch.to(device)
                logits = model(x_batch)
                probs = torch.softmax(logits, dim=1)
                # Shannon entropy of the softmax output; 1e-8 guards against log(0)
                entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
                # Predicted label; computed here but not used further in this cell
                y_hat = torch.argmax(probs, dim=1)

                source_ents.extend(entropy.cpu().tolist())
                # We have access to ground truth labels for init of CDF, so we are using them here
                source_labels.extend(labels.cpu().tolist())

        # 3) Load label-shifted stream (same subset, same arguments as step 1)
        x_shift, y_shift = load_cifar10_label_shift(keep_classes=subset, n_examples=8000, shift_point=split)
        full_dataset = BasicDataset(x_shift, y_shift, transform=transform)

        test_loader = DataLoader(
            Subset(full_dataset, range(split, len(full_dataset))), batch_size=args.batch_size, shuffle=False
        )

        # 4) Estimate the BBSE weights and label distributions from source and shifted parts
        weights, p_s, p_t_true, p_t_pred = estimate_shift_weights(model, source_loader, test_loader, device)
        print(f"p_true: {p_t_true}")

        # 5) Create weighted protector (rebinds the notebook-level `protector`)
        protector = protect.get_bbse_weighted_protector_from_ents(
            source_ents,
            source_labels,
            weights,
            argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device),
        )

        # 6) Evaluate full test stream and run martingale
        loader = DataLoader(full_dataset, batch_size=args.batch_size, shuffle=False)
        ents, acc, logits_list, labels_list = evaluate(model, loader, device)

        key = f"labelshift_{num_classes}cls_{'_'.join(map(str, subset))}"
        result = run_martingale({key: ents}, protector)[key]

        # 7) Record statistics
        entropy_peaks[key] = np.max(ents)
        threshold_crossed[key] = np.max(result["log_sj"]) > np.log(100)

        print("Subset:", subset)
        print("p_s:", p_s.numpy())
        # TODO: LOUIS: Not sure if we should be using the original preds or the transformed ones
        print("p_t:", p_t_true.numpy())
        print("min p_s:", p_s.min().item(), "min p_t:", p_t_true.min().item())

    # 8) Calculate and store FPR (fraction of evaluated subsets that crossed the threshold)
    n_tests = len(threshold_crossed)
    n_crossed = sum(threshold_crossed.values())
    fpr = n_crossed / n_tests
    fpr_by_size[num_classes] = fpr

    print(f"\n➡️ {n_crossed} / {n_tests} ({num_classes}-class) shifts triggered detection — FPR = {fpr:.3f}\n")

# fpr_by_size now holds one FPR per evaluated class-subset size.

#### BBSE ODS False Positives

In [None]:
from itertools import combinations
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader, Subset
from weighted_cdf_bbse_ods import (
    BBSEODSWeightedCDF,
    estimate_confusion_matrix,
    estimate_target_distribution_from_preds,
)

# BBSE/ODS false-positive experiment.
#
# For each candidate subset of classes, a stream restricted to that subset is
# loaded, a BBSE/ODS-weighted protector is built from the first `split`
# examples, and the martingale is run over the full stream. A subset counts as
# a detection when the maximum of `log_sj` exceeds `detection_log_threshold`.
# The fraction of subsets that trigger detection is stored in `fpr_by_size`,
# keyed by subset size.

# Experiment constants, gathered here so they are tuned in one place.
total_classes = 10  # label space size used for the class subsets and the mask
stream_length = 8000  # n_examples requested for every stream
split = 4000  # indices [0, split) feed the source loader, [split, end) the test loader
ods_alpha = 0.05
detection_log_threshold = np.log(100)

fpr_by_size = {}

for num_classes in [2]:  # Adjust as needed
    candidate_class_sets = list(combinations(range(total_classes), num_classes))
    entropy_peaks = {}
    threshold_crossed = {}

    for subset in candidate_class_sets[:3]:
        print(f"\n🔎 Evaluating label shift ({num_classes}-class): {subset}")

        # 1) Load clean source stream using only this subset of classes
        x_src, y_src = load_cifar10_label_shift(keep_classes=subset, n_examples=stream_length, shift_point=split)
        source_dataset = BasicDataset(x_src, y_src, transform=transform)

        source_loader = DataLoader(Subset(source_dataset, range(0, split)), batch_size=args.batch_size, shuffle=False)

        # 2) Collect per-sample softmax entropies and labels from the source half
        source_ents = []
        source_labels = []
        with torch.no_grad():
            for x_batch, labels in source_loader:
                x_batch = x_batch.to(device)
                logits = model(x_batch)
                probs = torch.softmax(logits, dim=1)
                # The 1e-8 offset keeps log() finite when a probability is exactly 0
                entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)

                source_ents.extend(entropy.cpu().tolist())
                # We have access to ground truth labels for init of CDF, so we are using them here
                source_labels.extend(labels.cpu().tolist())

        # BBSE: Estimate confusion matrix and p_source from clean source stream
        confusion_matrix, p_source = estimate_confusion_matrix(model, source_loader, device)

        # 3) Load label-shifted stream (same subset)
        x_shift, y_shift = load_cifar10_label_shift(keep_classes=subset, n_examples=stream_length, shift_point=split)
        full_dataset = BasicDataset(x_shift, y_shift, transform=transform)

        test_loader = DataLoader(
            Subset(full_dataset, range(split, len(full_dataset))), batch_size=args.batch_size, shuffle=False
        )

        # 4) Estimate p_test from shifted half of the stream
        p_test = estimate_target_distribution_from_preds(model, test_loader, device)

        # 5) Create BBSE/ODS weighted protector
        # A fresh Namespace is built per subset so no state is shared between protectors.
        protector = protect.get_bbse_ods_weighted_protector_from_ents(
            source_ents,
            p_test,
            p_source,
            source_labels,
            confusion_matrix,
            ods_alpha,
            argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device),
        )

        # Restrict both distributions to the classes in `subset` and renormalise.
        # These masked copies are only consumed by the diagnostic prints below;
        # the protector was already built from the unmasked estimates.
        mask = torch.tensor([i in subset for i in range(total_classes)], dtype=torch.bool)
        p_source = p_source.to("cpu") * mask
        p_source /= p_source.sum() + 1e-8

        p_test_true = protector.cdf.p_test_true.to("cpu") * mask
        p_test_true /= p_test_true.sum() + 1e-8

        # TODO(LOUIS): the stream is fully evaluated first and the martingale is run
        # afterwards on the collected entropies, so the protector CDF presumably
        # cannot be updated as samples arrive. The evaluate and run_martingale
        # logic may need to be combined — confirm.
        # 6) Evaluate full test stream and run martingale
        loader = DataLoader(full_dataset, batch_size=args.batch_size, shuffle=False)
        ents, acc, logits_list, labels_list = evaluate(model, loader, device)

        key = f"labelshift_{num_classes}cls_{'_'.join(map(str, subset))}"
        result = run_martingale({key: ents}, protector)[key]

        # 7) Record statistics
        entropy_peaks[key] = np.max(ents)
        threshold_crossed[key] = np.max(result["log_sj"]) > detection_log_threshold

        print("Subset:", subset)
        print("p_source:", p_source.numpy())
        print("p_test_true:", p_test_true.numpy())
        print("min p_source:", p_source.min().item(), "min p_t:", p_test_true.min().item())

    # 8) Calculate and store FPR
    n_tests = len(threshold_crossed)
    n_crossed = sum(threshold_crossed.values())
    # Guard against an empty candidate list so the cell reports NaN instead of raising
    fpr = n_crossed / n_tests if n_tests else float("nan")
    fpr_by_size[num_classes] = fpr

    print(f"\n➡️ {n_crossed} / {n_tests} ({num_classes}-class) shifts triggered detection — FPR = {fpr:.3f}\n")

# fpr_by_size now maps each subset size to the fraction of subsets that triggered detection.

#### True Positive Rate

In [None]:
# True-positive-rate experiment.
#
# For every (corruption, severity) pair a clean -> corrupt stream is loaded and
# the model's per-sample softmax entropies, logits and labels are collected.
# The martingale is then run on all entropy streams at once, and accuracy over
# time is attached to each result under the "accs" key.
accs, entropy_streams, logits_list_dict, logits_labels_dict = {}, {}, {}, {}
results = {}
corruptions = CORRUPTIONS if args.all_corruptions else [args.corruption]
# Shared by the collection loop and the accuracy loop so both build the same keys
severities = range(1, 6)

use_pbrs = False
confidence_threshold = 0.8  # You can tune this

for corruption in corruptions:
    for severity in severities:
        print(f"🔎 {corruption} severity {severity} (clean → corrupt)")

        # Load clean-to-corrupt stream
        loader, is_clean, labels = load_clean_then_corrupt_sequence(
            corruption=corruption,
            severity=severity,
            n_examples=4000,
            data_dir="./data",
            transform=transform,
            batch_size=args.batch_size,
        )

        if use_pbrs:
            # A fresh buffer per stream so samples from one stream never leak into the next
            buffer = PBRSBuffer(capacity=512, num_classes=10)

        logits_list = []
        labels_list = []
        entropy_list = []

        with torch.no_grad():
            for x_batch, y_batch in loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                logits = model(x_batch)
                probs = torch.softmax(logits, dim=1)
                # The 1e-8 offset keeps log() finite when a probability is exactly 0
                entropies = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)

                if use_pbrs:
                    # Pseudo-labels and confidences are only needed for buffer admission
                    pseudo_labels = torch.argmax(probs, dim=1)
                    max_probs = torch.max(probs, dim=1).values
                    for entropy, y_hat, conf in zip(
                        entropies.cpu().tolist(), pseudo_labels.cpu().tolist(), max_probs.cpu().tolist()
                    ):
                        # Use the tunable threshold declared above instead of a hard-coded 0.8
                        if confidence_accept(
                            conf, threshold=confidence_threshold, softness=15
                        ) and buffer.accept(y_hat):
                            buffer.add(None, entropy, y_hat)
                else:
                    # No filtering: keep every entropy in stream order
                    entropy_list.extend(entropies.cpu().tolist())

                logits_list.append(logits.cpu())
                labels_list.append(y_batch.cpu())

        key = f"{corruption}_s{severity}"

        # Store entropy stream
        ents = np.array(buffer.get_entropies()) if use_pbrs else np.array(entropy_list)
        entropy_streams[key] = ents
        logits_list_dict[key] = logits_list
        logits_labels_dict[key] = labels_list


# Run martingale detection on all entropy streams
# NOTE(review): `protector` is not created in this cell; it is whatever object an
# earlier cell left bound to that name. Confirm it is the protector intended for
# this experiment before trusting the results.
results = run_martingale(entropy_streams, protector)

# Compute accuracy over time
for corruption in corruptions:
    for severity in severities:
        key = f"{corruption}_s{severity}"
        accs = compute_accuracy_over_time_from_logits(logits_list_dict[key], logits_labels_dict[key])
        results[key]["accs"] = accs