# Selective Inference

The Models are evaluated by blocking certain codebook vectors, with high misclassification rates and testing if an improvement ind adversarial accuracy is achieved.

This is achieved by setting their weights to 0.

In [3]:
import re
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

### Model Loading

In [None]:
# dynamic S-VQVAE import
def get_s_vqvae_model_class(model_path):
    # extract codebook vector size from filename
    match = re.search(r"emb(\d+)", model_path)
    if not match:
        raise ValueError(f"Could not find codebook vector size in filename: {model_path}")
    emb_size = match.group(1)

    # map the size to the appropriate module and class import
    model_imports = {
        "512": "models.s_vqvae_emb512",
        "256": "models.s_vqvae_emb256",
        "128": "models.s_vqvae_emb128",
        "64":  "models.s_vqvae_emb64",
        "10":  "models.s_vqvae_emb10"
    }

    if emb_size not in model_imports:
        raise ValueError(f"Unsupported codebook vector size: {emb_size}")

    module_name = model_imports[emb_size]
    module = __import__(module_name, fromlist=['S_VQVAE'])
    return module.S_VQVAE 

In [None]:
# other model options
from models.pretrainedvgg16_256 import S_VQVAE_VGG 

In [None]:
# model loader
def load_model(model_path, device):


    # basic models without variations
    model_classes = {
        "cnn_classifier": CNNClassifier(num_classes=10),
        "vae": VAE(),
    }

    # attempt to dynamically import an S_VQVAE class
    try:
        s_vqvae_class = get_s_vqvae_model_class(model_path)
        model_classes["s_vqvae"] = s_vqvae_class()
    except ValueError as ve:
        print(f"S_VQVAE Import Error: {ve}")

    if "s_vqvae_vgg" in model_path.lower():
        model_classes["s_vqvae_vgg"] = S_VQVAE_VGG()
    elif "vgg16" in model_path.lower():
        pass

    # load checkpoint from file
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
    except FileNotFoundError:
        raise ValueError(f"File not found: {model_path}")
    except Exception as e:
        raise ValueError(f"Failed to load checkpoint from {model_path}: {e}")

    for model_type, model in model_classes.items():
        try:
            if 'state_dict' in checkpoint:
                model_state_dict = checkpoint['state_dict']
            else:
                model_state_dict = checkpoint

            model_keys = set(model.state_dict().keys())
            checkpoint_keys = set(model_state_dict.keys())

            # check for key mismatch
            if not model_keys.issubset(checkpoint_keys):
                print(f"Key mismatch for {model_type}. Skipping...")
                continue

            # load
            model.load_state_dict(model_state_dict, strict=False)
            model.to(device)
            model.eval()

            print(f"Successfully loaded {model_path} as {model_type}")
            return model, model_type

        except RuntimeError as re:
            print(f"RuntimeError for {model_type}: {re}")
        except KeyError as ke:
            print(f"KeyError for {model_type}: {ke}")
        except Exception as e:
            print(f"Unexpected error for {model_type}: {e}")

    raise ValueError(f"Model loading failed for {model_path}. No compatible architecture found.")



In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified


### Compute Failure Rates under PGD-k Attack

In [None]:
import torch
import torch.nn.functional as F

def pgd_attack(model, images, labels, eps=8/255, alpha=2/255, iters=10):

    device = images.device
    adv_images = images.clone().detach()

    for _ in range(iters):
        adv_images.requires_grad_()
        _, _, logits = model(adv_images)
        loss = F.cross_entropy(logits, labels)
        loss.backward()

        with torch.no_grad():
            adv_images = adv_images + alpha * adv_images.grad.sign()
            perturbation = torch.clamp(adv_images - images, min=-eps, max=eps)
            adv_images = torch.clamp(images + perturbation, min=0, max=1)

    return adv_images.detach()

In [None]:
# evaluate clean accuracy
def evaluate_clean_accuracy(model, dataloader):
    model.eval()
    correct, total = 0, 0
    for images, labels in dataloader:
        images, labels = images.to(model.quantize.embed.device), labels.to(model.quantize.embed.device)
        dec, diff, logits = model(images)  
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total

In [None]:
# evaluate PGD-k robustness
def evaluate_pgd_robustness(model, dataloader, device, eps=8/255, alpha=2/255, iters=10):
    """
    Evaluate accuracy on PGD adversarial images for a single model.
    """
    model.eval()
    correct, total = 0, 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        # No model_type; just pass the necessary arguments
        adv_images = pgd_attack(model, images, labels, eps, alpha, iters)

        _, _, logits = model(adv_images)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return correct / total


In [None]:
# additional flip rate computation
def compute_flip_rate(model, dataloader, eps=8/255, alpha=2/255, iters=10):
    """
    How often does the label flip from clean to adv predictions?
    """
    model.eval()
    flip_count = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(model.quantize.embed.device), labels.to(model.quantize.embed.device)

        # clean predictions
        _, _, clean_logits = model(images)
        clean_preds = clean_logits.argmax(dim=1)

        # PGD adversarial
        adv_images = pgd_attack(model, images, labels, eps, alpha, iters)
        _, _, adv_logits = model(adv_images)
        adv_preds = adv_logits.argmax(dim=1)

        flip_count += (clean_preds != adv_preds).sum().item()
        total += labels.size(0)

    return flip_count / total

In [None]:
def compute_codebook_usage_and_misclassification_pgd(
    model,
    dataloader,
    device,
    eps=8/255,
    alpha=2/255,
    iters=10
):

    model.eval()

    num_codebook_vectors = model.quantize.embed.shape[1]

    usage_count = torch.zeros(num_codebook_vectors, dtype=torch.long, device=device)
    miscount = torch.zeros(num_codebook_vectors, dtype=torch.long, device=device)

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        # create adversarial images
        adv_images = pgd_attack(model, images, labels, eps, alpha, iters)

        # codebook indices
        quant, diff, embed_ind = model.encode(adv_images)

        logits = model.classify(quant)

        # evaluate correctness
        preds = logits.argmax(dim=1)
        correct_mask = (preds == labels)

        flat_indices = embed_ind.view(embed_ind.size(0), -1)

        for i in range(flat_indices.size(0)):
            used_indices = flat_indices[i]
            usage_count.index_add_(
                0,
                used_indices,
                torch.ones_like(used_indices, dtype=torch.long)
            )
            if not correct_mask[i]:
                miscount.index_add_(
                    0,
                    used_indices,
                    torch.ones_like(used_indices, dtype=torch.long)
                )

    return usage_count, miscount


In [None]:
def compute_codebook_misclassification_rates_pgd(
    model,
    dataloader,
    device,
    eps=8/255,
    alpha=2/255,
    iters=10
):

    usage_count, miscount = compute_codebook_usage_and_misclassification_pgd(
        model, dataloader, device, eps, alpha, iters
    )

    usage_count = usage_count.float()
    miscount = miscount.float()

    mis_rates = torch.zeros_like(miscount)
    nonzero_mask = (usage_count > 0)
    mis_rates[nonzero_mask] = miscount[nonzero_mask] / usage_count[nonzero_mask]

    return mis_rates


### Blocking Process

Blocking based on misclassification Threshold, to avoid failing codebook vectors

In [None]:
def block_codebook_vectors(model, high_error_indices):
    with torch.no_grad():
        # set weight of failing codebook vectors to zero
        model.quantize.embed[:, high_error_indices] = 0.0

In [None]:
def evaluate_with_codebook_blocking_pgd(model, dataloader, device, eps, alpha, iters, thresholds):
    model.eval()
    results = []

    for threshold in thresholds:
        # collect stats
        total_clean = 0
        total_pgd   = 0
        total_samples = 0

        # loop over data
        for inputs, labels in dataloader:
            inputs = inputs.to(device, dtype=torch.float32)
            labels = labels.to(device)

        results.append({
            "threshold": threshold,
            "clean_acc": total_clean / total_samples,
            "pgd_acc": total_pgd / total_samples

        })

    df = pd.DataFrame(results)
    return df


In [None]:
import os

def evaluate_models_in_folder(
    folder_path,
    dataloader,
    device,
    thresholds=[1.00, 0.75, 0.50],
    eps=8/255,
    alpha=2/255,
    iters=10
):

    os.makedirs("results", exist_ok=True)

    model_files = [f for f in os.listdir(folder_path) if f.endswith(('.pt', '.pth'))]
    model_files.sort()

    for model_file in model_files:
        model_path = os.path.join(folder_path, model_file)
        print(f"\n=== Evaluating model: {model_file} ===")

        model, detected_type = load_model(model_path, device)

        # evaluate single model
        blocked_results_df = evaluate_with_codebook_blocking_pgd(
            model=model,
            dataloader=dataloader,
            device=device,
            eps=eps,
            alpha=alpha,
            iters=iters,
            thresholds=thresholds
        )

        # annotate with model name, type
        blocked_results_df["model_name"] = model_file
        blocked_results_df["model_type"] = detected_type

        # save
        base_name, _ = os.path.splitext(model_file)
        out_csv = f"results/codebook_blocking_{base_name}.csv"
        blocked_results_df.to_csv(out_csv, index=False)
        print(blocked_results_df)
        print(f"Saved results to: {out_csv}")

In [None]:
# individual model evaluation
def evaluate_models_individually_in_folder(
    folder_path,
    dataloader,
    device,
    thresholds=[0.75, 0.90, 0.95, 0.98],
    eps=8/255,
    alpha=2/255,
    iters=10
):
    os.makedirs("results", exist_ok=True)

    model_files = [f for f in os.listdir(folder_path) if f.endswith(('.pt', '.pth'))]
    model_files.sort()

    for model_file in model_files:
        model_path = os.path.join(folder_path, model_file)
        print(f"\n=== Evaluating model: {model_file} ===")


        model, detected_type = load_model(model_path, device)

        model = model.to(device, dtype=torch.float32)

        blocked_results_df = evaluate_with_codebook_blocking_pgd(
            model=model,
            dataloader=dataloader,
            device=device,
            eps=eps,
            alpha=alpha,
            iters=iters,
            thresholds=thresholds
        )

        blocked_results_df["model_name"] = model_file
        blocked_results_df["model_type"] = detected_type

        print(blocked_results_df)

        # save each model's result to its own csv
        base_name, _ = os.path.splitext(model_file)
        out_csv = f"results/codebook_blocking_{base_name}.csv"
        blocked_results_df.to_csv(out_csv, index=False)
        print(f"Saved results to: {out_csv}")


In [None]:
# blocking process
if __name__ == "__main__":
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets
    from torch.utils.data import DataLoader

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

 
    # CIFAR-10 test dataset and loader
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # directory of models
    folder_path = "best_model/blockage"

    # thresholds for blocking, 1.01 for no blocking
    thresholds = [1.01, 1.00, 0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65, 0.60, 0.55, 0.50, 0.45, 0.40]

    # evaluate each model in folder_path
    evaluate_models_individually_in_folder(
        folder_path=folder_path,
        dataloader=test_loader,
        device=device,
        thresholds=thresholds,
        eps=8/255,
        alpha=2/255,
        iters=10
    )

Now let's first calculate the relative changes of blocking certain thresholds and combine the individual results for visualization.

In [None]:
import glob
import pandas as pd

all_csvs = glob.glob("results/codebook_blocking_*.csv")

for csv_file in all_csvs:
    df = pd.read_csv(csv_file)

    # baseline first value, so no blocking
    baseline_clean_acc = df.loc[0, "clean_acc"]
    baseline_pgd_acc   = df.loc[0, "pgd_acc"]
    baseline_flip_rate = df.loc[0, "flip_rate"]

    # compute relative changes
    df["clean_acc_change"] = (df["clean_acc"] - baseline_clean_acc) / baseline_clean_acc * 100
    df["pgd_acc_change"]   = (df["pgd_acc"]   - baseline_pgd_acc)   / baseline_pgd_acc   * 100
    df["flip_rate_change"] = (df["flip_rate"] - baseline_flip_rate) / baseline_flip_rate * 100

    # save a new CSV with suffix `_changes`
    new_file = csv_file.replace(".csv", "_changes.csv")
    df.to_csv(new_file, index=False)
    print(f"Saved changes: {new_file}")

Saved changes: results\codebook_blocking_svqvae_cd_delayedat_emb256_marginchange_lr05_changes.csv
Saved changes: results\codebook_blocking_svqvae_cd_delayedat_emb256_marginchange_lr05_changes_changes.csv
Saved changes: results\codebook_blocking_svqvae_emb256_changes.csv
Saved changes: results\codebook_blocking_svqvae_emb256_changes_changes.csv
Saved changes: results\codebook_blocking_svqvae_emb256_delayedat_changes.csv
Saved changes: results\codebook_blocking_svqvae_emb256_delayedat_changes_changes.csv
Saved changes: results\codebook_blocking_under_pgd_ALL_MODELS_changes.csv
Saved changes: results\codebook_blocking_under_pgd_ALL_MODELS_changes_changes.csv


In [None]:
import numpy as np

# plot function for changes
def plot_changes_bar_chart_in_axis(df_changes, ax, figure_title=""):
    x_labels = df_changes["threshold"].astype(str).values

    # rename "1.01" -> "no blockage"
    x_labels = ["no blockage" if lbl == "1.01" else lbl for lbl in x_labels]

    clean_vals = df_changes["clean_acc_change"].values
    pgd_vals   = df_changes["pgd_acc_change"].values
    flip_vals  = df_changes["flip_rate_change"].values

    x = np.arange(len(x_labels))
    width = 0.25

    # grouped bars
    ax.bar(x - width, clean_vals, width=width, label="Clean Acc %Δ")
    ax.bar(x,         pgd_vals,   width=width, label="PGD Acc %Δ")
    ax.bar(x + width, flip_vals,  width=width, label="Flip Rate %Δ")

    # horizontal line at 0
    ax.axhline(y=0, color='grey', linewidth=1, linestyle='--')

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel("Threshold")
    ax.set_ylabel("Percentage Change (%)")
    ax.set_title(figure_title)
    ax.legend()

    return ax


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def visualize_three_models_individually(results_folder="results/collection"):

    # map file names for better visual clarity
    csv_map = [
        (
            "codebook_blocking_svqvae_cd_delayedat_emb256_marginchange_lr05_changes.csv",
            "SVQVAE256 CD DelayedAT"
        ),
        (
            "codebook_blocking_svqvae_emb256_changes.csv",
            "SVQVAE256"
        ),
        (
            "codebook_blocking_svqvae_emb256_delayedat_changes_changes.csv",
            "SVQVAE256 DelayedAT"
        )
    ]

    subplot_labels = ["(a)", "(b)", "(c)"]

    # 3×1 subplots 
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(14, 10))
    fig.suptitle("Relative Codebook Blocking Performance by Threshold", fontsize=16)

    os.makedirs(results_folder, exist_ok=True)

    for i, (csv_file, display_title) in enumerate(csv_map):
        csv_path = os.path.join(results_folder, csv_file)

        if not os.path.isfile(csv_path):
            print(f"File not found: {csv_path}. Skipping subplot {i}.")
            continue

        df = pd.read_csv(csv_path)

        ax = axes[i] 

        title_with_label = f"{subplot_labels[i]} {display_title}"

        # plot bars
        plot_changes_bar_chart_in_axis(df, ax, figure_title=title_with_label)

    for j in range(i + 1, 3):
        axes[j].set_visible(False)

    plt.tight_layout(rect=[0, 0, 1, 0.95])  
    out_fig = os.path.join(results_folder, "3models_changes_subplots.png")
    plt.savefig(out_fig, dpi=300)
    plt.close(fig)

    print(f"Saved subplot figure with 3 models at: {out_fig}")


In [21]:
visualize_three_models_individually(results_folder="results/collection")

Saved subplot figure with 3 models at: results/collection\3models_changes_subplots.png
