-
Notifications
You must be signed in to change notification settings - Fork 375
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #81 from QData/local-tests
local tests
- Loading branch information
Showing
10 changed files
with
270 additions
and
2 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
from test_models import CommandLineTest | ||
|
||
tests = [] | ||
|
||
def register_test(command, name=None, output_file=None, desc=None): | ||
if not os.path.exists(output_file): | ||
raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.') | ||
output = open(output_file).read() | ||
tests.append(CommandLineTest( | ||
command, name=name, output=output, desc=desc | ||
)) | ||
|
||
|
||
####################################### | ||
## BEGIN TESTS ## | ||
####################################### | ||
|
||
# | ||
# test: run_attack_parallel textfooler attack on 10 samples from BERT MR | ||
# (takes about 81s) | ||
# | ||
register_test('python scripts/run_attack.py --model bert-mr --recipe textfooler --num_examples 10', | ||
name='run_attack_textfooler_bert_mr_10', | ||
output_file='local_tests/outputs/run_attack_textfooler_bert_mr_10.txt', | ||
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the MR dataset') | ||
|
||
# | ||
# test: run_attack_parallel textfooler attack on 10 samples from BERT SNLI | ||
# (takes about 51s) | ||
# | ||
register_test('python scripts/run_attack.py --model bert-snli --recipe textfooler --num_examples 10', | ||
name='run_attack_textfooler_bert_snli_10', | ||
output_file='local_tests/outputs/run_attack_textfooler_bert_snli_10.txt', | ||
desc='Runs attack using TextFooler recipe on BERT using 10 examples from the SNLI dataset') | ||
|
||
# | ||
# test: run_attack deepwordbug attack on 10 samples from LSTM MR | ||
# (takes about 41s) | ||
# | ||
register_test('python scripts/run_attack.py --model lstm-mr --recipe deepwordbug --num_examples 10', | ||
name='run_attack_deepwordbug_lstm_mr_10', | ||
output_file='local_tests/outputs/run_attack_deepwordbug_lstm_mr_10.txt', | ||
desc='Runs attack using DeepWordBUG recipe on LSTM using 10 examples from the MR dataset') | ||
|
||
# | ||
# test: run_attack targeted classification of class 2 on BERT MNLI with enable_csv | ||
# and attack_n set, using the WordNet transformation and beam search with | ||
# beam width 2, using language tool constraint, on 10 samples | ||
# (takes about 171s) | ||
# | ||
register_test(('python scripts/run_attack.py --attack_n --goal_function targeted-classification:target_class=2 ' | ||
'--enable_csv --model bert-mnli --num_examples 10 --transformation word-swap-wordnet ' | ||
'--constraints lang-tool --attack beam-search:beam_width=2'), | ||
name='run_attack_targeted2_bertmnli_wordnet_beamwidth_2_enablecsv_attackn', | ||
output_file='local_tests/outputs/run_attack_targetedclassification2_wordnet_langtool_enable_csv_beamsearch2_attack_n_10.txt', | ||
desc=('Runs attack using targeted classification on class 2 on BERT MNLI with' | ||
'enable_csv and attack_n set, using the WordNet transformation and beam ' | ||
'search with beam width 2, using language tool constraint, on 10 samples') | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
from test_models import PythonFunctionTest | ||
|
||
tests = [] | ||
|
||
def register_test(function, name=None, output_file=None, desc=None): | ||
if not os.path.exists(output_file): | ||
raise FileNotFoundError(f'Error creating test {name}: cannot find file {output_file}.') | ||
output = open(output_file).read() | ||
tests.append(PythonFunctionTest( | ||
function, name=name, output=output, desc=desc | ||
)) | ||
|
||
|
||
####################################### | ||
## BEGIN TESTS ## | ||
####################################### | ||
|
||
# | ||
# | ||
# | ||
def check_gpu_count(): | ||
import torch | ||
num_gpus = torch.cuda.device_count() | ||
if num_gpus == 0: | ||
print(f'Error: detected 0 GPUs. Must run local tests with multiple GPUs. Perhaps you need to configure CUDA?') | ||
|
||
register_test(check_gpu_count, name='check CUDA', | ||
output_file='local_tests/outputs/empty_file.txt', | ||
desc='Makes sure CUDA is enabled, properly configured, and detects at least 1 GPU') | ||
|
||
# | ||
# test: import textattack | ||
# | ||
def import_textattack(): | ||
import textattack | ||
|
||
register_test(import_textattack, name='import textattack', | ||
output_file='local_tests/outputs/empty_file.txt', | ||
desc='Makes sure the textattack module can be imported') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
colored |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import os | ||
import time | ||
|
||
from test_models import color_text | ||
|
||
def log_sep(): | ||
print('\n' + ('-' * 60) + '\n') | ||
|
||
def print_gray(s): | ||
print(color_text(s, 'light_gray')) | ||
|
||
def main(): | ||
# Change to TextAttack root directory. | ||
this_file_path = os.path.abspath(__file__) | ||
test_directory_name = os.path.dirname(this_file_path) | ||
textattack_root_directory_name = os.path.dirname(test_directory_name) | ||
os.chdir(textattack_root_directory_name) | ||
print_gray(f'Executing tests from {textattack_root_directory_name}.') | ||
|
||
# Execute tests. | ||
start_time = time.time() | ||
passed_tests = 0 | ||
|
||
from tests import tests | ||
for test in tests: | ||
log_sep() | ||
test_passed = test() | ||
if test_passed: | ||
passed_tests += 1 | ||
log_sep() | ||
end_time = time.time() | ||
print_gray(f'Passed {passed_tests}/{len(tests)} in {end_time-start_time}s.') | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
# @TODO add argparser and test sizes. | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import colored | ||
import io | ||
import os | ||
import re | ||
import sys | ||
import subprocess | ||
|
||
def color_text(s, color): | ||
return colored.stylize(s, colored.fg(color)) | ||
|
||
FNULL = open(os.devnull, 'w') | ||
|
||
MAGIC_STRING = '/.*/' | ||
def compare_outputs(true_output, test_output): | ||
""" Desired have the magic string '/.*/' inserted wherever the output | ||
at that position doesn't actually matter. (For example, | ||
when the time to execute is printed, or another non-deterministic | ||
feature of the program.) | ||
`compare_outputs` makes sure all of the outputs match in between | ||
the magic strings. If they do, it returns True. | ||
""" | ||
output_pieces = true_output.split(MAGIC_STRING) | ||
for piece in output_pieces: | ||
index_in_test = test_output.find(piece) | ||
if index_in_test < 0: | ||
return False | ||
else: | ||
test_output = test_output[index_in_test + len(piece):] | ||
return True | ||
|
||
class TextAttackTest: | ||
def __init__(self, name=None, output=None, desc=None): | ||
if name is None: | ||
raise ValueError('Cannot initialize TextAttackTest without name') | ||
if output is None: | ||
raise ValueError('Cannot initialize TextAttackTest without output') | ||
if desc is None: | ||
raise ValueError('Cannot initialize TextAttackTest without description') | ||
self.name = name | ||
self.output = output | ||
self.desc = desc | ||
|
||
def execute(self): | ||
""" Executes test and returns test output. To be implemented by | ||
subclasses. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def __call__(self): | ||
""" Runs test and prints success or failure. """ | ||
self.log_start() | ||
test_output = self.execute() | ||
if compare_outputs(self.output, test_output): | ||
self.log_success() | ||
return True | ||
else: | ||
self.log_failure(test_output) | ||
return False | ||
|
||
def log_start(self): | ||
print(f'Executing test {color_text(self.name, "blue")}.') | ||
|
||
def log_success(self): | ||
success_text = f'✓ Succeeded.' | ||
print(color_text(success_text, 'green')) | ||
|
||
def log_failure(self, test_output): | ||
fail_text = f'✗ Failed.' | ||
print(color_text(fail_text, 'red')) | ||
print('\n') | ||
print(f'Test output: {test_output}.') | ||
print(f'Correct output: {self.output}.') | ||
|
||
class CommandLineTest(TextAttackTest): | ||
""" Runs a command-line command to check for desired output. """ | ||
def __init__(self, command, name=None, output=None, desc=None): | ||
if command is None: | ||
raise ValueError('Cannot initialize CommandLineTest without command') | ||
self.command = command | ||
super().__init__(name=name, output=output, desc=desc) | ||
|
||
def execute(self): | ||
result = subprocess.run( | ||
self.command.split(), | ||
stdout=subprocess.PIPE, | ||
# @TODO: Collect stderr somewhere. In the event of an error, point user to the error file. | ||
stderr=FNULL | ||
) | ||
return result.stdout.decode() | ||
|
||
class Capturing(list): | ||
""" A context manager that captures standard out during its execution. | ||
stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call | ||
""" | ||
def __enter__(self): | ||
self._stdout = sys.stdout | ||
sys.stdout = self._stringio = io.StringIO() | ||
return self | ||
def __exit__(self, *args): | ||
self.extend(self._stringio.getvalue().splitlines()) | ||
del self._stringio # free up some memory | ||
sys.stdout = self._stdout | ||
|
||
class PythonFunctionTest(TextAttackTest): | ||
""" Runs a Python function to check for desired output. """ | ||
def __init__(self, function, name=None, output=None, desc=None): | ||
if function is None: | ||
raise ValueError('Cannot initialize PythonFunctionTest without function') | ||
self.function = function | ||
super().__init__(name=name, output=output, desc=desc) | ||
|
||
def execute(self): | ||
with Capturing() as output_lines: | ||
self.function() | ||
output = '\n'.join(output_lines) | ||
return output | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
tests = [] | ||
|
||
import python_function_tests | ||
tests.extend(python_function_tests.tests) | ||
|
||
import command_line_tests | ||
tests.extend(command_line_tests.tests) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters