In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import pandas as pd
import warnings
from scipy.sparse import csc_matrix
import os
import json
import glob
import importlib

lpbox_admm_code = """
'''ADMM solver for binary quadratic programs with linear equality constraints.

Approximately solves

    min_x  x^T A x + b^T x    s.t.  x in {0, 1}^n,  C x = d

by writing {0, 1}^n as the intersection of the box [0, 1]^n and the shifted
Lp sphere {v : ||v - 1/2||_p = n^(1/p) / 2}, and running ADMM with one split
variable per set (y1: box, y2: sphere) plus a multiplier for C x = d.
'''
import numpy as np
import scipy.sparse.linalg as linalg
from scipy.sparse import csc_matrix

DEFAULT_PARAMS = {
    'stop_threshold': 1e-4,        # stop when max relative residual ||x - y|| / ||x|| is below this
    'gamma_val': 1.6,              # dual step multiplier (decayed towards 1)
    'rho_change_step': 5,          # grow the penalties every this many iterations
    'max_iters': 1e3,
    'initial_rho': 25,
    'learning_fact': 1 + 1 / 100,  # multiplicative penalty growth factor
    'x0': None,                    # initial point; uniform random in [0, 1) if None
    'pcg_tol': 1e-4,               # absolute tolerance of the inner CG solve
    'pcg_maxiters': 1e3,
    'projection_lp': 2,            # p of the shifted Lp sphere
}


def ADMM_bqp_linear_eq(A, b, C, d, all_params=None):
    '''Approximately solve min x^T A x + b^T x  s.t.  x binary, C x = d.

    Args:
        A: (n, n) matrix (sparse or dense) of the quadratic term.
        b: linear term with n entries; any shape, used as an (n, 1) column.
        C: (m, n) equality-constraint matrix.
        d: right-hand side with m entries; any shape, used as an (m, 1) column.
        all_params: optional dict overriding entries of DEFAULT_PARAMS.
            The caller's dict is NOT modified.

    Returns:
        (n, 1) float array: the final x iterate. It is continuous (the output
        of a CG solve), so it is not guaranteed to be exactly 0/1.
    '''
    params = dict(DEFAULT_PARAMS)
    if all_params is not None:
        params.update(all_params)

    # Normalise to float columns so 1-D inputs cannot silently broadcast to
    # (n, n) and integer d cannot break the in-place z3 update.
    b = np.asarray(b, dtype=np.float64).reshape(-1, 1)
    d = np.asarray(d, dtype=np.float64).reshape(-1, 1)
    n = b.size

    stop_threshold = params['stop_threshold']
    max_iters = int(params['max_iters'])
    rho_change_step = params['rho_change_step']
    gamma_val = params['gamma_val']
    learning_fact = params['learning_fact']
    projection_lp = params['projection_lp']
    pcg_tol = params['pcg_tol']
    pcg_maxiters = int(params['pcg_maxiters'])

    if params['x0'] is None:
        x_sol = np.random.rand(n, 1)
    else:
        x_sol = np.array(params['x0'], dtype=np.float64).reshape(-1, 1)

    y1, y2 = x_sol.copy(), x_sol.copy()
    z1, z2, z3 = np.zeros_like(y1), np.zeros_like(y2), np.zeros_like(d)
    rho1 = rho2 = rho3 = params['initial_rho']
    Csq = csc_matrix(C.transpose() @ C)

    for iter_num in range(max_iters):
        # y-updates: project onto the two sets whose intersection is {0, 1}^n.
        y1 = project_box(x_sol + z1 / rho1)
        y2 = project_shifted_Lp_ball(x_sol + z2 / rho2, projection_lp)

        # x-update: setting the augmented-Lagrangian gradient to zero gives
        # M x = q; M is symmetric positive definite when A is symmetric PSD,
        # so conjugate gradient applies.
        diag_rho = csc_matrix(((rho1 + rho2) * np.ones(n), (range(n), range(n))), shape=(n, n))
        M = 2 * A + rho3 * Csq + diag_rho
        q = -(b + z1 + z2 + C.transpose() @ z3) + rho1 * y1 + rho2 * y2 + rho3 * C.transpose() @ d
        x_sol, _cg_info = linalg.cg(M, q, x_sol, atol=pcg_tol, maxiter=pcg_maxiters)
        x_sol = x_sol.reshape(-1, 1)

        # Dual ascent on the constraints x = y1, x = y2 and C x = d.
        z1 += gamma_val * rho1 * (x_sol - y1)
        z2 += gamma_val * rho2 * (x_sol - y2)
        z3 += gamma_val * rho3 * (C @ x_sol - d)

        if (iter_num + 1) % rho_change_step == 0:
            rho1, rho2, rho3 = rho1 * learning_fact, rho2 * learning_fact, rho3 * learning_fact
            # Decay the dual step multiplier by 5% per penalty update, floored at 1.
            gamma_val = max(gamma_val * 0.95, 1)

        x_norm = max(np.linalg.norm(x_sol), 1e-16)
        res1 = np.linalg.norm(x_sol - y1) / x_norm
        res2 = np.linalg.norm(x_sol - y2) / x_norm
        if max(res1, res2) <= stop_threshold:
            break
    return x_sol


def project_box(x):
    '''Euclidean projection onto the box [0, 1]^n (elementwise clip).'''
    return np.clip(x, 0, 1)


def project_shifted_Lp_ball(x, p):
    '''Radially map x onto the shifted Lp sphere {v : ||v - 1/2||_p = n^(1/p) / 2}.

    Points are always rescaled to the sphere's radius (even those inside it).
    If x is within 1e-9 (in p-norm) of the centre, the centre itself is returned.
    The norm is taken over the flattened array: for a 2-D column np.linalg.norm
    would compute a matrix norm, which equals the vector norm only for
    p in {1, 2, inf} and raises for other p (e.g. p=3).
    '''
    shift_vec = 0.5 * np.ones_like(x)
    shift_x = x - shift_vec
    normp_shift = np.linalg.norm(shift_x.ravel(), p)
    if normp_shift < 1e-9:
        return shift_vec
    n = x.size
    return (n ** (1 / p)) * shift_x / (2 * normp_shift) + shift_vec
"""

# Persist the solver source as a real module file so it can be imported.
with open("lpbox_admm.py", "w") as module_file:
    module_file.write(lpbox_admm_code)

# The module file was just created/modified by this process; the import
# system's finders cache directory listings, so clear them to guarantee the
# new file is seen.
importlib.invalidate_caches()
import lpbox_admm
importlib.reload(lpbox_admm)  # re-execute the freshly written source on notebook re-runs
from lpbox_admm import ADMM_bqp_linear_eq

print("lpbox_admm.py (clean version) created and imported successfully.")

# NOTE: this silences *every* Python warning for the rest of the session.
warnings.filterwarnings("ignore")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Reproducibility: seed torch and NumPy's global RNG; on GPU additionally force
# deterministic cuDNN kernels and turn off cuDNN autotuning.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Optimisation ---
BATCH_SIZE = 128
LEARNING_RATE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EPOCHS = 100

# --- Sample selection: keep K_PERCENTAGE of BATCH_SIZE samples per step ---
K_PERCENTAGE = 0.5
K_SAMPLES = int(BATCH_SIZE * K_PERCENTAGE)

# --- PGD attack (sign-gradient steps, L-inf budget, pixels in [0, 1]) ---
EPSILON = 8 / 255
ALPHA = 2 / 255
PGD_STEPS_TRAIN = 10
PGD_STEPS_EVAL = 20

print(f"\nHyperparameters set. K_SAMPLES = {K_SAMPLES}")

In [None]:
# No Normalize step: pixels stay in [0, 1], the range the PGD attack clamps to.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.CIFAR10('./data', train=True, transform=transform_train, download=True)
test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test, download=True)

# drop_last=True keeps every training batch at exactly BATCH_SIZE samples; the
# ADMM selection solver is constructed for 2 * BATCH_SIZE pooled samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(f"Dataset loaded. Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
print(f"Number of training batches per epoch: {len(train_loader)}")

In [None]:
def pgd_attack(model, images, labels, epsilon, alpha, iters):
    '''Craft L-inf PGD adversarial examples (no random start).

    Starting from ``images`` themselves, takes ``iters`` signed-gradient ascent
    steps of size ``alpha`` on the mean cross-entropy loss. After each step the
    perturbation is clipped to [-epsilon, epsilon] elementwise and the result to
    the pixel range [0, 1].

    Args:
        model: classifier mapping images to logits. Its train/eval mode is not
            changed here (the caller decides, e.g. BatchNorm behaviour).
        images: clean input batch.
        labels: integer class labels for ``images``.
        epsilon: L-inf perturbation budget.
        alpha: step size per iteration.
        iters: number of iterations (0 returns a detached copy of ``images``).

    Returns:
        Detached adversarial batch on the model's device, same shape as ``images``.
        Only the input gradient is computed, so the model's parameter ``.grad``
        buffers are left untouched.
    '''
    # Use the model's device rather than a notebook-global; fall back to the
    # inputs' device for a parameter-free model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else images.device

    original_images = images.clone().detach().to(target_device)
    labels = labels.clone().detach().to(target_device)
    adv_images = original_images.clone()
    for _ in range(iters):
        adv_images.requires_grad_(True)
        loss = F.cross_entropy(model(adv_images), labels)
        # Gradient w.r.t. the input only: skips parameter-gradient work and
        # does not accumulate into model .grad.
        (input_grad,) = torch.autograd.grad(loss, adv_images)
        stepped = adv_images.detach() + alpha * input_grad.sign()
        eta = torch.clamp(stepped - original_images, min=-epsilon, max=epsilon)
        adv_images = torch.clamp(original_images + eta, min=0, max=1).detach()
    return adv_images

def select_informed(losses, k):
    '''Return indices of the ``k`` largest entries of ``losses``, highest first.'''
    return losses.topk(k).indices

class ADMM_Selection_Solver:
    '''Poses top-k sample selection as a binary program for ADMM_bqp_linear_eq.

    With A = 0, b = -V and the single constraint sum(x) = k, the program
    ``min x^T A x + b^T x  s.t.  x in {0,1}^n, C x = d`` asks for the k samples
    with the largest total V. ``solve`` returns the solver's continuous iterate
    (one score per sample), not a hard 0/1 mask.

    ``n`` passed at construction is only the initial problem size: if ``solve``
    receives a vector of a different length, A and C are rebuilt to match
    instead of failing inside the solver with a shape error.
    '''

    def __init__(self, n, k):
        self.k = k
        # Solver settings; any key not listed is filled from the solver's defaults.
        self.admm_params = {
            'max_iters': 100,
            'stop_threshold': 1e-3,
            'initial_rho': 10,
            'projection_lp': 2
        }
        self._build_problem(n)

    def _build_problem(self, n):
        '''(Re)create A (all-zero n x n), C (1 x n ones) and d ([[k]]) for size n.'''
        self.n = n
        self.A = csc_matrix((n, n))
        self.C = np.ones((1, n))
        self.d = np.array([[self.k]], dtype=np.float64)

    def solve(self, V):
        '''Solve the selection problem for values V and return the CONTINUOUS scores.

        Args:
            V: tensor of per-sample values (flattened); larger values are
                preferred for selection.

        Returns:
            Tensor with one score per element of V, on V's device.
        '''
        # float64 on CPU: NumPy has no bfloat16, and the solver works in float64.
        V_np = V.detach().to('cpu', torch.float64).numpy().reshape(-1, 1)
        n = V_np.shape[0]
        if n != self.n:
            self._build_problem(n)
        # Per-call copy: neither x0 nor the solver's default-filling touches
        # the shared settings dict.
        params = dict(self.admm_params)
        params['x0'] = np.random.rand(n, 1)  # random start from NumPy's global RNG

        x_sol = ADMM_bqp_linear_eq(self.A, -V_np, self.C, self.d, params)

        return torch.from_numpy(np.asarray(x_sol).flatten()).to(V.device)

admm_selection_solver = ADMM_Selection_Solver(k=K_SAMPLES, n=BATCH_SIZE * 2)  # pooled batch = clean + adversarial copies

def select_admm_sampler_ref(losses, k):
    '''Select ``k`` pooled-sample indices using ADMM relaxation scores.

    The module-level ``admm_selection_solver`` turns ``losses`` into one
    continuous score per sample; the indices of the ``k`` highest scores are
    returned, so exactly ``k`` indices come back (``topk`` raises if ``k``
    exceeds the number of samples).

    Not pure: the solver draws its random starting point from NumPy's global RNG.
    '''
    scores = admm_selection_solver.solve(losses)
    return scores.topk(k).indices

In [None]:
def save_checkpoint(epoch, model, optimizer, scheduler, history, filename):
    '''Serialise training state to ``filename`` via ``torch.save``.

    Stores ``epoch`` as given plus the model/optimizer/scheduler state dicts and
    the metrics ``history``. Missing parent directories are created. The log
    line reports ``epoch + 1``.
    '''
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        history=history,
    )

    parent_dir = os.path.dirname(filename)
    if parent_dir:  # empty for a bare file name in the working directory
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(payload, filename)
    print(f"Checkpoint saved to {filename} at epoch {epoch + 1}")

def _checkpoint_epoch(path):
    '''Return the trailing ``_<int>`` of a checkpoint file stem, or -1 if there is none.'''
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdecimal() else -1


def _resolve_checkpoint_path(filename):
    '''Expand a glob pattern to the highest-numbered matching file (None if nothing matches).

    Paths without glob characters (``*``, ``?``, ``[``) are returned unchanged.
    '''
    if not any(ch in filename for ch in '*?['):
        return filename
    matches = glob.glob(filename)
    if not matches:
        return None
    # Numeric, not lexicographic, order: '..._100.pth' must beat '..._90.pth'.
    return max(matches, key=lambda p: (_checkpoint_epoch(p), p))


def load_checkpoint(filename, model, optimizer, scheduler, map_location=None):
    '''Restore model/optimizer/scheduler state from a checkpoint, if one exists.

    Args:
        filename: checkpoint path, or a glob pattern such as
            ``.../checkpoint_*.pth``. A pattern resolves to the match with the
            largest trailing ``_<epoch>`` number. (A literal ``os.path.exists``
            on a pattern is always False, which silently disabled resuming.)
        model, optimizer, scheduler: objects whose ``load_state_dict`` receives
            the corresponding saved state.
        map_location: passed to ``torch.load``; defaults to the notebook's
            global ``device``.

    Returns:
        ``(start_epoch, history)``: the epoch to resume from (saved epoch + 1)
        and the saved metrics history, or ``(0, <empty history>)`` when no
        checkpoint is found.
    '''
    resolved = _resolve_checkpoint_path(filename)
    if resolved is None or not os.path.exists(resolved):
        print(f"No checkpoint found at {filename}. Starting training from scratch.")
        return 0, {'epoch': [], 'std_acc': [], 'robust_acc': [], 'time': []}

    print(f"Attempting to load checkpoint from: {resolved}")
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(resolved, map_location=device if map_location is None else map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    history = checkpoint['history']

    print(f"Checkpoint loaded successfully. Resuming training from epoch {start_epoch}.")
    return start_epoch, history

def save_experiment_result(method_name, history_df, final_res):
    '''Write history and final metrics to ``results_<sanitised method_name>.json``.

    Sanitising turns spaces into underscores and drops ``(``, ``)``, ``=`` and
    ``,``. ``history_df`` is stored as a JSON *string* (pandas
    ``orient='split'``) nested inside the outer JSON document.
    '''
    file_stem = method_name.replace(' ', '_').translate(str.maketrans('', '', '()=,'))
    out_path = "results_" + file_stem + ".json"
    payload = {'history': history_df.to_json(orient='split'), 'final_result': final_res}
    with open(out_path, 'w') as handle:
        json.dump(payload, handle, indent=4)
    print(f"\nFinal result for '{method_name}' saved to {out_path}")

def train_epoch(model, optimizer, data_loader, training_mode, selection_fn=None, k=None):
    '''Run one epoch of adversarial training on a selected subset of each batch - NO TQDM.

    Per batch:
    1. Craft PGD adversarial examples (EPSILON / ALPHA / PGD_STEPS_TRAIN).
    2. Pool the clean and adversarial images (twice the batch size).
    3. Score every pooled sample by its per-sample cross-entropy, without
       tracking gradients.
    4. Keep the indices returned by ``selection_fn(losses, k)``.
    5. Take one optimizer step on that subset.

    Args:
        model: classifier being trained (switched to train mode here).
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(images, labels)`` batches.
        training_mode: must be ``'selection'``, the only mode implemented.
        selection_fn: callable ``(losses, k) -> indices`` into the pooled batch.
        k: number of pooled samples to keep, forwarded to ``selection_fn``.

    Raises:
        ValueError: if ``training_mode`` is not ``'selection'``. Checked before
            the first batch, so no PGD attack is wasted on a misconfigured run.
    '''
    if training_mode != 'selection':
        raise ValueError("This script is configured for 'selection' mode.")

    model.train()
    for clean_images, labels in data_loader:
        clean_images, labels = clean_images.to(device), labels.to(device)
        adv_images = pgd_attack(model, clean_images, labels, EPSILON, ALPHA, PGD_STEPS_TRAIN)

        combined_images = torch.cat([clean_images, adv_images], dim=0)
        combined_labels = torch.cat([labels, labels], dim=0)
        # Scoring pass only, so no graph is needed. The model is in train mode,
        # so BatchNorm layers (if any) use batch statistics and also update
        # their running stats here. NOTE(review): confirm that is intended.
        with torch.no_grad():
            outputs = model(combined_images)
            losses = F.cross_entropy(outputs, combined_labels, reduction='none')
        selected_indices = selection_fn(losses, k)
        final_images = combined_images[selected_indices]
        final_labels = combined_labels[selected_indices]

        if len(final_images) > 0:
            optimizer.zero_grad()
            outputs = model(final_images)
            loss = F.cross_entropy(outputs, final_labels)
            loss.backward()
            optimizer.step()

def evaluate(model, data_loader, attack_fn=None):
    '''Return top-1 accuracy (in percent) of ``model`` over ``data_loader`` - NO TQDM.

    If ``attack_fn`` is given, each batch is first replaced by
    ``attack_fn(model, images, labels, epsilon=EPSILON, alpha=ALPHA, iters=PGD_STEPS_EVAL)``
    and accuracy is measured on those images. The attack runs outside
    ``torch.no_grad`` because it needs input gradients.
    '''
    model.eval()
    n_correct = 0
    n_seen = 0
    for batch_images, batch_labels in data_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        if attack_fn:
            batch_images = attack_fn(model, batch_images, batch_labels, epsilon=EPSILON, alpha=ALPHA, iters=PGD_STEPS_EVAL)
        with torch.no_grad():
            predictions = model(batch_images).max(dim=1).indices
        n_seen += batch_labels.size(0)
        n_correct += predictions.eq(batch_labels).sum().item()
    return 100. * n_correct / n_seen

def run_experiment(method_name, training_mode, checkpoint_to_load, new_checkpoint_base, result_filename, selection_fn=None, k=None):
    '''Train (or resume) one configuration, evaluating and checkpointing along the way.

    Steps:
    1. Build a fresh ResNet-18, SGD and MultiStepLR (x0.1 at epochs 75 and 90).
    2. Restore state via ``load_checkpoint(checkpoint_to_load, ...)``.
    3. Train until EPOCHS, measuring clean and PGD accuracy after every epoch.
    4. Save ``<new_checkpoint_base>_<epoch>.pth`` every 10 epochs and at the end.
    5. Save and print the summary via ``save_experiment_result``.

    If the restored checkpoint is already at EPOCHS, no training happens, but
    the summary is still written from the restored history. Previously this
    path returned early, so the result file was never produced.

    Args:
        method_name: label used in logs and in the result file name.
        training_mode: forwarded to ``train_epoch``.
        checkpoint_to_load: checkpoint path (or pattern) given to ``load_checkpoint``.
        new_checkpoint_base: prefix for checkpoints written by this run.
        result_filename: NOTE(review): not used here; ``save_experiment_result``
            derives the file name from ``method_name`` itself.
        selection_fn, k: forwarded to ``train_epoch``.

    Returns:
        The final-results dict, or None if there is no history to summarise.
    '''
    print(f"\n{'='*20} Running Experiment: {method_name} {'='*20}")

    model = models.resnet18(weights=None, num_classes=10).to(device)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75, 90], gamma=0.1)

    start_epoch, history = load_checkpoint(checkpoint_to_load, model, optimizer, scheduler)

    if start_epoch >= EPOCHS:
        # Do not return here: still (re)produce the summary from the restored history.
        print("Training for this method is already complete.")

    # history['time'] is cumulative wall time, so offset the clock by what
    # previous sessions already spent.
    start_time = time.time() - (history['time'][-1] if history['time'] else 0)

    for epoch in range(start_epoch, EPOCHS):
        epoch_start_time = time.time()
        train_epoch(model, optimizer, train_loader, training_mode, selection_fn, k)

        std_acc = evaluate(model, test_loader, attack_fn=None)
        robust_acc = evaluate(model, test_loader, attack_fn=pgd_attack)
        scheduler.step()

        history['epoch'].append(epoch + 1)
        history['std_acc'].append(std_acc)
        history['robust_acc'].append(robust_acc)
        history['time'].append(time.time() - start_time)

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{EPOCHS} -> Std Acc: {std_acc:.2f}%, Robust Acc: {robust_acc:.2f}%, Time: {epoch_time:.2f}s")

        if (epoch + 1) % 10 == 0 or (epoch + 1) == EPOCHS:
            save_checkpoint(epoch, model, optimizer, scheduler, history, f"{new_checkpoint_base}_{epoch+1}.pth")

    if not history['epoch']:
        print("No training history available; nothing to summarise.")
        return None

    total_time_hours = (time.time() - start_time) / 3600
    final_results = {
        'Method': method_name,
        'SA (%)': history['std_acc'][-1],
        'RA (PGD-20, %)': history['robust_acc'][-1],
        'Training Time (hours)': total_time_hours
    }
    history_df = pd.DataFrame(history)
    save_experiment_result(method_name, history_df, final_results)
    print(pd.DataFrame([final_results]).to_string(index=False))
    return final_results

In [None]:
# Resume-or-skip driver for the ADMM top-k sampler experiment.
method_name = f"ADMM Sampler (k={K_SAMPLES})"
# Same sanitisation as save_experiment_result: spaces -> '_', drop ( ) = ,
safe_name = method_name.replace(' ', '_').translate(str.maketrans('', '', '()=,'))

# NOTE: this is a glob pattern, not a literal path; it only resumes from a
# checkpoint if load_checkpoint expands wildcards.
checkpoint_to_load = "/kaggle/input/admm/pytorch/default/1/checkpoint_*.pth"
new_checkpoint_base_path = f"checkpoint_{safe_name}"
result_filename = f"results_{safe_name}.json"

print("\n".join([
    f"Preparing to RESUME experiment: {method_name} with Top-K Strategy",
    f"Loading from: {checkpoint_to_load}",
    f"Saving new checkpoints with base name: {new_checkpoint_base_path}_[epoch].pth",
    f"Final results will be saved to: {result_filename}",
]))

if not os.path.exists(result_filename):
    run_experiment(
        method_name=method_name,
        training_mode='selection',
        checkpoint_to_load=checkpoint_to_load,
        new_checkpoint_base=new_checkpoint_base_path,
        result_filename=result_filename,
        selection_fn=select_admm_sampler_ref,
        k=K_SAMPLES,
    )
else:
    print(f"Final result file '{result_filename}' already exists. Skipping experiment.")
    try:
        with open(result_filename, 'r') as f:
            existing = json.load(f)
        print("\n--- Existing Final Results ---")
        print(pd.DataFrame([existing['final_result']]).to_string(index=False))
    except Exception as e:
        print(f"Could not read existing result file: {e}")