Skip to content

Commit

Permalink
Merge pull request #657 from VijayKalmath/Fix-Train-AdvesarialDataset
Browse files Browse the repository at this point in the history
Fix adversarial dataset generation in trainer.py
  • Loading branch information
jxmorris12 committed Jun 8, 2022
2 parents f59e49a + 6c1351c commit 2ce4df1
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions textattack/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,24 @@ def _generate_adversarial_examples(self, epoch):
f"Attack success rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
)
# TODO: This will produce a bug if we need to manipulate ground truth output.

# To Fix Issue #498 , We need to add the Non Output columns in one tuple to represent input columns
# Since adversarial_example won't be an input to the model , we will have to remove it from the input
# dictionary in collate_fn
adversarial_examples = [
(
tuple(r.perturbed_result.attacked_text._text_input.values()),
tuple(r.perturbed_result.attacked_text._text_input.values())
+ ("adversarial_example",),
r.perturbed_result.ground_truth_output,
"adversarial_example",
)
for r in results
if isinstance(r, (SuccessfulAttackResult, MaximizedAttackResult))
]

# Name for column indicating if an example is adversarial is set as "_example_type".
adversarial_dataset = textattack.datasets.Dataset(
adversarial_examples,
input_columns=self.train_dataset.input_columns,
input_columns=self.train_dataset.input_columns + ("_example_type",),
label_map=self.train_dataset.label_map,
label_names=self.train_dataset.label_names,
output_scale_factor=self.train_dataset.output_scale_factor,
Expand Down Expand Up @@ -399,9 +405,15 @@ def collate_fn(data):
targets = []
is_adv_sample = []
for item in data:
if len(item) == 3:
# `len(item)` is 3 for adversarial training dataset
_input, label, adv = item
if "_example_type" in item[0].keys():

# Get example type value from OrderedDict and remove it

adv = item[0].pop("_example_type")

# with _example_type removed from item[0] OrderedDict
# all other keys should be part of input
_input, label = item
if adv != "adversarial_example":
raise ValueError(
"`item` has length of 3 but last element is not for marking if the item is an `adversarial example`."
Expand Down

1 comment on commit 2ce4df1

@plasmashen
Copy link
Contributor

@plasmashen plasmashen commented on 2ce4df1 Aug 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File "/home/zjunesa/anaconda2/envs/pt2/lib/python3.8/site-packages/textattack/trainer.py", line 245, in _generate_adversarial_examples
    input_columns=self.train_dataset.input_columns + ("_example_type",),
TypeError: can only concatenate list (not "tuple") to list

Please sign in to comment.