Skip to content

Commit

Permalink
Fix marian slow test (huggingface#6854)
Browse files Browse the repository at this point in the history
  • Loading branch information
sshleifer authored and Zigur committed Oct 26, 2020
1 parent dc67132 commit 7cd4b45
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tests/test_modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
convert_hf_name_to_opus_name,
convert_opus_name_to_hf_name,
)
from transformers.modeling_bart import shift_tokens_right
from transformers.pipelines import TranslationPipeline


Expand Down Expand Up @@ -116,18 +117,21 @@ def test_forward(self):
expected_ids = [38, 121, 14, 697, 38848, 0]

model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)

self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())

desired_keys = {
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"labels",
}
self.assertSetEqual(desired_keys, set(model_inputs.keys()))
model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id)
model_inputs["return_dict"] = True
model_inputs["use_cache"] = False
with torch.no_grad():
logits, *enc_features = self.model(**model_inputs)
max_indices = logits.argmax(-1)
outputs = self.model(**model_inputs)
max_indices = outputs.logits.argmax(-1)
self.tokenizer.batch_decode(max_indices)

def test_unk_support(self):
Expand Down

0 comments on commit 7cd4b45

Please sign in to comment.