In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
import torch
import pandas as pd
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import utils.protector as protect
from utils.cli_utils import softmax_ent
from utils.tent import Tent, configure_model, collect_params
from utils.utilities import *
from utils.plotting import *
from typing import Sequence, Tuple, Dict, Optional
import argparse
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

In [3]:
# Names of the corruption types this notebook can evaluate on corrupted CIFAR-10.
# Each name is passed, together with a severity level, to the corrupted-data loaders.
# Order matters: it is the iteration order whenever a cell loops over all corruptions.
CORRUPTIONS = (
    "shot_noise",
    "motion_blur",
    "snow",
    "pixelate",
    "gaussian_noise",
    "defocus_blur",
    "brightness",
    "fog",
    "zoom_blur",
    "frost",
    "glass_blur",
    "impulse_noise",
    "contrast",
    "jpeg_compression",
    "elastic_transform",
)

In [None]:
## ENTER PARAMETERS ##

# Manual settings for the notebook, kept in a plain attribute container.
# argparse.Namespace is the standard type for this (the protector config below
# uses it as well) and still allows adding attributes later, e.g. args.severity.
args = argparse.Namespace(
    # NOTE: the setup cell auto-detects the device (cuda -> mps -> cpu) and does
    # not read this value; it is kept so existing references to args.device work.
    device="cpu",
    method="none",  # Options: 'none' or 'tent'
    corruption="gaussian_noise",  # Choose from CORRUPTIONS
    all_corruptions=False,  # Set to True to test all corruptions
    n_examples=1000,
    batch_size=64,
)

# Seed the RNGs for replicability. NumPy is seeded alongside torch so that any
# NumPy-based sampling is repeatable too. It is unverified whether results stay
# consistent when switching between devices.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

Basic setup: we load clean CIFAR-10 and then test on the corrupted version. This is a good reference for getting a feel for how everything works together.

In [None]:
# Pick the best available device: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Per-channel normalization using CIFAR-10 mean and std values
transform = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))

# Pre-trained model, placed on the selected device
print("Loading model...")
model = get_model(device)

# Clean CIFAR-10 test split (train=False); its entropies are the source reference
print("Loading clean CIFAR-10 as source entropy")
clean_transform = transforms.Compose([transforms.ToTensor(), transform])
clean_ds = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=clean_transform)
clean_loader = DataLoader(clean_ds, batch_size=args.batch_size, shuffle=False)
source_ents, accuracy, logits_list, labels_list = evaluate(model, clean_loader, device)

# Build the shift-detection protector from the source entropies
protector_cfg = argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device)
protector = protect.get_protector_from_ents(source_ents, protector_cfg)

# Either every known corruption or only the one selected in args
if args.all_corruptions:
    corruptions = CORRUPTIONS
else:
    corruptions = [args.corruption]

# Per-stream outputs, keyed by "<corruption>_s<severity>"
entropy_streams = {}
accs = {}
logits_list_dict = {}
logits_labels_dict = {}

# Evaluate the model on each corruption at severity levels 1 through 5
for corruption in corruptions:
    for severity in range(1, 6):
        print(f"{corruption} severity {severity}")

        # Corrupted samples wrapped in a dataset / loader (order preserved, no shuffling)
        x, y = load_cifar10c(args.n_examples, severity, corruption)
        dataset = BasicDataset(x, y, transform=transform)
        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

        ents, acc, logits_list, labels_list = evaluate(model, loader, device)

        # Record everything this stream produced under a single key
        key = f"{corruption}_s{severity}"
        entropy_streams[key] = ents
        accs[key] = acc
        logits_list_dict[key] = logits_list
        logits_labels_dict[key] = labels_list

# Martingale-based shift detection over all collected entropy streams
results = run_martingale(entropy_streams, protector)

# Accuracy per stream, as a percentage with two decimals
print("\nAccuracy summary:")
for name, value in accs.items():
    print(f"{name}: {value * 100:.2f}%")

In [6]:
# Attach an accuracy-over-time sequence to every entry of `results`.
# Each results[key] is a dict that already holds "log_sj" and "eps"; this cell adds "accs".
for corruption in corruptions:
    # For each severity level (1-5)
    for severity in range(1, 6):
        # Key format: "<corruption>_s<severity>"
        key = f"{corruption}_s{severity}"

        # Computed from the logits and labels stored for this corruption/severity.
        # A dedicated name is used on purpose: rebinding `accs` here would replace
        # the per-key final-accuracy dict built by the evaluation cell with a single
        # sequence.
        acc_over_time = compute_accuracy_over_time_from_logits(logits_list_dict[key], logits_labels_dict[key])

        results[key]["accs"] = acc_over_time

**Experiment 1: Plotting martingale values, epsilon, model accuracy, and entropy to see whether they all change together at the shift**

In [None]:
# Experiment 1 data collection: one clean -> corrupt stream per corruption and severity.
# All per-stream containers are reset so nothing from earlier cells leaks in.
accs, entropy_streams, logits_list_dict, logits_labels_dict = {}, {}, {}, {}


for corruption in corruptions:
    for severity in range(1, 6):
        print(f"🔎 {corruption} severity {severity} (clean → corrupt)")

        # is_clean and labels are returned by the loader helper but not used in this cell
        loader, is_clean, labels = load_clean_then_corrupt_sequence(
            corruption=corruption,
            severity=severity,
            n_examples=4000,
            data_dir="./data",
            transform=transform,
            batch_size=args.batch_size,
        )

        ents, acc, logits_list, labels_list = evaluate(model, loader, device)

        key = f"{corruption}_s{severity}"
        entropy_streams[key] = ents
        accs[key] = acc
        logits_list_dict[key] = logits_list
        logits_labels_dict[key] = labels_list

# Run the martingale detector once over all collected streams
results = run_martingale(entropy_streams, protector)

# Attach accuracy-over-time to each result entry. The sequence gets its own name so
# the `accs` dict of final accuracies filled above is not overwritten.
for key, stream_logits in logits_list_dict.items():
    acc_over_time = compute_accuracy_over_time_from_logits(stream_logits, logits_labels_dict[key])
    results[key]["accs"] = acc_over_time

In [None]:
# Largest log_sj value reached by each stream; keys are the result keys
# ("<corruption>_s<severity>"), not bare severity numbers.
max_log_sj = {}
for key, entry in results.items():
    max_log_sj[key] = max(entry["log_sj"])
    print(f"Max log_sj for {key}: {max_log_sj[key]}")

# Same values, re-ordered alphabetically by key for a cleaner display
sorted_max_log_sj = {key: max_log_sj[key] for key in sorted(max_log_sj)}

# Compact listing rounded to two decimals
for key, peak in sorted_max_log_sj.items():
    print(f"{key}: {peak:.2f}")

In [None]:
# Number of epsilon values recorded for the selected corruption at severity 1.
# The key follows args.corruption so the check still works when another corruption is chosen.
len(results[f"{args.corruption}_s1"]["eps"])

In [None]:
# Per-severity series for the corruption selected in `args`, keyed by "<corruption>_s<severity>"
log_sj_dict = {}
epsilons_dict = {}
accuracies_dict = {}

# Pull the series for each severity level. Keys are derived from args.corruption,
# so changing the configured corruption no longer leaves these dicts empty.
for severity in range(1, 6):
    key = f"{args.corruption}_s{severity}"
    if key in results:
        log_sj_dict[key] = results[key]["log_sj"]
        epsilons_dict[key] = results[key]["eps"]
        accuracies_dict[key] = results[key]["accs"]

# Plot the combined results. batch_size follows args.batch_size so it matches the
# batch size the streams were evaluated with. The title is built from the
# corruption name (e.g. "gaussian_noise" -> "Gaussian Noise").
plot_combined_martingale_accuracy_severity(
    log_sj_dict,
    epsilons_dict,
    accuracies_dict,
    entropy_dict=entropy_streams,
    batch_size=args.batch_size,
    title=f"{args.corruption.replace('_', ' ').title()} Comparison Across Severities",
)

In [None]:
# Number of streams held in each dictionary
print("Log SJ Dictionary size:", len(log_sj_dict))
print("Epsilons Dictionary size:", len(epsilons_dict))
print("Accuracies Dictionary size:", len(accuracies_dict))

# For one stream (the selected corruption at severity 1), print the series lengths.
# The key follows args.corruption instead of a hard-coded corruption name.
key = f"{args.corruption}_s1"
if key in log_sj_dict:
    print(f"\nFor {key}:")
    print(f"log_sj length: {len(log_sj_dict[key])}")
    print(f"epsilons length: {len(epsilons_dict[key])}")
    print(f"accuracies length: {len(accuracies_dict[key])}")

In [None]:
# Derive one detection delay per stream from its log_sj series.
# NOTE(review): the threshold and the unit of the delay are decided inside the
# project helper; confirm there before interpreting the numbers.
detection_delays = compute_detection_delays_from_threshold(log_sj_dict)

# Visualize the delays (one entry per key of log_sj_dict)
plot_detection_delays(detection_delays)

**Experiment 2: Fix severity at 5 and change the corruption type to see if shift detector works differently based on corruption type**

In [None]:
args.severity = 5  # Fix severity to level 5 for all corruptions

# Reset the per-stream containers so results from Experiment 1 do not leak in
accs, entropy_streams, logits_list_dict, logits_labels_dict = {}, {}, {}, {}

# One clean -> corrupt stream per corruption type at the fixed severity
for corruption in CORRUPTIONS:
    print(f"🔎 {corruption} severity {args.severity} (clean → corrupt)")

    # is_clean and labels are returned by the loader helper but not used in this cell
    loader, is_clean, labels = load_clean_then_corrupt_sequence(
        corruption=corruption,
        severity=args.severity,
        n_examples=4000,
        data_dir="./data",
        transform=transform,
        batch_size=args.batch_size,
    )

    ents, acc, logits_list, labels_list = evaluate(model, loader, device)

    key = f"{corruption}_s{args.severity}"
    entropy_streams[key] = ents
    accs[key] = acc
    logits_list_dict[key] = logits_list
    logits_labels_dict[key] = labels_list

# Run the martingale detector once over all collected streams
results = run_martingale(entropy_streams, protector)

# Attach accuracy-over-time to each result entry. The sequence gets its own name so
# the `accs` dict of final accuracies filled above is not overwritten.
for corruption in CORRUPTIONS:
    key = f"{corruption}_s{args.severity}"
    acc_over_time = compute_accuracy_over_time_from_logits(logits_list_dict[key], logits_labels_dict[key])
    results[key]["accs"] = acc_over_time

# Per-corruption series used by the plotting and analysis cells
log_sj_dict = {}
epsilons_dict = {}
accuracies_dict = {}
entropy_dict = {}

# Extract values for each corruption at the fixed severity level
for corruption in CORRUPTIONS:
    key = f"{corruption}_s{args.severity}"
    if key in results:
        log_sj_dict[key] = results[key]["log_sj"]
        epsilons_dict[key] = results[key]["eps"]
        accuracies_dict[key] = results[key]["accs"]
        entropy_dict[key] = entropy_streams[key]

In [None]:
# Plot the combined results for all corruption types.
# batch_size and the severity in the title are taken from `args`, the same values
# the streams were built with, so the plot cannot drift from the data if the
# configuration is changed.
plot_combined_martingale_accuracy_corruption(
    log_sj_dict,
    epsilons_dict,
    accuracies_dict,
    entropy_dict=entropy_dict,
    batch_size=args.batch_size,
    title=f"Comparison Across Corruption Types at Severity {args.severity}",
)

**Experiment 3: Delay analysis**

In [None]:
# Derive one detection delay per stream from its log_sj series; the result is
# reused by the correlation and clustering cells further down.
# NOTE(review): the threshold and the unit of the delay are decided inside the
# project helper; confirm there before interpreting the numbers.
detection_delays = compute_detection_delays_from_threshold(log_sj_dict)

# Plot detection delays
plot_detection_delays(detection_delays)


In [None]:
# Accuracy drop per stream, derived from the accuracy-over-time series
accuracy_drops = compute_accuracy_drops(accuracies_dict)

# Sanity check: one line per stream, rounded to three decimals
for name in accuracy_drops:
    print(f"{name}: {accuracy_drops[name]:.3f}")

**Experiment 4: Correlation Analysis**

In [None]:
# Align data: one (delay, drop) pair per stream key.
# The key list has its own name; rebinding the global `corruptions` to these
# "<corruption>_s<severity>" keys would break any later re-run of the cells that
# build keys from `corruptions`.
shift_keys = list(detection_delays.keys())
delays = [detection_delays[k] for k in shift_keys]
drops = [accuracy_drops[k] for k in shift_keys]

# Plot
plt.figure(figsize=(8, 6))
plt.scatter(delays, drops, color="crimson")
# Label every point with its stream key, slightly offset so text and marker do not overlap
for name, delay, drop in zip(shift_keys, delays, drops):
    plt.annotate(name, (delay, drop), textcoords="offset points", xytext=(5, 5), ha="left", fontsize=9)

plt.title("Detection Delay vs Accuracy Drop")
plt.xlabel("Detection Delay (samples after corruption)")
plt.ylabel("Accuracy Drop (clean - corrupt)")
plt.grid(True)
plt.tight_layout()
plt.show()

Negative correlation between detection delay and accuracy drop implies the detector prioritizes responding to impactful shifts, rather than reacting to every minor change. That’s exactly what you want in a reliable shift detector.

**Experiment 5: Clustering Corruptions by Detection Behavior**

In [25]:
# Reuse data from previous experiment (some of this is redundant and just here for readability)
# Could maybe add more features if necessary
detection_delays = compute_detection_delays_from_threshold(log_sj_dict)
accuracy_drops = compute_accuracy_drops(accuracies_dict)
entropy_spikes = compute_entropy_spikes(entropy_streams)
confidence_slopes = compute_detection_confidence_slope(log_sj_dict, epsilons_dict)

# One record per stream key. The list has its own name so the `results` dict
# returned by run_martingale stays intact for re-running earlier cells.
clustering_records = [
    {
        "corruption": corruption,
        "detection_delay": detection_delays[corruption],
        "accuracy_drop": accuracy_drops[corruption],
        "entropy_spike": entropy_spikes[corruption],
        "confidence_slope": confidence_slopes[corruption],
    }
    for corruption in entropy_spikes
]

# Build the frame once from the full list of records
clustering_df = pd.DataFrame(clustering_records)

In [None]:
# Extract features for clustering.
# Every column that is not a feature is excluded: the label column plus the three
# columns this cell writes ("cluster", "tsne_1", "tsne_2"). Without "cluster" in the
# list, re-running the cell would feed the previous cluster ids back in as a feature.
non_feature_cols = {"corruption", "cluster", "tsne_1", "tsne_2"}
feature_cols = [col for col in clustering_df.columns if col not in non_feature_cols]
X = clustering_df[feature_cols].values
# NOTE(review): the features are used unscaled, so the column with the largest numeric
# range will dominate the distances; consider standardizing if that is not intended.


# K-means setup - 4 clusters seems reasonable considering the result
n_clusters = 4
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
clustering_df["cluster"] = kmeans.fit_predict(X)

# TSNE for 2D visualization
tsne = TSNE(n_components=2, random_state=42, perplexity=5)
X_tsne = tsne.fit_transform(X)
clustering_df["tsne_1"] = X_tsne[:, 0]
clustering_df["tsne_2"] = X_tsne[:, 1]


# Scatter the 2D embedding, one color per k-means cluster
plt.figure(figsize=(10, 7))
colors = plt.cm.rainbow(np.linspace(0, 1, n_clusters))

for i in range(n_clusters):
    subset = clustering_df[clustering_df["cluster"] == i]
    plt.scatter(subset["tsne_1"], subset["tsne_2"], color=colors[i], label=f"Cluster {i}", s=100)

# Label every point with its stream key, nudged off the marker
for _, row in clustering_df.iterrows():
    plt.text(row["tsne_1"] + 0.5, row["tsne_2"] + 0.5, row["corruption"], fontsize=9)

plt.title("t-SNE Projection Colored by k-means Cluster")
plt.xlabel("t-SNE Dim 1")
plt.ylabel("t-SNE Dim 2")
plt.legend()
plt.tight_layout()
plt.show()

**Experiment 6: Label Noise and Concept Shift**

In [27]:
# Shift experiment setup.
# Target class priors (class index -> probability), hard-coded for replicability.
new_priors = dict(enumerate((0.3, 0.15, 0.1, 0.1, 0.11, 0.07, 0.05, 0.05, 0.04, 0.03)))

# Label noise: the input-label mapping changes, e.g. a "dog" labeled as "cat"
label_noise_data = apply_label_noise_to_dataset(clean_ds, noise_rate=0.3)
# Class prior shift: class proportions change, classes are no longer evenly split
prior_shift_data = simulate_prior_shift(clean_ds, class_priors=new_priors)
# Prior shift and label noise applied together
combined_shift_data = simulate_prior_shift_with_label_noise(clean_ds, class_priors=new_priors)

# Scenario name -> dataset, in the order the scenarios are evaluated
shift_experiments = dict(
    [
        ("label noise", label_noise_data),
        ("prior shift", prior_shift_data),
        ("combined", combined_shift_data),
    ]
)

In [None]:
# Can reuse components of experiment 2 here
accs, entropy_streams, logits_list_dict, logits_labels_dict = {}, {}, {}, {}

for shift, shifted_dataset in shift_experiments.items():
    print(f" Running shift scenario: {shift}")

    # batch_size=64 matches the value passed to the plotting cell for this experiment
    shifted_loader = DataLoader(shifted_dataset, batch_size=64, shuffle=False)

    # Evaluate the model on the shifted data
    ents, acc, logits_list, labels_list = evaluate(model, shifted_loader, device)

    entropy_streams[shift] = ents
    accs[shift] = acc
    logits_list_dict[shift] = logits_list
    logits_labels_dict[shift] = labels_list

# Run the martingale detector once, after every stream has been collected.
# Previously this ran inside the loop and re-processed the partial set of streams
# on each iteration; only the last call's result was ever kept.
results = run_martingale(entropy_streams, protector)

# Attach accuracy-over-time to each result entry. The sequence gets its own name so
# the `accs` dict of final accuracies filled above is not overwritten.
for shift in shift_experiments:
    acc_over_time = compute_accuracy_over_time_from_logits(logits_list_dict[shift], logits_labels_dict[shift])
    results[shift]["accs"] = acc_over_time

# Per-scenario series used by the plotting cell
log_sj_dict = {}
epsilons_dict = {}
accuracies_dict = {}
entropy_dict = {}

# Extract values for each shift
for shift in shift_experiments:
    if shift in results:
        log_sj_dict[shift] = results[shift]["log_sj"]
        epsilons_dict[shift] = results[shift]["eps"]
        accuracies_dict[shift] = results[shift]["accs"]
        entropy_dict[shift] = entropy_streams[shift]

In [None]:
# Plot results for shift types: one entry per scenario key
# ("label noise", "prior shift", "combined").
# batch_size=64 is the same value the DataLoader for these scenarios was built with.
plot_combined_noise_shift(
    log_sj_dict,
    epsilons_dict,
    accuracies_dict,
    entropy_dict=entropy_dict,
    batch_size=64,
    title="Comparison for Label Noise, Prior Shift, and Combined Shift",
)

**Experiment 7: Continual and Recurring Shifts**

In [None]:
# Stream layout for this experiment: one corruption type and a schedule of
# segments ("clean" or "s<severity>"), each `segment_size` long
corruption = "gaussian_noise"
segment_schedule = ["clean", "s2", "s5", "clean", "s4"]
segment_size = 1500

# Build the loader for the full dynamic stream
loader, is_clean, labels, seg_labels = load_dynamic_sequence(
    segments=segment_schedule,
    segment_size=segment_size,
    corruption_name=corruption,
    data_dir="./data",
    transform=transform,
    batch_size=args.batch_size,
)

# Single evaluation pass over the whole stream
ents, acc, logits_list, labels_list = evaluate(model, loader, device)

# Martingale detection on this one stream, stored under a dedicated key
key = f"{corruption}_dynamic"
entropy_streams = {key: ents}
results = run_martingale(entropy_streams, protector)

# Accuracy over time, also stored next to the martingale outputs
acc_stream = compute_accuracy_over_time_from_logits(logits_list, labels_list)
results[key]["accs"] = acc_stream

# Single-entry dicts in the shape the plotting helper is called with
stream_result = results[key]
log_sj_dict, epsilons_dict = {key: stream_result["log_sj"]}, {key: stream_result["eps"]}
accuracies_dict, entropy_dict = {key: acc_stream}, {key: ents}

In [None]:
# Plot the dynamic stream: martingale values, epsilons, accuracy and entropy for the
# single "<corruption>_dynamic" key.
# The schedule, segment size and batch size passed here are the same values the
# stream was built with.
# NOTE(review): presumably the helper uses them to mark segment boundaries; confirm.
plot_dynamic_stream_analysis(
    log_sj_dict,
    epsilons_dict,
    accuracies_dict,
    entropy_dict,
    segment_schedule=segment_schedule,
    segment_size=segment_size,
    batch_size=args.batch_size,
    title=f"Dynamic Stream Analysis: {corruption} with Varying Severity",
)