
Commit 08af37f
modified default gen_kwargs to work better with CLI; changed prompt_logprobs=1 (EleutherAI#1345)
baberabb authored and anjor committed Jan 31, 2024
1 parent 4a2c48a commit 08af37f
Showing 1 changed file with 3 additions and 3 deletions.
lm_eval/models/vllm_causallms.py: 3 additions & 3 deletions
@@ -175,7 +175,7 @@ def _model_generate(
             sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
         else:
             sampling_params = SamplingParams(
-                temperature=0, prompt_logprobs=2, max_tokens=1
+                temperature=0, prompt_logprobs=1, max_tokens=1
             )
         if self.data_parallel_size > 1:
             requests = [list(x) for x in divide(requests, self.data_parallel_size)]
@@ -405,8 +405,8 @@ def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
     @staticmethod
     def modify_gen_kwargs(kwargs: dict) -> dict:
         # sampling_params
-        do_sample = kwargs.pop("do_sample", False)
-        if do_sample is not True:
+        do_sample = kwargs.pop("do_sample", None)
+        if do_sample is False or "temperature" not in kwargs:
             kwargs["temperature"] = 0.0
         # hf defaults
         kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
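And a rough sketch of why prompt_logprobs=1 is enough for the non-generate (loglikelihood) path: vLLM reports the logprob of each actual prompt token in addition to the top-k candidates, so top-1 already contains everything needed to score a continuation. The model name, ctxlen, and float-vs-Logprob handling below are assumptions for illustration, not the harness's own code.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration

# Greedy "generation" of a single token; only the prompt scores matter.
params = SamplingParams(temperature=0, prompt_logprobs=1, max_tokens=1)
out = llm.generate(["The capital of France is Paris"], params)[0]

# out.prompt_logprobs has one entry per prompt token (the first is None).
# Each entry maps token id -> logprob, and the actual prompt token is always
# included even when it is not the top-1 candidate.
ctxlen = 4  # hypothetical number of context tokens to skip (>= 1)
score = 0.0
for token_id, pos in zip(out.prompt_token_ids[ctxlen:], out.prompt_logprobs[ctxlen:]):
    lp = pos[token_id]
    score += getattr(lp, "logprob", lp)  # float in older vLLM, Logprob object in newer releases
print(score)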
