add min_new_tokens sampling parameter (#384)
hiworldwzj committed Apr 1, 2024
1 parent 7cd10a0 commit a231505
Showing 3 changed files with 59 additions and 21 deletions.
6 changes: 5 additions & 1 deletion lightllm/server/router/model_infer/infer_batch.py
@@ -30,6 +30,7 @@ def __init__(
top_p: float = 1.0,
top_k: int = -1,
vocab_size: int = -1,
min_new_tokens: int = 1,
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
@@ -39,6 +40,7 @@ def __init__(
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_new_tokens = min_new_tokens
if self.top_k == -1:
self.top_k = vocab_size
return
@@ -202,7 +204,9 @@ def filter(self, request_ids: List[str], finished_request_ids: List[str]):
return self
if len(request_ids) == 0:
self.free_self()
return InferBatch(batch_id=self.batch_id, request_ids=[], req_manager=self.req_manager, radix_cache=self.radix_cache)
return InferBatch(
batch_id=self.batch_id, request_ids=[], req_manager=self.req_manager, radix_cache=self.radix_cache
)
free_req_index = []
free_token_index = []
for request_id in finished_request_ids:
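The hunk above only threads the new parameter through the per-request sampling object. A minimal stand-in sketch (hypothetical class name, not lightllm's actual InferSamplingParams) of the two behaviours visible in this hunk — storing min_new_tokens and resolving top_k = -1 to the full vocabulary:

```python
# Illustrative stand-in for the per-request sampling object (hypothetical name;
# only the two behaviours visible in the hunk above are reproduced).
class PerRequestSamplingParams:
    def __init__(self, top_k: int = -1, vocab_size: int = -1, min_new_tokens: int = 1) -> None:
        # minimum number of new tokens to generate before EOS may be sampled
        self.min_new_tokens = min_new_tokens
        # top_k == -1 means "no truncation", i.e. consider the whole vocabulary
        self.top_k = vocab_size if top_k == -1 else top_k

p = PerRequestSamplingParams(top_k=-1, vocab_size=32000, min_new_tokens=4)
print(p.top_k, p.min_new_tokens)  # 32000 4
```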
17 changes: 12 additions & 5 deletions lightllm/server/router/model_infer/post_process.py
@@ -5,7 +5,7 @@
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty


def sample(logits, reqs, eos_id:List[int]=[2]):
def sample(logits, reqs, eos_id: List[int] = [2]):
logits = logits.contiguous()
(
presence_penalties,
@@ -20,6 +20,7 @@ def sample(logits, reqs, eos_id:List[int]=[2]):
p_cumsum_seq_len,
p_max_len_in_batch,
length_penalty_idx,
mask_eos_reqs,
) = _get_post_sample_tensors(reqs)

apply_penalty(
@@ -35,6 +36,7 @@ def sample(logits, reqs, eos_id:List[int]=[2]):
logits[:, eos_id] = logits[:, eos_id] + torch.abs(logits[:, eos_id]) * (
torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1, 1)) - 1
)
logits[mask_eos_reqs, eos_id] = -1000000.0
logits.div_(temperatures.view((-1, 1)))
probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
@@ -67,19 +69,22 @@ def _get_post_sample_tensors(reqs):
top_ks: List[int] = []
p_token_ids: List[int] = []
p_token_counts: List[int] = []
p_seq_len: List[int] = [0,]
p_seq_len: List[int] = [
0,
]
p_max_len_in_batch: int = 0
length_penalty_idx: List[int] = []
mask_eos_reqs: List[bool] = []
for i, req_obj in enumerate(reqs):
id_to_count = req_obj.out_token_id_count
sample_param = req_obj.sampling_param
presence_penalties.append(sample_param.presence_penalty)
frequency_penalties.append(sample_param.frequency_penalty)
repetition_penalties.append(sample_param.repetition_penalty)
exponential_decay_length_penalties.append(sample_param.exponential_decay_length_penalty[1])
length_penalty_idx.append(
max(len(req_obj.input_token_ids) - req_obj.prompt_len - sample_param.exponential_decay_length_penalty[0], 0)
)
out_token_len = len(req_obj.input_token_ids) - req_obj.prompt_len
length_penalty_idx.append(max(out_token_len - sample_param.exponential_decay_length_penalty[0], 0))
mask_eos_reqs.append(out_token_len < sample_param.min_new_tokens - 1)

temperatures.append(sample_param.temperature)
top_ps.append(sample_param.top_p)
@@ -105,6 +110,7 @@ def _get_post_sample_tensors(reqs):
p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
length_penalty_idx = torch.tensor(length_penalty_idx, dtype=torch.int32, device="cuda")
mask_eos_reqs = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cuda")
return (
presence_penalties,
frequency_penalties,
Expand All @@ -118,4 +124,5 @@ def _get_post_sample_tensors(reqs):
p_cumsum_seq_len,
p_max_len_in_batch,
length_penalty_idx,
mask_eos_reqs,
)
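The core of the feature is in the two added lines above: _get_post_sample_tensors records, per request, whether fewer than min_new_tokens - 1 tokens have been generated so far, and sample() then pushes the EOS logits of those requests to -1000000.0 so EOS cannot be picked yet. A minimal sketch of that masking step, assuming a single EOS id and a (batch, vocab) logits tensor; the function and variable names here are illustrative, and the commit itself applies the mask to a list of EOS ids right after the length-penalty adjustment:

```python
import torch

def mask_eos_before_min_new_tokens(logits: torch.Tensor,
                                    out_token_lens: torch.Tensor,
                                    min_new_tokens: torch.Tensor,
                                    eos_id: int) -> torch.Tensor:
    """Sketch of the masking idea: while a request has generated fewer than
    min_new_tokens - 1 tokens, push its EOS logit to a very negative value so
    EOS cannot be sampled yet (the EOS token itself counts toward the minimum)."""
    mask = out_token_lens < (min_new_tokens - 1)   # one bool per request in the batch
    logits[mask, eos_id] = -1000000.0              # same constant used in the commit
    return logits

# toy batch of 2 requests over a 5-token vocabulary, EOS id = 2
logits = torch.zeros(2, 5)
out_token_lens = torch.tensor([0, 3])              # tokens generated so far
min_new_tokens = torch.tensor([4, 4])              # both requests require >= 4 new tokens
masked = mask_eos_before_min_new_tokens(logits, out_token_lens, min_new_tokens, eos_id=2)
print(masked[0, 2].item(), masked[1, 2].item())    # -1000000.0 0.0
```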
57 changes: 42 additions & 15 deletions lightllm/server/sampling_params.py
@@ -5,7 +5,6 @@


class SamplingParams:

def __init__(
self,
do_sample: bool = False,
@@ -15,13 +14,14 @@ def __init__(
exponential_decay_length_penalty: Tuple[int, float] = (1, 1.0),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1, # -1 is for all
top_k: int = -1, # -1 is for all
ignore_eos: bool = False,
max_new_tokens: int = 16,
min_new_tokens: int = 1,
stop_sequences: Optional[Union[str, List[str]]] = None,  # stop sentence condition
skip_special_tokens: bool = True, # whether to skip special tokens when decoding
add_spaces_between_special_tokens: bool = True, # whether to add spaces between special tokens when decoding
print_eos_token: bool = False, # eos_id will be always ignored except the value is set to True
skip_special_tokens: bool = True, # whether to skip special tokens when decoding
add_spaces_between_special_tokens: bool = True, # whether to add spaces between special tokens when decoding
print_eos_token: bool = False, # eos_id will be always ignored except the value is set to True
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
@@ -33,19 +33,22 @@ def __init__(
self.top_k = top_k
self.ignore_eos = ignore_eos
self.max_new_tokens = max_new_tokens
self.min_new_tokens = min_new_tokens
self.stop_sequences = stop_sequences
self.skip_special_tokens = skip_special_tokens
self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
self.print_eos_token = print_eos_token
if self.do_sample == False:
if self.do_sample is False:
self.temperature = 1.0
self.top_p = 1.0
self.top_k = 1
if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search
if (
self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS
): # temperature is too slow, change to greedy search
self.temperature = 1.0
self.top_k = 1
return

def verify(self):
if self.presence_penalty < 0.0:
raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}")
@@ -61,12 +64,35 @@ def verify(self):
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
if self.min_new_tokens < 1:
raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
if self.min_new_tokens > self.max_new_tokens:
raise ValueError(
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
)

if len(self.exponential_decay_length_penalty) != 2:
raise ValueError(f"exponential_decay_length_penalty must be a tuple of (int, float), got {self.exponential_decay_length_penalty}.")
if not isinstance(self.exponential_decay_length_penalty[0], int) or self.exponential_decay_length_penalty[0] < 0:
raise ValueError(f"exponential_decay_length_penalty[0] must be a non-negative integer, got {self.exponential_decay_length_penalty[0]}.")
if not isinstance(self.exponential_decay_length_penalty[1], float) or self.exponential_decay_length_penalty[1] < 1.0:
raise ValueError(f"exponential_decay_length_penalty[1] must be a float >= 1.0, got {self.exponential_decay_length_penalty[1]}.")
raise ValueError(
f"exponential_decay_length_penalty must be a tuple of (int, float), \
got {self.exponential_decay_length_penalty}."
)
if (
not isinstance(self.exponential_decay_length_penalty[0], int)
or self.exponential_decay_length_penalty[0] < 0
):
raise ValueError(
f"exponential_decay_length_penalty[0] must be a non-negative integer, \
got {self.exponential_decay_length_penalty[0]}."
)
if (
not isinstance(self.exponential_decay_length_penalty[1], float)
or self.exponential_decay_length_penalty[1] < 1.0
):
raise ValueError(
f"exponential_decay_length_penalty[1] must be a float >= 1.0, \
got {self.exponential_decay_length_penalty[1]}."
)

return

def stop_sentences_to_token_ids(self, tokenizer):
@@ -78,13 +104,13 @@ def stop_sentences_to_token_ids(self, tokenizer):
new_stop_sequences = []
for stop_str in self.stop_sequences:
stop_str_ids = tokenizer.encode(stop_str)
if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id
if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id
stop_str_ids = stop_str_ids[1:]
if len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
self.stop_sequences = new_stop_sequences
return

def to_dict(self):
ret = {}
ret["do_sample"] = self.do_sample
@@ -95,6 +121,7 @@ def to_dict(self):
ret["temperature"] = self.temperature
ret["top_p"] = self.top_p
ret["top_k"] = self.top_k
ret["min_new_tokens"] = self.min_new_tokens
# if self.ignore_eos is not None:
# ret["ignore_eos"] = self.ignore_eos
# if self.max_tokens is not None:
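Taken together with the constructor and verify() changes above, the new option is set like any other request-level sampling parameter. A small usage sketch, assuming SamplingParams can be imported from the lightllm.server.sampling_params module shown in this diff:

```python
# Usage sketch; import path follows the file shown above, argument names follow the diff.
from lightllm.server.sampling_params import SamplingParams

params = SamplingParams(do_sample=True, temperature=0.8, top_p=0.9,
                        min_new_tokens=8, max_new_tokens=128)
params.verify()                     # passes: 1 <= min_new_tokens <= max_new_tokens

bad = SamplingParams(min_new_tokens=32, max_new_tokens=16)
try:
    bad.verify()                    # raises: min_new_tokens must <= max_new_tokens
except ValueError as e:
    print(e)
```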
