Skip to content

Commit 3161014

Browse files
authored
[BugFix] Fix v1 loader MoE bf16, and support dynamic_load_weight creating quant param (#4229)
* Fix v1 loader MoE bf16, and support dynamic_load_weight creating quant param * When include_stop_str_in_output=False, do not return the EOS text
1 parent 44010ce commit 3161014

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

fastdeploy/input/text_processor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob
185185
from paddleformers.trl.llm_utils import get_eos_token_id
186186

187187
self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
188+
data_processor_logger.info(
189+
f"The eos_token_ids obtained by merging tokenizer and generation_config is {self.eos_token_ids}"
190+
)
188191
self.eos_token_id_len = len(self.eos_token_ids)
189192
self.pad_token_id = self.get_pad_id()
190193
self.reasoning_parser = None
@@ -396,7 +399,7 @@ def process_response_dict_normal(self, response_dict, **kwargs):
396399
is_end = response_dict["finished"]
397400
req_id = response_dict["request_id"]
398401
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
399-
if token_ids[-1] == self.tokenizer.eos_token_id:
402+
if token_ids[-1] in self.eos_token_ids:
400403
token_ids = token_ids[:-1]
401404
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
402405
if is_end:
@@ -434,7 +437,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
434437
token_ids = response_dict["outputs"]["token_ids"]
435438

436439
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
437-
if token_ids[-1] == self.tokenizer.eos_token_id:
440+
if token_ids[-1] in self.eos_token_ids:
438441
token_ids = token_ids[:-1]
439442
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
440443
response_dict["outputs"]["raw_prediction"] = delta_text

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,15 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
199199
layer.up_gate_proj_weight,
200200
{
201201
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
202+
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
202203
"model_format": extra_weight_attrs.get("model_format", ""),
203204
},
204205
)
205206
set_weight_attrs(
206207
layer.down_proj_weight,
207208
{
208209
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
210+
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
209211
"model_format": extra_weight_attrs.get("model_format", ""),
210212
},
211213
)

fastdeploy/model_executor/layers/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
8585
else:
8686
if not quantization_config.get("is_quantized"):
8787
quantization_config["is_quantized"] = model_config.is_quantized
88+
if args.dynamic_load_weight and quantization_config is not None:
89+
quantization_config["is_quantized"] = True
8890
quant_cls = get_quantization_config(quant_config_name)
8991
quant_config = quant_cls.from_config(quantization_config)
9092
return quant_config

0 commit comments

Comments
 (0)