Skip to content

Commit

Permalink
Merge pull request #658 from VijayKalmath/Fix-total_adv_training_steps-Calculation
Browse files Browse the repository at this point in the history

Fix total_adv_training_steps calculation in trainer.py
  • Loading branch information
jxmorris12 committed Jun 8, 2022
2 parents 6d85f29 + 24ff6f2 commit ea3ae24
Showing 1 changed file with 57 additions and 13 deletions.
70 changes: 57 additions & 13 deletions textattack/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,18 +177,40 @@ def _generate_adversarial_examples(self, epoch):
else:
num_train_adv_examples = self.training_args.num_train_adv_examples

attack_args = AttackArgs(
num_successful_examples=num_train_adv_examples,
num_examples_offset=0,
query_budget=self.training_args.query_budget_train,
shuffle=True,
parallel=self.training_args.parallel,
num_workers_per_device=self.training_args.attack_num_workers_per_device,
disable_stdout=True,
silent=True,
log_to_txt=log_file_name + ".txt",
log_to_csv=log_file_name + ".csv",
)
# Use Different AttackArgs based on num_train_adv_examples value.
# If num_train_adv_examples >= 0 , num_train_adv_examples is
# set as number of successful examples.
# If num_train_adv_examples == -1 , num_examples is set to -1 to
# generate example for all of training data.
if num_train_adv_examples >= 0:
attack_args = AttackArgs(
num_successful_examples=num_train_adv_examples,
num_examples_offset=0,
query_budget=self.training_args.query_budget_train,
shuffle=True,
parallel=self.training_args.parallel,
num_workers_per_device=self.training_args.attack_num_workers_per_device,
disable_stdout=True,
silent=True,
log_to_txt=log_file_name + ".txt",
log_to_csv=log_file_name + ".csv",
)
elif num_train_adv_examples == -1:
# set num_examples when num_train_adv_examples = -1
attack_args = AttackArgs(
num_examples=num_train_adv_examples,
num_examples_offset=0,
query_budget=self.training_args.query_budget_train,
shuffle=True,
parallel=self.training_args.parallel,
num_workers_per_device=self.training_args.attack_num_workers_per_device,
disable_stdout=True,
silent=True,
log_to_txt=log_file_name + ".txt",
log_to_csv=log_file_name + ".csv",
)
else:
assert False, "num_train_adv_examples is negative and not equal to -1."

attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
results = attacker.attack_dataset()
Expand Down Expand Up @@ -609,8 +631,30 @@ def train(self):
)
* num_clean_epochs
)

# calculate total_adv_training_data_length based on type of
# num_train_adv_examples.
# if num_train_adv_examples is float , num_train_adv_examples is a portion of train_dataset.
if isinstance(self.training_args.num_train_adv_examples, float):
total_adv_training_data_length = (
len(self.train_dataset) * self.training_args.num_train_adv_examples
)

# if num_train_adv_examples is int and >=0 then it is taken as value.
elif (
isinstance(self.training_args.num_train_adv_examples, int)
and self.training_args.num_train_adv_examples >= 0
):
total_adv_training_data_length = self.training_args.num_train_adv_examples

# if num_train_adv_examples is = -1 , we generate all possible adv examples.
# Max number of all possible adv examples would be equal to train_dataset.
else:
total_adv_training_data_length = len(self.train_dataset)

# Based on total_adv_training_data_length calculation , find total total_adv_training_steps
total_adv_training_steps = math.ceil(
(len(self.train_dataset) + self.training_args.num_train_adv_examples)
(len(self.train_dataset) + total_adv_training_data_length)
/ (train_batch_size * self.training_args.gradient_accumulation_steps)
) * (self.training_args.num_epochs - num_clean_epochs)

Expand Down

0 comments on commit ea3ae24

Please sign in to comment.