This is the source code for the paper *Robust Neural Text Classification and Entailment via Mixup Regularized Adversarial Training (MRAT)*, published in Proceedings of the 44th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR '21). [link]
- Create the environment and install the required packages using the provided `environment.yml`:

  ```
  conda env create -f environment.yml
  conda activate MRAT
  ```
- Download datasets. We follow the practice of TextFooler and use the datasets they provide.
  - Download `mr.zip` and `snli.zip` from Google Drive.
  - Extract `mr.zip` to `train/data/mr` and extract `snli.zip` to `train/data/snli`.
  - Download the 1k split datasets of `mr` and `snli` from url.
  - Put `mr` and `snli` into `attack/data` (see the layout sketch after this list).
  - [Optional] For ease of replication, we share the generated adversarial examples and the trained BERT model for download. See details here.
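  Assuming the default paths above, the data should end up laid out like this (directory names are taken from the steps in this README; the file names inside each folder are not listed here):

  ```
  train/
    data/
      mr/      # full MR dataset (from mr.zip)
      snli/    # full SNLI dataset (from snli.zip)
  attack/
    data/
      mr/      # 1k split of MR used for attacks
      snli/    # 1k split of SNLI used for attacks
  ```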
- Train the victim model. Alternatively, you can download our trained victim model from here.

  ```
  cd train/train_command_mr/
  python bert_mr_normal.py
  ```

  For the `snli` dataset this step is not needed: we use the `bert-base-uncased-snli` model from the TextAttack Model Zoo as the attack target model on SNLI. It can be loaded straight from the Hugging Face hub, as sketched below.
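  A minimal sketch of loading that checkpoint (for reference only; the attack scripts handle model loading themselves, and the hub path below assumes the Model Zoo's standard `textattack/` namespace):

  ```python
  # Load the TextAttack Model Zoo SNLI target model from the Hugging Face hub.
  from transformers import AutoModelForSequenceClassification, AutoTokenizer

  tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-snli")
  model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-snli")
  ```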
- Test the adversarial robustness of the victim model (a programmatic sketch follows the commands):

  ```
  cd ../../
  cd attack/attack_command_mr/
  python attack_bert_mr_test_textbugger.py
  python attack_bert_mr_test_deepwordbug.py
  python attack_bert_mr_test_textfooler.py
  ```
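  If you prefer to drive an attack from Python, here is a minimal sketch against the current public TextAttack API; the copy pinned under `attack/textattack` may expose a slightly different interface, and the checkpoint name is a public stand-in for your own trained victim model:

  ```python
  import transformers
  import textattack
  from textattack.attack_recipes import TextFoolerJin2019
  from textattack.models.wrappers import HuggingFaceModelWrapper

  # Stand-in checkpoint; substitute the victim model trained in the previous step.
  name = "textattack/bert-base-uncased-rotten-tomatoes"
  model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
  tokenizer = transformers.AutoTokenizer.from_pretrained(name)
  wrapper = HuggingFaceModelWrapper(model, tokenizer)

  # TextBuggerLi2018 / DeepWordBugGao2018 recipes can be swapped in analogously.
  attack = TextFoolerJin2019.build(wrapper)
  dataset = textattack.datasets.HuggingFaceDataset("rotten_tomatoes", split="test")
  attacker = textattack.Attacker(attack, dataset, textattack.AttackArgs(num_examples=1000))
  attacker.attack_dataset()
  ```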
- Attack the victim model on the training set and save the generated adversarial examples. Alternatively, you can download our generated adversarial examples from here.

  ```
  python attack_bert_mr_train_textbugger.py
  python attack_bert_mr_train_deepwordbug.py
  python attack_bert_mr_train_textfooler.py
  ```
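  With the public TextAttack API, persisting the generated examples looks roughly like this (`log_to_csv` is a standard `AttackArgs` option; `attack` and `train_dataset` are assumed to be built as in the previous sketch):

  ```python
  import textattack

  # Attack every training example and write original/perturbed pairs to a CSV,
  # which can then serve as the adversarial half of MRAT training.
  attack_args = textattack.AttackArgs(num_examples=-1, log_to_csv="adv_mr_textfooler.csv")
  # attacker = textattack.Attacker(attack, train_dataset, attack_args)
  # attacker.attack_dataset()
  ```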
- Train the `MRAT` and `MRAT+` models (an illustrative sketch of the mixup objective follows the commands). Alternatively, you can download our trained models from here.

  ```
  cd ../../
  cd train/train_command_mr/
  python bert_mr_mix_multi.py   # MRAT
  python bert_mr_mixN_multi.py  # MRAT+
  ```
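  For orientation only, here is an illustrative PyTorch sketch of the mixup idea, not the repo's implementation: representations of a clean example and its adversarial counterpart are interpolated before classification (`encode`, `classify`, and `alpha` are assumed names and settings):

  ```python
  import torch
  import torch.nn.functional as F

  def mrat_term(encode, classify, clean_batch, adv_batch, labels, alpha=1.0):
      """One mixup training term. encode: texts -> [B, H] representations;
      classify: [B, H] -> logits."""
      lam = torch.distributions.Beta(alpha, alpha).sample()  # mixing coefficient
      h_mix = lam * encode(clean_batch) + (1 - lam) * encode(adv_batch)
      # A clean example and its adversarial counterpart share the same label,
      # so the mixed target is simply that label.
      return F.cross_entropy(classify(h_mix), labels)
  ```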
- Test the adversarial robustness of `MRAT` and `MRAT+`:

  ```
  cd ../../
  cd attack/attack_command_mr/
  # MRAT
  python attack_bert_mr_test_mix-multi_textbugger.py
  python attack_bert_mr_test_mix-multi_deepwordbug.py
  python attack_bert_mr_test_mix-multi_textfooler.py
  # MRAT+
  python attack_bert_mr_test_mixN-multi_textbugger.py
  python attack_bert_mr_test_mixN-multi_deepwordbug.py
  python attack_bert_mr_test_mixN-multi_textfooler.py
  ```
- The steps above use the `mr` dataset as an example. For the `snli` dataset, the usage is the same: change `attack_command_mr` and `train_command_mr` to the `attack_command_snli` and `train_command_snli` folders and run the Python scripts inside.
- Following the steps above, you can reproduce the results in Table 2 of our paper. For additional results, small changes to the argument settings in these scripts may be needed.
- The code has been tested on CentOS 7 with multiple GPUs. For single-GPU usage, small changes may be needed.
- If you need help, feel free to open an issue. 😊
🎉 A huge thanks to TextAttack! Most of the code in this project is derived from the helpful TextAttack toolbox. Because TextAttack updates frequently, we keep a copy in the `attack/textattack` folder.