In [None]:
import os
import numpy as np
import torch
import pandas as pd
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import protector as protect
from utils.cli_utils import softmax_ent

from tent import Tent, configure_model, collect_params
from typing import Sequence, Tuple, Dict, Optional
import argparse

from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

from utilities import *  ## created by me
from plotting import *  ## created by me
from LabelShiftExperiments import *  ## created by me
from protector import *  ## created by me

In [None]:
# Names of the image corruptions this notebook can evaluate; the parameter
# cell's comment says `args.corruption` is chosen from this tuple.
# NOTE(review): these look like the CIFAR-10-C corruption identifiers —
# confirm against the loader that consumes them.
CORRUPTIONS = (
    "shot_noise",
    "motion_blur",
    "snow",
    "pixelate",
    "gaussian_noise",
    "defocus_blur",
    "brightness",
    "fog",
    "zoom_blur",
    "frost",
    "glass_blur",
    "impulse_noise",
    "contrast",
    "jpeg_compression",
    "elastic_transform",
)

In [None]:
## ENTER PARAMETERS ##

# Manual settings for arguments.
# `argparse.Namespace` is the same container this notebook uses for the
# protector settings, and it prints its fields when displayed.
# NOTE: the setup cell computes its own `device` variable and does not read
# `args.device`; keep the two consistent if you change the value here.
args = argparse.Namespace(
    device="cpu",  # Change this manually as needed
    method="none",  # Options: 'none' or 'tent'
    corruption="gaussian_noise",  # Choose from CORRUPTIONS
    all_corruptions=False,  # Set to True to test all corruptions
    n_examples=1000,
    batch_size=64,
)

In [None]:
# Pick the compute device: CUDA when available, otherwise Apple MPS, otherwise CPU.
# (The manually entered `args.device` is not consulted here.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Per-channel normalization using CIFAR-10 mean and std values
transform = transforms.Normalize(
    (0.4914, 0.4822, 0.4465),
    (0.2471, 0.2435, 0.2616),
)

# Obtain the model for the configured method, handing it the chosen device
print("🚀 Loading model...")
model = get_model(args.method, device)

# The clean CIFAR-10 test split is evaluated to obtain the source entropies
print("📦 Loading clean CIFAR-10 as source entropy")
clean_ds = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transform]),
)
clean_loader = DataLoader(clean_ds, batch_size=args.batch_size, shuffle=False)
source_ents, accuracy, logits_list, labels_list = evaluate(model, clean_loader, device)

# Build the shift-detection protector from those source entropies
protector = protect.get_protector_from_ents(
    source_ents,
    argparse.Namespace(gamma=1 / (8 * np.sqrt(3)), eps_clip=1.8, device=device),
)

In [None]:
from copy import deepcopy
import optuna

def run_optuna_pbrs_optimization(model,
                                  compare_fpr_across_seeds,
                                  evaluate_covariate_shift_detection,
                                  load_cifar10_corruption,
                                  load_cifar10_label_shift_balanced,
                                  BasicDataset,
                                  run_martingale,
                                  protector_factory,
                                  transform,
                                  args,
                                  device,
                                  corruption_types,
                                  severities,
                                  num_classes_list=[1, 2, 3, 4, 5, 6, 7, 8, 9],
                                  seeds_tpr=range(3),
                                  seeds_fpr=range(3),
                                  n_trials=50,
                                  log_path='optuna_pbrs_results.csv'):
    """Search PBRS / protector hyper-parameters with an Optuna study.

    Every trial samples ``buffer_size``, ``confidence_threshold``, ``gamma``
    and ``eps_clip``, deep-copies ``model`` twice, builds two protectors from
    the clean CIFAR-10 test entropies, and then calls the two evaluation
    helpers. The trial score, which the study maximises, is::

        0.8 * (1 - mean_fpr) + 0.9 * mean_tpr - 0.3 * (mean_delay / 4000)

    Relies on the notebook-level names ``evaluate``, ``protect``, ``optuna``,
    ``deepcopy``, ``torchvision``, ``DataLoader``, ``argparse`` and ``pd``.

    Args:
        model: Model that is deep-copied per trial (one copy for the TPR run,
            one for the FPR run); the original object is never passed on.
        compare_fpr_across_seeds: Callable returning a mapping whose values
            are averaged and used as the false-positive rate.
        evaluate_covariate_shift_detection: Callable returning a mapping whose
            values are dicts with ``'detection_rate'`` and ``'avg_delay'``
            (``None`` delays are skipped).
        load_cifar10_corruption: Forwarded to
            ``evaluate_covariate_shift_detection``.
        load_cifar10_label_shift_balanced: Forwarded to
            ``compare_fpr_across_seeds``.
        BasicDataset: Forwarded to both helpers.
        run_martingale: Forwarded to both helpers.
        protector_factory: Accepted for interface compatibility; this
            function does not use it.
        transform: Applied after ``ToTensor`` on the clean dataset and
            forwarded to both helpers.
        args: Object providing ``batch_size``; also forwarded to both helpers.
        device: Passed to ``evaluate``, to the protector settings and to both
            helpers.
        corruption_types: Forwarded to ``evaluate_covariate_shift_detection``.
        severities: Forwarded to ``evaluate_covariate_shift_detection``.
        num_classes_list: Class counts forwarded to
            ``compare_fpr_across_seeds``; each trial receives its own copy.
        seeds_tpr: Seeds forwarded to ``evaluate_covariate_shift_detection``.
        seeds_fpr: Seeds forwarded to ``compare_fpr_across_seeds``.
        n_trials: Number of Optuna trials to run.
        log_path: CSV file receiving one row per finished trial.

    Returns:
        The ``optuna`` study object after optimisation.

    Side effects:
        May download CIFAR-10 into ``./data``; writes ``log_path`` (also when
        the study stops early because of an exception or interrupt); hands
        per-trial ``tpr_results_*.csv`` / ``fpr_results_*.csv`` paths to the
        helpers.
    """
    # Imported here because no notebook cell imports `datetime` explicitly; the
    # name would otherwise have to leak in through one of the star imports.
    from datetime import datetime

    # 4000 is both the fallback delay when no run reports one and the
    # normaliser of the delay penalty in the score.
    # NOTE(review): presumably the length of the evaluated stream — confirm.
    delay_cap = 4000

    # Snapshot the class counts so that neither the shared default list nor a
    # caller-owned list can be altered by a helper between trials.
    class_counts = tuple(num_classes_list)

    run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    results = []

    # The clean test set does not depend on any sampled hyper-parameter, so it
    # is built once here instead of once per trial.
    clean_loader = DataLoader(
        torchvision.datasets.CIFAR10(root="./data", train=False, download=True,
                                    transform=torchvision.transforms.Compose([
                                        torchvision.transforms.ToTensor(),
                                        transform
                                    ])),
        batch_size=args.batch_size,
        shuffle=False
    )

    def objective(trial):
        """Run one TPR + FPR evaluation and return the combined score."""
        buffer_size = trial.suggest_categorical("buffer_size", [32, 64, 128, 256, 512, 1024])
        confidence_threshold = trial.suggest_float("confidence_threshold", 0.5, 0.95, step=0.05)
        gamma = trial.suggest_categorical("gamma", [1 / (16 * 3 ** 0.5), 1 / (8 * 3 ** 0.5), 1 / (4 * 3 ** 0.5)])
        eps_clip = trial.suggest_categorical("eps_clip", [1.5, 1.8, 2.0, 2.5, 3.0])

        # NOTE(review): run_id is built only from the start time and the sampled
        # values, so two trials that draw the same combination share the same
        # per-trial log file names — confirm how the helpers treat an existing
        # log_path.
        run_id = f"{run_time}_B{buffer_size}_T{confidence_threshold:.2f}_G{gamma:.3f}_E{eps_clip:.2f}"
        print(f"\n🔧 Trial {trial.number}: {run_id}")

        # === Step 1: Deepcopy the model to ensure reproducibility ===
        model_copy_tpr = deepcopy(model)
        model_copy_fpr = deepcopy(model)

        # === Step 2: Recompute source entropy for this trial ===
        # Kept per trial (on the TPR copy) because `evaluate` is opaque here and
        # may alter the model it is given.
        source_ents, _, _, _ = evaluate(model_copy_tpr, clean_loader, device)

        # === Step 3: Create protector using clean entropies ===
        args_protector = argparse.Namespace(gamma=gamma, eps_clip=eps_clip, device=device)
        protector_tpr = protect.get_protector_from_ents(source_ents, args_protector)
        protector_fpr = protect.get_protector_from_ents(source_ents, args_protector)

        # === Step 4: Evaluate TPR ===
        tpr_result = evaluate_covariate_shift_detection(
            model=model_copy_tpr,
            load_cifar10_corruption=load_cifar10_corruption,
            BasicDataset=BasicDataset,
            run_martingale=run_martingale,
            protector=protector_tpr,
            transform=transform,
            args=args,
            device=device,
            corruption_types=corruption_types,
            severities=severities,
            seeds=seeds_tpr,
            buffer_capacity=buffer_size,
            confidence_threshold=confidence_threshold,
            num_classes=10,
            use_pbrs=True,
            log_path=f'tpr_results_{run_id}.csv'
        )

        tpr_scores = [v['detection_rate'] for v in tpr_result.values()]
        delays = [v['avg_delay'] for v in tpr_result.values() if v['avg_delay'] is not None]
        mean_tpr = float(pd.Series(tpr_scores).mean())
        # No reported delay at all is scored as the full penalty.
        mean_delay = float(pd.Series(delays).mean()) if delays else delay_cap

        # === Step 5: Evaluate FPR ===
        fpr_result = compare_fpr_across_seeds(
            model=model_copy_fpr,
            load_cifar10_label_shift_balanced=load_cifar10_label_shift_balanced,
            BasicDataset=BasicDataset,
            run_martingale=run_martingale,
            protector=protector_fpr,
            transform=transform,
            args=args,
            device=device,
            seeds=seeds_fpr,
            num_classes_list=list(class_counts),  # fresh list for every trial
            buffer_capacity=buffer_size,
            confidence_threshold=confidence_threshold,
            use_pbrs=True,
            log_path=f'fpr_results_{run_id}.csv'
        )

        fpr_scores = list(fpr_result.values())
        mean_fpr = float(pd.Series(fpr_scores).mean())

        # === Step 6: Compute objective ===
        score = (0.8 * (1 - mean_fpr)) + (0.9 * mean_tpr) - (0.3 * (mean_delay / delay_cap))
        trial.set_user_attr("fpr", mean_fpr)
        trial.set_user_attr("tpr", mean_tpr)
        trial.set_user_attr("delay", mean_delay)

        results.append({
            'run_id': run_id,
            'trial': trial.number,
            'buffer_size': buffer_size,
            'confidence_threshold': confidence_threshold,
            'gamma': gamma,
            'eps_clip': eps_clip,
            'mean_fpr': mean_fpr,
            'mean_tpr': mean_tpr,
            'avg_delay': mean_delay,
            'score': score
        })

        print(f"🔎 Trial {trial.number}: score = {score:.3f} | FPR={mean_fpr:.3f} | TPR={mean_tpr:.3f} | Delay={mean_delay:.0f}")
        return score

    study = optuna.create_study(direction="maximize")
    try:
        study.optimize(objective, n_trials=n_trials)
    finally:
        # Written even when a trial raises or the run is interrupted, so the
        # trials that already finished are not lost.
        df = pd.DataFrame(results)
        df.to_csv(log_path, index=False)
        print(f"\n✅ Optuna results saved to {log_path}")

    try:
        print("Best trial:", study.best_trial.params)
    except ValueError:
        # Optuna raises ValueError when the study has no completed trial.
        print("Best trial: none (no trial completed)")
    return study