
RobustBench: a standardized adversarial robustness benchmark

Francesco Croce* (University of Tübingen), Maksym Andriushchenko* (EPFL), Vikash Sehwag* (Princeton University), Edoardo Debenedetti* (EPFL), Nicolas Flammarion (EPFL), Mung Chiang (Purdue University), Prateek Mittal (Princeton University), Matthias Hein (University of Tübingen)

Leaderboard: https://robustbench.github.io/

Paper: https://arxiv.org/abs/2010.09670

❗Note❗: if you experience problems with the automatic downloading of the models from Google Drive, install the latest version of RobustBench via pip install git+https://github.com/RobustBench/robustbench.git.

News

  • May 2022: We have extended the common corruptions leaderboard on ImageNet with 3D Common Corruptions (ImageNet-3DCC). ImageNet-3DCC evaluation is interesting since (1) it includes more realistic corruptions and (2) it can be used to assess the generalization of existing models, which may have overfit to ImageNet-C. For a quickstart, click here. Note that the entries in the leaderboard are still sorted according to ImageNet-C performance.

  • May 2022: We fixed the preprocessing issue for ImageNet corruption evaluations: previously we resized images to 256x256 and center-cropped them to 224x224, which was unnecessary since the ImageNet-C images are already 224x224 (see this issue). Note that this changed the ranking between the top-1 and top-2 entries.
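To make the effect of the old preprocessing concrete, here is a small back-of-the-envelope calculation. It assumes a square 224x224 ImageNet-C image that is resized to exactly 256x256 and then center-cropped to 224x224, as described above; the helper name is ours and is not part of RobustBench.

```python
def center_crop_coverage(image_size=224, resize_to=256, crop=224):
    """Return the side length (in original pixels) of the region that survives
    Resize(resize_to) + CenterCrop(crop), plus the retained side and area fractions.
    Assumes a square input image and a square resize."""
    # The crop covers `crop` pixels of the upscaled image, which corresponds to
    # crop * image_size / resize_to pixels of the original image.
    kept_side = crop * image_size / resize_to
    side_fraction = kept_side / image_size
    area_fraction = side_fraction ** 2
    return kept_side, side_fraction, area_fraction


kept_side, side_fraction, area_fraction = center_crop_coverage()
print(f"kept {kept_side:.0f}x{kept_side:.0f} px of 224x224 "
      f"({side_fraction:.1%} of each side, {area_fraction:.1%} of the area)")
```

Under these assumptions, the old pipeline kept only the central 196x196 region of each image (dropping a 14-pixel border on every side, about 23% of the area) and upsampled it, so the models were evaluated on slightly zoomed-in images.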

Main idea

The goal of RobustBench is to systematically track the real progress in adversarial robustness. There are already more than 3,000 papers on this topic, but it is still often unclear which approaches really work and which only lead to overestimated robustness. We start by benchmarking Linf, L2, and common corruption robustness since these are the most studied settings in the literature.

Evaluating robustness to Lp perturbations is in general not straightforward and requires adaptive attacks (Tramer et al., 2020). Thus, in order to establish a reliable standardized benchmark, we need to impose some restrictions on the defenses we consider. In particular, we accept only defenses that (1) in general have non-zero gradients w.r.t. the inputs and (2) have a fully deterministic forward pass (i.e., no randomness) that (3) does not contain an optimization loop. Often, defenses that violate these 3 principles only make gradient-based attacks harder but do not substantially improve robustness (Carlini et al., 2019), except for those that can provide concrete provable guarantees (e.g., Cohen et al., 2019).
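The first two criteria can be probed numerically. The sketch below is a toy illustration on plain Python functions, not how RobustBench actually screens submissions: it checks determinism by repeating the forward pass on the same input and estimates the input gradient with central finite differences. The three "models" are hypothetical: a smooth linear score, an input-quantization-style step function whose gradient is zero almost everywhere, and a score with injected randomness. Criterion (3) is not covered.

```python
import random


def is_deterministic(f, x, n_calls=5):
    """True if repeated forward passes on the same input give identical outputs."""
    first = f(x)
    return all(f(x) == first for _ in range(n_calls - 1))


def finite_difference_grad(f, x, h=1e-4):
    """Central finite-difference estimate of the gradient of a scalar f at x."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((f(x_plus) - f(x_minus)) / (2 * h))
    return grad


def has_nonzero_grad(f, x, tol=1e-8):
    return any(abs(g) > tol for g in finite_difference_grad(f, x))


# Hypothetical toy "models" mapping an input vector to a scalar score.
def smooth_model(x):
    return 0.5 * x[0] - 2.0 * x[1]


def quantized_model(x):
    # Rounding the input to a coarse grid makes the score piecewise constant.
    return float(round(4 * x[0]) - round(4 * x[1]))


def noisy_model(x):
    # Randomness in the forward pass makes repeated evaluations disagree.
    return smooth_model(x) + random.gauss(0.0, 0.1)


x = [0.3, 0.6]
for name, f in [("smooth", smooth_model), ("quantized", quantized_model), ("noisy", noisy_model)]:
    print(name, "deterministic:", is_deterministic(f, x), "non-zero grad:", has_nonzero_grad(f, x))
```

The quantized model has a zero gradient at this input and the noisy one gives different outputs on every call: both make gradient-based attacks harder to run without necessarily making the underlying classifier more robust.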

To prevent potential overadaptation of new defenses to AutoAttack, we also welcome external evaluations based on adaptive attacks, especially where AutoAttack flags a potential overestimation of robustness. For each model, we are interested in the best known robust accuracy and see AutoAttack and adaptive attacks as complementary to each other.
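Combining complementary evaluations amounts to taking the worst case per example: a point counts as robust only if no attack found an adversarial example for it. A minimal sketch with hypothetical per-example flags (True means the example is still correctly classified after that attack); the helper is ours, not part of RobustBench or AutoAttack:

```python
def worst_case_robust_accuracy(robust_flags_per_attack):
    """robust_flags_per_attack maps an attack name to a list of booleans, one per
    test example, that are True if the example is still classified correctly
    after that attack. Returns the robust accuracy against all attacks combined."""
    flag_lists = list(robust_flags_per_attack.values())
    n = len(flag_lists[0])
    if any(len(flags) != n for flags in flag_lists):
        raise ValueError("all attacks must be evaluated on the same examples")
    robust = [all(flags[i] for flags in flag_lists) for i in range(n)]
    return sum(robust) / n


# Hypothetical results on 4 examples: each attack alone leaves 3/4 points robust,
# but they break different points, so the combined robust accuracy is only 2/4.
flags = {
    "autoattack": [True, False, True, True],
    "adaptive": [True, True, False, True],
}
for attack, f in flags.items():
    print(attack, sum(f) / len(f))
print("worst case:", worst_case_robust_accuracy(flags))
```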

RobustBench consists of two parts:

  • a website https://robustbench.github.io/ with the leaderboard based on many recent papers (plots below 👇)
  • a collection of the most robust models, Model Zoo, which are easy to use for any downstream application (see the tutorial below after FAQ 👇)

FAQ

Q: How does the RobustBench leaderboard differ from the AutoAttack leaderboard? 🤔
A: The AutoAttack leaderboard was the starting point of RobustBench. Now only the RobustBench leaderboard is actively maintained.

Q: How does the RobustBench leaderboard differ from robust-ml.org? 🤔
A: robust-ml.org focuses on adaptive evaluations, while we provide a standardized benchmark. Adaptive evaluations have been very useful (e.g., see Tramer et al., 2020), but they are also very time-consuming and, by definition, not standardized. Instead, we argue that robustness can be estimated accurately, in most cases without adaptive attacks, but this requires introducing some restrictions on the considered models. However, we do welcome adaptive evaluations and are always interested in showing the best known robust accuracy.

Q: How is it related to libraries like foolbox / cleverhans / advertorch? 🤔
A: These libraries provide implementations of different attacks. Besides the standardized benchmark, RobustBench additionally provides a repository of the most robust models. So you can start using the robust models in one line of code (see the tutorial below 👇).

Q: Why is Lp-robustness still interesting? 🤔
A: There are numerous interesting applications of Lp-robustness, spanning transfer learning (Salman et al. (2020), Utrera et al. (2020)), interpretability (Tsipras et al. (2018), Kaur et al. (2019), Engstrom et al. (2019)), security (Tramèr et al. (2018), Saadatpanah et al. (2019)), generalization (Xie et al. (2019), Zhu et al. (2019), Bochkovskiy et al. (2020)), robustness to unseen perturbations (Xie et al. (2019), Kang et al. (2019)), and stabilization of GAN training (Zhong et al. (2020)).

Q: What about verified adversarial robustness? 🤔
A: We mostly focus on defenses which improve empirical robustness, given the lack of clarity regarding which approaches really improve robustness and which only make some particular attacks unsuccessful. However, we do not restrict submissions of verifiably robust models (e.g., we have Zhang et al. (2019) in our CIFAR-10 Linf leaderboard). For methods targeting verified robustness, we encourage the readers to check out Salman et al. (2019) and Li et al. (2020).

Q: What if I have a better attack than the one used in this benchmark? 🤔
A: We will be happy to add a better attack or any adaptive evaluation that would complement our default standardized attacks.

Model Zoo: quick tour

The goal of our Model Zoo is to simplify the usage of robust models as much as possible. Check out our Colab notebook here 👉 RobustBench: quick start for a quick introduction. It is also summarized below 👇.

First, install the latest version of RobustBench (recommended):

pip install git+https://github.com/RobustBench/robustbench.git

or the latest stable version of RobustBench (note that automatic downloading of the models may not work with it):

pip install git+https://github.com/RobustBench/robustbench.git@v1.0

Now let's load CIFAR-10 and a fairly robust CIFAR-10 model, Carmon2019Unlabeled, which achieves 59.53% robust accuracy under AutoAttack (AA) with eps=8/255:

from robustbench.data import load_cifar10

x_test, y_test = load_cifar10(n_examples=50)

from robustbench.utils import load_model

model = load_model(model_name='Carmon2019Unlabeled', dataset='cifar10', threat_model='Linf')

Let's evaluate the robustness of this model. We can use any library we like for this. For example, Foolbox implements many different attacks. We can start with a simple PGD attack:

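!pip install -q foolbox  # notebook syntax; in a terminal, run: pip install foolbox
import foolbox as fb
model = model.to('cuda:0')  # the model must be on the same device as the inputs below
fmodel = fb.PyTorchModel(model, bounds=(0, 1))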

_, advs, success = fb.attacks.LinfPGD()(fmodel, x_test.to('cuda:0'), y_test.to('cuda:0'), epsilons=[8/255])
print('Robust accuracy: {:.1%}'.format(1 - success.float().mean()))
>>> Robust accuracy: 58.0%

Wonderful! Can we do better with a more accurate attack?

Let's evaluate its robustness with a cheap version of AutoAttack (ICML 2020) that runs only 2 of its 4 attacks (APGD-CE and APGD-DLR):

# autoattack is installed as a dependency of robustbench, so there is no need to install it separately
from autoattack import AutoAttack
adversary = AutoAttack(model, norm='Linf', eps=8/255, version='custom', attacks_to_run=['apgd-ce', 'apgd-dlr'])
adversary.apgd.n_restarts = 1
x_adv = adversary.run_standard_evaluation(x_test, y_test)
>>> initial accuracy: 92.00%
>>> apgd-ce - 1/1 - 19 out of 46 successfully perturbed
>>> robust accuracy after APGD-CE: 54.00% (total time 10.3 s)
>>> apgd-dlr - 1/1 - 1 out of 27 successfully perturbed
>>> robust accuracy after APGD-DLR: 52.00% (total time 17.0 s)
>>> max Linf perturbation: 0.03137, nan in tensor: 0, max: 1.00000, min: 0.00000
>>> robust accuracy: 52.00%

Note that for our standardized evaluation of Linf-robustness we use the full version of AutoAttack, which is slower but more accurate (to run it, just use adversary = AutoAttack(model, norm='Linf', eps=8/255)).
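The numbers in the AutoAttack log above can be reproduced by simple counting on the 50-example subset: each attack runs only on the points that are still correctly classified, and the robust accuracy drops by the number of points it perturbs. A small sketch (the helper name is ours, not part of AutoAttack):

```python
def accuracy_trace(n_examples, n_clean_correct, n_perturbed_per_attack):
    """Accuracy before any attack and after each attack in the sequence, assuming
    every attack only runs on the points that are still classified correctly."""
    remaining = n_clean_correct
    trace = [remaining / n_examples]
    for n_perturbed in n_perturbed_per_attack:
        remaining -= n_perturbed
        trace.append(remaining / n_examples)
    return trace


# Counts taken from the log: 46 of 50 points are initially correct,
# APGD-CE perturbs 19 of them, and APGD-DLR perturbs 1 of the remaining 27.
trace = accuracy_trace(n_examples=50, n_clean_correct=46, n_perturbed_per_attack=[19, 1])
for stage, acc in zip(["initial", "after APGD-CE", "after APGD-DLR"], trace):
    print(f"{stage}: {acc:.2%}")
print(f"eps = 8/255 = {8 / 255:.5f}")  # matches the max Linf perturbation in the log
```

Because the attacks are applied in sequence to the surviving points, adding attacks can only lower the reported robust accuracy, which is why the full four-attack version gives a tighter estimate.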

What about other types of perturbations? Is Lp-robustness useful there? We can evaluate the available models on more general perturbations. For example, let's take images corrupted by fog from CIFAR-10-C at the highest severity level (5). Do different Linf-robust models perform better on them?

from robustbench.data import load_cifar10c
from robustbench.utils import clean_accuracy, load_model

corruptions = ['fog']
x_test, y_test = load_cifar10c(n_examples=1000, corruptions=corruptions, severity=5)

for model_name in ['Standard', 'Engstrom2019Robustness', 'Rice2020Overfitting',
                   'Carmon2019Unlabeled']:
    model = load_model(model_name, dataset='cifar10', threat_model='Linf')
    acc = clean_accuracy(model, x_test, y_test)
    print(f'Model: {model_name}, CIFAR-10-C accuracy: {acc:.1%}')
>>> Model: Standard, CIFAR-10-C accuracy: 74.4%
>>> Model: Engstrom2019Robustness, CIFAR-10-C accuracy: 38.8%
>>> Model: Rice2020Overfitting, CIFAR-10-C accuracy: 22.0%
>>> Model: Carmon2019Unlabeled, CIFAR-10-C accuracy: 31.1%

As we can see, all these Linf-robust models perform considerably worse than the standard model on this type of corruption. This curious phenomenon was first noticed in Adversarial Examples Are a Natural Consequence of Test Error in Noise and explained from the frequency perspective in A Fourier Perspective on Model Robustness in Computer Vision.

However, on average, adversarial training does help on CIFAR-10-C. One can easily check this by loading all corruption types via load_cifar10c(n_examples=1000, severity=5) and repeating the evaluation on them.
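A model can lose on a single corruption, such as fog above, and still win on average. Once you have per-corruption accuracies (e.g., by calling load_cifar10c with corruptions=[c] for each corruption c and then clean_accuracy on the result), a small helper like the following can compare models both per corruption and on average. The helper is hypothetical, not part of RobustBench, and weights all corruptions equally.

```python
def compare_on_corruptions(acc_by_model):
    """acc_by_model maps a model name to a dict {corruption: accuracy}.
    Returns the mean accuracy of each model over corruptions (equal weights)
    and the best model for each individual corruption."""
    mean_acc = {model: sum(accs.values()) / len(accs) for model, accs in acc_by_model.items()}
    corruptions = next(iter(acc_by_model.values())).keys()
    best_per_corruption = {
        c: max(acc_by_model, key=lambda model: acc_by_model[model][c]) for c in corruptions
    }
    return mean_acc, best_per_corruption


# Toy numbers for illustration only, not actual CIFAR-10-C results.
toy = {
    "model_a": {"fog": 0.70, "gaussian_noise": 0.30},
    "model_b": {"fog": 0.40, "gaussian_noise": 0.80},
}
mean_acc, best = compare_on_corruptions(toy)
print(mean_acc)  # model_b has the higher mean accuracy
print(best)      # model_a wins on fog, model_b on gaussian_noise
```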

*New*: Evaluating robustness of ImageNet models against 3D Common Corruptions (ImageNet-3DCC)

3D Common Corruptions (3DCC) is a recent benchmark by Kar et al. (CVPR 2022) that uses scene geometry to generate realistic corruptions. You can evaluate the robustness of a standard ResNet-50 against ImageNet-3DCC by following these steps:

  1. Download the data from here using the provided tool. The data will be saved into a folder named ImageNet-3DCC.

  2. Run the sample evaluation script to obtain accuracies and save them in a pickle file:

import pickle

import torch
from robustbench.data import load_imagenet3dcc
from robustbench.utils import clean_accuracy, load_model

corruptions_3dcc = ['near_focus', 'far_focus', 'bit_error', 'color_quant',
                    'flash', 'fog_3d', 'h265_abr', 'h265_crf',
                    'iso_noise', 'low_light', 'xy_motion_blur', 'z_motion_blur']  # 12 corruptions in ImageNet-3DCC

PATH_IMAGENET_3DCC = '/path/to/data'  # placeholder: set to the directory where you saved the data
device = torch.device("cuda:0")
model_name = 'Standard_R50'
model = load_model(model_name, dataset='imagenet', threat_model='corruptions').to(device)

results = {}  # (corruption, severity) -> accuracy
for corruption in corruptions_3dcc:
    for s in [1, 2, 3, 4, 5]:  # 5 severity levels
        x_test, y_test = load_imagenet3dcc(n_examples=5000, corruptions=[corruption], severity=s, data_dir=PATH_IMAGENET_3DCC)
        acc = clean_accuracy(model, x_test.to(device), y_test.to(device), device=device)
        results[(corruption, s)] = acc
        print(f'Model: {model_name}, ImageNet-3DCC corruption: {corruption} severity: {s} accuracy: {acc:.1%}')

with open(f'{model_name}_3dcc.pkl', 'wb') as f:  # the output file name is arbitrary
    pickle.dump(results, f)
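The evaluation loop produces 60 accuracies (12 corruptions x 5 severities). To condense them, a short helper like the one below averages them per severity level and overall. It assumes the accuracies are collected in a dict keyed by (corruption, severity) and that every combination is weighted equally; both are our assumptions, and this is not necessarily how the leaderboard aggregates results.

```python
from collections import defaultdict
from statistics import mean


def summarize_3dcc(acc):
    """acc maps (corruption, severity) -> accuracy.
    Returns the mean accuracy per severity level and the overall mean."""
    by_severity = defaultdict(list)
    for (_corruption, severity), value in acc.items():
        by_severity[severity].append(value)
    per_severity = {s: mean(values) for s, values in sorted(by_severity.items())}
    overall = mean(acc.values())
    return per_severity, overall


# Toy values for illustration only (two corruptions, two severities).
toy_acc = {
    ("fog_3d", 1): 0.8, ("fog_3d", 2): 0.6,
    ("flash", 1): 0.6, ("flash", 2): 0.4,
}
per_severity, overall = summarize_3dcc(toy_acc)
for s, a in per_severity.items():
    print(f"severity {s}: {a:.1%}")
print(f"overall: {overall:.1%}")
```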

Model Zoo

In order to use a model, you just need to know its ID, e.g. Carmon2019Unlabeled, and to run:

from robustbench import load_model

model = load_model(model_name='Carmon2019Unlabeled', dataset='cifar10', threat_model='Linf')

which automatically downloads the model (all models are defined in model_zoo/models.py).

The evaluation of models from the Model Zoo can be reproduced directly from the command line. Here is an example evaluation of the Salman2020Do_R18 model with AutoAttack on ImageNet for eps=4/255=0.0156862745:

python -m robustbench.eval --n_ex=5000 --dataset=imagenet --threat_model=Linf --model_name=Salman2020Do_R18 --data_dir=/path/to/imagenet --batch_size=128 --eps=0.0156862745
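The --eps flag takes a decimal number, so fractions like 4/255 have to be written out. If you launch several evaluations from Python, a sketch like the following builds the same command from its parts. The helper and the placeholder data directory are ours; only the flags shown in the example command are used.

```python
import shlex


def robustbench_eval_command(model_name, dataset, threat_model, eps_numerator,
                             eps_denominator=255, n_ex=5000, batch_size=128,
                             data_dir="/path/to/imagenet"):
    """Build the argument list for `python -m robustbench.eval` (data_dir is a placeholder)."""
    eps = eps_numerator / eps_denominator
    return [
        "python", "-m", "robustbench.eval",
        f"--n_ex={n_ex}",
        f"--dataset={dataset}",
        f"--threat_model={threat_model}",
        f"--model_name={model_name}",
        f"--data_dir={data_dir}",
        f"--batch_size={batch_size}",
        f"--eps={eps:.10f}",  # 4/255 -> 0.0156862745
    ]


cmd = robustbench_eval_command("Salman2020Do_R18", "imagenet", "Linf", eps_numerator=4)
print(shlex.join(cmd))
# To actually run it: subprocess.run(cmd, check=True)
```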

The CIFAR-10, CIFAR-10-C, CIFAR-100, and CIFAR-100-C datasets are downloaded automatically. However, the ImageNet datasets should be downloaded manually due to their licensing:

  • ImageNet: Obtain the download link here (this only requires signing up with an academic email address; approval is automatic and instant) and then follow the instructions here to extract the validation set in a PyTorch-compatible format into a folder named val.
  • ImageNet-C: Please visit here for the instructions.
  • ImageNet-3DCC: Download the data from here using the provided tool. The data will be saved into a folder named ImageNet-3DCC.

In order to use the models from the Model Zoo, you can find all available model IDs in the tables below. Note that the full leaderboard contains a few more models, which either have not yet been added to the Model Zoo or whose authors prefer not to have them in the Model Zoo.

CIFAR-10

Linf, eps=8/255

# Model ID Paper Clean accuracy Robust accuracy Architecture Venue
1 Peng2023Robust Robust Principles: Architectural Design Principles for Adversarially Robust CNNs 93.27% 71.07% RaWideResNet-70-16 BMVC 2023
2 Wang2023Better_WRN-70-16 Better Diffusion Models Further Improve Adversarial Training 93.25% 70.69% WideResNet-70-16 ICML 2023
3 Bai2024MixedNUTS MixedNUTS: Training-Free Accuracy-Robustness Balance via Nonlinearly Mixed Classifiers 95.19% 69.71% ResNet-152 + WideResNet-70-16 arXiv, Feb 2024
4 Bai2023Improving_edm Improving the Accuracy-Robustness Trade-off of Classifiers via Adaptive Smoothing 95.23% 68.06% ResNet-152 + WideResNet-70-16 + mixing network SIMODS 2024
5 Cui2023Decoupled_WRN-28-10 Decoupled Kullback-Leibler Divergence Loss 92.16% 67.73% WideResNet-28-10 arXiv, May 2023
6 Wang2023Better_WRN-28-10 Better Diffusion Models Further Improve Adversarial Training 92.44% 67.31% WideResNet-28-10 ICML 2023
7 Rebuffi2021Fixing_70_16_cutmix_extra Fixing Data Augmentation to Improve Adversarial Robustness 92.23% 66.56% WideResNet-70-16 arXiv, Mar 2021
8 Gowal2021Improving_70_16_ddpm_100m Improving Robustness using Generated Data 88.74% 66.10% WideResNet-70-16 NeurIPS 2021
9 Gowal2020Uncovering_70_16_extra Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 91.10% 65.87% WideResNet-70-16 arXiv, Oct 2020
10 Huang2022Revisiting_WRN-A4 Revisiting Residual Networks for Adversarial Robustness: An Architectural Perspective 91.58% 65.79% WideResNet-A4 arXiv, Dec 2022
11 Rebuffi2021Fixing_106_16_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 88.50% 64.58% WideResNet-106-16 arXiv, Mar 2021
12 Rebuffi2021Fixing_70_16_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 88.54% 64.20% WideResNet-70-16 arXiv, Mar 2021
13 Kang2021Stable Stable Neural ODE with Lyapunov-Stable Equilibrium Points for Defending Against Adversarial Attacks 93.73% 64.20% WideResNet-70-16, Neural ODE block NeurIPS 2021
14 Xu2023Exploring_WRN-28-10 Exploring and Exploiting Decision Boundary Dynamics for Adversarial Robustness 93.69% 63.89% WideResNet-28-10 ICLR 2023
15 Gowal2021Improving_28_10_ddpm_100m Improving Robustness using Generated Data 87.50% 63.38% WideResNet-28-10 NeurIPS 2021
16 Pang2022Robustness_WRN70_16 Robustness and Accuracy Could Be Reconcilable by (Proper) Definition 89.01% 63.35% WideResNet-70-16 ICML 2022
17 Rade2021Helper_extra Helper-based Adversarial Training: Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off 91.47% 62.83% WideResNet-34-10 OpenReview, Jun 2021
18 Sehwag2021Proxy_ResNest152 Robust Learning Meets Generative Models: Can Proxy Distributions Improve Adversarial Robustness? 87.30% 62.79% ResNest152 ICLR 2022
19 Gowal2020Uncovering_28_10_extra Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 89.48% 62.76% WideResNet-28-10 arXiv, Oct 2020
20 Huang2021Exploring_ema Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks 91.23% 62.54% WideResNet-34-R NeurIPS 2021
21 Huang2021Exploring Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks 90.56% 61.56% WideResNet-34-R NeurIPS 2021
22 Dai2021Parameterizing Parameterizing Activation Functions for Adversarial Robustness 87.02% 61.55% WideResNet-28-10-PSSiLU arXiv, Oct 2021
23 Pang2022Robustness_WRN28_10 Robustness and Accuracy Could Be Reconcilable by (Proper) Definition 88.61% 61.04% WideResNet-28-10 ICML 2022
24 Rade2021Helper_ddpm Helper-based Adversarial Training: Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off 88.16% 60.97% WideResNet-28-10 OpenReview, Jun 2021
25 Rebuffi2021Fixing_28_10_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 87.33% 60.73% WideResNet-28-10 arXiv, Mar 2021
26 Sridhar2021Robust_34_15 Improving Neural Network Robustness via Persistency of Excitation 86.53% 60.41% WideResNet-34-15 ACC 2022
27 Sehwag2021Proxy Robust Learning Meets Generative Models: Can Proxy Distributions Improve Adversarial Robustness? 86.68% 60.27% WideResNet-34-10 ICLR 2022
28 Wu2020Adversarial_extra Adversarial Weight Perturbation Helps Robust Generalization 88.25% 60.04% WideResNet-28-10 NeurIPS 2020
29 Sridhar2021Robust Improving Neural Network Robustness via Persistency of Excitation 89.46% 59.66% WideResNet-28-10 ACC 2022
30 Zhang2020Geometry Geometry-aware Instance-reweighted Adversarial Training 89.36% 59.64% WideResNet-28-10 ICLR 2021
31 Carmon2019Unlabeled Unlabeled Data Improves Adversarial Robustness 89.69% 59.53% WideResNet-28-10 NeurIPS 2019
32 Gowal2021Improving_R18_ddpm_100m Improving Robustness using Generated Data 87.35% 58.50% PreActResNet-18 NeurIPS 2021
33 Chen2024Data_WRN_34_20 Data filtering for efficient adversarial training 86.10% 58.09% WideResNet-34-20 Pattern Recognition 2024
34 Addepalli2021Towards_WRN34 Scaling Adversarial Training to Large Perturbation Bounds 85.32% 58.04% WideResNet-34-10 ECCV 2022
35 Addepalli2022Efficient_WRN_34_10 Efficient and Effective Augmentation Strategy for Adversarial Training 88.71% 57.81% WideResNet-34-10 NeurIPS 2022
36 Chen2021LTD_WRN34_20 LTD: Low Temperature Distillation for Robust Adversarial Training 86.03% 57.71% WideResNet-34-20 arXiv, Nov 2021
37 Rade2021Helper_R18_extra Helper-based Adversarial Training: Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off 89.02% 57.67% PreActResNet-18 OpenReview, Jun 2021
38 Jia2022LAS-AT_70_16 LAS-AT: Adversarial Training with Learnable Attack Strategy 85.66% 57.61% WideResNet-70-16 arXiv, Mar 2022
39 Debenedetti2022Light_XCiT-L12 A Light Recipe to Train Robust Vision Transformers 91.73% 57.58% XCiT-L12 arXiv, Sep 2022
40 Chen2024Data_WRN_34_10 Data filtering for efficient adversarial training 86.54% 57.30% WideResNet-34-10 Pattern Recognition 2024
41 Debenedetti2022Light_XCiT-M12 A Light Recipe to Train Robust Vision Transformers 91.30% 57.27% XCiT-M12 arXiv, Sep 2022
42 Sehwag2020Hydra HYDRA: Pruning Adversarially Robust Neural Networks 88.98% 57.14% WideResNet-28-10 NeurIPS 2020
43 Gowal2020Uncovering_70_16 Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 85.29% 57.14% WideResNet-70-16 arXiv, Oct 2020
44 Rade2021Helper_R18_ddpm Helper-based Adversarial Training: Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off 86.86% 57.09% PreActResNet-18 OpenReview, Jun 2021
45 Cui2023Decoupled_WRN-34-10 Decoupled Kullback-Leibler Divergence Loss 85.31% 57.09% WideResNet-34-10 arXiv, May 2023
46 Chen2021LTD_WRN34_10 LTD: Low Temperature Distillation for Robust Adversarial Training 85.21% 56.94% WideResNet-34-10 arXiv, Nov 2021
47 Gowal2020Uncovering_34_20 Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 85.64% 56.82% WideResNet-34-20 arXiv, Oct 2020
48 Rebuffi2021Fixing_R18_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 83.53% 56.66% PreActResNet-18 arXiv, Mar 2021
49 Wang2020Improving Improving Adversarial Robustness Requires Revisiting Misclassified Examples 87.50% 56.29% WideResNet-28-10 ICLR 2020
50 Jia2022LAS-AT_34_10 LAS-AT: Adversarial Training with Learnable Attack Strategy 84.98% 56.26% WideResNet-34-10 arXiv, Mar 2022
51 Wu2020Adversarial Adversarial Weight Perturbation Helps Robust Generalization 85.36% 56.17% WideResNet-34-10 NeurIPS 2020
52 Debenedetti2022Light_XCiT-S12 A Light Recipe to Train Robust Vision Transformers 90.06% 56.14% XCiT-S12 arXiv, Sep 2022
53 Sehwag2021Proxy_R18 Robust Learning Meets Generative Models: Can Proxy Distributions Improve Adversarial Robustness? 84.59% 55.54% ResNet-18 ICLR 2022
54 Hendrycks2019Using Using Pre-Training Can Improve Model Robustness and Uncertainty 87.11% 54.92% WideResNet-28-10 ICML 2019
55 Pang2020Boosting Boosting Adversarial Training with Hypersphere Embedding 85.14% 53.74% WideResNet-34-20 NeurIPS 2020
56 Cui2020Learnable_34_20 Learnable Boundary Guided Adversarial Training 88.70% 53.57% WideResNet-34-20 ICCV 2021
57 Zhang2020Attacks Attacks Which Do Not Kill Training Make Adversarial Learning Stronger 84.52% 53.51% WideResNet-34-10 ICML 2020
58 Rice2020Overfitting Overfitting in adversarially robust deep learning 85.34% 53.42% WideResNet-34-20 ICML 2020
59 Huang2020Self Self-Adaptive Training: beyond Empirical Risk Minimization 83.48% 53.34% WideResNet-34-10 NeurIPS 2020
60 Zhang2019Theoretically Theoretically Principled Trade-off between Robustness and Accuracy 84.92% 53.08% WideResNet-34-10 ICML 2019
61 Cui2020Learnable_34_10 Learnable Boundary Guided Adversarial Training 88.22% 52.86% WideResNet-34-10 ICCV 2021
62 Addepalli2022Efficient_RN18 Efficient and Effective Augmentation Strategy for Adversarial Training 85.71% 52.48% ResNet-18 NeurIPS 2022
63 Chen2020Adversarial Adversarial Robustness: From Self-Supervised Pre-Training to Fine-Tuning 86.04% 51.56% ResNet-50 (3x ensemble) CVPR 2020
64 Chen2020Efficient Efficient Robust Training via Backward Smoothing 85.32% 51.12% WideResNet-34-10 arXiv, Oct 2020
65 Addepalli2021Towards_RN18 Scaling Adversarial Training to Large Perturbation Bounds 80.24% 51.06% ResNet-18 ECCV 2022
66 Sitawarin2020Improving Improving Adversarial Robustness Through Progressive Hardening 86.84% 50.72% WideResNet-34-10 arXiv, Mar 2020
67 Engstrom2019Robustness Robustness library 87.03% 49.25% ResNet-50 GitHub, Oct 2019
68 Zhang2019You You Only Propagate Once: Accelerating Adversarial Training via Maximal Principle 87.20% 44.83% WideResNet-34-10 NeurIPS 2019
69 Andriushchenko2020Understanding Understanding and Improving Fast Adversarial Training 79.84% 43.93% PreActResNet-18 NeurIPS 2020
70 Wong2020Fast Fast is better than free: Revisiting adversarial training 83.34% 43.21% PreActResNet-18 ICLR 2020
71 Ding2020MMA MMA Training: Direct Input Space Margin Maximization through Adversarial Training 84.36% 41.44% WideResNet-28-4 ICLR 2020
72 Standard Standardly trained model 94.78% 0.00% WideResNet-28-10 N/A

L2, eps=0.5

# Model ID Paper Clean accuracy Robust accuracy Architecture Venue
1 Wang2023Better_WRN-70-16 Better Diffusion Models Further Improve Adversarial Training 95.54% 84.97% WideResNet-70-16 arXiv, Feb 2023
2 Wang2023Better_WRN-28-10 Better Diffusion Models Further Improve Adversarial Training 95.16% 83.68% WideResNet-28-10 arXiv, Feb 2023
3 Rebuffi2021Fixing_70_16_cutmix_extra Fixing Data Augmentation to Improve Adversarial Robustness 95.74% 82.32% WideResNet-70-16 arXiv, Mar 2021
4 Gowal2020Uncovering_extra Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 94.74% 80.53% WideResNet-70-16 arXiv, Oct 2020
5 Rebuffi2021Fixing_70_16_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 92.41% 80.42% WideResNet-70-16 arXiv, Mar 2021
6 Rebuffi2021Fixing_28_10_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 91.79% 78.80% WideResNet-28-10 arXiv, Mar 2021
7 Augustin2020Adversarial_34_10_extra Adversarial Robustness on In- and Out-Distribution Improves Explainability 93.96% 78.79% WideResNet-34-10 ECCV 2020
8 Sehwag2021Proxy Robust Learning Meets Generative Models: Can Proxy Distributions Improve Adversarial Robustness? 90.93% 77.24% WideResNet-34-10 ICLR 2022
9 Augustin2020Adversarial_34_10 Adversarial Robustness on In- and Out-Distribution Improves Explainability 92.23% 76.25% WideResNet-34-10 ECCV 2020
10 Rade2021Helper_R18_ddpm Helper-based Adversarial Training: Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off 90.57% 76.15% PreActResNet-18 OpenReview, Jun 2021
11 Rebuffi2021Fixing_R18_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 90.33% 75.86% PreActResNet-18 arXiv, Mar 2021
12 Gowal2020Uncovering Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 90.90% 74.50% WideResNet-70-16 arXiv, Oct 2020
13 Sehwag2021Proxy_R18 Robust Learning Meets Generative Models: Can Proxy Distributions Improve Adversarial Robustness? 89.76% 74.41% ResNet-18 ICLR 2022
14 Wu2020Adversarial Adversarial Weight Perturbation Helps Robust Generalization 88.51% 73.66% WideResNet-34-10 NeurIPS 2020
15 Augustin2020Adversarial Adversarial Robustness on In- and Out-Distribution Improves Explainability 91.08% 72.91% ResNet-50 ECCV 2020
16 Engstrom2019Robustness Robustness library 90.83% 69.24% ResNet-50 GitHub, Sep 2019
17 Rice2020Overfitting Overfitting in adversarially robust deep learning 88.67% 67.68% PreActResNet-18 ICML 2020
18 Rony2019Decoupling Decoupling Direction and Norm for Efficient Gradient-Based L2 Adversarial Attacks and Defenses 89.05% 66.44% WideResNet-28-10 CVPR 2019
19 Ding2020MMA MMA Training: Direct Input Space Margin Maximization through Adversarial Training 88.02% 66.09% WideResNet-28-4 ICLR 2020
20 Standard Standardly trained model 94.78% 0.00% WideResNet-28-10 N/A

Common Corruptions

# Model ID Paper Clean accuracy Robust accuracy Architecture Venue
1 Diffenderfer2021Winning_LRR_CARD_Deck A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 96.56% 92.78% WideResNet-18-2 NeurIPS 2021
2 Diffenderfer2021Winning_LRR A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 96.66% 90.94% WideResNet-18-2 NeurIPS 2021
3 Diffenderfer2021Winning_Binary_CARD_Deck A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 95.09% 90.15% WideResNet-18-2 NeurIPS 2021
4 Kireev2021Effectiveness_RLATAugMix On the effectiveness of adversarial training against common corruptions 94.75% 89.60% ResNet-18 arXiv, Mar 2021
5 Hendrycks2020AugMix_ResNeXt AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty 95.83% 89.09% ResNeXt29_32x4d ICLR 2020
6 Modas2021PRIMEResNet18 PRIME: A Few Primitives Can Boost Robustness to Common Corruptions 93.06% 89.05% ResNet-18 arXiv, Dec 2021
7 Hendrycks2020AugMix_WRN AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty 95.08% 88.82% WideResNet-40-2 ICLR 2020
8 Kireev2021Effectiveness_RLATAugMixNoJSD On the effectiveness of adversarial training against common corruptions 94.77% 88.53% PreActResNet-18 arXiv, Mar 2021
9 Diffenderfer2021Winning_Binary A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 94.87% 88.32% WideResNet-18-2 NeurIPS 2021
10 Rebuffi2021Fixing_70_16_cutmix_extra_L2 Fixing Data Augmentation to Improve Adversarial Robustness 95.74% 88.23% WideResNet-70-16 arXiv, Mar 2021
11 Kireev2021Effectiveness_AugMixNoJSD On the effectiveness of adversarial training against common corruptions 94.97% 86.60% PreActResNet-18 arXiv, Mar 2021
12 Kireev2021Effectiveness_Gauss50percent On the effectiveness of adversarial training against common corruptions 93.24% 85.04% PreActResNet-18 arXiv, Mar 2021
13 Kireev2021Effectiveness_RLAT On the effectiveness of adversarial training against common corruptions 93.10% 84.10% PreActResNet-18 arXiv, Mar 2021
14 Rebuffi2021Fixing_70_16_cutmix_extra_Linf Fixing Data Augmentation to Improve Adversarial Robustness 92.23% 82.82% WideResNet-70-16 arXiv, Mar 2021
15 Addepalli2022Efficient_WRN_34_10 Efficient and Effective Augmentation Strategy for Adversarial Training 88.71% 80.12% WideResNet-34-10 CVPRW 2022
16 Addepalli2021Towards_WRN34 Towards Achieving Adversarial Robustness Beyond Perceptual Limits 85.32% 76.78% WideResNet-34-10 arXiv, Apr 2021
17 Standard Standardly trained model 94.78% 73.46% WideResNet-28-10 N/A

CIFAR-100

Linf, eps=8/255

# Model ID Paper Clean accuracy Robust accuracy Architecture Venue
1 Wang2023Better_WRN-70-16 Better Diffusion Models Further Improve Adversarial Training 75.22% 42.67% WideResNet-70-16 arXiv, Feb 2023
2 Bai2024MixedNUTS MixedNUTS: Training-Free Accuracy-Robustness Balance via Nonlinearly Mixed Classifiers 83.08% 41.80% ResNet-152 + WideResNet-70-16 arXiv, Feb 2024
3 Cui2023Decoupled_WRN-28-10 Decoupled Kullback-Leibler Divergence Loss 73.85% 39.18% WideResNet-28-10 arXiv, May 2023
4 Wang2023Better_WRN-28-10 Better Diffusion Models Further Improve Adversarial Training 72.58% 38.83% WideResNet-28-10 ICML 2023
5 Bai2023Improving_edm Improving the Accuracy-Robustness Trade-off of Classifiers via Adaptive Smoothing 85.21% 38.72% ResNet-152 + WideResNet-70-16 + mixing network SIMODS 2024
6 Gowal2020Uncovering_extra Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 69.15% 36.88% WideResNet-70-16 arXiv, Oct 2020
7 Bai2023Improving_trades Improving the Accuracy-Robustness Trade-off of Classifiers via Adaptive Smoothing 80.18% 35.15% ResNet-152 + WideResNet-70-16 + mixing network SIMODS 2024
8 Debenedetti2022Light_XCiT-L12 A Light Recipe to Train Robust Vision Transformers 70.76% 35.08% XCiT-L12 arXiv, Sep 2022
9 Rebuffi2021Fixing_70_16_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 63.56% 34.64% WideResNet-70-16 arXiv, Mar 2021
10 Debenedetti2022Light_XCiT-M12 A Light Recipe to Train Robust Vision Transformers 69.21% 34.21% XCiT-M12 arXiv, Sep 2022
11 Pang2022Robustness_WRN70_16 Robustness and Accuracy Could Be Reconcilable by (Proper) Definition 65.56% 33.05% WideResNet-70-16 ICML 2022
12 Cui2023Decoupled_WRN-34-10_autoaug Decoupled Kullback-Leibler Divergence Loss 65.93% 32.52% WideResNet-34-10 arXiv, May 2023
13 Debenedetti2022Light_XCiT-S12 A Light Recipe to Train Robust Vision Transformers 67.34% 32.19% XCiT-S12 arXiv, Sep 2022
14 Rebuffi2021Fixing_28_10_cutmix_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 62.41% 32.06% WideResNet-28-10 arXiv, Mar 2021
15 Jia2022LAS-AT_34_20 LAS-AT: Adversarial Training with Learnable Attack Strategy 67.31% 31.91% WideResNet-34-20 arXiv, Mar 2022
16 Addepalli2022Efficient_WRN_34_10 Efficient and Effective Augmentation Strategy for Adversarial Training 68.75% 31.85% WideResNet-34-10 NeurIPS 2022
17 Cui2023Decoupled_WRN-34-10 Decoupled Kullback-Leibler Divergence Loss 64.08% 31.65% WideResNet-34-10 arXiv, May 2023
18 Cui2020Learnable_34_10_LBGAT9_eps_8_255 Learnable Boundary Guided Adversarial Training 62.99% 31.20% WideResNet-34-10 ICCV 2021
19 Sehwag2021Proxy Robust Learning Meets Generative Models: Can Proxy Distributions Improve Adversarial Robustness? 65.93% 31.15% WideResNet-34-10 ICLR 2022
20 Chen2024Data_WRN_34_10 Data filtering for efficient adversarial training 64.32% 31.13% WideResNet-34-10 Pattern Recognition 2024
21 Pang2022Robustness_WRN28_10 Robustness and Accuracy Could Be Reconcilable by (Proper) Definition 63.66% 31.08% WideResNet-28-10 ICML 2022
22 Jia2022LAS-AT_34_10 LAS-AT: Adversarial Training with Learnable Attack Strategy 64.89% 30.77% WideResNet-34-10 arXiv, Mar 2022
23 Chen2021LTD_WRN34_10 LTD: Low Temperature Distillation for Robust Adversarial Training 64.07% 30.59% WideResNet-34-10 arXiv, Nov 2021
24 Addepalli2021Towards_WRN34 Scaling Adversarial Training to Large Perturbation Bounds 65.73% 30.35% WideResNet-34-10 ECCV 2022
25 Cui2020Learnable_34_20_LBGAT6 Learnable Boundary Guided Adversarial Training 62.55% 30.20% WideResNet-34-20 ICCV 2021
26 Gowal2020Uncovering Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 60.86% 30.03% WideResNet-70-16 arXiv, Oct 2020
27 Cui2020Learnable_34_10_LBGAT6 Learnable Boundary Guided Adversarial Training 60.64% 29.33% WideResNet-34-10 ICCV 2021
28 Rade2021Helper_R18_ddpm Helper-based Adversarial Training: Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off 61.50% 28.88% PreActResNet-18 OpenReview, Jun 2021
29 Wu2020Adversarial Adversarial Weight Perturbation Helps Robust Generalization 60.38% 28.86% WideResNet-34-10 NeurIPS 2020
30 Rebuffi2021Fixing_R18_ddpm Fixing Data Augmentation to Improve Adversarial Robustness 56.87% 28.50% PreActResNet-18 arXiv, Mar 2021
31 Hendrycks2019Using Using Pre-Training Can Improve Model Robustness and Uncertainty 59.23% 28.42% WideResNet-28-10 ICML 2019
32 Addepalli2022Efficient_RN18 Efficient and Effective Augmentation Strategy for Adversarial Training 65.45% 27.67% ResNet-18 NeurIPS 2022
33 Cui2020Learnable_34_10_LBGAT0 Learnable Boundary Guided Adversarial Training 70.25% 27.16% WideResNet-34-10 ICCV 2021
34 Addepalli2021Towards_PARN18 Scaling Adversarial Training to Large Perturbation Bounds 62.02% 27.14% PreActResNet-18 ECCV 2022
35 Chen2020Efficient Efficient Robust Training via Backward Smoothing 62.15% 26.94% WideResNet-34-10 arXiv, Oct 2020
36 Sitawarin2020Improving Improving Adversarial Robustness Through Progressive Hardening 62.82% 24.57% WideResNet-34-10 arXiv, Mar 2020
37 Rice2020Overfitting Overfitting in adversarially robust deep learning 53.83% 18.95% PreActResNet-18 ICML 2020

Corruptions

# Model ID Paper Clean accuracy Robust accuracy Architecture Venue
1 Diffenderfer2021Winning_LRR_CARD_Deck A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 79.93% 71.08% WideResNet-18-2 NeurIPS 2021
2 Diffenderfer2021Winning_Binary_CARD_Deck A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 78.50% 69.09% WideResNet-18-2 NeurIPS 2021
3 Modas2021PRIMEResNet18 PRIME: A Few Primitives Can Boost Robustness to Common Corruptions 77.60% 68.28% ResNet-18 arXiv, Dec 2021
4 Diffenderfer2021Winning_LRR A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 78.41% 66.45% WideResNet-18-2 NeurIPS 2021
5 Diffenderfer2021Winning_Binary A Winning Hand: Compressing Deep Networks Can Improve Out-Of-Distribution Robustness 77.69% 65.26% WideResNet-18-2 NeurIPS 2021
6 Hendrycks2020AugMix_ResNeXt AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty 78.90% 65.14% ResNeXt29_32x4d ICLR 2020
7 Hendrycks2020AugMix_WRN AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty 76.28% 64.11% WideResNet-40-2 ICLR 2020
8 Addepalli2022Efficient_WRN_34_10 Efficient and Effective Augmentation Strategy for Adversarial Training 68.75% 56.95% WideResNet-34-10 CVPRW 2022
9 Gowal2020Uncovering_extra_Linf Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 69.15% 56.00% WideResNet-70-16 arXiv, Oct 2020
10 Addepalli2021Towards_WRN34 Towards Achieving Adversarial Robustness Beyond Perceptual Limits 65.73% 54.88% WideResNet-34-10 OpenReview, Jun 2021
11 Addepalli2021Towards_PARN18 Towards Achieving Adversarial Robustness Beyond Perceptual Limits 62.02% 51.77% PreActResNet-18 OpenReview, Jun 2021
12 Gowal2020Uncovering_Linf Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples 60.86% 49.46% WideResNet-70-16 arXiv, Oct 2020

ImageNet

Note: the values (even clean accuracy) may fluctuate slightly depending on package versions, e.g. of torchvision.

Linf, eps=4/255

# Model ID Paper Clean accuracy Robust accuracy Architecture Venue
1 Liu2023Comprehensive_Swin-L A Comprehensive Study on Robustness of Image Classification Models: Benchmarking and Rethinking 78.92% 59.56% Swin-L arXiv, Feb 2023
2 Bai2024MixedNUTS MixedNUTS: Training-Free Accuracy-Robustness Balance via Nonlinearly Mixed Classifiers 81.48% 58.50% ConvNeXtV2-L + Swin-L arXiv, Feb 2024
3 Liu2023Comprehensive_ConvNeXt-L A Comprehensive Study on Robustness of Image Classification Models: Benchmarking and Rethinking 78.02% 58.48% ConvNeXt-L arXiv, Feb 2023
4 Singh2023Revisiting_ConvNeXt-L-ConvStem Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models 77.00% 57.70% ConvNeXt-L + ConvStem NeurIPS 2023
5 Liu2023Comprehensive_Swin-B A Comprehensive Study on Robustness of Image Classification Models: Benchmarking and Rethinking 76.16% 56.16% Swin-B arXiv, Feb 2023
6 Singh2023Revisiting_ConvNeXt-B-ConvStem Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models 75.90% 56.14% ConvNeXt-B + ConvStem NeurIPS 2023
7 Liu2023Comprehensive_ConvNeXt-B A Comprehensive Study on Robustness of Image Classification Models: Benchmarking and Rethinking 76.02% 55.82% ConvNeXt-B arXiv, Feb 2023
8 Singh2023Revisiting_ViT-B-ConvStem Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models 76.30% 54.66% ViT-B + ConvStem NeurIPS 2023
9 Singh2023Revisiting_ConvNeXt-S-ConvStem Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models 74.10% 52.42% ConvNeXt-S + ConvStem NeurIPS 2023
10 Singh2023Revisiting_ConvNeXt-T-ConvStem Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models 72.72% 49.46% ConvNeXt-T + ConvStem NeurIPS 2023
11 Peng2023Robust Robust Principles: Architectural Design Principles for Adversarially Robust CNNs 73.44% 48.94% RaWideResNet-101-2 BMVC 2023
12 Singh2023Revisiting_ViT-S-ConvStem Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models 72.56% 48.08% ViT-S + ConvStem NeurIPS 2023
13 Debenedetti2022Light_XCiT-L12 A Light Recipe to Train Robust Vision Transformers 73.76% 47.60% XCiT-L12 arXiv, Sep 2022
14 Debenedetti2022Light_XCiT-M12 A Light Recipe to Train Robust Vision Transformers 74.04% 45.24% XCiT-M12 arXiv, Sep 2022
15 Debenedetti2022Light_XCiT-S12 A Light Recipe to Train Robust Vision Transformers 72.34% 41.78% XCiT-S12 arXiv, Sep 2022
16 Chen2024Data_WRN_50_2 Data filtering for efficient adversarial training 68.76% 40.60% WideResNet-50-2 Pattern Recognition 2024
17 Salman2020Do_50_2 Do Adversarially Robust ImageNet Models Transfer Better? 68.46% 38.14% WideResNet-50-2 NeurIPS 2020
18 Salman2020Do_R50 Do Adversarially Robust ImageNet Models Transfer Better? 64.02% 34.96% ResNet-50 NeurIPS 2020
19 Engstrom2019Robustness Robustness library 62.56% 29.22% ResNet-50 GitHub, Oct 2019
20 Wong2020Fast Fast is better than free: Revisiting adversarial training 55.62% 26.24% ResNet-50 ICLR 2020
21 Salman2020Do_R18 Do Adversarially Robust ImageNet Models Transfer Better? 52.92% 25.32% ResNet-18 NeurIPS 2020
22 Standard_R50 Standardly trained model 76.52% 0.00% ResNet-50 N/A

Corruptions (ImageNet-C & ImageNet-3DCC)

# Model ID Paper Clean accuracy Robust accuracy Architecture Venue
1 Tian2022Deeper_DeiT-B Deeper Insights into the Robustness of ViTs towards Common Corruptions 81.38% 67.55% DeiT Base arXiv, Apr 2022
2 Tian2022Deeper_DeiT-S Deeper Insights into the Robustness of ViTs towards Common Corruptions 79.76% 62.91% DeiT Small arXiv, Apr 2022
3 Erichson2022NoisyMix_new NoisyMix: Boosting Robustness by Combining Data Augmentations, Stability Training, and Noise Injections 76.90% 53.28% ResNet-50 arXiv, Feb 2022
4 Hendrycks2020Many The Many Faces of Robustness: A Critical Analysis of Out-of-Distribution Generalization 76.86% 52.90% ResNet-50 ICCV 2021
5 Erichson2022NoisyMix NoisyMix: Boosting Robustness by Combining Data Augmentations, Stability Training, and Noise Injections 76.98% 52.47% ResNet-50 arXiv, Feb 2022
6 Hendrycks2020AugMix AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty 77.34% 49.33% ResNet-50 ICLR 2020
7 Geirhos2018_SIN_IN ImageNet-trained CNNs are biased towards texture; increasing shape bias improves accuracy and robustness 74.98% 45.76% ResNet-50 ICLR 2019
8 Geirhos2018_SIN_IN_IN ImageNet-trained CNNs are biased towards texture; increasing shape bias improves accuracy and robustness 77.56% 42.00% ResNet-50 ICLR 2019
9 Geirhos2018_SIN ImageNet-trained CNNs are biased towards texture; increasing shape bias improves accuracy and robustness 60.08% 39.92% ResNet-50 ICLR 2019
10 Standard_R50 Standardly trained model 76.72% 39.48% ResNet-50 N/A
11 Salman2020Do_50_2_Linf Do Adversarially Robust ImageNet Models Transfer Better? 68.64% 36.09% WideResNet-50-2 NeurIPS 2020
12 AlexNet ImageNet Classification with Deep Convolutional Neural Networks 56.24% 21.12% AlexNet NeurIPS 2012

Notebooks

We host all the notebooks at Google Colab:

  • RobustBench: quick start: a quick tutorial to get started that illustrates the main features of RobustBench.
  • RobustBench: json stats: various plots based on the jsons from model_info (robustness over venues, robustness vs. accuracy, etc.).

Feel free to suggest a new notebook based on the Model Zoo or the jsons from model_info. We are very interested in collecting new insights about benefits and tradeoffs between different perturbation types.
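For a quick look at the same data without the notebook, here is a minimal sketch that reads the jsons of one dataset/threat-model folder and ranks the entries by robust accuracy. It assumes the field names shown in the json example further below (accuracies stored as strings under "clean_acc" and either "autoattack_acc" or "corruptions_acc"); the helper name load_leaderboard is hypothetical.

```python
import json
from pathlib import Path


def load_leaderboard(json_dir):
    """Collect (model_id, clean_acc, robust_acc) tuples from a folder of model_info jsons.

    Assumes accuracies are stored as strings under "clean_acc" and either
    "autoattack_acc" (Linf/L2) or "corruptions_acc" (corruptions).
    """
    rows = []
    for path in sorted(Path(json_dir).glob("*.json")):
        info = json.loads(path.read_text())
        robust = info.get("autoattack_acc")
        if robust is None:
            robust = info.get("corruptions_acc")
        if robust is None:
            continue  # skip entries without any robust accuracy
        rows.append((path.stem, float(info["clean_acc"]), float(robust)))
    # Highest robust accuracy first, as on the leaderboard
    rows.sort(key=lambda row: row[2], reverse=True)
    return rows
```

For example, load_leaderboard("model_info/cifar10/Linf") would return the CIFAR-10 Linf entries sorted by robust accuracy, ready to be plotted as robustness vs. accuracy.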

How to contribute

Contributions to RobustBench are very welcome! You can help to improve RobustBench:

  • Are you an author of a recent paper focusing on improving adversarial robustness? Consider adding new models (see the instructions below 👇).
  • Do you have a better standardized attack in mind? Do you want to extend RobustBench to other threat models? We'll be glad to discuss that!
  • Do you have an idea of how to make the existing codebase better? Just open a pull request or create an issue and we'll be happy to discuss potential changes.

Adding a new evaluation

If you have a new (potentially adaptive) evaluation that yields a lower robust accuracy than AutoAttack, we will be happy to add it to the leaderboard. The easiest way is to open an issue with the "New external evaluation(s)" template and fill in all the fields.
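One common way to compare attacks is per test point: a point counts as robust only if no attack finds an adversarial example for it, so combining a new attack with AutoAttack can only lower robust accuracy. The following is a minimal sketch of that worst-case combination; the helper name and the boolean-list inputs are assumptions, not part of the RobustBench API.

```python
def worst_case_robust_accuracy(*robust_flags):
    """Robust accuracy (%) when each point must survive every attack.

    Each argument is a sequence of booleans, one per test point, that is True
    if the model still classifies that point correctly under one attack
    (e.g. AutoAttack, or a new adaptive attack). Hypothetical helper.
    """
    n = len(robust_flags[0])
    if n == 0:
        raise ValueError("no test points given")
    if any(len(flags) != n for flags in robust_flags):
        raise ValueError("all attacks must be evaluated on the same points")
    # A point is robust only if it is correctly classified under all attacks
    survived = [all(point) for point in zip(*robust_flags)]
    return 100.0 * sum(survived) / n
```

With AutoAttack flags [True, True, False, True] and new-attack flags [True, False, False, True], the combined robust accuracy drops from 75% to 50%.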

Adding a new model

Public model submission (Leaderboard + Model Zoo)

The easiest way to add new models to the leaderboard and/or the Model Zoo is to open an issue with the "New Model(s)" template and fill in all the fields.

In the following sections there are some tips on how to prepare the claim.

Claim

The claim can be computed in the following way (example for cifar10, Linf threat model):

import torch

from robustbench import benchmark
from myrobustmodel import MyRobustModel  # hypothetical module that defines your model

threat_model = "Linf"  # one of {"Linf", "L2", "corruptions"}
dataset = "cifar10"  # one of {"cifar10", "cifar100", "imagenet"}

model = MyRobustModel()
model_name = "<Name><Year><FirstWordOfTheTitle>"
device = torch.device("cuda:0")

# Computes clean and robust accuracy (AutoAttack at eps=8/255 for Linf) on
# 10,000 test points; to_disk=True also writes the results to a json file.
clean_acc, robust_acc = benchmark(model, model_name=model_name, n_examples=10000, dataset=dataset,
                                  threat_model=threat_model, eps=8/255, device=device,
                                  to_disk=True)
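The model_name follows the <Name><Year><FirstWordOfTheTitle> pattern. Assuming <Name> is the first author's surname, as in the leaderboard IDs above (e.g. Rice2020Overfitting), a hypothetical helper like the following can build it from the paper metadata:

```python
import re


def make_model_id(first_author_surname, year, title):
    """Build an identifier following the <Name><Year><FirstWordOfTheTitle> pattern.

    Hypothetical helper: it keeps only the letters and digits of the first
    title word and capitalizes it, e.g. "Overfitting in ..." -> "Overfitting".
    """
    first_word = re.sub(r"[^0-9A-Za-z]", "", title.split()[0])
    return f"{first_author_surname}{year}{first_word[:1].upper()}{first_word[1:]}"
```

Leaderboard entries sometimes append a suffix to distinguish variants of the same paper (e.g. Rebuffi2021Fixing_28_10_cutmix_ddpm); such suffixes would have to be added by hand.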

In particular, the to_disk argument, if True, generates a json file at the path model_info/<dataset>/<threat_model>/<Name><Year><FirstWordOfTheTitle>.json which is structured in the following way (example from model_info/cifar10/Linf/Rice2020Overfitting.json):

{
  "link": "https://arxiv.org/abs/2002.11569",
  "name": "Overfitting in adversarially robust deep learning",
  "authors": "Leslie Rice, Eric Wong, J. Zico Kolter",
  "additional_data": false,
  "number_forward_passes": 1,
  "dataset": "cifar10",
  "venue": "ICML 2020",
  "architecture": "WideResNet-34-20",
  "eps": "8/255",
  "clean_acc": "85.34",
  "reported": "58",
  "autoattack_acc": "53.42"
}

The only difference is that in the generated json only the fields "clean_acc" and "autoattack_acc" (for the "Linf" and "L2" threat models) or "corruptions_acc" (for the "corruptions" threat model) are filled in. The other fields have to be filled in manually.
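Filling in the remaining fields can be scripted. The following is a minimal sketch, assuming that unfilled fields in the generated file are either absent or null; fill_manual_fields is a hypothetical helper, not part of RobustBench.

```python
import json
from pathlib import Path


def fill_manual_fields(json_path, metadata):
    """Fill in the hand-written fields of a json produced with to_disk=True.

    A field is only written if it is missing or null in the file, so the
    accuracies computed by benchmark() are never overwritten.
    """
    path = Path(json_path)
    info = json.loads(path.read_text())
    for key, value in metadata.items():
        if info.get(key) is None:
            info[key] = value
    path.write_text(json.dumps(info, indent=2))
    return info
```

The metadata dict would contain the manual fields from the example above, such as "link", "name", "authors", "venue" and "eps".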

If the given threat_model is corruptions, we also save unaggregated results on the different combinations of corruption types and severities in this csv file (for CIFAR-10).
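A natural way to summarize such unaggregated results is an equal-weight mean over all (corruption type, severity) combinations, optionally broken down by severity. The sketch below assumes that equal-weight mean (which may differ in detail from how the leaderboard number is computed) and assumes the csv has already been read into a mapping from (corruption, severity) to accuracy; the function names are hypothetical.

```python
from statistics import mean


def aggregate_corruption_accuracy(per_combination_acc):
    """Average accuracy (%) over all (corruption type, severity) combinations.

    `per_combination_acc` maps (corruption_name, severity) -> accuracy in %.
    Every combination gets equal weight (same number of test images assumed).
    """
    if not per_combination_acc:
        raise ValueError("no results to aggregate")
    return mean(per_combination_acc.values())


def accuracy_by_severity(per_combination_acc):
    """Average over corruption types separately for each severity level."""
    by_severity = {}
    for (_, severity), acc in per_combination_acc.items():
        by_severity.setdefault(severity, []).append(acc)
    return {s: mean(accs) for s, accs in sorted(by_severity.items())}
```

The per-severity breakdown makes it easy to see whether a model degrades gracefully or collapses at the highest severities.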

For ImageNet benchmarks, users should specify which preprocessing to use (e.g. resizing and cropping to the required resolution). Some preprocessings are already defined in robustbench.data.PREPROCESSINGS and can be selected by passing their key as the preprocessing parameter of benchmark. Alternatively, it's possible to pass an arbitrary torchvision (or torchvision-compatible) transform, e.g.:

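from torchvision import transforms

# Resize the shorter side to 256 px, center-crop to 224x224 and convert to a
# tensor with values in [0, 1]; any normalization should happen inside the model.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
# The ImageNet Linf leaderboard uses eps=4/255 and 5,000 evaluation images
clean_acc, robust_acc = benchmark(model, model_name=model_name, n_examples=5000, dataset="imagenet",
                                  threat_model="Linf", eps=4/255, device=device,
                                  to_disk=True, preprocessing=transform)

Model definition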

If you want to add a model to the Model Zoo yourself, you should also open a PR with the new model(s). All the models for each <dataset> are saved in robustbench/model_zoo/<dataset>.py. Each file contains a dictionary for every threat model, where the keys are the model identifiers and the values are either class constructors, for models that need to change standard architectures, or lambda functions that return the constructed model.

If your model uses a standard architecture (e.g. WideResNet), applies no normalization to the input, and does not otherwise deviate from the standard architecture, consider adding it as a lambda function, e.g.:

('Cui2020Learnable_34_10', {
    'model': lambda: WideResNet(depth=34, widen_factor=10, sub_block1=True),
    'gdrive_id': '16s9pi_1QgMbFLISVvaVUiNfCzah6g2YV'
})
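To see why the value is a lambda (or a class) rather than a ready-made model instance, here is a minimal sketch of how such a registry could be consumed. TinyNet, linf_models and build_model are hypothetical names, and the real loading code in RobustBench additionally downloads and loads the checkpoint, which is omitted here.

```python
from collections import OrderedDict


class TinyNet:
    """Stand-in for a PyTorch model class (hypothetical, for illustration only)."""

    def __init__(self, depth, widen_factor):
        self.depth = depth
        self.widen_factor = widen_factor


# Hypothetical registry in the entry format shown above: the 'model' value is a
# zero-argument callable, so no model is built until one is requested.
linf_models = OrderedDict([
    ('Example2020Tiny', {
        'model': lambda: TinyNet(depth=34, widen_factor=10),
        'gdrive_id': '<your-gdrive-id>',
    }),
])


def build_model(registry, model_id):
    """Look up `model_id` and call its constructor to get a fresh instance."""
    return registry[model_id]['model']()
```

Because construction is deferred, importing the registry module stays cheap, and every lookup returns an independent model.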

If your model uses a standard architecture but needs to do something differently (e.g. apply normalization), consider inheriting from the class defined in wide_resnet.py or resnet.py. For example:

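import torch

from robustbench.model_zoo.architectures.wide_resnet import WideResNet


class Rice2020OverfittingNet(WideResNet):
    def __init__(self, depth, widen_factor):
        super().__init__(depth=depth, widen_factor=widen_factor, sub_block1=False)
        # CIFAR-10 channel means and stds. Non-persistent buffers move with
        # model.to(device) but are not expected in the checkpoint's state_dict.
        self.register_buffer("mu", torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1),
                             persistent=False)
        self.register_buffer("sigma", torch.tensor([0.2471, 0.2435, 0.2616]).view(3, 1, 1),
                             persistent=False)

    def forward(self, x):
        # Inputs are in [0, 1]; normalize them before the standard WideResNet forward pass
        x = (x - self.mu) / self.sigma
        return super().forward(x)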

If instead you need to create a new architecture, please put it in robustbench/model_zoo/architectures/<my_architecture>.py.

Model checkpoint

You should also add your model entry in the corresponding <threat_model> dict in the file robustbench/model_zoo/<dataset>.py. For instance, let's say your model is robust against common corruptions in CIFAR-10 (i.e. CIFAR-10-C), then you should add your model to the common_corruptions dict in robustbench/model_zoo/cifar10.py.
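A minimal sketch of this per-threat-model layout, and of a guarded way to add an entry, is given below. The dict name cifar_10_models, its keys (taken from the threat_model strings accepted by benchmark) and register_model are hypothetical; the actual dict names in the repository (e.g. common_corruptions) may differ.

```python
from collections import OrderedDict


# Hypothetical layout of a robustbench/model_zoo/<dataset>.py file: one
# OrderedDict of model entries per threat model.
cifar_10_models = {
    'Linf': OrderedDict(),
    'L2': OrderedDict(),
    'corruptions': OrderedDict(),
}


def register_model(zoo, threat_model, model_id, constructor, gdrive_id):
    """Add one entry to the dict of the given threat model, refusing duplicates."""
    if threat_model not in zoo:
        raise KeyError(f"unknown threat model: {threat_model}")
    if model_id in zoo[threat_model]:
        raise ValueError(f"{model_id} is already registered for {threat_model}")
    zoo[threat_model][model_id] = {'model': constructor, 'gdrive_id': gdrive_id}
```

For a CIFAR-10-C model, one would call register_model(cifar_10_models, 'corruptions', '<Name><Year><FirstWordOfTheTitle>', <constructor>, '<your-gdrive-id>').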

The entry should also contain the Google Drive ID of your PyTorch checkpoint, so that it can be downloaded automatically from Google Drive:

('Rice2020Overfitting', {
    'model': lambda: Rice2020OverfittingNet(34, 20),
    'gdrive_id': '1vC_Twazji7lBjeMQvAD9uEQxi9Nx2oG-'
})

Private model submission (leaderboard only)

If you want to keep your checkpoints private for some reason, you can also submit your claim by opening an issue with the same "New Model(s)" template, specifying that the submission is private, and sharing the checkpoints with the email address adversarial.benchmark@gmail.com. In this case, we will add your model to the leaderboard but not to the Model Zoo, and we will not share your checkpoints publicly.

License of the models

By default, the models are released under the MIT license, but you can also tell us if you want to release your model under a customized license.

Automatic tests

To run the tests, use:

  • python -m unittest discover tests -t . -v for fast testing
  • RUN_SLOW=true python -m unittest discover tests -t . -v for slower testing

For example, the tests can check whether the clean accuracy on 200 examples exceeds a threshold (70%), or whether the clean accuracy of each model on 10'000 examples matches the value stored in the jsons located at robustbench/model_info.
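The two kinds of checks can be sketched as plain functions. This is a minimal illustration, not the code in the tests folder: the function names and the 0.5-point tolerance (meant to absorb small package-version fluctuations) are assumptions.

```python
import json


def clean_acc_above(measured_acc, threshold=70.0):
    """Sanity check: clean accuracy (%) on a small subset exceeds a threshold."""
    return measured_acc > threshold


def clean_acc_matches(measured_acc, json_path, tol=0.5):
    """True if a measured clean accuracy (%) is within `tol` points of the json value.

    The 0.5 default is an arbitrary assumption, not the tolerance used by the
    RobustBench test suite.
    """
    with open(json_path) as f:
        reported = float(json.load(f)["clean_acc"])
    return abs(measured_acc - reported) <= tol
```

A unittest case would then simply assert both conditions for every model in the zoo.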

Settings such as batch_size, data_dir and model_dir can be configured in tests/config.py before running the tests.

Citation

If you reference the RobustBench leaderboard or use models from the Model Zoo, please consider citing our whitepaper:

@inproceedings{croce2021robustbench,
  title     = {RobustBench: a standardized adversarial robustness benchmark},
  author    = {Croce, Francesco and Andriushchenko, Maksym and Sehwag, Vikash and Debenedetti, Edoardo and Flammarion, Nicolas and Chiang, Mung and Mittal, Prateek and Hein, Matthias},
  booktitle = {Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  year      = {2021},
  url       = {https://openreview.net/forum?id=SSKZPJCt7B}
}

Contact

Feel free to contact us about anything related to RobustBench by creating an issue, a pull request or by email at adversarial.benchmark@gmail.com.