-
Notifications
You must be signed in to change notification settings - Fork 375
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #80 from QData/attack-results-as-goal-functions
Attack results take goal function result
- Loading branch information
Showing
161 changed files
with
10,387 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Configuration file for the Sphinx documentation builder. | ||
# | ||
# This file only contains a selection of the most common options. For a full | ||
# list see the documentation: | ||
# https://www.sphinx-doc.org/en/master/usage/configuration.html | ||
|
||
# -- Path setup -------------------------------------------------------------- | ||
|
||
# If extensions (or modules to document with autodoc) are in another directory, | ||
# add these directories to sys.path here. If the directory is relative to the | ||
# documentation root, use os.path.abspath to make it absolute, like shown here. | ||
# | ||
import os
import sys

# Put '../' (resolved against the current working directory) at the front of
# sys.path so that modules documented with autodoc can be imported.
sys.path.insert(0, os.path.abspath('../'))


# -- Project information -----------------------------------------------------

project = 'TextAttack'
author = 'UVA QData Lab'
copyright = '2019, UVA QData Lab'

# Full version string, including any alpha/beta/rc tag.
release = '0.0.1'

# The master document is `index.rst`.
master_doc = 'index'


# -- General configuration ---------------------------------------------------

# Sphinx extensions to load: either ones shipped with Sphinx
# ('sphinx.ext.*') or custom ones.
extensions = [
    'sphinx.ext.viewcode',
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx_rtd_theme',
]

# Template directories, relative to this directory.
templates_path = ['_templates']

# Patterns (relative to the source directory) for files and directories that
# are ignored when looking for source files. These patterns also affect
# html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# Mock language_check to stop issues with Sphinx not loading it.
autodoc_mock_imports = ['language_check']


# -- Options for HTML output -------------------------------------------------

# Theme used for HTML and HTML Help pages.
html_theme = 'sphinx_rtd_theme'

# Directories with custom static files (such as style sheets), relative to
# this directory. They are copied after the builtin static files, so a file
# named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
from test_models import CommandLineTest | ||
|
||
tests = [] | ||
|
||
def register_test(command, name=None, output_file=None, desc=None):
    """Register a command-line test whose expected output is stored in a file.

    Reads the expected output from ``output_file`` and appends a
    ``CommandLineTest`` built from it to the module-level ``tests`` list.

    Args:
        command: Command line handed to ``CommandLineTest``.
        name: Test name; also used in the error message.
        output_file: Path of the file holding the expected output.
        desc: Description handed to ``CommandLineTest``.

    Raises:
        FileNotFoundError: If ``output_file`` does not exist.
    """
    if not os.path.exists(output_file):
        raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.')
    # Read through a context manager so the file handle is closed right away
    # instead of being left for the garbage collector.
    with open(output_file) as expected_file:
        output = expected_file.read()
    tests.append(CommandLineTest(
        command, name=name, output=output, desc=desc
    ))
|
||
|
||
####################################### | ||
## BEGIN TESTS ## | ||
####################################### | ||
|
||
#
# test: run_attack textfooler attack on 10 samples from BERT MR
#   (takes about 81s)
#
register_test(
    'python scripts/run_attack.py --model bert-mr --recipe textfooler --num_examples 10',
    desc='Runs attack using TextFooler recipe on BERT using 10 examples from the MR dataset',
    name='run_attack_textfooler_bert_mr_10',
    output_file='local_tests/outputs/run_attack_textfooler_bert_mr_10.txt',
)

#
# test: run_attack textfooler attack on 10 samples from BERT SNLI
#   (takes about 51s)
#
register_test(
    'python scripts/run_attack.py --model bert-snli --recipe textfooler --num_examples 10',
    desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset',
    name='run_attack_textfooler_bert_snli_10',
    output_file='local_tests/outputs/run_attack_textfooler_bert_snli_10.txt',
)

#
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
#   (takes about 41s)
#
register_test(
    'python scripts/run_attack.py --model lstm-mr --recipe deepwordbug --num_examples 10',
    desc='Runs attack using DeepWordBUG recipe on LSTM using 10 examples from the MR dataset',
    name='run_attack_deepwordbug_lstm_mr_10',
    output_file='local_tests/outputs/run_attack_deepwordbug_lstm_mr_10.txt',
)
|
||
# | ||
# test: run_attack targeted classification of class 2 on BERT MNLI with enable_csv | ||
# and attack_n set, using the WordNet transformation and beam search with | ||
# beam width 2, using language tool constraint, on 10 samples | ||
# (takes about 171s) | ||
# | ||
register_test(('python scripts/run_attack.py --attack_n --goal_function targeted-classification:target_class=2 '
    '--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet '
    '--constraints lang-tool --attack beam-search:beam_width=2'),
    name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
    output_file='local_tests/outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_10.txt',
    # Adjacent string literals are concatenated as-is, so every fragment but
    # the last must end with a space (the first one used to yield
    # "withenable_csv").
    desc=('Runs attack using targeted classification on class 2 on BERT MNLI with '
        'enable_csv and attack_n set, using the WordNet transformation and beam '
        'search with beam width 2, using language tool constraint, on 10 samples')
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
from test_models import PythonFunctionTest | ||
|
||
tests = [] | ||
|
||
def register_test(function, name=None, output_file=None, desc=None):
    """Register a Python-function test whose expected output is stored in a file.

    Reads the expected output from ``output_file`` and appends a
    ``PythonFunctionTest`` built from it to the module-level ``tests`` list.

    Args:
        function: Callable handed to ``PythonFunctionTest``.
        name: Test name; also used in the error message.
        output_file: Path of the file holding the expected output.
        desc: Description handed to ``PythonFunctionTest``.

    Raises:
        FileNotFoundError: If ``output_file`` does not exist.
    """
    if not os.path.exists(output_file):
        raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.')
    # Read through a context manager so the file handle is closed right away
    # instead of being left for the garbage collector.
    with open(output_file) as expected_file:
        output = expected_file.read()
    tests.append(PythonFunctionTest(
        function, name=name, output=output, desc=desc
    ))
|
||
|
||
####################################### | ||
## BEGIN TESTS ## | ||
####################################### | ||
|
||
# | ||
# | ||
# | ||
def check_gpu_count():
    """Print an error message when PyTorch reports zero CUDA devices.

    Prints nothing when at least one device is reported.
    """
    import torch

    if not torch.cuda.device_count():
        print('Error: detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?')
|
||
# The expected output is an empty file: the check passes by printing nothing.
register_test(
    check_gpu_count,
    desc='Makes sure CUDA is enabled, properly configured, and detects at least 1 GPU',
    name='check CUDA',
    output_file='local_tests/outputs/empty_file.txt',
)
|
||
# | ||
# test: import textattack | ||
# | ||
def import_textattack():
    """Import the ``textattack`` module; the imported name is not used further."""
    import textattack
|
||
# The expected output is an empty file: a clean import prints nothing.
register_test(
    import_textattack,
    desc='Makes sure the textattack module can be imported',
    name='import textattack',
    output_file='local_tests/outputs/empty_file.txt',
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import os | ||
import time | ||
|
||
from test_models import color_text | ||
|
||
def log_sep():
    """Print a line of 60 dashes with a blank line above and below it."""
    dashes = '-' * 60
    print(f'\n{dashes}\n')
|
||
def print_gray(s):
    """Print ``s`` after passing it through ``color_text`` with 'light_gray'."""
    grayed = color_text(s, 'light_gray')
    print(grayed)
|
||
def main():
    """Run every registered test from the TextAttack root and print a summary.

    Changes the working directory to the parent of the directory containing
    this file, calls each entry of ``tests.tests`` in order, and prints how
    many returned a truthy value together with the elapsed wall-clock time.
    """
    # Change to TextAttack root directory (two levels up from this file).
    root_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    os.chdir(root_directory)
    print_gray(f'Executing tests from {root_directory}.')

    # Execute tests.
    started_at = time.time()

    # Imported only after the chdir above.
    # NOTE(review): presumably because test registration reads expected-output
    # files through paths relative to the root -- confirm.
    from tests import tests

    num_passed = 0
    for run_test in tests:
        log_sep()
        if run_test():
            num_passed += 1
    log_sep()

    elapsed = time.time() - started_at
    print_gray(f'Passed {num_passed}/{len(tests)} in {elapsed}s.')
|
||
|
||
|
||
if __name__ == '__main__':
    # TODO: add an argument parser and test sizes.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import colored | ||
import io | ||
import os | ||
import re | ||
import sys | ||
import subprocess | ||
|
||
def color_text(s, color):
    """Return ``s`` stylized with the ``colored`` foreground color ``color``."""
    foreground = colored.fg(color)
    return colored.stylize(s, foreground)
|
||
FNULL = open(os.devnull, 'w') | ||
|
||
MAGIC_STRING = '/.*/'
def compare_outputs(true_output, test_output):
    """ Return True if `test_output` matches the desired `true_output`.

        Desired outputs have the magic string '/.*/' inserted wherever the
        output at that position doesn't actually matter. (For example,
        when the time to execute is printed, or another non-deterministic
        feature of the program.) Each magic string matches any run of
        characters, including none.

        Everything outside the magic strings must match exactly: the test
        output has to start with the text before the first magic string, end
        with the text after the last one, and contain the pieces in between
        in order. A desired output without any magic string therefore only
        matches an identical test output; in particular, an empty desired
        output only matches an empty test output.
    """
    output_pieces = true_output.split(MAGIC_STRING)
    first_piece, last_piece = output_pieces[0], output_pieces[-1]
    if len(output_pieces) == 1:
        # No wildcard at all: nothing may differ.
        return test_output == first_piece
    if not test_output.startswith(first_piece):
        return False
    position = len(first_piece)
    # Taking the earliest occurrence of each middle piece leaves the most
    # room for the pieces that follow, so this greedy scan never misses a
    # valid match.
    for piece in output_pieces[1:-1]:
        index_in_test = test_output.find(piece, position)
        if index_in_test < 0:
            return False
        position = index_in_test + len(piece)
    # The final piece must close out the output without overlapping text
    # that earlier pieces already consumed.
    has_room = len(test_output) - position >= len(last_piece)
    return has_room and test_output.endswith(last_piece)
|
||
class TextAttackTest:
    """ Base class for a named test that checks produced output against an
        expected output. Subclasses supply `execute`; calling the instance
        runs the test.
    """

    def __init__(self, name=None, output=None, desc=None):
        """ Stores the test name, expected output and description.
            Raises ValueError if any of the three is None.
        """
        required = ((name, 'name'), (output, 'output'), (desc, 'description'))
        for value, label in required:
            if value is None:
                raise ValueError(f'Cannot initialize TextAttackTest without {label}')
        self.name, self.output, self.desc = name, output, desc

    def execute(self):
        """ Executes test and returns test output. To be implemented by
            subclasses.
        """
        raise NotImplementedError()

    def __call__(self):
        """ Runs test, prints success or failure, and returns True on
            success and False on failure.
        """
        self.log_start()
        test_output = self.execute()
        if not compare_outputs(self.output, test_output):
            self.log_failure(test_output)
            return False
        self.log_success()
        return True

    def log_start(self):
        """ Prints the name of the test about to run, colored blue. """
        shown_name = color_text(self.name, 'blue')
        print(f'Executing test {shown_name}.')

    def log_success(self):
        """ Prints the success marker in green. """
        print(color_text('✓ Succeeded.', 'green'))

    def log_failure(self, test_output):
        """ Prints the failure marker in red, followed by the produced and
            the expected output.
        """
        print(color_text('✗ Failed.', 'red'))
        print('\n')
        print(f'Test output: {test_output}.')
        print(f'Correct output: {self.output}.')
|
||
class CommandLineTest(TextAttackTest):
    """ Runs a command-line command to check for desired output. """
    def __init__(self, command, name=None, output=None, desc=None):
        """ Stores `command` and forwards `name`, `output` and `desc` to
            `TextAttackTest`. Raises ValueError if `command` is None.
        """
        if command is None:
            raise ValueError('Cannot initialize CommandLineTest without command')
        self.command = command
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """ Runs the command (without a shell) and returns its decoded
            standard output. Standard error is discarded.
        """
        # Imported locally so this block does not depend on the file's
        # top-level import list.
        import shlex

        result = subprocess.run(
            # shlex.split tokenizes like a POSIX shell, so quoted arguments
            # that contain spaces stay together; for commands without quotes
            # or backslashes it yields the same list as str.split().
            shlex.split(self.command),
            stdout=subprocess.PIPE,
            # @TODO: Collect stderr somewhere. In the event of an error, point user to the error file.
            # subprocess.DEVNULL avoids relying on a file handle that is
            # opened at import time and never closed.
            stderr=subprocess.DEVNULL
        )
        return result.stdout.decode()
|
||
class Capturing(list):
    """ A list that doubles as a context manager: while the `with` block
        runs, `sys.stdout` is swapped for an in-memory buffer, and on exit
        the buffered text is appended to the list one line per element.
        stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
    """

    def __enter__(self):
        # Remember the real stdout so __exit__ can put it back.
        self._stdout = sys.stdout
        self._stringio = io.StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        del self._stringio    # free up some memory
        sys.stdout = self._stdout
|
||
class PythonFunctionTest(TextAttackTest):
    """ Runs a Python function to check for desired output. """

    def __init__(self, function, name=None, output=None, desc=None):
        """ Stores `function` and forwards `name`, `output` and `desc` to
            `TextAttackTest`. Raises ValueError if `function` is None.
        """
        if function is None:
            raise ValueError('Cannot initialize PythonFunctionTest without function')
        self.function = function
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """ Calls the function with stdout captured and returns the captured
            lines joined by newlines.
        """
        with Capturing() as captured_lines:
            self.function()
        return '\n'.join(captured_lines)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Collects the tests registered by each test module into one list:
# Python-function tests first, then command-line tests.
import python_function_tests
import command_line_tests

tests = python_function_tests.tests + command_line_tests.tests
Oops, something went wrong.