
Commit 8637d20
Merge pull request #476 from QData/api-tests
Fix bugs and update training test
jinyongyoo committed Jun 25, 2021
2 parents: 254c984 + 4ece786
Showing 3 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/test_command_line/test_train.py
@@ -5,7 +5,7 @@
 
 
 def test_train_tiny():
-    command = "textattack train --model-name-or-path lstm --attack deepwordbug --dataset glue^cola --model-max-length 32 --num-epochs 1 --num-train-adv-examples 200"
+    command = "textattack train --model-name-or-path lstm --attack deepwordbug --dataset glue^cola --model-max-length 32 --num-epochs 2 --num-clean-epochs 1 --num-train-adv-examples 200"
 
     # Run command and validate outputs.
     result = run_command_and_get_result(command)
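
For reference, a rough Python-API equivalent of the updated test command. This is a sketch only: it assumes the textattack.Trainer / textattack.TrainingArgs API mirrors these CLI flags, and it swaps in a small HuggingFace model where the test uses the built-in LSTM.

    import textattack
    import transformers

    # Assumption: a small HF model stands in for the LSTM that the CLI test trains.
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # --attack deepwordbug / --dataset glue^cola
    attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
    train_dataset = textattack.datasets.HuggingFaceDataset("glue", "cola", split="train")
    eval_dataset = textattack.datasets.HuggingFaceDataset("glue", "cola", split="validation")

    # --num-epochs 2 --num-clean-epochs 1 --num-train-adv-examples 200
    training_args = textattack.TrainingArgs(
        num_epochs=2,
        num_clean_epochs=1,
        num_train_adv_examples=200,
    )
    trainer = textattack.Trainer(
        model_wrapper, "classification", attack, train_dataset, eval_dataset, training_args
    )
    trainer.train()
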
1 change: 1 addition & 0 deletions textattack/models/tokenizers/glove_tokenizer.py
@@ -128,6 +128,7 @@ def __init__(
         self.pad_token_id = pad_token_id
         self.oov_token_id = unk_token_id
         self.convert_id_to_word = self.id_to_token
+        self.model_max_length = max_length
         # Set defaults.
         self.enable_padding(length=max_length, pad_id=pad_token_id)
         self.enable_truncation(max_length=max_length)
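
The added attribute mirrors the model_max_length field that HuggingFace tokenizers expose, so callers no longer need tokenizer-specific handling. A minimal sketch of the kind of uniform access this enables (get_max_length is a hypothetical helper, not part of TextAttack):

    def get_max_length(tokenizer):
        # Works for both transformers tokenizers and TextAttack's GloveTokenizer,
        # now that both expose model_max_length; falls back to None if absent.
        return getattr(tokenizer, "model_max_length", None)
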
13 changes: 8 additions & 5 deletions textattack/trainer.py
@@ -539,7 +539,7 @@ def train(self):
         else:
             train_batch_size = self.training_args.per_device_train_batch_size
 
-        if self.training_args.attack is None:
+        if self.attack is None:
             num_clean_epochs = self.training_args.num_epochs
         else:
             num_clean_epochs = self.training_args.num_clean_epochs
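
The check now looks at the attack stored on the Trainer itself rather than a field on the training args. With the flags used in the updated test, and assuming (as the CLI flags suggest) that the remaining epochs train on adversarial examples, the split works out as in this illustrative snippet (values hard-coded for clarity):

    # Illustration only, using the values from the updated test command.
    num_epochs = 2             # --num-epochs 2
    attack_configured = True   # --attack deepwordbug, so self.attack is not None

    if not attack_configured:
        num_clean_epochs = num_epochs   # no attack: every epoch trains on clean data
    else:
        num_clean_epochs = 1            # --num-clean-epochs 1

    num_adv_epochs = num_epochs - num_clean_epochs
    print(num_clean_epochs, num_adv_epochs)  # -> 1 1: one clean epoch, one adversarial epoch
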
@@ -831,12 +831,15 @@ def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size
             and self.training_args.model_max_length
         ):
             model_max_length = self.training_args.model_max_length
-        elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
-            model_max_length = self.model_wrapper.model.config.max_position_embeddings
         elif isinstance(
-            self.model_wrapper.model, (LSTMForClassification, WordCNNForClassification)
+            self.model_wrapper.model,
+            (
+                transformers.PreTrainedModel,
+                LSTMForClassification,
+                WordCNNForClassification,
+            ),
         ):
-            model_max_length = self.model_wrapper.model.max_length
+            model_max_length = self.model_wrapper.tokenizer.model_max_length
         else:
             model_max_length = None

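
Taken together, the consolidated branch resolves model_max_length with a single precedence order: an explicit --model-max-length wins, otherwise the limit comes from the wrapped tokenizer (which is why GloveTokenizer now exposes model_max_length), otherwise the README omits the value. A condensed, standalone sketch of that precedence (resolve_model_max_length is a hypothetical name, and the isinstance checks are folded into attribute lookups):

    def resolve_model_max_length(training_args, model_wrapper):
        # 1) an explicit model_max_length on the training args always wins
        if getattr(training_args, "model_max_length", None):
            return training_args.model_max_length
        # 2) otherwise read the limit from the wrapper's tokenizer, which both
        #    transformers tokenizers and GloveTokenizer now expose
        tokenizer = getattr(model_wrapper, "tokenizer", None)
        # 3) otherwise return None and leave the value out of the README
        return getattr(tokenizer, "model_max_length", None)
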
