add sampling params: skip_special_tokens spaces_between_special_tokens
shihaobai committed Mar 26, 2024
1 parent aa98b35 commit 5667c98
Showing 8 changed files with 36 additions and 31 deletions.
11 changes: 1 addition & 10 deletions lightllm/server/api_server.py
@@ -322,7 +322,7 @@ def main():
default=None,
help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM",
)
parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id")
parser.add_argument("--eos_id", nargs='+', type=int, default=2, help="eos stop token id")
parser.add_argument(
"--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
)
@@ -364,15 +364,6 @@ def main():
"--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
)

parser.add_argument(
"--no_skipping_special_tokens", action="store_true", help="whether to skip special tokens when decoding"
)
parser.add_argument(
"--no_spaces_between_special_tokens",
action="store_true",
help="whether to add spaces between special tokens when decoding",
)

parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test")

parser.add_argument("--splitfuse_mode", action="store_true", help="use splitfuse mode")
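With nargs='+', --eos_id now accepts one or more stop-token ids from the command line. A minimal sketch of the resulting argparse behavior (the id 32000 below is a made-up example, not from this commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eos_id", nargs='+', type=int, default=2, help="eos stop token id")

args = parser.parse_args(["--eos_id", "2", "32000"])
print(args.eos_id)  # [2, 32000] -- a list, so downstream code checks membership with `in`

# Note: when the flag is omitted, argparse keeps the default as the bare int 2, not [2];
# only values passed on the command line arrive as a list.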
17 changes: 10 additions & 7 deletions lightllm/server/detokenization/decode.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, List

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -9,21 +9,24 @@ def decode_token(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
req: ReqDetokenizationState,
new_token_id: int,
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
eos_id: List[int],
) -> str:
new_token = tokenizer.convert_ids_to_tokens(
new_token_id, skip_special_tokens=skip_special_tokens)
new_token_id, skip_special_tokens=req.skip_special_tokens)
req.output_tokens.append(new_token)

if skip_special_tokens and new_token_id in tokenizer.all_special_ids:

is_eos_id = new_token_id in eos_id
if is_eos_id and not req.print_eos_token:
return req.output_str

if req.skip_special_tokens and new_token_id in tokenizer.all_special_ids and not is_eos_id:
return req.output_str

if not getattr(tokenizer, "added_tokens_encoder", {}):
output_text = tokenizer.convert_tokens_to_string(req.output_tokens)
return output_text

sep = " " if spaces_between_special_tokens else ""
sep = " " if req.add_spaces_between_special_tokens else ""

if new_token in tokenizer.added_tokens_encoder:
if req.current_sub_text:
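The decoding path now reads the per-request flags from ReqDetokenizationState and treats eos ids separately from other special tokens. A simplified paraphrase of the new filtering rules (illustration only, not the lightllm implementation):

from typing import Iterable

def should_emit(token_id: int, eos_ids: Iterable[int], special_ids: Iterable[int],
                skip_special_tokens: bool, print_eos_token: bool) -> bool:
    # eos text is suppressed unless the request explicitly sets print_eos_token=True
    if token_id in eos_ids:
        return print_eos_token
    # other special tokens follow the per-request skip_special_tokens flag
    if skip_special_tokens and token_id in special_ids:
        return False
    return True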
12 changes: 4 additions & 8 deletions lightllm/server/detokenization/manager.py
@@ -17,13 +17,12 @@
class DeTokenizationManager:
def __init__(
self,
eos_id,
model_weightdir,
tokenizor_mode,
detokenization_port,
httpserver_port,
trust_remote_code,
skip_special_tokens,
spaces_between_special_tokens,
):
context = zmq.asyncio.Context(2)
self.recv_from_router = context.socket(zmq.PULL)
@@ -35,9 +34,8 @@ def __init__(
self.tokenizer = get_tokenizer(model_weightdir, tokenizor_mode, trust_remote_code=trust_remote_code)
self.all_special_ids = set(self.tokenizer.all_special_ids)
self.req_id_to_out = {}
self.eos_id = eos_id

self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens

async def handle_loop(self):
while True:
@@ -67,8 +65,7 @@ async def handle_loop(self):
self.tokenizer,
req_out,
new_token_id,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
self.eos_id,
)

if out_text.endswith(u'\ufffd'):
@@ -92,13 +89,12 @@ async def handle_loop(self):
def start_detokenization_process(args, detokenization_port, httpserver_port, pipe_writer):
try:
router = DeTokenizationManager(
args.eos_id,
args.model_dir,
args.tokenizer_mode,
detokenization_port=detokenization_port,
httpserver_port=httpserver_port,
trust_remote_code=args.trust_remote_code,
skip_special_tokens=not args.no_skipping_special_tokens,
spaces_between_special_tokens=not args.no_spaces_between_special_tokens,
)
except Exception as e:
pipe_writer.send(str(e))
12 changes: 10 additions & 2 deletions lightllm/server/io_struct.py
@@ -64,7 +64,9 @@ def to_rpc_obj(self):

def to_req_detokenization_state(self):
out = ReqDetokenizationState(
self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos
self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos,
self.sample_params.skip_special_tokens, self.sample_params.add_spaces_between_special_tokens,
self.sample_params.print_eos_token
)
# if self.output_metadata_list: # looks like no use
# out.gen_metadata.update(self.output_metadata_list[-1])
@@ -226,6 +228,9 @@ def __init__(
prompt_ids: List[int],
max_output_len: int,
ignore_eos: bool,
skip_special_tokens: bool,
add_spaces_between_special_tokens: bool,
print_eos_token: bool,
) -> None:
self.request_id = request_id
self.prompt_ids = prompt_ids
@@ -237,6 +242,9 @@ def __init__(
self.max_output_len = max_output_len
self.ignore_eos = ignore_eos
self.gen_metadata = {}
self.skip_special_tokens = skip_special_tokens
self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
self.print_eos_token = print_eos_token


class Batch:
@@ -265,7 +273,7 @@ def mark_and_get_finished_req_and_preupdate_status(self, eos_id):
for req in self.reqs:
if req.stop_sequences_matched():
req.finish_status = FinishStatus.FINISHED_STOP
elif len(req.output_ids) >= 1 and req.output_ids[-1] == eos_id and req.sample_params.ignore_eos is False:
elif len(req.output_ids) >= 1 and req.output_ids[-1] in eos_id and req.sample_params.ignore_eos is False:
req.finish_status = FinishStatus.FINISHED_STOP
elif len(req.output_ids) >= req.max_output_len:
req.finish_status = FinishStatus.FINISHED_LENGTH
1 change: 1 addition & 0 deletions lightllm/server/router/manager.py
@@ -75,6 +75,7 @@ async def wait_to_model_ready(self):
"splitfuse_block_size": self.splitfuse_block_size,
"return_all_prompt_logprobs": self.args.return_all_prompt_logprobs,
"use_dynamic_prompt_cache": self.args.use_dynamic_prompt_cache,
"eos_id": self.eos_id,
}
init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))

2 changes: 1 addition & 1 deletion lightllm/server/router/model_infer/model_rpc.py
@@ -66,6 +66,7 @@ def exposed_init_model(self, kvargs):
self.splitfuse_block_size = kvargs.get("splitfuse_block_size", None)
self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
self.eos_id = kvargs.get("eos_id", [2])

self.cache = {}
self.logger = init_logger(__name__)
@@ -101,7 +102,6 @@ def exposed_init_model(self, kvargs):

try:
self.model_type = model_cfg.get("model_type", "")
self.eos_id = model_cfg.get("eos_token_id", 2)
if self.model_type == "bloom":
self.model = BloomTpPartModel(model_kvargs)
elif self.model_type == "llama":
4 changes: 2 additions & 2 deletions lightllm/server/router/model_infer/post_process.py
@@ -5,7 +5,7 @@
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty


def sample(logits, reqs, eos_id=2):
def sample(logits, reqs, eos_id:List[int]=[2]):
logits = logits.contiguous()
(
presence_penalties,
@@ -33,7 +33,7 @@ def sample(logits, reqs, eos_id=2):
p_max_len_in_batch,
)
logits[:, eos_id] = logits[:, eos_id] + torch.abs(logits[:, eos_id]) * (
torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1,)) - 1
torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1, 1)) - 1
)
logits.div_(temperatures.view((-1, 1)))
probs = torch.softmax(logits, dim=-1)
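Because eos_id is now a list, logits[:, eos_id] has shape (batch, num_eos), so the per-request decay factor has to be reshaped into a column to broadcast over every eos column. A small shape check illustrating the fix (sizes are arbitrary):

import torch

batch, vocab = 3, 10
eos_id = [2, 7]                                    # hypothetical eos ids
logits = torch.randn(batch, vocab)
factor = torch.rand(batch)                         # one decay factor per request

assert logits[:, eos_id].shape == (batch, len(eos_id))
scaled = logits[:, eos_id] * factor.view((-1, 1))  # (batch, 1) broadcasts over the eos columns
# factor.view((-1,)) has shape (batch,), which only lines up when len(eos_id) equals batch,
# hence the change from view((-1,)) to view((-1, 1)).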
8 changes: 7 additions & 1 deletion lightllm/server/sampling_params.py
@@ -18,7 +18,10 @@ def __init__(
top_k: int = -1, # -1 is for all
ignore_eos: bool = False,
max_new_tokens: int = 16,
stop_sequences: Optional[Union[str, List[str]]] = None # 停止句子条件
stop_sequences: Optional[Union[str, List[str]]] = None, # 停止句子条件
skip_special_tokens: bool = True, # whether to skip special tokens when decoding
add_spaces_between_special_tokens: bool = True, # whether to add spaces between special tokens when decoding
print_eos_token: bool = False, # eos_id will be always ignored except the value is set to True
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
@@ -31,6 +34,9 @@ def __init__(
self.ignore_eos = ignore_eos
self.max_new_tokens = max_new_tokens
self.stop_sequences = stop_sequences
self.skip_special_tokens = skip_special_tokens
self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
self.print_eos_token = print_eos_token
if self.do_sample == False:
self.temperature = 1.0
self.top_p = 1.0
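A minimal construction sketch for the three new per-request fields, assuming the class defined in this file is SamplingParams as it is referenced elsewhere in lightllm:

from lightllm.server.sampling_params import SamplingParams

params = SamplingParams(
    max_new_tokens=64,
    skip_special_tokens=True,                # drop non-eos special tokens from the decoded text
    add_spaces_between_special_tokens=True,  # join kept special tokens with a space separator
    print_eos_token=False,                   # eos text stays hidden unless set to True
)

These fields ride along on Req.to_req_detokenization_state() into ReqDetokenizationState, which is why decode.py reads them from the request instead of from the removed server-wide flags.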
