Skip to content

Commit

Permalink
Merge pull request #81 from QData/local-tests
Browse files Browse the repository at this point in the history
local tests
  • Loading branch information
jxmorris12 committed Apr 29, 2020
2 parents 5d7259e + 96e204c commit fb07f7f
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 2 deletions.
Empty file added local_tests/__init__.py
Empty file.
60 changes: 60 additions & 0 deletions local_tests/command_line_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
from test_models import CommandLineTest

tests = []

def register_test(command, name=None, output_file=None, desc=None):
    """Create a ``CommandLineTest`` for ``command`` and append it to ``tests``.

    Args:
        command: Command line to run, given as a single string.
        name: Name of the test; also quoted in the error raised below.
        output_file: Path of the file whose contents are the expected output.
        desc: Description of the test.

    Raises:
        FileNotFoundError: If ``output_file`` does not exist.
    """
    if not os.path.exists(output_file):
        raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.')
    # Read inside a ``with`` block so the handle is closed immediately rather
    # than being left open until the garbage collector gets to it.
    with open(output_file) as expected_file:
        output = expected_file.read()
    tests.append(CommandLineTest(
        command, name=name, output=output, desc=desc
    ))


#######################################
## BEGIN TESTS ##
#######################################

#
# test: run_attack textfooler attack on 10 samples from BERT MR
# (takes about 81s)
#
register_test('python scripts/run_attack.py --model bert-mr --recipe textfooler --num_examples 10',
name='run_attack_textfooler_bert_mr_10',
output_file='local_tests/outputs/run_attack_textfooler_bert_mr_10.txt',
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the MR dataset')

#
# test: run_attack textfooler attack on 10 samples from BERT SNLI
# (takes about 51s)
#
register_test('python scripts/run_attack.py --model bert-snli --recipe textfooler --num_examples 10',
name='run_attack_textfooler_bert_snli_10',
output_file='local_tests/outputs/run_attack_textfooler_bert_snli_10.txt',
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset')

#
# test: run_attack deepwordbug attack on 10 samples from LSTM MR
# (takes about 41s)
#
register_test('python scripts/run_attack.py --model lstm-mr --recipe deepwordbug --num_examples 10',
name='run_attack_deepwordbug_lstm_mr_10',
output_file='local_tests/outputs/run_attack_deepwordbug_lstm_mr_10.txt',
desc='Runs attack using DeepWordBUG recipe on LSTM using 10 examples from the MR dataset')

#
# test: run_attack targeted classification of class 2 on BERT MNLI with enable_csv
#   and attack_n set, using the WordNet transformation and beam search with
#   beam width 2, using language tool constraint, on 10 samples
#   (takes about 171s)
#
register_test(
    ('python scripts/run_attack.py --attack_n --goal_function targeted-classification:target_class=2 '
     '--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet '
     '--constraints lang-tool --attack beam-search:beam_width=2'),
    name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn',
    output_file='local_tests/outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_10.txt',
    # Adjacent string literals are concatenated with nothing in between, so
    # every fragment except the last must end with its own space.
    desc=('Runs attack using targeted classification on class 2 on BERT MNLI with '
          'enable_csv and attack_n set, using the WordNet transformation and beam '
          'search with beam width 2, using language tool constraint, on 10 samples'),
)
40 changes: 40 additions & 0 deletions local_tests/python_function_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
from test_models import PythonFunctionTest

tests = []

def register_test(function, name=None, output_file=None, desc=None):
    """Create a ``PythonFunctionTest`` for ``function`` and append it to ``tests``.

    Args:
        function: Callable that the test will invoke.
        name: Name of the test; also quoted in the error raised below.
        output_file: Path of the file whose contents are the expected output.
        desc: Description of the test.

    Raises:
        FileNotFoundError: If ``output_file`` does not exist.
    """
    if not os.path.exists(output_file):
        raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.')
    # Read inside a ``with`` block so the handle is closed immediately rather
    # than being left open until the garbage collector gets to it.
    with open(output_file) as expected_file:
        output = expected_file.read()
    tests.append(PythonFunctionTest(
        function, name=name, output=output, desc=desc
    ))


#######################################
## BEGIN TESTS ##
#######################################

#
# test: check that torch detects at least one CUDA device
#
def check_gpu_count():
    """Print an error message when torch reports zero CUDA devices.

    Nothing is printed when one or more devices are reported.
    """
    # Imported inside the function, so torch is only needed when the check runs.
    import torch
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Plain literal: the message has no placeholders, so no f-string.
        # NOTE(review): the text says "multiple GPUs" but only a count of zero
        # is rejected here -- confirm which requirement is intended.
        print('Error: detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?')

# Registered against the expected-output file empty_file.txt;
# check_gpu_count prints only when it detects zero GPUs.
register_test(check_gpu_count, name='check CUDA',
output_file='local_tests/outputs/empty_file.txt',
desc='Makes sure CUDA is enabled, properly configured, and detects at least 1 GPU')

#
# test: import textattack
#
def import_textattack():
    """Import the ``textattack`` package; an import error propagates to the caller."""
    import textattack  # noqa: F401 -- the import itself is the whole check

# Uses the same expected-output file (empty_file.txt) as the 'check CUDA' test;
# import_textattack contains no print calls of its own.
register_test(import_textattack, name='import textattack',
output_file='local_tests/outputs/empty_file.txt',
desc='Makes sure the textattack module can be imported')
1 change: 1 addition & 0 deletions local_tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
colored
38 changes: 38 additions & 0 deletions local_tests/run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import time

from test_models import color_text

def log_sep():
    """Print a rule of 60 dashes with a newline before and after it."""
    rule = '-' * 60
    print(f'\n{rule}\n')

def print_gray(s):
    """Print ``s`` after styling it with the 'light_gray' color."""
    gray_text = color_text(s, 'light_gray')
    print(gray_text)

def main():
    """Run every registered test from the TextAttack root directory and print a summary."""
    # Change to TextAttack root directory (the parent of this file's directory).
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    root_dir = os.path.dirname(tests_dir)
    os.chdir(root_dir)
    print_gray(f'Executing tests from {root_dir}.')

    # Execute tests.
    started_at = time.time()
    num_passed = 0

    # Imported only after the chdir above: the test modules look up their
    # expected-output files (e.g. 'local_tests/outputs/...') by relative path
    # while they are being imported.
    from tests import tests
    for run_test in tests:
        log_sep()
        if run_test():
            num_passed += 1
    log_sep()
    finished_at = time.time()
    print_gray(f'Passed {num_passed}/{len(tests)} in {finished_at-started_at}s.')



if __name__ == '__main__':
    # @TODO add argparser and test sizes.
    main()
120 changes: 120 additions & 0 deletions local_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import colored
import io
import os
import re
import sys
import subprocess

def color_text(s, color):
    """Return ``s`` styled with foreground ``color`` using the ``colored`` package."""
    foreground = colored.fg(color)
    return colored.stylize(s, foreground)

# Write handle on the null device, opened once at import time and never
# explicitly closed; CommandLineTest.execute passes it as the stderr of the
# commands it runs.
FNULL = open(os.devnull, 'w')

MAGIC_STRING = '/.*/'
def compare_outputs(true_output, test_output):
    """ Desired outputs have the magic string '/.*/' inserted wherever the
        output at that position doesn't actually matter. (For example,
        when the time to execute is printed, or another non-deterministic
        feature of the program.)

        `compare_outputs` makes sure all of the outputs match in between
        the magic strings. If they do, it returns True.

        The comparison is anchored at both ends: text before the first magic
        string must start `test_output`, text after the last one must end it,
        and the pieces in between must occur in order without overlapping.
        Without magic strings the two outputs must be equal. In particular an
        empty desired output matches only an empty test output; an unanchored
        search would accept any output at all in that case.
    """
    output_pieces = true_output.split(MAGIC_STRING)
    if len(output_pieces) == 1:
        # No wildcard anywhere: only an exact match counts.
        return test_output == true_output

    first_piece, *middle_pieces, last_piece = output_pieces
    if not test_output.startswith(first_piece):
        return False
    if not test_output.endswith(last_piece):
        return False

    # Region of test_output still available to the wildcards and middle pieces.
    position = len(first_piece)
    limit = len(test_output) - len(last_piece)
    if position > limit:
        # The fixed prefix and suffix would have to overlap.
        return False

    # Taking the leftmost occurrence of each piece leaves the most room for
    # the pieces that follow, so a match is found whenever one exists.
    for piece in middle_pieces:
        index_in_test = test_output.find(piece, position, limit)
        if index_in_test < 0:
            return False
        position = index_in_test + len(piece)
    return True

class TextAttackTest:
    """ A named test that produces output and compares it to a desired output.

        Subclasses implement `execute`; calling the instance runs the test,
        prints the outcome and returns whether it passed.
    """
    def __init__(self, name=None, output=None, desc=None):
        """ Stores the test's name, desired output and description.

            Raises ValueError if any of the three is None.
        """
        required_fields = ((name, 'name'), (output, 'output'), (desc, 'description'))
        for value, label in required_fields:
            if value is None:
                raise ValueError(f'Cannot initialize TextAttackTest without {label}')
        self.name = name
        self.output = output
        self.desc = desc

    def execute(self):
        """ Executes test and returns test output. To be implemented by
            subclasses.
        """
        raise NotImplementedError()

    def __call__(self):
        """ Runs test and prints success or failure.

            Returns True if the test output matches the desired output and
            False otherwise. An exception raised by `execute` is reported as
            a failure of this test (with its traceback shown as the test
            output) so that the remaining tests can still run.
        """
        self.log_start()
        try:
            test_output = self.execute()
        except Exception:
            # Local import: only needed on this error path.
            import traceback
            self.log_failure(traceback.format_exc())
            return False
        if compare_outputs(self.output, test_output):
            self.log_success()
            return True
        else:
            self.log_failure(test_output)
            return False

    def log_start(self):
        """ Prints the name of the test that is about to run. """
        print(f'Executing test {color_text(self.name, "blue")}.')

    def log_success(self):
        """ Prints a green success marker. """
        success_text = '✓ Succeeded.'
        print(color_text(success_text, 'green'))

    def log_failure(self, test_output):
        """ Prints a red failure marker, then the test and desired outputs. """
        fail_text = '✗ Failed.'
        print(color_text(fail_text, 'red'))
        print('\n')
        print(f'Test output: {test_output}.')
        print(f'Correct output: {self.output}.')

class CommandLineTest(TextAttackTest):
    """ A test that runs a command-line command and checks what it writes
        to standard output.
    """
    def __init__(self, command, name=None, output=None, desc=None):
        """ Stores `command`; raises ValueError when it is None. """
        if command is None:
            raise ValueError('Cannot initialize CommandLineTest without command')
        self.command = command
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """ Runs the command and returns its decoded standard output. """
        # The command string is split on whitespace to form the argument list.
        argv = self.command.split()
        # @TODO: Collect stderr somewhere. In the event of an error, point user to the error file.
        completed = subprocess.run(argv, stdout=subprocess.PIPE, stderr=FNULL)
        return completed.stdout.decode()

class Capturing(list):
    """ Context manager that collects everything written to standard out
        inside its `with` block; on exit the captured lines are appended
        to this list.
        stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
    """
    def __enter__(self):
        # Remember the real stream, then redirect sys.stdout to a buffer.
        self._stdout = sys.stdout
        self._stringio = io.StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        del self._stringio # free up some memory
        sys.stdout = self._stdout

class PythonFunctionTest(TextAttackTest):
    """ A test that calls a Python function and checks what it prints. """
    def __init__(self, function, name=None, output=None, desc=None):
        """ Stores `function`; raises ValueError when it is None. """
        if function is None:
            raise ValueError('Cannot initialize PythonFunctionTest without function')
        self.function = function
        super().__init__(name=name, output=output, desc=desc)

    def execute(self):
        """ Calls the function with standard out captured and returns the
            captured lines joined by newlines.
        """
        with Capturing() as captured_lines:
            self.function()
        # NOTE(review): the lines are re-joined without a final newline, so
        # the result never ends in '\n' even if the function printed one.
        return '\n'.join(captured_lines)

7 changes: 7 additions & 0 deletions local_tests/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import python_function_tests
import command_line_tests

# Every registered test, Python-function tests first, then command-line tests.
tests = [*python_function_tests.tests, *command_line_tests.tests]
1 change: 0 additions & 1 deletion scripts/run_attack_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def set_env_variables(gpu_id):

def attack_from_queue(args, in_queue, out_queue):
gpu_id = torch.multiprocessing.current_process()._identity[0] - 2
print('Using GPU #' + str(gpu_id))
set_env_variables(gpu_id)
_, attack = parse_goal_function_and_attack_from_args(args)
if gpu_id == 0:
Expand Down
File renamed without changes.
5 changes: 4 additions & 1 deletion textattack/attack_methods/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def _filter_transformations(self, transformations, text, original_text=None):
self.constraints_cache[t] = self.constraints_cache[t]
self._filter_transformations_uncached(uncached_transformations, text, original_text=original_text)
# Return transformations from cache.
return [t for t in transformations if self.constraints_cache[t]]
filtered_transformations = [t for t in transformations if self.constraints_cache[t]]
# Sort transformations to ensure order is preserved between runs.
filtered_transformations.sort(key=lambda t: t.text)
return filtered_transformations

def attack_one(self, tokenized_text):
"""
Expand Down

0 comments on commit fb07f7f

Please sign in to comment.