Commit
support fused weights for export_model
ronny1996 committed Jun 5, 2024
1 parent f36ed75 commit 67a23e6
Showing 1 changed file with 53 additions and 38 deletions.
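
In short: set_state_dict can now consume checkpoints whose attention and FFN weights are already fused (self_attn.qkv_proj.weight, mlp.gate_up_fused_proj.weight), falling back to concatenating the separate q/k/v and gate/up projections when the fused keys are absent, and it casts incoming tensors to each parameter's dtype before set_value. A minimal standalone sketch of the QKV dispatch, with a hypothetical helper name and toy shapes (not the PaddleNLP code itself):

import numpy as np

def load_qkv_weight(state_dict, idx, num_heads, head_size, hidden_size, tp_degree=1):
    """Illustrative only: prefer a pre-fused QKV weight, else fuse q/k/v on the fly."""
    fused_key = "llama.layers.{}.self_attn.qkv_proj.weight".format(idx)
    if fused_key in state_dict:
        # Checkpoint already ships a fused [hidden_size, 3 * hidden_size] weight.
        return state_dict[fused_key].transpose([1, 0])
    # Otherwise concatenate the separate projections, then transpose and reshape
    # to [3 * (num_heads // tp_degree) * head_size, hidden_size].
    q = state_dict["llama.layers.{}.self_attn.q_proj.weight".format(idx)]
    k = state_dict["llama.layers.{}.self_attn.k_proj.weight".format(idx)]
    v = state_dict["llama.layers.{}.self_attn.v_proj.weight".format(idx)]
    fused = np.concatenate([q, k, v], axis=-1).transpose(1, 0)
    return fused.reshape(3 * (num_heads // tp_degree) * head_size, hidden_size)

# Toy check: hidden_size=8, two heads of size 4 -> fused weight is (24, 8) either way.
rng = np.random.default_rng(0)
sd = {f"llama.layers.0.self_attn.{p}_proj.weight": rng.standard_normal((8, 8)) for p in "qkv"}
print(load_qkv_weight(sd, 0, num_heads=2, head_size=4, hidden_size=8).shape)  # (24, 8)
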
91 changes: 53 additions & 38 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -474,47 +474,57 @@ def set_state_dict(self, state_dict):
         unfused_state_dict = {}
         head_size = self.hidden_size // self.num_attention_heads
 
-        self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
+        self.embed_tokens.weight.set_value(paddle.to_tensor(
+            state_dict["llama.embed_tokens.weight"], dtype=self.embed_tokens.weight.dtype))
         self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))
 
         for idx in range(self.config.num_hidden_layers):
             logger.info(f"set state for layer {idx}")
 
             if self.use_weight_only:
                 logger.info("weight only is enabled")
-            unfused_state_dict = {}
-            unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.q_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.k_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.v_proj.weight".format(idx)
-            ]
-
-            concated_qkv_weight = (
-                np.concatenate(
-                    [
-                        unfused_state_dict["self_attn.q_proj.weight"],
-                        unfused_state_dict["self_attn.k_proj.weight"],
-                        unfused_state_dict["self_attn.v_proj.weight"],
-                    ],
-                    axis=-1,
-                )
-                .transpose(1, 0)
-                .reshape(
-                    3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
-                    self.hidden_size,
+            if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
+                concated_qkv_weight = state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(
+                    idx)].transpose([1, 0])
+            else:
+                unfused_state_dict = {}
+                unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.q_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.k_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.v_proj.weight".format(idx)
+                ]
+                concated_qkv_weight = (
+                    np.concatenate(
+                        [
+                            unfused_state_dict["self_attn.q_proj.weight"],
+                            unfused_state_dict["self_attn.k_proj.weight"],
+                            unfused_state_dict["self_attn.v_proj.weight"],
+                        ],
+                        axis=-1,
+                    )
+                    .transpose(1, 0)
+                    .reshape(
+                        3 * (self.num_attention_heads //
+                             self.config.tensor_parallel_degree) * (head_size),
+                        self.hidden_size,
+                    )
+                )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
+            if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
+                concated_ffn1_weight = state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(
+                    idx)]
+            else:
+                unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(
+                    idx)]
+                unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(
+                    idx)]
+                concated_ffn1_weight = np.concatenate(
+                    [unfused_state_dict["mlp.gate_proj.weight"],
+                     unfused_state_dict["mlp.up_proj.weight"]], axis=-1
                 )
-            )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
-
-            unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
-            unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
-
-            concated_ffn1_weight = np.concatenate(
-                [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
-            )
             ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)
 
             qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
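
The first hunk applies the same pattern to the FFN1 weights: a checkpoint may already contain mlp.gate_up_fused_proj.weight, otherwise the gate and up projections are concatenated along the last axis. A small illustration of the layout (toy shapes, assumed for the example only):

import numpy as np

# Toy shapes only: hidden_size=8, intermediate_size=16.
gate = np.ones((8, 16), dtype="float32")  # llama.layers.{i}.mlp.gate_proj.weight
up = np.ones((8, 16), dtype="float32")    # llama.layers.{i}.mlp.up_proj.weight

# Fusing on the fly, as the else-branch above does ...
fused_ffn1 = np.concatenate([gate, up], axis=-1)
print(fused_ffn1.shape)  # (8, 32)

# ... which matches the layout a pre-fused gate_up_fused_proj.weight is expected to
# have, so both paths feed the same ffn1_weight_tensor shape downstream.
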
@@ -534,7 +544,8 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
                 )
             else:
-                self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
+                self.transformer_block.qkv_weights[idx].set_value(
+                    qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype))
 
             linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
             if self.use_weight_only:
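
This hunk and the three that follow make the same change: the weight tensor is cast to the destination parameter's dtype before set_value, so an exported model whose parameters are, say, float16 can ingest a float32 checkpoint. A minimal sketch of the pattern (standalone, not the PaddleNLP code):

import paddle

# A float16 parameter receiving a float32 weight: cast first, then set_value,
# mirroring the qkv/linear/ffn1/ffn2 branches in this diff.
param = paddle.create_parameter(shape=[4, 4], dtype="float16")
incoming = paddle.randn([4, 4], dtype="float32")
param.set_value(incoming.cast(param.dtype))
print(param.dtype)  # paddle.float16
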
@@ -556,7 +567,8 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
+                self.transformer_block.linear_weights[idx].set_value(
+                    linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype))
 
             if self.use_weight_only:
                 ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
@@ -572,7 +584,8 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
                 )
             else:
-                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
+                self.transformer_block.ffn1_weights[idx].set_value(
+                    ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype))
 
             ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -594,7 +607,8 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
+                self.transformer_block.ffn2_weights[idx].set_value(
+                    ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype))
 
             if self.quant_type == "a8w8":
                 if self.shift_smooth_all_linears:
@@ -1264,7 +1278,8 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
-            self.lm_head.weight.set_value(state_dict["lm_head.weight"])
+            self.lm_head.weight.set_value(paddle.to_tensor(
+                state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype))
         self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


