In [None]:
'''
This cell sets up the environment, imports necessary libraries, and defines hyperparameters and seeds for reproducibility.
'''
import os, time, json
import torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import pandas as pd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Notebook-level seeding; each experiment run re-seeds itself with its own seed.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Optimisation ---
BATCH_SIZE = 256
LEARNING_RATE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EPOCHS = 100

# --- Top-K selection ---
# Every training step scores a pool of 2 * BATCH_SIZE candidates (the clean
# batch plus its PGD copies) and trains on the K_SAMPLES highest-loss ones.
# K_PERCENTAGE is a fraction of BATCH_SIZE, not of the pool:
# 0.5 -> 128 of 512 candidates.
K_PERCENTAGE = 0.5
K_SAMPLES = int(BATCH_SIZE * K_PERCENTAGE)
# Fail fast on an unusable k instead of erroring after the dataset download and
# a full epoch (torch.topk rejects k > pool size; the stability metric divides by k).
if not 1 <= K_SAMPLES <= 2 * BATCH_SIZE:
    raise ValueError(
        f"K_SAMPLES={K_SAMPLES} must be in [1, {2 * BATCH_SIZE}] "
        f"(K_PERCENTAGE={K_PERCENTAGE}, BATCH_SIZE={BATCH_SIZE})"
    )

# --- Adversary: L-inf ball, pixel values in [0, 1] ---
EPSILON, ALPHA = 8/255, 2/255
PGD_STEPS_TRAIN, PGD_STEPS_EVAL = 10, 20

SEEDS = [42, 43]

NUM_CLASSES = 100
METHOD_NAME_TOPK = f"TopK_CIFAR100_k{K_SAMPLES}_bs{BATCH_SIZE}"
OUT_DIR = "results_topk_runs"
os.makedirs(OUT_DIR, exist_ok=True)

print(f"Method: {METHOD_NAME_TOPK}, K={K_SAMPLES}, BATCH_SIZE={BATCH_SIZE}, seeds={SEEDS}")

In [None]:
'''
CIFAR-100 dataset and data loaders
'''
# Training augmentation: pad-and-crop plus horizontal flip. There is no
# normalisation step, so tensors stay in the [0, 1] range that pgd_attack clamps to.
train_augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform_train = transforms.Compose(train_augmentations)
transform_test = transforms.Compose([transforms.ToTensor()])

cifar_root = './data'
train_dataset = datasets.CIFAR100(root=cifar_root, train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR100(root=cifar_root, train=False, download=True, transform=transform_test)

# Settings shared by both loaders. drop_last=True on the training loader keeps
# every training batch at exactly BATCH_SIZE samples.
shared_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, num_workers=4, drop_last=True, **shared_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=2, **shared_loader_kwargs)

print("Loaded CIFAR-100:", len(train_dataset), "train,", len(test_dataset), "test")

In [None]:
'''
PGD attack, selection-overlap / selection-stability metrics, and per-class evaluation (clean or adversarial)
'''
def pgd_attack(model, images, labels, epsilon, alpha, iters, random_start=False):
    """Untargeted L-infinity PGD attack that maximises cross-entropy.

    Takes ``iters`` signed-gradient ascent steps of size ``alpha``. After each
    step it projects back onto the ``epsilon``-ball around the original images
    and onto the pixel range [0, 1].

    Args:
        model: classifier returning logits. It is used in whatever train/eval
            mode the caller has set; this function does not change the mode.
        images: input batch. It is copied and moved to the global ``device``.
        labels: class indices for ``images``, moved to ``device``.
        epsilon: L-infinity radius of the allowed perturbation.
        alpha: step size of each PGD iteration.
        iters: number of PGD iterations. With 0, the starting point is returned.
        random_start: if True, start from a uniform random point in the
            epsilon-ball (clamped to [0, 1]). The default False starts at the
            clean images, as before.

    Returns:
        Detached adversarial images on ``device``.

    Only the gradient with respect to the input is computed
    (``torch.autograd.grad``). The model's parameter ``.grad`` buffers are
    neither read, zeroed, nor written.
    """
    orig = images.clone().detach().to(device)
    labels = labels.to(device)
    adv = orig.clone()
    if random_start:
        adv = torch.clamp(adv + torch.empty_like(adv).uniform_(-epsilon, epsilon), 0.0, 1.0)
    for _ in range(iters):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), labels)
        # Input-only gradient: avoids computing parameter grads and never
        # clobbers gradients the caller may have accumulated.
        (grad,) = torch.autograd.grad(loss, adv)
        adv = adv.detach() + alpha * grad.sign()
        eta = torch.clamp(adv - orig, -epsilon, epsilon)
        adv = torch.clamp(orig + eta, 0.0, 1.0).detach()
    return adv

def compute_overlap_indices(losses, indices_a, indices_b):
    """Fraction of the distinct indices in ``indices_a`` that also occur in ``indices_b``.

    ``losses`` is accepted for call-site compatibility and is not used. Both
    index arguments are tensors; duplicates collapse (set semantics). Returns
    0.0 when ``indices_a`` is empty.
    """
    members_a = set(indices_a.cpu().numpy().tolist())
    members_b = set(indices_b.cpu().numpy().tolist())
    if not members_a:
        return 0.0
    shared = members_a.intersection(members_b)
    return len(shared) / float(len(members_a))

def compute_selection_stability(list_of_selected_indices, k):
    """Average overlap, normalised by ``k``, between consecutive selections.

    Args:
        list_of_selected_indices: sequence of 1-D index tensors, one per batch.
        k: normaliser for each pairwise intersection size.

    Returns:
        1.0 if fewer than two selections are given. Otherwise, the mean of
        ``|S_t & S_{t+1}| / k`` over consecutive pairs, as a Python float.

    NOTE(review): train_epoch_topk_collect passes positions inside each
    (shuffled) combined batch, not dataset sample ids. Equal values in
    consecutive batches therefore generally refer to different images.
    """
    if len(list_of_selected_indices) < 2:
        return 1.0
    as_sets = [set(sel.cpu().numpy().tolist()) for sel in list_of_selected_indices]
    pair_scores = [
        len(earlier & later) / float(k)
        for earlier, later in zip(as_sets, as_sets[1:])
    ]
    return float(np.mean(pair_scores))

def evaluate_per_class(model, data_loader, attack_fn=None, num_classes=NUM_CLASSES):
    """
    Top-1 accuracy overall and per class, optionally under an adversarial attack.

    Switches ``model`` to eval mode and leaves it there. When ``attack_fn`` is
    given, each batch is first replaced by
    ``attack_fn(model, images, labels, EPSILON, ALPHA, PGD_STEPS_EVAL)``. The
    attack runs with autograd enabled; only the final prediction pass is under
    ``torch.no_grad()``.

    Returns:
        overall_acc: percentage of correctly classified samples (0.0 for an
            empty loader).
        per_class_acc: list of length ``num_classes``. Entry c is the accuracy
            (percent) on samples labelled c, or None if no such sample was seen.
    """
    model.eval()
    correct_per_class = np.zeros(num_classes, dtype=np.int64)
    total_per_class = np.zeros(num_classes, dtype=np.int64)
    total_correct = 0
    total_samples = 0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        if attack_fn is not None:
            images = attack_fn(model, images, labels, EPSILON, ALPHA, PGD_STEPS_EVAL)
        with torch.no_grad():
            _, preds = torch.max(model(images), 1)
        preds_np = preds.cpu().numpy()
        labels_np = labels.cpu().numpy()
        hit = preds_np == labels_np
        total_correct += int(hit.sum())
        total_samples += labels_np.shape[0]
        # One vectorised histogram per batch instead of a Python loop over every
        # class. Labels outside [0, num_classes) still count towards the overall
        # totals but not towards any class, matching per-class masking.
        in_range = (labels_np >= 0) & (labels_np < num_classes)
        total_per_class += np.bincount(labels_np[in_range], minlength=num_classes)
        correct_per_class += np.bincount(labels_np[in_range & hit], minlength=num_classes)
    overall_acc = 100.0 * total_correct / total_samples if total_samples > 0 else 0.0
    per_class_acc = [
        100.0 * correct / total if total > 0 else None
        for correct, total in zip(correct_per_class, total_per_class)
    ]
    return overall_acc, per_class_acc

In [None]:
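'''
This cell defines the one-epoch Top-K training function (which also collects the selected indices and overlap statistics) and the per-seed experiment runner, which trains, evaluates clean and PGD accuracy after every epoch, and saves the results to JSON.
'''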
def train_epoch_topk_collect(model, optimizer, data_loader, k):
    '''
    Run one epoch of Top-K adversarial training.

    For each batch:
      1. craft PGD copies (PGD_STEPS_TRAIN steps);
      2. pool them with the clean images;
      3. score the pool with per-sample cross-entropy under no_grad;
      4. take one optimizer step on the k highest-loss samples only.

    Returns (mean_overlap, stability):
        mean_overlap: batch average of compute_overlap_indices.
            NOTE(review): the Top-K selection is compared with itself, so this
            is 1.0 for every batch whenever k > 0. It is presumably a
            placeholder for comparing another selector against Top-K; confirm.
        stability: compute_selection_stability over the per-batch selections.
            NOTE(review): the indices are positions inside each shuffled pool
            (< batch size: clean, >= batch size: adversarial), not dataset ids.
    '''
    model.train()
    overlaps, selected_history = [], []
    for clean_images, labels in data_loader:
        clean_images = clean_images.to(device)
        labels = labels.to(device)
        adv_images = pgd_attack(model, clean_images, labels, EPSILON, ALPHA, PGD_STEPS_TRAIN)
        pool_images = torch.cat([clean_images, adv_images], dim=0)
        pool_labels = torch.cat([labels, labels], dim=0)

        # Score the pool without building a graph. The model is in train mode
        # here, so any BatchNorm layers use batch statistics and update their
        # running statistics during this pass.
        with torch.no_grad():
            pool_losses = F.cross_entropy(model(pool_images), pool_labels, reduction='none')

        chosen = torch.topk(pool_losses, k).indices.to(device)
        selected_history.append(chosen.clone())
        overlaps.append(compute_overlap_indices(pool_losses, chosen, chosen))

        batch_images = pool_images[chosen]
        batch_labels = pool_labels[chosen]
        if batch_images.size(0) == 0:
            continue
        optimizer.zero_grad()
        step_loss = F.cross_entropy(model(batch_images), batch_labels)
        step_loss.backward()
        optimizer.step()

    mean_overlap = float(np.mean(overlaps)) if overlaps else 0.0
    return mean_overlap, compute_selection_stability(selected_history, k)

def _seed_everything(seed):
    '''Seed torch and numpy, plus CUDA with deterministic cuDNN when available.'''
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _sync_device():
    '''Wait for queued CUDA work so wall-clock timings include it (no-op without CUDA).'''
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def run_experiment_seed_topk(seed, method_name=METHOD_NAME_TOPK):
    '''
    Train and evaluate one Top-K adversarial-training run for ``seed``.

    Re-seeds torch/numpy (and CUDA when available), then trains a fresh
    ResNet-18 for EPOCHS epochs with SGD. The learning rate is multiplied by
    0.1 at 75% and 90% of the epochs. After every epoch, clean and
    PGD-PGD_STEPS_EVAL test accuracy (overall and per class) are measured.

    Timing keys:
        'epoch_time', 'cumulative_time', 'total_training_time': include the two
            per-epoch evaluation passes (unchanged meaning).
        'train_time', 'cumulative_train_time', 'total_train_only_time': cover
            only the training pass, which is the cost Top-K selection changes.

    Writes the result to ``OUT_DIR/<method_name>_seed<seed>.json`` and returns
    ``(path, result_dict)``.
    '''
    _seed_everything(seed)

    print(f"\n=== TOPK RUN seed {seed} ===")
    model = models.resnet18(weights=None, num_classes=NUM_CLASSES).to(device)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    lr_milestones = [int(0.75 * EPOCHS), int(0.9 * EPOCHS)]
    lr_gamma = 0.1
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=lr_gamma)

    history = {
        'epoch': [],
        'std_acc': [],
        'robust_acc': [],
        'epoch_time': [],
        'cumulative_time': [],
        'overlap': [],
        'selection_stability': [],
        'per_class_std_acc': [],
        'per_class_robust_acc': [],
        'train_time': [],
        'cumulative_train_time': [],
    }

    _sync_device()
    start_time = time.time()
    cumulative_train_time = 0.0
    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        mean_overlap, stability = train_epoch_topk_collect(model, optimizer, train_loader, K_SAMPLES)
        _sync_device()
        train_time = time.time() - t0
        cumulative_train_time += train_time

        std_acc, per_class_std = evaluate_per_class(model, test_loader, attack_fn=None, num_classes=NUM_CLASSES)
        robust_acc, per_class_rob = evaluate_per_class(model, test_loader, attack_fn=pgd_attack, num_classes=NUM_CLASSES)

        scheduler.step()
        epoch_time = time.time() - t0
        cumulative_time = time.time() - start_time

        epoch_record = {
            'epoch': epoch,
            'std_acc': std_acc,
            'robust_acc': robust_acc,
            'epoch_time': epoch_time,
            'cumulative_time': cumulative_time,
            'overlap': mean_overlap,
            'selection_stability': stability,
            'per_class_std_acc': per_class_std,
            'per_class_robust_acc': per_class_rob,
            'train_time': train_time,
            'cumulative_train_time': cumulative_train_time,
        }
        for key, value in epoch_record.items():
            history[key].append(value)

        print(f"Seed {seed} Epoch {epoch}/{EPOCHS} | Std: {std_acc:.2f}% | Robust: {robust_acc:.2f}% | Overlap: {mean_overlap:.3f} | Stability: {stability:.3f} | EpochTime: {epoch_time:.1f}s | TrainTime: {train_time:.1f}s")

    def _last(key):
        # None instead of IndexError when EPOCHS == 0.
        return history[key][-1] if history[key] else None

    out = {
        'experiment_name': f"{method_name}_seed{seed}",
        'seed': seed,
        'hyperparameters': {
            'batch_size': BATCH_SIZE,
            'learning_rate': LEARNING_RATE,
            'epochs': EPOCHS,
            'k_percentage': K_PERCENTAGE,
            'k_samples': K_SAMPLES,
            'epsilon': EPSILON,
            'alpha': ALPHA,
            'pgd_steps_train': PGD_STEPS_TRAIN,
            'pgd_steps_eval': PGD_STEPS_EVAL,
            'momentum': MOMENTUM,
            'weight_decay': WEIGHT_DECAY,
            'lr_milestones': lr_milestones,
            'lr_gamma': lr_gamma,
            'num_classes': NUM_CLASSES,
        },
        'training_history': history,
        'final_summary': {
            'final_std_acc': _last('std_acc'),
            'final_robust_acc': _last('robust_acc'),
            'total_training_time': _last('cumulative_time'),
            'total_train_only_time': _last('cumulative_train_time'),
        }
    }
    seed_fname = os.path.join(OUT_DIR, f"{method_name}_seed{seed}.json")
    with open(seed_fname, 'w') as f:
        json.dump(out, f, indent=4)
    print("Saved TopK seed results to", seed_fname)
    return seed_fname, out

In [None]:
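'''
Run the Top-K experiment once for every seed in SEEDS, then aggregate the per-epoch and final metrics across seeds (mean and sample standard deviation) and save the aggregate as JSON.
'''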
if not SEEDS:
    raise ValueError("SEEDS must contain at least one seed")

seed_files_topk = []
seed_outputs_topk = []

for s in SEEDS:
    fname, out = run_experiment_seed_topk(s, method_name=METHOD_NAME_TOPK)
    seed_files_topk.append(fname)
    seed_outputs_topk.append(out)


def _json_ready(values):
    '''Convert an array or scalar to plain Python floats, mapping NaN/inf to None (JSON null).'''
    arr = np.asarray(values, dtype=float)
    return np.where(np.isfinite(arr), arr, None).tolist()


def _mean_std_across_seeds(values):
    '''NaN-aware mean and sample std (ddof=1) over axis 0, the seed axis.

    With fewer than two seeds the sample std is undefined. It is reported as
    None rather than letting np.nanstd warn and emit NaN, which json.dump would
    write as non-standard JSON.
    '''
    arr = np.asarray(values, dtype=float)
    mean = np.nanmean(arr, axis=0)
    if arr.shape[0] < 2:
        std = np.full(np.shape(mean), np.nan)
    else:
        std = np.nanstd(arr, axis=0, ddof=1)
    return _json_ready(mean), _json_ready(std)


min_epochs = min(len(o['training_history']['epoch']) for o in seed_outputs_topk)
metrics = ['std_acc', 'robust_acc', 'epoch_time', 'cumulative_time', 'overlap', 'selection_stability']
# Training-only timings exist only when the runner records them.
metrics += [m for m in ('train_time', 'cumulative_train_time')
            if all(m in o['training_history'] for o in seed_outputs_topk)]

agg_history = {'epoch': list(range(1, min_epochs + 1))}
for m in metrics:
    agg_history[m + '_mean'], agg_history[m + '_std'] = _mean_std_across_seeds(
        [o['training_history'][m][:min_epochs] for o in seed_outputs_topk])

# Shape (seeds, epochs, classes); None entries (classes never seen) become NaN.
per_class_std = np.array([o['training_history']['per_class_std_acc'][:min_epochs] for o in seed_outputs_topk], dtype=float)
per_class_rob = np.array([o['training_history']['per_class_robust_acc'][:min_epochs] for o in seed_outputs_topk], dtype=float)
per_class_std_mean, per_class_std_std = _mean_std_across_seeds(per_class_std)
per_class_rob_mean, per_class_rob_std = _mean_std_across_seeds(per_class_rob)

summary_keys = ['final_std_acc', 'final_robust_acc', 'total_training_time']
summary_keys += [k for k in ('total_train_only_time',)
                 if all(k in o['final_summary'] for o in seed_outputs_topk)]
final_summary_aggregate = {}
for key in summary_keys:
    mean, std = _mean_std_across_seeds([o['final_summary'][key] for o in seed_outputs_topk])
    final_summary_aggregate[key + '_mean'] = mean
    final_summary_aggregate[key + '_std'] = std

aggregate_output_topk = {
    'experiment_name': METHOD_NAME_TOPK + "_aggregate",
    'seed_files': seed_files_topk,
    # Every run shares the same configuration; reuse the first run's record.
    'hyperparameters': dict(seed_outputs_topk[0]['hyperparameters']),
    'training_history_aggregate': agg_history,
    'per_class_std_mean_per_epoch': per_class_std_mean,
    'per_class_std_std_per_epoch': per_class_std_std,
    'per_class_robust_mean_per_epoch': per_class_rob_mean,
    'per_class_robust_std_per_epoch': per_class_rob_std,
    'final_summary_aggregate': final_summary_aggregate,
}

agg_fname_topk = os.path.join(OUT_DIR, f"{METHOD_NAME_TOPK}_aggregate.json")
with open(agg_fname_topk, 'w') as f:
    json.dump(aggregate_output_topk, f, indent=4)

print("Saved TopK aggregate results to", agg_fname_topk)
print("TopK seed files:", seed_files_topk)