In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
import time
from typing import Callable, Optional, Tuple, Dict, Any

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

Using: cuda


In [None]:
def zo_adamm(
    # required
    model_or_predict: Any,            # either a torch.nn.Module OR a callable predict_fn(batch_tensor)->logits
    x0,                               # numpy array or torch tensor with shape (C,H,W) or (1,C,H,W)
    true_label: int,

    # ZO / VR hyperparams
    mu: float = 1e-2,
    q: int = 64,
    Q_ref: int = 256,
    inner_m: int = 20,
    epochs: int = 5,
    dist_weight: float = 0.01,

    # Adam-like
    alpha: float = 0.05,
    beta1: float = 0.9,
    beta2: float = 0.99,
    eps: float = 1e-8,
    v_init: float = 1e-6,

    # projection / constraints
    constrained: bool = True,
    lb: float = -0.5,
    ub: float = 0.5,

    # batching / device
    device: Optional[torch.device] = None,
    forward_batch_size: int = 256,   # split model batches when evaluating many probes
    rng_seed: int = 1234,

    # stopping / budgets
    early_stop: bool = True,
    stop_threshold: float = 0.0,
    max_queries: int = 50000,
    max_distortion: Optional[float] = None,
    patience: int = 40,
    tol: float = 1e-3,
    smooth_window: int = 20,
    flat_threshold: float = 1e-3,

    # misc
    predict_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    verbose: bool = True,
) -> Tuple[np.ndarray, Dict[str,Any]]:
    """
    ZO–AdaMM (variance-reduction removed) generalized for arbitrary PyTorch model.

    Returns:
      delta (numpy array shape d,) : perturbation vector to add to x0 (flattened)
      diagnostics: dict with keys: queries, history: {loss, queries, dist}, stop_reason, time_elapsed

    Query counting convention: 1 forward on B images -> +B queries.
    """

    # ---------- prepare device and predict function ----------
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # if user passed a model, make default predict_fn
    if predict_fn is None:
        if callable(model_or_predict):
            # assume model_or_predict is a torch.nn.Module
            model = model_or_predict
            model.eval()
            model.to(device)
            def _predict_fn(batch: torch.Tensor) -> torch.Tensor:
                # expects batch on same device
                with torch.no_grad():
                    return model(batch)
            predict = _predict_fn
        else:
            raise ValueError("Either provide a model (nn.Module) as model_or_predict or pass predict_fn explicitly.")
    else:
        predict = predict_fn

    # ---------- prepare x0 and shapes ----------
    if isinstance(x0, torch.Tensor):
        x0_t = x0.detach().cpu().float()
    else:
        x0_t = torch.from_numpy(np.array(x0)).float()

    # support x0 shaped (C,H,W) or (1,C,H,W)
    if x0_t.ndim == 3:
        C,H,W = x0_t.shape
    elif x0_t.ndim == 4 and x0_t.shape[0] == 1:
        _, C,H,W = x0_t.shape
        x0_t = x0_t.squeeze(0)
    else:
        raise ValueError("x0 must have shape (C,H,W) or (1,C,H,W)")

    d = int(C*H*W)
    x_shape = (C,H,W)

    # move base x0 to device once
    x0_dev = x0_t.to(device)

    torch.manual_seed(rng_seed)
    rng = np.random.RandomState(rng_seed)

    # ---------- helpers ----------
    def sample_dirs(q_local:int) -> torch.Tensor:
        U = torch.randn((q_local, d), device=device)
        norms = U.norm(dim=1, keepdim=True)
        norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        return U / norms  # (q, d)

    def batched_predict(batch_imgs: torch.Tensor) -> torch.Tensor:
        """
        Run predict on batch_imgs (N,C,H,W) in sub-batches sized forward_batch_size.
        Returns logits tensor (N, n_classes)
        Also increments queries by N (handled by callers).
        """
        n = batch_imgs.shape[0]
        outs = []
        for i in range(0, n, forward_batch_size):
            chunk = batch_imgs[i:i+forward_batch_size].to(device)
            with torch.no_grad():
                out = predict(chunk)
            outs.append(out.detach().cpu())
        return torch.cat(outs, dim=0).to(device)

    def imgs_from_deltas(deltas: torch.Tensor) -> torch.Tensor:
        # deltas: (q, d) or (B, d) on device -> returns (q, C, H, W)
        return (x0_dev.unsqueeze(0) + deltas.view(-1, C, H, W))

    # ZO symmetric estimator using batched forwards:
    def symmetric_zo_grad(x: torch.Tensor, mu_t: float, directions: torch.Tensor) -> Tuple[torch.Tensor,int]:
        """
        x: tensor (d,) on device
        directions: (q, d) on device
        returns: grad_est (d,) on device, queries_added (int)
        """
        q_local = int(directions.shape[0])

        delta_plus  = x.unsqueeze(0) + mu_t * directions   # (q,d)
        delta_minus = x.unsqueeze(0) - mu_t * directions   # (q,d)

        imgs_plus  = imgs_from_deltas(delta_plus).cpu()
        imgs_minus = imgs_from_deltas(delta_minus).cpu()

        batch = torch.cat([imgs_plus, imgs_minus], dim=0)  # (2q, C, H, W)
        logits = batched_predict(batch)  # (2q, nclass)
        q_inc = batch.shape[0]

        probs = F.softmax(logits, dim=1)  # (2q, C)

        # compute margin per sample: log p_true - log max_other
        probs_clone = probs.clone()
        probs_clone[:, true_label] = float('-inf')
        p_true = probs[:, true_label]            # (2q,)
        p_other = probs_clone.max(dim=1).values  # (2q,)

        margin = torch.log(p_true + 1e-12) - torch.log(p_other + 1e-12)  # (2q,)

        # distortion term per delta
        dist_plus  = (delta_plus.norm(dim=1)**2).detach()  # (q,)
        dist_minus = (delta_minus.norm(dim=1)**2).detach() # (q,)

        f_plus  = (margin[:q_local] + dist_weight * dist_plus).detach()
        f_minus = (margin[q_local:] + dist_weight * dist_minus).detach()

        dir_deriv = (f_plus - f_minus) / (2.0 * mu_t)  # (q,)
        grad_est = (d / float(q_local)) * (directions.t() @ dir_deriv)  # (d,)
        grad_est = grad_est.to(device)

        return grad_est, q_inc

    def mahalanobis_projection(x: torch.Tensor, sqrt_vhat: torch.Tensor) -> torch.Tensor:
        # x, sqrt_vhat, x0_dev: all on device
        z = sqrt_vhat * x
        Lb = sqrt_vhat * (lb - x0_dev.view(-1))
        Ub = sqrt_vhat * (ub - x0_dev.view(-1))
        z_clipped = torch.max(torch.min(z, Ub), Lb)
        return z_clipped / sqrt_vhat

    # ---------- initialization ----------
    # keep x_snap variable for minimal edits (not used for VR anymore)
    x_snap = torch.zeros(d, device=device)
    x_curr = torch.zeros(d, device=device)
    m = torch.zeros(d, device=device)
    v = v_init * torch.ones(d, device=device)
    vhat = v.clone()

    query_count = 0
    loss_trace = []
    queries_trace = []
    dist_trace = []
    best_val = float("inf")
    no_improve = 0
    start_time = time.time()

    mu_t = float(mu)

    # ---------- stopping helper ----------
    def check_stopping(val_f: float, delta: torch.Tensor) -> Tuple[bool,str]:
        nonlocal best_val, no_improve, query_count
        if val_f < stop_threshold:
            return True, "margin reached"
        if (max_distortion is not None) and (float(delta.norm()) > max_distortion):
            return True, "distortion limit exceeded"
        if val_f < best_val - tol:
            best_val = val_f
            no_improve = 0
        else:
            no_improve += 1
        if no_improve > patience:
            return True, "patience exhausted"
        # moving average flattening
        loss_trace.append(val_f)
        if len(loss_trace) > 2 * smooth_window:
            recent = np.mean(loss_trace[-smooth_window:])
            previous = np.mean(loss_trace[-2*smooth_window:-smooth_window])
            if previous - recent < flat_threshold:
                return True, "moving-average plateau"
        if query_count >= max_queries:
            return True, "query budget exhausted"
        return False, None

    # ---------- main loop ----------
    stop_reason = None
    hist = {"loss": [], "queries": [], "dist": [], "time": []}

    for epoch in range(epochs):
        # --- VR reference removed: no reference gradient is computed here anymore ---

        # reset moments (optionally you can keep them)
        m.zero_()
        v = v_init * torch.ones_like(v)
        vhat = v.clone()

        for t in range(inner_m):
            directions = sample_dirs(q)
            g_curr, q_inc = symmetric_zo_grad(x_curr, mu_t, directions); query_count += q_inc

            # plain ZO gradient (no variance reduction)
            g_vr = g_curr

            # Adam/AMSGrad
            m = beta1 * m + (1.0 - beta1) * g_vr
            v = beta2 * v + (1.0 - beta2) * (g_vr * g_vr)
            vhat = torch.max(vhat, v)

            step = alpha * m / (torch.sqrt(vhat) + eps)
            x_new = x_curr - step

            if constrained:
                sqrt_vhat = torch.sqrt(vhat)
                x_new = mahalanobis_projection(x_new, sqrt_vhat)

            x_curr = x_new

            # evaluate objective once (single forward)
            img = (x0_dev + x_curr.view(C,H,W)).unsqueeze(0)  # (1, C, H, W)
            logits = batched_predict(img)   # returns (1, nclass) on device
            query_count += 1
            probs = F.softmax(logits, dim=1)[0]
            probs_clone = probs.clone()
            probs_clone[true_label] = float('-inf')
            p_true = probs[true_label].item()
            p_other = probs_clone.max().item()
            margin = np.log(p_true + 1e-12) - np.log(p_other + 1e-12)
            distortion = float(x_curr.norm()**2)
            val = margin + dist_weight * distortion

            elapsed = time.time() - start_time
            hist["loss"].append(val)
            hist["queries"].append(query_count)
            hist["dist"].append(np.sqrt(distortion))
            hist["time"].append(elapsed)

            if verbose and (t % 10 == 0):
                print(f"[epoch {epoch} step {t}] f={val:.6f} queries={query_count} time={elapsed:.1f}s dist={np.sqrt(distortion):.3f}")

            if early_stop:
                stop, reason = check_stopping(val, x_curr)
                if stop:
                    stop_reason = reason
                    if verbose:
                        print("Early stop:", reason, "final f:", val, "queries:", query_count)
                    delta_final = x_curr.detach().cpu().numpy()
                    diagnostics = {
                        "queries": int(query_count),
                        "stop_reason": stop_reason,
                        "history": hist,
                        "time_elapsed": time.time() - start_time
                    }
                    return delta_final, diagnostics

        # keep same assignment to x_snap for minimal edits (harmless)
        x_snap = x_curr.clone()

    # finished all epochs
    delta_final = x_curr.detach().cpu().numpy()
    diagnostics = {
        "queries": int(query_count),
        "stop_reason": stop_reason or "completed_epochs",
        "history": hist,
        "time_elapsed": time.time() - start_time
    }
    if verbose:
        print("Finished epochs; stop_reason:", diagnostics["stop_reason"], "queries:", diagnostics["queries"])
    return delta_final, diagnostics


In [None]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),                # Output: (64, 14, 14)
        )
        self.fc = nn.Sequential(
            nn.Linear(64*14*14, 128), nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        z = self.conv(x)              # (B, 64, 14, 14)
        z = z.view(z.size(0), -1)     # Flatten correctly
        return self.fc(z)

In [None]:
# MNIST loading
model = SimpleCNN().to(device)

model.load_state_dict(torch.load("best_mnist_cnn.pt"))

transform = T.Compose([
    T.ToTensor(),
    T.Lambda(lambda x: x - 0.5)
])

trainset = torchvision.datasets.MNIST("./data", train=True, download=True, transform=transform)
testset  = torchvision.datasets.MNIST("./data", train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader  = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False)

# Training setup
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5)

def evaluate(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 510kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.55MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 12.8MB/s]


In [None]:
def evaluate_full_mnist(model, testloader, attack_fn, max_samples=None):
    """
    attack_fn(model, x0 (C,H,W), true_label) → (delta_flat, diagnostics)
    Supports ANY batch size from the testloader.
    """

    device = next(model.parameters()).device

    success = 0
    distortions = []
    queries_list = []
    total = 0

    for X, Y in tqdm(testloader, desc="Evaluating MNIST"):

        # Move whole batch to device
        X = X.to(device)     # shape (B,1,28,28)
        Y = Y.to(device)     # shape (B,)

        # Clean predictions for whole batch
        with torch.no_grad():
            preds = model(X).argmax(1)        # shape (B,)

        # Loop over each element in batch
        batch_size = X.size(0)

        for i in range(batch_size):

            if preds[i].item() != Y[i].item():
                continue   # skip misclassified samples

            # Prepare CPU np array for attack
            x_i_cpu = X[i].cpu().numpy()   # (1,28,28)
            y_i_cpu = Y[i].cpu().item()

            # Run attack
            delta_flat, diag = attack_fn(model, x_i_cpu, y_i_cpu)
            queries = diag["queries"]

            # Build adversarial sample
            adv = x_i_cpu + delta_flat.reshape(1, 28, 28)
            adv_t = torch.from_numpy(adv).float().unsqueeze(0).to(device)

            # Predict adversarial label
            with torch.no_grad():
                adv_pred = model(adv_t).argmax(1).item()

            total += 1
            dist = np.linalg.norm(delta_flat)

            if adv_pred != y_i_cpu:
                success += 1
                distortions.append(dist)
                queries_list.append(queries)

            if max_samples and total >= max_samples:
                break

        if max_samples and total >= max_samples:
            break

    success_rate = success / total if total > 0 else 0

    return {
        "success_rate": success_rate,
        "distortions": distortions,
        "queries": queries_list,
        "total_attacked": total
    }


# ----------------------------------------------------------
#  Attack wrapper for the generalized VR–ZO–AdamM
# ----------------------------------------------------------

def attack_wrapper(model):
    def attack_fn(model, x0_CHW, true_label):
        # x0_CHW: numpy array (C,H,W), e.g. (1,28,28)
        delta_flat, diag = zo_adamm(
            model_or_predict=model,
            x0=x0_CHW,                      # no flattening
            true_label=true_label,
            mu=0.01,
            q=64,
            Q_ref=256,
            inner_m=20,
            epochs=5,
            alpha=0.05,
            constrained=True,
            lb=-0.5,
            ub=0.5,
            verbose=False
        )
        return delta_flat, diag

    return attack_fn


attack_fn = attack_wrapper(model)

In [None]:
import itertools
import pandas as pd
import csv
from datetime import datetime

# ---------------------------------------
# Define search ranges
# ---------------------------------------
# hyper_space = {
#     "mu": [0.0005, 0.001, 0.002, 0.003, 0.004, 0.005],
#     "q": [4, 6, 8, 10],
#     "alpha": [0.001, 0.002, 0.005],
#     "inner_m": [5, 10],
#     "epochs": [2, 3]
# }
    # "mu": [0.005, 0.01, 0.02, 0.03],
hyper_space = {
    "mu": [0.01, 0.02, 0.03],
    "q": [64,32],
    "alpha": [0.05, 0.1],
    "inner_m": [5, 10],
    "epochs": [3]
}

# Cartesian product of all hyperparameter combos
keys = list(hyper_space.keys())
search_list = list(itertools.product(*hyper_space.values()))

# ---------------------------------------
# CSV output file
# ---------------------------------------
csv_file = "mnist_zoadamm_hyperparam_results_1.csv"

# Create CSV header
header = keys + [
    "success_rate",
    "avg_distortion",
    "avg_queries",
    "total_attacked",
    "timestamp"
]

# Start fresh
with open(csv_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(header)

# ---------------------------------------
# Loop over combinations
# ---------------------------------------
for values in search_list:

    params = dict(zip(keys, values))

    print("\n====================================")
    print("Running hyperparameters:", params)
    print("====================================")

    # Build attack_fn with these parameters
    def attack_wrapper_hp(model):
        def attack_fn(model, x0_CHW, true_label):
            delta, diag = zo_adamm(
                model_or_predict=model,
                x0=x0_CHW,
                true_label=true_label,
                mu=params["mu"],
                q=params["q"],
                inner_m=params["inner_m"],
                epochs=params["epochs"],
                alpha=params["alpha"],
                Q_ref=256,
                constrained=True,
                lb=-1.0,
                ub=1.0,
                verbose=False
            )
            return delta, diag
        return attack_fn

    attack_fn_hp = attack_wrapper_hp(model)

    # Evaluate
    results = evaluate_full_mnist(
        model=model,
        testloader=testloader,
        attack_fn=attack_fn_hp,
        max_samples=None
    )

    success_rate   = results["success_rate"]
    avg_distortion = float(np.mean(results["distortions"])) if results["distortions"] else None
    avg_queries    = float(np.mean(results["queries"])) if results["queries"] else None
    total_attacked = results["total_attacked"]

    # Print metrics
    print("Success rate:", success_rate)
    print("Avg distortion:", avg_distortion)
    print("Avg queries:", avg_queries)
    # ---------------------------------------
    # Write row to CSV
    # ---------------------------------------
    row = list(values) + [
        success_rate,
        avg_distortion,
        avg_queries,
        total_attacked,
        datetime.now().isoformat()
    ]

    with open(csv_file, "a", newline='') as f:
        writer = csv.writer(f)
        writer.writerow(row)

    print(f"✔ Logged results to {csv_file}")



Running hyperparameters: {'mu': 0.01, 'q': 64, 'alpha': 0.05, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:48<00:00,  5.71s/it]


Success rate: 0.997582107596212
Avg distortion: 4.776084899902344
Avg queries: 605.3046859220359
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.01, 'q': 64, 'alpha': 0.05, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:42<00:00,  5.56s/it]


Success rate: 0.998388071730808
Avg distortion: 4.898564338684082
Avg queries: 596.0034308779011
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.01, 'q': 64, 'alpha': 0.1, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:27<00:00,  3.70s/it]


Success rate: 0.9982873262139835
Avg distortion: 6.536474704742432
Avg queries: 378.537995761429
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.01, 'q': 64, 'alpha': 0.1, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:29<00:00,  3.74s/it]


Success rate: 0.998791053798106
Avg distortion: 6.540417194366455
Avg queries: 379.4791204357474
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.01, 'q': 32, 'alpha': 0.05, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:28<00:00,  5.21s/it]


Success rate: 0.9968768889784405
Avg distortion: 5.231054782867432
Avg queries: 366.8969176351693
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.01, 'q': 32, 'alpha': 0.05, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:19<00:00,  4.98s/it]


Success rate: 0.9984888172476325
Avg distortion: 5.473419666290283
Avg queries: 351.8565230551912
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.01, 'q': 32, 'alpha': 0.1, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:12<00:00,  3.30s/it]


Success rate: 0.9982873262139835
Avg distortion: 7.227199554443359
Avg queries: 219.408618427692
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.01, 'q': 32, 'alpha': 0.1, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:11<00:00,  3.29s/it]


Success rate: 0.998791053798106
Avg distortion: 7.242759704589844
Avg queries: 219.75035303611054
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 64, 'alpha': 0.05, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:48<00:00,  5.71s/it]


Success rate: 0.997582107596212
Avg distortion: 4.775692462921143
Avg queries: 605.1874368814381
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 64, 'alpha': 0.05, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:45<00:00,  5.64s/it]


Success rate: 0.998388071730808
Avg distortion: 4.897448539733887
Avg queries: 595.795156407669
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 64, 'alpha': 0.1, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:29<00:00,  3.74s/it]


Success rate: 0.9981865806971589
Avg distortion: 6.536055088043213
Avg queries: 378.34184497375855
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 64, 'alpha': 0.1, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:30<00:00,  3.76s/it]


Success rate: 0.9988917993149304
Avg distortion: 6.539890289306641
Avg queries: 379.60998487140694
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 32, 'alpha': 0.05, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:26<00:00,  5.15s/it]


Success rate: 0.9968768889784405
Avg distortion: 5.23148250579834
Avg queries: 366.93633148054573
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 32, 'alpha': 0.05, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:17<00:00,  4.95s/it]


Success rate: 0.9986903082812815
Avg distortion: 5.4729228019714355
Avg queries: 352.13961464743267
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 32, 'alpha': 0.1, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:10<00:00,  3.27s/it]


Success rate: 0.998589562764457
Avg distortion: 7.225176811218262
Avg queries: 219.59140435835351
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.02, 'q': 32, 'alpha': 0.1, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:11<00:00,  3.28s/it]


Success rate: 0.9990932903485795
Avg distortion: 7.241096496582031
Avg queries: 219.99193304426743
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 64, 'alpha': 0.05, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:47<00:00,  5.68s/it]


Success rate: 0.997582107596212
Avg distortion: 4.773733615875244
Avg queries: 604.7705514037568
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 64, 'alpha': 0.05, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:42<00:00,  5.57s/it]


Success rate: 0.998388071730808
Avg distortion: 4.895926475524902
Avg queries: 595.5868819374369
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 64, 'alpha': 0.1, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:27<00:00,  3.70s/it]


Success rate: 0.9981865806971589
Avg distortion: 6.534821510314941
Avg queries: 378.1856075898264
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 64, 'alpha': 0.1, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:28<00:00,  3.71s/it]


Success rate: 0.998992544831755
Avg distortion: 6.539707183837891
Avg queries: 380.1180919725696
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 32, 'alpha': 0.05, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:22<00:00,  5.05s/it]


Success rate: 0.9967761434616159
Avg distortion: 5.230449199676514
Avg queries: 366.72377198302
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 32, 'alpha': 0.05, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [03:14<00:00,  4.85s/it]


Success rate: 0.9984888172476325
Avg distortion: 5.470466613769531
Avg queries: 351.6597719705378
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 32, 'alpha': 0.1, 'inner_m': 5, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:08<00:00,  3.22s/it]


Success rate: 0.9984888172476325
Avg distortion: 7.223358154296875
Avg queries: 219.45615982241952
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv

Running hyperparameters: {'mu': 0.03, 'q': 32, 'alpha': 0.1, 'inner_m': 10, 'epochs': 3}


Evaluating MNIST: 100%|██████████| 40/40 [02:09<00:00,  3.23s/it]

Success rate: 0.999194035865404
Avg distortion: 7.239182949066162
Avg queries: 220.04184311353094
✔ Logged results to mnist_zoadamm_hyperparam_results_1.csv



