
support fused weights for export_model
ronny1996 committed Jun 5, 2024
1 parent f36ed75 commit ff61d4a
Showing 1 changed file with 70 additions and 43 deletions.
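For context: this change lets set_state_dict consume checkpoints whose attention projections are exported as a single fused qkv_proj weight (and whose MLP gate/up projections are exported as gate_up_fused_proj). When such a key is present, the new code splits it with split_param_func and rebuilds the concatenated layout that the fused inference blocks expect; otherwise it falls back to the previous per-projection path. Below is a minimal numpy sketch of the QKV idea only, with toy shapes and a plain np.split standing in for split_param_func (the real helper in paddlenlp.transformers.conversion_utils also handles per-head and tensor-parallel reordering); it is illustrative and not the commit's code.

import numpy as np

# Toy sizes, purely illustrative; real values come from the model config.
hidden_size, num_heads, head_dim = 16, 4, 4

# Assumed fused layout: q, k and v stacked along the last axis.
fused_qkv = np.random.rand(hidden_size, 3 * num_heads * head_dim).astype("float32")

# Stand-in for split_fn(weight, is_qkv=True, ...): recover the q/k/v blocks.
q, k, v = np.split(fused_qkv, 3, axis=-1)

# Re-concatenate and reshape into the [3 * num_heads * head_dim, hidden_size]
# layout used by the fused transformer block's qkv weight.
concated_qkv_weight = (
    np.concatenate([q, k, v], axis=-1)
    .transpose(1, 0)
    .reshape(3 * num_heads * head_dim, hidden_size)
)
print(concated_qkv_weight.shape)  # (48, 16)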
113 changes: 70 additions & 43 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -47,6 +47,7 @@
GenerationInferenceModel,
)
from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
from paddlenlp.transformers.conversion_utils import split_param_func

from paddlenlp.transformers.llama.modeling import LlamaLMHead
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -473,48 +474,66 @@ def forward(
def set_state_dict(self, state_dict):
unfused_state_dict = {}
head_size = self.hidden_size // self.num_attention_heads
split_fn = split_param_func()


self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))
self.embed_tokens.weight.set_value(
paddle.to_tensor(state_dict["llama.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype)
)
self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"]).cast(self.norm.weight.dtype))

for idx in range(self.config.num_hidden_layers):
logger.info(f"set state for layer {idx}")

if self.use_weight_only:
logger.info("weight only is enabled")
unfused_state_dict = {}
unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]

concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
concated_qkv_weight = np.concatenate(
split_fn(
state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
is_qkv=True,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
),
axis=-1,
)
.transpose(1, 0)
.reshape(
3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
self.hidden_size,
else:
unfused_state_dict = {}
unfused_state_dict["self_attn.q_proj.weight"] = state_dict[

Check warning on line 501 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L500-L501

Added lines #L500 - L501 were not covered by tests
"llama.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["self_attn.k_proj.weight"] = state_dict[

Check warning on line 504 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L504

Added line #L504 was not covered by tests
"llama.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[

Check warning on line 507 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L507

Added line #L507 was not covered by tests
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]
concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
self.hidden_size,
)
) # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
ffn1_weight_tensor = np.concatenate(
split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
)
else:
unfused_state_dict["mlp.gate_proj.weight"] = state_dict[

Check warning on line 530 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L530

Added line #L530 was not covered by tests
"llama.layers.{}.mlp.gate_proj.weight".format(idx)
]
unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
concated_ffn1_weight = np.concatenate(
[unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
)
) # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)

unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]

concated_ffn1_weight = np.concatenate(
[unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
)
ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)

qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
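The mlp.gate_up_fused_proj branch above follows the same pattern for the feed-forward weights: if the fused key exists, split it and concatenate the halves into the ffn1 weight; otherwise fall back to the separate gate_proj/up_proj tensors. A hedged numpy sketch with made-up sizes, again using a plain np.split where the commit calls split_fn:

import numpy as np

hidden_size, intermediate_size = 16, 32  # toy sizes, not the real config values

# Assumed fused layout: gate and up halves stacked along the last axis.
gate_up_fused = np.random.rand(hidden_size, 2 * intermediate_size).astype("float32")

# Stand-in for split_fn(gate_up_fused): recover the two projections...
gate_w, up_w = np.split(gate_up_fused, 2, axis=-1)

# ...then rebuild the ffn1 weight the same way the unfused branch does.
ffn1_weight = np.concatenate([gate_w, up_w], axis=-1)
print(ffn1_weight.shape)  # (16, 64)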
@@ -534,7 +553,9 @@ def set_state_dict(self, state_dict):
paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
)
else:
self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
self.transformer_block.qkv_weights[idx].set_value(
qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype)
)

linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
if self.use_weight_only:
@@ -556,7 +577,9 @@ def set_state_dict(self, state_dict):
)
)
else:
self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
self.transformer_block.linear_weights[idx].set_value(
linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype)
)

if self.use_weight_only:
ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
@@ -572,7 +595,9 @@ def set_state_dict(self, state_dict):
paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
)
else:
self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
self.transformer_block.ffn1_weights[idx].set_value(
ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype)
)

ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
if self.use_weight_only:
@@ -594,7 +619,9 @@ def set_state_dict(self, state_dict):
)
)
else:
self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
self.transformer_block.ffn2_weights[idx].set_value(
ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype)
)

if self.quant_type == "a8w8":
if self.shift_smooth_all_linears:
@@ -660,16 +687,14 @@ def set_state_dict(self, state_dict):
)

self.transformer_block.ln_scales[idx].set_value(
paddle.to_tensor(
state_dict["llama.layers.{}.input_layernorm.weight".format(idx)],
dtype=self.transformer_block.ln_scales[idx].dtype,
paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.weight".format(idx)]).cast(
self.transformer_block.ln_scales[idx].dtype
)
)

self.transformer_block.ffn_ln_scales[idx].set_value(
paddle.to_tensor(
state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)],
dtype=self.transformer_block.ffn_ln_scales[idx].dtype,
paddle.to_tensor(state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)]).cast(
self.transformer_block.ffn_ln_scales[idx].dtype
)
)

@@ -1264,7 +1289,9 @@ def forward(
@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.lm_head.weight.set_value(
paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype)
)
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
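A second recurring change in this diff is casting each loaded tensor to the destination parameter's dtype before calling set_value, rather than passing dtype= to paddle.to_tensor or assuming the dtypes already match. A small self-contained sketch of that pattern using standard Paddle APIs and toy shapes (float32 here; in the exported model the target would typically be float16 or bfloat16):

import numpy as np
import paddle

# A destination parameter in the model's working dtype.
dst = paddle.create_parameter(shape=[4, 4], dtype="float32")

# A checkpoint tensor that may arrive in a different dtype (float64 numpy array here).
src = np.random.rand(4, 4)

# Cast to the parameter's dtype first, then copy the values in place.
dst.set_value(paddle.to_tensor(src).cast(dst.dtype))
print(dst.dtype)  # paddle.float32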


