beam_width argument in retrieval_lm/run_short_form.py #28
Comments
I have run into the same problem.
@fate-ubw I did the same thing. So far I have only been able to run run_short_form.py in always_retrieve mode; the other modes throw errors.
Thank you so much for reporting! I was changing the codebase before the release and it seems I forgot to fix the variable name. I will fix it.
I fixed the beam search argument in the script. Thanks again for reporting the issue!
Hello,
I'm trying to reproduce the paper's numbers on arc_challenge by running the following command:
python run_short_form.py \
    --model_name selfrag/selfrag_llama2_7b \
    --input_file eval_data/arc_challenge_processed.jsonl \
    --max_new_tokens 50 --threshold 0.2 \
    --output_file OUTPUT_FILE_NAME \
    --metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
    --task arc_c
But I'm getting an error:
return call_model_rerank_w_scores_batch(prompt, evidences=evidences, model=model, max_new_tokens=max_new_tokens,
TypeError: call_model_rerank_w_scores_batch() got an unexpected keyword argument 'beam_width'
In retrieval_lm/run_short_form.py, the function signature has no beam_width or max_depth parameter, but generate() passes both:
def call_model_rerank_w_scores_batch(prompt, evidences, model, max_new_tokens=15,
                                     ret_tokens=None, rel_tokens=None, grd_tokens=None, ut_tokens=None,
                                     use_seqscore=False, threshold=0.5,
                                     w_rel=1.0, w_sup=1.0, w_use=0.5, mode="adaptive_retrieval", closed=False):

def generate(prompt, evidences, max_new_tokens):
    return call_model_rerank_w_scores_batch(prompt, evidences=evidences, model=model, max_new_tokens=max_new_tokens,
                                            rel_tokens=rel_tokens, ret_tokens=ret_tokens, grd_tokens=grd_tokens, ut_tokens=ut_tokens,
                                            threshold=args.threshold, beam_width=args.beam_width, max_depth=args.max_depth, use_seqscore=args.use_seqscore,
                                            w_rel=args.w_rel, w_sup=args.w_sup, w_use=args.w_use, mode=args.mode, closed=args.task in ["fever", "arc_c"])
Maybe I'm missing something. Any help would be appreciated!
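The mismatch can be checked in isolation, without loading the model. The following is only a rough sketch: the function below is a stand-in that copies the keyword parameters from the signature quoted above (its body is a placeholder, not the real implementation), `unsupported_kwargs` is a hypothetical helper written for this illustration, and the argument values are dummies.

```python
import inspect


# Stand-in that mirrors the quoted signature; the body is a placeholder.
def call_model_rerank_w_scores_batch(prompt, evidences, model, max_new_tokens=15,
                                     ret_tokens=None, rel_tokens=None,
                                     grd_tokens=None, ut_tokens=None,
                                     use_seqscore=False, threshold=0.5,
                                     w_rel=1.0, w_sup=1.0, w_use=0.5,
                                     mode="adaptive_retrieval", closed=False):
    return "placeholder"


def unsupported_kwargs(func, kwargs):
    """Hypothetical helper: list the keyword names that func does not accept."""
    params = inspect.signature(func).parameters
    # A **kwargs parameter would accept any keyword, so nothing is unsupported.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(name for name in kwargs if name not in params)


# Keywords passed by generate() in the excerpt above (values are dummies).
call_kwargs = dict(evidences=[], model=None, max_new_tokens=50,
                   rel_tokens=None, ret_tokens=None, grd_tokens=None,
                   ut_tokens=None, threshold=0.2, beam_width=2, max_depth=2,
                   use_seqscore=True, w_rel=1.0, w_sup=1.0, w_use=0.5,
                   mode="adaptive_retrieval", closed=True)

extra = unsupported_kwargs(call_model_rerank_w_scores_batch, call_kwargs)
print(extra)

# Calling the stand-in with the same keywords raises the same kind of TypeError.
error_message = None
try:
    call_model_rerank_w_scores_batch("prompt", **call_kwargs)
except TypeError as err:
    error_message = str(err)
    print(error_message)
```

If this reading is right, both beam_width and max_depth are rejected, so either the call in generate() has to stop passing them or the function signature has to accept them.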