In [None]:
import os, time, json
import torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import pandas as pd

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Seed every random number generator this notebook relies on.
torch.manual_seed(42)
np.random.seed(42)
if device.type == "cuda":
    torch.cuda.manual_seed_all(42)
    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# --- Optimisation hyperparameters ---
BATCH_SIZE = 128
LEARNING_RATE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EPOCHS = 100

# --- Sample selection ---
# K_PERCENTAGE is a fraction of BATCH_SIZE (1 means K_SAMPLES == BATCH_SIZE).
K_PERCENTAGE = 1
K_SAMPLES = int(BATCH_SIZE * K_PERCENTAGE)

# --- PGD attack: perturbation budget, step size, and iteration counts ---
EPSILON = 8/255
ALPHA = 2/255
PGD_STEPS_TRAIN = 10
PGD_STEPS_EVAL = 20

print(f"K = {K_SAMPLES}, PGD steps (train/eval) = {PGD_STEPS_TRAIN}/{PGD_STEPS_EVAL}")

In [None]:
'''
Data loading and transformations for CIFAR-10
'''
# Training augmentation: padded random crop + horizontal flip, then tensor conversion.
# No normalisation is applied here; pgd_attack clamps images to [0.0, 1.0].
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
# Test images are only converted to tensors.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# drop_last=True keeps every training batch at exactly BATCH_SIZE samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print("Loaded CIFAR-10:", len(train_dataset), "train,", len(test_dataset), "test")

In [None]:
'''
Utility functions for ADMM projection operations
'''
def project_shifted_lp_ball_torch(x, p=2, eps=1e-9):
    '''
    Rescale x onto the surface of the Lp sphere centred at 0.5 in every coordinate.

    The sphere radius is (n^(1/p)) / 2, where n is the number of elements in x.
    The input is flattened for the computation and the result is returned in
    the input's original shape.

    Args:
        x: tensor of any shape.
        p: order of the norm used for the sphere.
        eps: if the p-norm of (x - 0.5) is below this value, x is treated as
            sitting on the centre and the centre itself is returned.

    Returns:
        Tensor with the same shape as x.
    '''
    flat = x.reshape(-1)
    center = torch.ones_like(flat, device=flat.device) * 0.5
    offset = flat - center
    offset_norm = torch.norm(offset, p=p)
    radius = (flat.numel() ** (1.0 / p)) / 2.0

    # Degenerate case: no usable direction to scale along, so return the centre.
    if offset_norm.item() < eps:
        return center.reshape(x.shape)

    # Scale the offset to the target radius, then shift back to the centre.
    scale = radius / (offset_norm + 1e-20)
    projected = scale * offset + center
    return projected.reshape(x.shape)

def project_cardinality_topk_torch(x, k):
    '''
    Build a binary mask that marks the k largest entries of x.

    Args:
        x: tensor of any shape; it is flattened for the ranking.
        k: number of entries to mark. It is converted with int() and clamped to
            the number of elements in x, so k larger than x.numel() marks every
            entry instead of raising inside torch.topk. k <= 0 marks nothing.

    Returns:
        Tensor with the shape and dtype of x holding 1 at the selected
        positions and 0 elsewhere.
    '''
    x_flat = x.reshape(-1)
    mask = torch.zeros_like(x_flat)
    # torch.topk requires k <= number of elements; clamp instead of crashing.
    k = min(int(k), x_flat.numel())
    if k <= 0:
        return mask.reshape(x.shape)
    _, idx = torch.topk(x_flat, k)
    mask[idx] = 1.0
    return mask.reshape(x.shape)

def admm_selection_discrete_torch(V, d, all_params=None, warm_x0=None, device=device):
    """
    ADMM solver that selects d entries of V, using a discrete top-k projection
    for y1 at every iteration.

    V enters the x-update with a positive sign, so entries with larger V are
    pushed toward being selected.

    Args:
        V: tensor of per-item scores (losses); flattened to shape (n,).
        d: requested cardinality. The number of returned indices is int(d)
            clamped to [0, n]; the sum constraint in the iterations uses d as given.
        all_params: optional dict of ADMM parameters. Missing keys fall back to
            the defaults below. The dict is read only and never modified.
        warm_x0: optional starting point with n elements; when absent or of a
            different size, x is initialised uniformly at random.
        device: device on which the solver runs.

    Returns:
        (sel, x_sol): sel is a 1D index tensor of the selected entries and
        x_sol is the detached continuous iterate of shape (n,).
    """
    default_params = {
        'stop_threshold': 1e-4, 'gamma_val': 1.0, 'rho_change_step': 5,
        'max_iters': 200, 'initial_rho': 50.0, 'learning_fact': 1.005,
        'projection_lp': 2, 'eps': 1e-9
    }
    # Merge into a private copy so the caller's dict is left untouched.
    params = dict(default_params)
    if all_params is not None:
        params.update(all_params)

    V = V.reshape(-1).to(device).float()
    n = V.numel()
    # Keep the selection size inside what torch.topk accepts.
    k_val = max(0, min(int(d), n))

    # Warm start only when the previous iterate has exactly n elements.
    if warm_x0 is not None and warm_x0.numel() == n:
        x_sol = warm_x0.reshape(-1).to(device).float().clone()
    else:
        x_sol = torch.rand(n, device=device, dtype=torch.float32)

    # y1/y2 are the split copies of x; z1/z2/z3 are the dual variables for
    # x = y1, x = y2 and sum(x) = d respectively.
    y1 = x_sol.clone()
    y2 = x_sol.clone()
    z1 = torch.zeros_like(y1, device=device)
    z2 = torch.zeros_like(y2, device=device)
    z3 = torch.zeros(1, device=device)

    rho1 = rho2 = rho3 = float(params['initial_rho'])
    gamma_val = float(params['gamma_val'])
    max_iters = int(params['max_iters'])
    stop_threshold = float(params['stop_threshold'])
    p = float(params['projection_lp'])
    learning_fact = float(params['learning_fact'])
    eps = float(params['eps'])
    rho_change_step = int(params['rho_change_step'])
    u = torch.ones(n, device=device, dtype=torch.float32)

    for it in range(max_iters):
        # Projection steps: binary top-k mask and shifted Lp sphere.
        y1 = project_cardinality_topk_torch(x_sol + z1 / rho1, k_val)
        y2 = project_shifted_lp_ball_torch(x_sol + z2 / rho2, p=p)

        # x-update: solve ((rho1 + rho2) I + rho3 u u^T) x = q in closed form
        # via the Sherman-Morrison identity.
        q = (V - z1 - z2 - z3 * u) + rho1 * y1 + rho2 * y2 + rho3 * (d * u)
        alpha = float(rho1 + rho2)
        beta = float(rho3)
        denom = alpha * (alpha + beta * n) + 1e-30
        factor = beta / denom
        x_sol = q / alpha - factor * torch.sum(q) * u

        # Dual ascent on the three constraints.
        z1 = z1 + gamma_val * rho1 * (x_sol - y1)
        z2 = z2 + gamma_val * rho2 * (x_sol - y2)
        z3 = z3 + gamma_val * rho3 * (torch.sum(x_sol) - float(d))

        # Periodically grow the penalties; gamma decays but never below 1.0.
        if (it + 1) % rho_change_step == 0:
            rho1 *= learning_fact
            rho2 *= learning_fact
            rho3 *= learning_fact
            gamma_val = max(gamma_val * 0.95, 1.0)

        # Relative residuals; fall back to a unit norm when x is not positive-normed.
        norm_x = torch.norm(x_sol)
        if not bool(norm_x > 0):
            norm_x = torch.tensor(1.0, device=device)
        res1 = torch.norm(x_sol - y1) / (norm_x + eps)
        res2 = torch.norm(x_sol - y2) / (norm_x + eps)
        if max(res1.item(), res2.item()) <= stop_threshold:
            break

    # Read the selection from the binary iterate; if it does not contain exactly
    # k_val entries, fall back to the top-k of the continuous solution.
    sel = torch.nonzero(y1.reshape(-1) >= 0.5).reshape(-1)
    if sel.numel() != k_val:
        _, sel = torch.topk(x_sol, k_val)
    return sel, x_sol.detach()

class ADMM_Discrete_Solver_Torch:
    '''
    Stateful wrapper around admm_selection_discrete_torch.

    It stores the problem size, the cardinality and the ADMM parameters, and
    remembers the last continuous solution so the next solve can be warm-started.
    '''
    def __init__(self, n, k, admm_params=None, device=device, use_warmstart=True):
        '''
        Args:
            n: expected number of entries in each score vector.
            k: number of entries to select (stored as int).
            admm_params: optional dict of ADMM parameters; a default set is
                used when None.
            device: device handed to the underlying solver.
            use_warmstart: whether to reuse the previous solution as a start point.
        '''
        if admm_params is None:
            admm_params = {
                'max_iters':200, 'stop_threshold':1e-4, 'initial_rho':50.0, 'projection_lp':2,
                'rho_change_step':5, 'learning_fact':1.005, 'gamma_val':1.0, 'eps':1e-9
            }
        self.admm_params = admm_params
        self.n = n
        self.k = int(k)
        self.device = device
        self.use_warmstart = use_warmstart
        # Continuous solution of the most recent solve; None until solve() has run.
        self.last_x = None

    def solve(self, V):
        '''
        Run the ADMM selection on V and cache the continuous solution.

        Returns:
            (sel, x_sol) exactly as produced by admm_selection_discrete_torch.
        '''
        start_point = None
        if self.use_warmstart and self.last_x is not None:
            # Only reuse the cached iterate when its length matches n.
            if self.last_x.shape[0] == self.n:
                start_point = self.last_x
        sel, x_sol = admm_selection_discrete_torch(
            V,
            d=self.k,
            all_params=self.admm_params,
            warm_x0=start_point,
            device=self.device,
        )
        self.last_x = x_sol.clone()
        return sel, x_sol

In [None]:
'''
PGD attack function for adversarial training and evaluation
'''
def pgd_attack(model, images, labels, epsilon, alpha, iters):
    '''
    PGD attack implementation for adversarial training and evaluation.

    Starting from the clean images (no random start), takes `iters` signed
    gradient-ascent steps on the cross-entropy loss, keeping the perturbation
    inside [-epsilon, epsilon] and the result inside [0.0, 1.0].

    The gradient is taken with respect to the images only, so the model's
    parameter gradients are neither cleared nor overwritten by the attack.
    The model's train/eval mode is left as the caller set it.

    Args:
        model: callable mapping an image batch to logits.
        images: clean input batch; it is copied, never modified in place.
        labels: ground-truth class indices for the batch.
        epsilon: maximum absolute per-element perturbation.
        alpha: step size of each iteration.
        iters: number of iterations; 0 returns a detached copy of the input.

    Returns:
        Detached tensor of adversarial images on `device`.
    '''
    images = images.clone().detach().to(device)
    labels = labels.to(device)
    orig = images.clone().detach()
    for _ in range(iters):
        images.requires_grad_(True)
        loss = F.cross_entropy(model(images), labels)
        # Differentiate w.r.t. the input only; parameter .grad stays untouched.
        (grad,) = torch.autograd.grad(loss, images)
        stepped = images.detach() + alpha * grad.sign()
        # Clip the perturbation to the epsilon budget, then to the valid pixel range.
        eta = torch.clamp(stepped - orig, -epsilon, epsilon)
        images = torch.clamp(orig + eta, 0.0, 1.0).detach()
    return images

def compute_overlap_indices(losses, indices_a, indices_b):
    '''
    Utility to compute overlap between two sets of indices.

    Args:
        losses: unused; kept so existing call sites keep working.
        indices_a: index tensor used as the reference set.
        indices_b: index tensor compared against the reference.

    Returns:
        Fraction of the distinct indices in indices_a that also appear in
        indices_b, as a float. Returns 0.0 when indices_a is empty.
    '''
    # Flatten first so index tensors of any shape yield hashable scalars.
    sa = set(indices_a.reshape(-1).tolist())
    sb = set(indices_b.reshape(-1).tolist())
    if not sa:
        # Nothing to compare against; avoid dividing by zero.
        return 0.0
    return len(sa & sb) / float(len(sa))

# Shared solver instance. n is 2*BATCH_SIZE because the training loop scores the
# clean and adversarial copies of each batch together; with warm-starting enabled
# the instance keeps its last solution between calls.
admm_discrete_solver = ADMM_Discrete_Solver_Torch(n=2*BATCH_SIZE, k=K_SAMPLES, admm_params=None, device=device, use_warmstart=True)

def select_admm_discrete_wrapper(losses, k):
    '''
    Run the module-level ADMM solver on a vector of per-sample losses.

    losses: tensor on device, shape (2*BATCH_SIZE,)
    k: not used here; the cardinality is the one configured on the solver.

    Returns the (selected indices, continuous scores) pair from the solver.
    '''
    selected, continuous_scores = admm_discrete_solver.solve(losses)
    return selected, continuous_scores

def train_epoch_admm_discrete(model, optimizer, data_loader, admm_solver, k):
    '''
    Run one training epoch with ADMM-based sample selection.

    For each batch: craft PGD adversarial copies, score clean + adversarial
    samples by per-sample loss, let the ADMM solver pick a subset, and take one
    optimizer step on that subset. The plain top-k selection (size k) is computed
    only as a reference for the overlap statistic.

    Returns:
        Mean per-batch overlap between the top-k reference and the ADMM
        selection (0.0 if the loader produced no batches).
    '''
    model.train()
    batch_overlaps = []

    for x_clean, y in data_loader:
        x_clean = x_clean.to(device)
        y = y.to(device)
        x_adv = pgd_attack(model, x_clean, y, EPSILON, ALPHA, PGD_STEPS_TRAIN)

        # Candidate pool: clean samples followed by their adversarial versions.
        pool_x = torch.cat([x_clean, x_adv], dim=0)
        pool_y = torch.cat([y, y], dim=0)

        # Per-sample losses are only used for ranking, so no graph is needed.
        with torch.no_grad():
            logits = model(pool_x)
            per_sample_loss = F.cross_entropy(logits, pool_y, reduction='none')

        reference_idx = torch.topk(per_sample_loss, k).indices

        chosen_idx, _ = admm_solver.solve(per_sample_loss)
        chosen_idx = chosen_idx.to(device)

        batch_overlaps.append(
            compute_overlap_indices(per_sample_loss, reference_idx, chosen_idx)
        )

        batch_x = pool_x[chosen_idx]
        batch_y = pool_y[chosen_idx]
        if batch_x.size(0) == 0:
            # Nothing selected for this batch; skip the update.
            continue

        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(batch_x), batch_y)
        batch_loss.backward()
        optimizer.step()

    if not batch_overlaps:
        return 0.0
    return float(np.mean(batch_overlaps))

def evaluate(model, data_loader, attack_fn=None):
    '''
    Evaluate the model's top-1 accuracy on a data loader.

    Args:
        model: network to evaluate; it is switched to eval mode.
        data_loader: iterable of (images, labels) batches.
        attack_fn: optional attack with the signature
            attack_fn(model, images, labels, epsilon, alpha, iters), e.g.
            pgd_attack. When given, it is called with EPSILON, ALPHA and
            PGD_STEPS_EVAL and accuracy is measured on its output. When None,
            clean accuracy is measured.

    Returns:
        Accuracy in percent (0.0 when the loader yields no samples).
    '''
    model.eval()
    total_correct, total = 0, 0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        if attack_fn is not None:
            # Use the attack that was actually passed in, not a hard-coded one.
            images = attack_fn(model, images, labels, EPSILON, ALPHA, PGD_STEPS_EVAL)
        with torch.no_grad():
            outs = model(images)
            _, preds = torch.max(outs, 1)
            total += labels.size(0)
            total_correct += (preds == labels).sum().item()
    if total == 0:
        # Empty loader: report 0.0 instead of dividing by zero.
        return 0.0
    return 100.0 * total_correct / total

In [None]:
'''
Checkpoint utilities for saving and loading model state
'''
def save_checkpoint(epoch, model, optimizer, scheduler, history, path):
    '''
    Saves checkpoint to a file.

    Args:
        epoch: number of completed epochs, stored under 'epoch'.
        model, optimizer, scheduler: objects whose state_dict() is stored.
        history: training history object, stored as-is under 'history'.
        path: destination file; missing parent directories are created.
    '''
    dir_name = os.path.dirname(path)
    if dir_name:
        # exist_ok avoids the race between checking for and creating the directory.
        os.makedirs(dir_name, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history
    }, path)
    print(f"Checkpoint saved to {path}")

def load_checkpoint(path, model, optimizer, scheduler):
    '''
    Loads checkpoint from a file.

    When `path` is falsy or does not exist, nothing is loaded and a fresh
    history with empty lists is returned together with epoch 0. Otherwise the
    model, optimizer and scheduler states are restored in place, and any
    expected history key that the stored history lacks is added as an empty list.

    Returns:
        (start_epoch, history)
    '''
    expected_keys = ('epoch', 'std_acc', 'robust_acc', 'cumulative_time', 'epoch_time', 'overlap')

    # Guard clause: no usable checkpoint, start fresh.
    if not (path and os.path.exists(path)):
        print("No checkpoint found. Starting from scratch.")
        return 0, {key: [] for key in expected_keys}

    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']

    # Back-fill keys that the stored history does not contain.
    history = checkpoint['history']
    for key in expected_keys:
        if key not in history:
            history[key] = []

    print(f"Checkpoint loaded from {path}. Resuming from epoch {start_epoch + 1}")
    return start_epoch, history

In [None]:
'''
Experiment runner for ADMM-discrete training
'''
def run_experiment_admm_discrete(method_name="ADMM_discrete_test", checkpoint_to_load=None, new_checkpoint_base='admm_discrete_ckpt'):
    '''
    Train a ResNet-18 with ADMM-based sample selection and record the results.

    Args:
        method_name: label used in log output and in the results file name.
        checkpoint_to_load: optional checkpoint path to resume from; when it is
            falsy or missing, training starts from scratch.
        new_checkpoint_base: file-name prefix for checkpoints, which are written
            every 10 epochs and after the last epoch.

    Returns:
        The history dict with one entry per trained epoch. The same history,
        the hyperparameters and a final summary are also written to
        results_<method_name>.json.
    '''
    print("Running:", method_name)
    model = models.resnet18(weights=None, num_classes=10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75,90], gamma=0.1)

    # load_checkpoint treats a falsy path as "start from scratch".
    start_epoch, history = load_checkpoint(checkpoint_to_load, model, optimizer, scheduler)

    # When resuming, continue the clock from the time already recorded so that
    # cumulative_time does not restart from zero.
    elapsed_before = history['cumulative_time'][-1] if history['cumulative_time'] else 0.0

    start_time = time.time()
    for epoch in range(start_epoch, EPOCHS):
        t0 = time.time()

        mean_overlap = train_epoch_admm_discrete(model, optimizer, train_loader, admm_discrete_solver, K_SAMPLES)

        std_acc = evaluate(model, test_loader, attack_fn=None)
        robust_acc = evaluate(model, test_loader, attack_fn=pgd_attack)

        scheduler.step()

        epoch_time = time.time() - t0
        cumulative_time = elapsed_before + (time.time() - start_time)

        history['epoch'].append(epoch + 1)
        history['std_acc'].append(std_acc)
        history['robust_acc'].append(robust_acc)
        history['cumulative_time'].append(cumulative_time)
        history['epoch_time'].append(epoch_time)
        history['overlap'].append(mean_overlap)

        print(f"Epoch {epoch+1}/{EPOCHS} | Std Acc: {std_acc:.2f}% | Robust Acc: {robust_acc:.2f}% | Overlap(topk): {mean_overlap:.3f} | Epoch Time: {epoch_time:.1f}s")

        if (epoch + 1) % 10 == 0 or (epoch + 1) == EPOCHS:
            save_checkpoint(epoch + 1, model, optimizer, scheduler, history, f"{new_checkpoint_base}_epoch_{epoch+1}.pth")

    fname = f"results_{method_name}.json"
    # If no epoch ran and the history is empty, report None instead of failing
    # on an empty list.
    final_results = {
        'final_std_acc': history['std_acc'][-1] if history['std_acc'] else None,
        'final_robust_acc': history['robust_acc'][-1] if history['robust_acc'] else None,
        'total_training_time': history['cumulative_time'][-1] if history['cumulative_time'] else None
    }

    output_data = {
        'experiment_name': method_name,
        'hyperparameters': {
            'batch_size': BATCH_SIZE,
            'learning_rate': LEARNING_RATE,
            'epochs': EPOCHS,
            'k_percentage': K_PERCENTAGE,
            'epsilon': EPSILON
        },
        'training_history': history,
        'final_summary': final_results
    }

    with open(fname, 'w') as f:
        json.dump(output_data, f, indent=4)
    print(f"Full experiment results saved to {fname}")

    return history

# Launch the full training run (EPOCHS epochs, from scratch since no checkpoint
# path is passed) and keep the returned per-epoch history.
history_admm_discrete = run_experiment_admm_discrete(method_name=f"ADMM_discrete_k{K_SAMPLES}")