Commit dd8e714
fix attention mask
Haskely committed Aug 10, 2023
1 parent 2d12248 commit dd8e714
Showing 4 changed files with 17,156 additions and 4 deletions.
llama_gen_and_eval.py: 9 changes (5 additions & 4 deletions)
@@ -97,21 +97,22 @@ def gsm8k_batch_gen(
 def get_batch_llama(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
     @torch.inference_mode()
     def batch_llama(input_strs: list[str]) -> list[str]:
-        input_ids = tokenizer(
+        input_ids_w_attnmask = tokenizer(
             input_strs,
             padding=True,
             return_tensors="pt",
-        ).input_ids.to(model.device)
+        ).to(model.device)
         output_ids = model.generate(
-            inputs=input_ids,
+            input_ids=input_ids_w_attnmask.input_ids,
+            attention_mask=input_ids_w_attnmask.attention_mask,
             generation_config=GenerationConfig(
                 max_length=512,
                 do_sample=False,
                 temperature=0.0,  # t=0.0 raise error if do_sample=True
             ),
         ).tolist()
         real_output_ids = [
-            output_id[len(input_ids[i]) :] for i, output_id in enumerate(output_ids)
+            output_id[len(input_ids_w_attnmask.input_ids[i]) :] for i, output_id in enumerate(output_ids)
         ]
         output_strs = tokenizer.batch_decode(real_output_ids, skip_special_tokens=True)
         return output_strs
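For context, the pattern this fix arrives at looks roughly like the sketch below: tokenizing a padded batch returns both input_ids and an attention_mask, and both need to be forwarded to model.generate() so padding tokens are masked out during decoding; the final ids are then sliced past the prompt length before decoding, as real_output_ids does above. This is a minimal, self-contained sketch using the Hugging Face transformers API, not the repository's exact script; the checkpoint path and prompts are placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "path/to/llama-checkpoint"  # placeholder path, not from the repo
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers often ship without a pad token

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
model.eval()

prompts = [  # placeholder prompts
    "Question: What is 7 * 8? Answer:",
    "Question: A train travels 60 miles per hour for 2 hours. How far does it go? Answer:",
]

with torch.inference_mode():
    # padding=True gives a rectangular batch plus the matching attention_mask
    batch = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,  # the argument this commit starts passing
        generation_config=GenerationConfig(max_length=512, do_sample=False),
    )

# Drop the prompt tokens before decoding, mirroring real_output_ids above.
prompt_len = batch.input_ids.shape[1]
completions = tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
print(completions)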
