Skip to content

Commit

Permalink
Merge pull request #80 from QData/attack-results-as-goal-functions
Browse files Browse the repository at this point in the history
Attack results take goal function result
  • Loading branch information
jxmorris12 committed May 2, 2020
2 parents a17054d + 1604290 commit a8eeb73
Show file tree
Hide file tree
Showing 161 changed files with 10,387 additions and 164 deletions.
65 changes: 65 additions & 0 deletions build/lib/docs/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
# Put the parent of the current directory first on the import path so that
# autodoc can locate the modules it documents.
sys.path.insert(0, os.path.abspath("../"))

# -- Project information -----------------------------------------------------

project = "TextAttack"
copyright = "2019, UVA QData Lab"
author = "UVA QData Lab"

# Full release string, including any alpha/beta/rc tags.
release = "0.0.1"

# The root document is `index.rst`.
master_doc = "index"

# -- General configuration ---------------------------------------------------

# Sphinx extension module names, as strings: either extensions shipped with
# Sphinx ('sphinx.ext.*') or custom ones.
extensions = [
    "sphinx.ext.viewcode",
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx_rtd_theme",
]

# Template directories, relative to this directory.
templates_path = ["_templates"]

# Patterns (relative to the source directory) of files and directories to
# skip when looking for source files. Also applies to html_static_path and
# html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# Mock language_check to stop issues with Sphinx not loading it.
autodoc_mock_imports = ["language_check"]

# -- Options for HTML output -------------------------------------------------

# Theme used for HTML and HTML Help pages.
html_theme = "sphinx_rtd_theme"

# Directories with custom static files (such as style sheets), relative to
# this directory. They are copied after the builtin static files, so a file
# named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
File renamed without changes.
60 changes: 60 additions & 0 deletions build/lib/local_tests/command_line_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
from test_models import CommandLineTest

tests = []

def register_test(command, name=None, output_file=None, desc=None):
    """Register a command-line test whose expected output lives in a file.

    Reads the expected output from ``output_file`` and appends a
    ``CommandLineTest`` built from it to the module-level ``tests`` list.

    Args:
        command: Command line handed to ``CommandLineTest``.
        name: Test name; also used in the error message below.
        output_file: Path of the file holding the expected output.
        desc: Description handed to ``CommandLineTest``.

    Raises:
        FileNotFoundError: If ``output_file`` does not exist.
    """
    if not os.path.exists(output_file):
        raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.')
    # Read through a context manager so the handle is closed immediately
    # instead of staying open until garbage collection.
    with open(output_file) as expected_file:
        output = expected_file.read()
    tests.append(CommandLineTest(
        command, name=name, output=output, desc=desc
    ))


#######################################
## BEGIN TESTS ##
#######################################

#
# test: run_attack textfooler attack on 10 samples from BERT MR
# (takes about 81s)
#
register_test(
    'python scripts/run_attack.py --model bert-mr --recipe textfooler --num_examples 10',
    name='run_attack_textfooler_bert_mr_10',
    output_file='local_tests/outputs/run_attack_textfooler_bert_mr_10.txt',
    desc='Runs attack using TextFooler recipe on BERT using 10 examples from the MR dataset',
)

#
# test: run_attack textfooler attack on 10 samples from BERT SNLI
# (takes about 51s)
#
register_test(
    'python scripts/run_attack.py --model bert-snli --recipe textfooler --num_examples 10',
    name='run_attack_textfooler_bert_snli_10',
    output_file='local_tests/outputs/run_attack_textfooler_bert_snli_10.txt',
    desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset',
)

#
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
# (takes about 41s)
#
register_test(
    'python scripts/run_attack.py --model lstm-mr --recipe deepwordbug --num_examples 10',
    name='run_attack_deepwordbug_lstm_mr_10',
    output_file='local_tests/outputs/run_attack_deepwordbug_lstm_mr_10.txt',
    desc='Runs attack using DeepWordBUG recipe on LSTM using 10 examples from the MR dataset',
)

#
# test: run_attack targeted classification of class 2 on BERT MNLI with enable_csv
# and attack_n set, using the WordNet transformation and beam search with
# beam width 2, using language tool constraint, on 10 samples
# (takes about 171s)
#
register_test(
    (
        'python scripts/run_attack.py --attack_n --goal_function targeted-classification:target_class=2 '
        '--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet '
        '--constraints lang-tool --attack beam-search:beam_width=2'
    ),
    name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
    output_file='local_tests/outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_10.txt',
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must carry its own trailing space (the first one previously did not,
    # producing "withenable_csv").
    desc=(
        'Runs attack using targeted classification on class 2 on BERT MNLI with '
        'enable_csv and attack_n set, using the WordNet transformation and beam '
        'search with beam width 2, using language tool constraint, on 10 samples'
    ),
)
40 changes: 40 additions & 0 deletions build/lib/local_tests/python_function_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
from test_models import PythonFunctionTest

tests = []

def register_test(function, name=None, output_file=None, desc=None):
    """Register a Python-function test whose expected output lives in a file.

    Reads the expected output from ``output_file`` and appends a
    ``PythonFunctionTest`` built from it to the module-level ``tests`` list.

    Args:
        function: Callable handed to ``PythonFunctionTest``.
        name: Test name; also used in the error message below.
        output_file: Path of the file holding the expected output.
        desc: Description handed to ``PythonFunctionTest``.

    Raises:
        FileNotFoundError: If ``output_file`` does not exist.
    """
    if not os.path.exists(output_file):
        raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.')
    # Read through a context manager so the handle is closed immediately
    # instead of staying open until garbage collection.
    with open(output_file) as expected_file:
        output = expected_file.read()
    tests.append(PythonFunctionTest(
        function, name=name, output=output, desc=desc
    ))


#######################################
## BEGIN TESTS ##
#######################################

#
# test: check that CUDA is configured and at least one GPU is detected
#
def check_gpu_count():
    """Print an error message when torch reports zero CUDA devices.

    Prints nothing when at least one device is reported.
    """
    import torch
    if torch.cuda.device_count() == 0:
        print('Error: detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?')

# The expected output is an empty file: the check prints only on failure.
register_test(
    check_gpu_count,
    name='check CUDA',
    output_file='local_tests/outputs/empty_file.txt',
    desc='Makes sure CUDA is enabled, properly configured, and detects at least 1 GPU',
)

#
# test: import textattack
#
def import_textattack():
    """Import the ``textattack`` package; prints nothing itself."""
    import textattack

# The expected output is an empty file: the function itself prints nothing.
register_test(
    import_textattack,
    name='import textattack',
    output_file='local_tests/outputs/empty_file.txt',
    desc='Makes sure the textattack module can be imported',
)
38 changes: 38 additions & 0 deletions build/lib/local_tests/run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import time

from test_models import color_text

def log_sep():
    """Print a 60-dash separator line with a blank line above and below."""
    dashes = '-' * 60
    print(f'\n{dashes}\n')

def print_gray(s):
    """Print ``s`` after passing it through ``color_text`` with 'light_gray'."""
    gray_text = color_text(s, 'light_gray')
    print(gray_text)

def main():
    """Run every registered test from the TextAttack root and print a summary.

    The working directory is switched to the parent of this file's directory
    before the tests are imported, since test output files are given as
    paths relative to that root.
    """
    # Change to TextAttack root directory (the parent of this file's folder).
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    root_dir = os.path.dirname(tests_dir)
    os.chdir(root_dir)
    print_gray(f'Executing tests from {root_dir}.')

    # Execute tests. The clock starts before the import because importing
    # `tests` is what registers them.
    started_at = time.time()
    num_passed = 0

    from tests import tests
    for run_test in tests:
        log_sep()
        if run_test():
            num_passed += 1
    log_sep()

    elapsed = time.time() - started_at
    print_gray(f'Passed {num_passed}/{len(tests)} in {elapsed}s.')



if __name__ == '__main__':
    # Run the full suite when this file is executed as a script.
    # @TODO add argparser and test sizes.
    main()
120 changes: 120 additions & 0 deletions build/lib/local_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import colored
import io
import os
import re
import sys
import subprocess

def color_text(s, color):
    """Return ``s`` styled by ``colored`` with the foreground color ``color``."""
    foreground = colored.fg(color)
    return colored.stylize(s, foreground)

# Write handle to the OS null device (os.devnull); opened at import time and
# never explicitly closed.
FNULL = open(os.devnull, 'w')

MAGIC_STRING = '/.*/'


def compare_outputs(true_output, test_output):
    """ Desired outputs have the magic string '/.*/' inserted wherever the
        output at that position doesn't actually matter. (For example,
        when the time to execute is printed, or another non-deterministic
        feature of the program.)

        `compare_outputs` makes sure all of the outputs match in between
        the magic strings. If they do, it returns True.

        The match is anchored at both ends: the text before the first magic
        string must start `test_output` and the text after the last magic
        string must end it. Without anchoring, an empty desired output
        would match any test output at all.

        Args:
            true_output: Desired output, optionally containing magic strings.
            test_output: Output actually produced by the test.

        Returns:
            True if `test_output` matches `true_output`, otherwise False.
    """
    output_pieces = true_output.split(MAGIC_STRING)
    first_piece, last_piece = output_pieces[0], output_pieces[-1]

    # No magic string at all: the outputs must be identical.
    if len(output_pieces) == 1:
        return test_output == first_piece

    if not test_output.startswith(first_piece):
        return False
    position = len(first_piece)

    # Interior pieces must appear in order; taking the leftmost match leaves
    # the most room for the pieces that follow.
    for piece in output_pieces[1:-1]:
        index_in_test = test_output.find(piece, position)
        if index_in_test < 0:
            return False
        position = index_in_test + len(piece)

    # The final piece must be a suffix that does not overlap text already
    # consumed by earlier pieces.
    has_room = len(test_output) - len(last_piece) >= position
    return has_room and test_output.endswith(last_piece)

class TextAttackTest:
    """ Base class for a named test that compares produced output against
        an expected output. Subclasses implement `execute`.
    """

    def __init__(self, name=None, output=None, desc=None):
        # All three fields are mandatory; check them in a fixed order so the
        # first missing one is the one reported.
        required_fields = (
            ('name', name),
            ('output', output),
            ('description', desc),
        )
        for label, value in required_fields:
            if value is None:
                raise ValueError(f'Cannot initialize TextAttackTest without {label}')
        self.name, self.output, self.desc = name, output, desc

    def execute(self):
        """ Executes test and returns test output. To be implemented by
            subclasses.
        """
        raise NotImplementedError()

    def __call__(self):
        """ Runs test and prints success or failure. Returns True when the
            produced output matches the expected output, False otherwise.
        """
        self.log_start()
        test_output = self.execute()
        if not compare_outputs(self.output, test_output):
            self.log_failure(test_output)
            return False
        self.log_success()
        return True

    def log_start(self):
        """ Announces the test, with its name colored blue. """
        name_in_blue = color_text(self.name, 'blue')
        print(f'Executing test {name_in_blue}.')

    def log_success(self):
        """ Prints the success marker in green. """
        print(color_text('✓ Succeeded.', 'green'))

    def log_failure(self, test_output):
        """ Prints the failure marker in red, followed by the produced and
            the expected output.
        """
        print(color_text('✗ Failed.', 'red'))
        print('\n')
        print(f'Test output: {test_output}.')
        print(f'Correct output: {self.output}.')

class CommandLineTest(TextAttackTest):
    """ Runs a command-line command to check for desired output. """

    def __init__(self, command, name=None, output=None, desc=None):
        """ Stores `command` and forwards the remaining fields to the base
            class.

            Raises:
                ValueError: If `command` is None (or if the base class
                    rejects `name`, `output` or `desc`).
        """
        if command is None:
            raise ValueError('Cannot initialize CommandLineTest without command')
        self.command = command
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """ Runs the command without a shell and returns its decoded
            standard output. Standard error is discarded.
        """
        # Imported locally so this class does not depend on a new
        # module-level import.
        import shlex
        result = subprocess.run(
            # shlex.split keeps quoted arguments together, unlike str.split.
            shlex.split(self.command),
            stdout=subprocess.PIPE,
            # @TODO: Collect stderr somewhere. In the event of an error, point user to the error file.
            # subprocess.DEVNULL avoids relying on a long-lived open handle
            # to os.devnull.
            stderr=subprocess.DEVNULL
        )
        return result.stdout.decode()

class Capturing(list):
    """ A list that doubles as a context manager: while the `with` block
        runs, everything written to standard out is collected, and on exit
        the collected lines are appended to the list itself.
        stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
    """

    def __enter__(self):
        # Remember the real stream, then swap in an in-memory buffer.
        self._stdout = sys.stdout
        self._stringio = io.StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout

class PythonFunctionTest(TextAttackTest):
    """ Runs a Python function to check for desired output. """

    def __init__(self, function, name=None, output=None, desc=None):
        """ Stores `function` and forwards the remaining fields to the base
            class. Raises ValueError when `function` is None.
        """
        if function is None:
            raise ValueError('Cannot initialize PythonFunctionTest without function')
        self.function = function
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """ Calls the function with no arguments and returns everything it
            printed to standard out, with lines joined by newlines.
        """
        with Capturing() as captured_lines:
            self.function()
        return '\n'.join(captured_lines)

7 changes: 7 additions & 0 deletions build/lib/local_tests/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Importing each module registers its tests; collect them in one list with
# the Python function tests first, followed by the command-line tests.
import python_function_tests
import command_line_tests

tests = [*python_function_tests.tests, *command_line_tests.tests]

0 comments on commit a8eeb73

Please sign in to comment.