Skip to content

Commit

Permalink
Merge pull request #439 from QData/revert-400-specify_split
Browse files — browse the repository at this point in the history
Revert "add --split to specify train/test/dev dataset"
  • Loading branch information
qiyanjun committed Apr 2, 2021
2 parents 77a244d + aa0d39b commit 0555059
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 33 deletions.
11 changes: 3 additions & 8 deletions tests/test_command_line/test_attack.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,8 +44,7 @@
"textattack attack --model-from-huggingface "
"distilbert-base-uncased-finetuned-sst-2-english "
"--dataset-from-huggingface glue^sst2^train --recipe deepwordbug --num-examples 3 "
"--shuffle=False "
"--split='train'"
"--shuffle=False"
),
"tests/sample_outputs/run_attack_transformers_datasets.txt",
),
Expand Down Expand Up @@ -97,7 +96,6 @@
"/tmp/textattack_test.csv --model bert-base-uncased-mnli --num-examples 2 --attack-n --transformation "
"word-swap-wordnet --constraints lang-tool repeat stopword --search beam-search^beam_width=2 "
"--shuffle=False "
"--split='validation_matched'"
),
"tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_log-to-csv_beamsearch2_attack_n.txt",
),
Expand Down Expand Up @@ -133,8 +131,7 @@
"run_attack_kuleshov_nn",
(
"textattack attack --recipe kuleshov --num-examples 2 --model cnn-sst2 "
"--attack-n --query-budget 200 --shuffle=False "
"--split='validation'"
"--attack-n --query-budget 200 --shuffle=False"
),
"tests/sample_outputs/kuleshov_cnn_sst_2.txt",
),
Expand All @@ -145,8 +142,7 @@
"run_attack_stanza_pos_tagger",
(
"textattack attack --model lstm-mr --num-examples 4 --search-method greedy --transformation word-swap-embedding "
"--constraints repeat stopword part-of-speech^tagger_type=\\'stanza\\' --shuffle=False "
"--split='test'"
"--constraints repeat stopword part-of-speech^tagger_type=\\'stanza\\' --shuffle=False"
),
"tests/sample_outputs/run_attack_stanza_pos_tagger.txt",
),
Expand Down Expand Up @@ -203,7 +199,6 @@ def test_command_line_attack(name, command, sample_output_file):

if DEBUG and not re.match(desired_re, stdout, flags=re.S):
pdb.set_trace()

assert re.match(desired_re, stdout, flags=re.S)

assert result.returncode == 0
27 changes: 3 additions & 24 deletions textattack/commands/attack/attack_args_helpers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -110,14 +110,6 @@ def add_dataset_args(parser):
help="The offset to start at in the dataset.",
)

parser.add_argument(
"--split",
type=str,
required=False,
default="test",
help="Choose train, test or dev dataset.",
)


def load_module_from_file(file_path):
"""Uses ``importlib`` to dynamically open a file and load an object from
Expand Down Expand Up @@ -421,22 +413,9 @@ def parse_dataset_from_args(args):
dataset_args = dataset_args.split(ARGS_SPLIT_TOKEN)
else:
dataset_args = (dataset_args,)
if args.split:
dataset_list = list(dataset_args)
if len(dataset_list) > 2:
dataset_list[2] = args.split
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_list, shuffle=args.shuffle
)
else:
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_list, shuffle=args.shuffle, split=args.split
)

else:
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_args, shuffle=args.shuffle
)
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_args, shuffle=args.shuffle
)
dataset.examples = dataset.examples[args.num_examples_offset :]
else:
raise ValueError("Must supply pretrained model or dataset")
Expand Down
2 changes: 1 addition & 1 deletion textattack/datasets/huggingface_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
self,
name,
subset=None,
split="test",
split="train",
label_map=None,
output_scale_factor=None,
dataset_columns=None,
Expand Down

0 comments on commit 0555059

Please sign in to comment.