Skip to content

Commit

Permalink
Merge pull request #400 from QData/specify_split
Browse files Browse the repository at this point in the history
add --split to specify train/test/dev dataset
  • Loading branch information
qiyanjun committed Apr 2, 2021
2 parents e343f6d + 211abc5 commit 77a244d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
11 changes: 8 additions & 3 deletions tests/test_command_line/test_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
"textattack attack --model-from-huggingface "
"distilbert-base-uncased-finetuned-sst-2-english "
"--dataset-from-huggingface glue^sst2^train --recipe deepwordbug --num-examples 3 "
"--shuffle=False"
"--shuffle=False "
"--split='train'"
),
"tests/sample_outputs/run_attack_transformers_datasets.txt",
),
Expand Down Expand Up @@ -96,6 +97,7 @@
"/tmp/textattack_test.csv --model bert-base-uncased-mnli --num-examples 2 --attack-n --transformation "
"word-swap-wordnet --constraints lang-tool repeat stopword --search beam-search^beam_width=2 "
"--shuffle=False "
"--split='validation_matched'"
),
"tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_log-to-csv_beamsearch2_attack_n.txt",
),
Expand Down Expand Up @@ -131,7 +133,8 @@
"run_attack_kuleshov_nn",
(
"textattack attack --recipe kuleshov --num-examples 2 --model cnn-sst2 "
"--attack-n --query-budget 200 --shuffle=False"
"--attack-n --query-budget 200 --shuffle=False "
"--split='validation'"
),
"tests/sample_outputs/kuleshov_cnn_sst_2.txt",
),
Expand All @@ -142,7 +145,8 @@
"run_attack_stanza_pos_tagger",
(
"textattack attack --model lstm-mr --num-examples 4 --search-method greedy --transformation word-swap-embedding "
"--constraints repeat stopword part-of-speech^tagger_type=\\'stanza\\' --shuffle=False"
"--constraints repeat stopword part-of-speech^tagger_type=\\'stanza\\' --shuffle=False "
"--split='test'"
),
"tests/sample_outputs/run_attack_stanza_pos_tagger.txt",
),
Expand Down Expand Up @@ -199,6 +203,7 @@ def test_command_line_attack(name, command, sample_output_file):

if DEBUG and not re.match(desired_re, stdout, flags=re.S):
pdb.set_trace()

assert re.match(desired_re, stdout, flags=re.S)

assert result.returncode == 0
27 changes: 24 additions & 3 deletions textattack/commands/attack/attack_args_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def add_dataset_args(parser):
help="The offset to start at in the dataset.",
)

parser.add_argument(
"--split",
type=str,
required=False,
default="test",
help="Choose train, test or dev dataset.",
)


def load_module_from_file(file_path):
"""Uses ``importlib`` to dynamically open a file and load an object from
Expand Down Expand Up @@ -413,9 +421,22 @@ def parse_dataset_from_args(args):
dataset_args = dataset_args.split(ARGS_SPLIT_TOKEN)
else:
dataset_args = (dataset_args,)
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_args, shuffle=args.shuffle
)
if args.split:
dataset_list = list(dataset_args)
if len(dataset_list) > 2:
dataset_list[2] = args.split
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_list, shuffle=args.shuffle
)
else:
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_list, shuffle=args.shuffle, split=args.split
)

else:
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_args, shuffle=args.shuffle
)
dataset.examples = dataset.examples[args.num_examples_offset :]
else:
raise ValueError("Must supply pretrained model or dataset")
Expand Down
2 changes: 1 addition & 1 deletion textattack/datasets/huggingface_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
self,
name,
subset=None,
split="train",
split="test",
label_map=None,
output_scale_factor=None,
dataset_columns=None,
Expand Down

0 comments on commit 77a244d

Please sign in to comment.