[XPU] xpu devices support llama-7b basic mode inference (turn on BlockAttention) (#8588)

* xpu devices support llama-7b basic mode inference (turn on BlockAttention)
zhink committed Jun 13, 2024
1 parent 5ba7a94 commit 3d777c1
Showing 11 changed files with 165 additions and 71 deletions.
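Taken together, the changes wire up the XPU path end to end: the docs gain `--device xpu` variants of the existing commands, `predictor.py` learns to configure a Paddle Inference XPU backend, and the fused transformer layers dispatch to an XPU block-attention kernel. A rough sketch of the resulting workflow, mirroring the commands added to llm/docs/inference.md below (run from the directories the doc describes):

```shell
# Build the custom operators for XPU devices
cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh

# Dynamic-graph inference on XPU (BlockAttention must be enabled, hence --block_attn)
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --device xpu

# Export to a static graph and run it on XPU
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --device xpu
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn --device xpu
```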
12 changes: 12 additions & 0 deletions llm/docs/inference.md
@@ -83,7 +83,10 @@ PaddleNLP provides high-performance custom operators for the Transformer series, improving

```shell
git clone https://github.com/PaddlePaddle/PaddleNLP
# Install the custom operators for GPU devices
cd ./paddlenlp/csrc && python setup_cuda.py install
# Install the custom operators for XPU devices
cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh
```

### 2.3 High-performance inference with BlockAttention disabled
@@ -163,6 +166,9 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_
# Reference command for dynamic-graph model inference
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn

# Reference command for dynamic-graph model inference on XPU devices
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --device xpu

# Reference command for Weight Only Int8 dynamic-graph inference
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --quant_type weight_only_int8 --block_attn

@@ -179,6 +185,9 @@ python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_
# Reference command for dynamic-to-static export
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn

# Reference command for dynamic-to-static export on XPU devices
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --device xpu

# Reference command for Weight Only Int8 dynamic-to-static export
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --quant_type weight_only_int8 --block_attn

@@ -194,6 +203,9 @@ python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --infere
# Reference command for static-graph inference
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn

# Reference command for static-graph inference on XPU devices
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn --device xpu

# Reference command for Weight Only Int8 static-graph inference
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --quant_type weight_only_int8 --block_attn

27 changes: 25 additions & 2 deletions llm/predictor.py
@@ -650,6 +650,11 @@ def _create_predictor(self, predictor_args: PredictorArgument):
if predictor_args.device in paddle.device.get_all_custom_device_type():
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
config.enable_custom_device(predictor_args.device, device_id)
elif predictor_args.device == "xpu":
raise ValueError(
"To run a static-graph model on XPU, export it with --block_attn and run the predictor with --block_attn as well. "
"See https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
)
else:
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
config.enable_use_gpu(100, device_id)
@@ -920,7 +925,9 @@ def _preprocess(self, source):
source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]

for i, text in enumerate(source):
add_special_tokens = self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer))
add_special_tokens = self.tokenizer.chat_template is None or isinstance(
self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)
)
add_special_tokens = add_special_tokens if not self.benchmark else False
tokens = self.tokenizer(
text,
@@ -1076,6 +1083,15 @@ def _create_predictor(self, predictor_args: PredictorArgument):
if predictor_args.device in paddle.device.get_all_custom_device_type():
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
config.enable_custom_device(predictor_args.device, device_id)
elif predictor_args.device == "xpu":
config.enable_xpu()
device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
config.set_xpu_device_id(device_id)
xpu_config = paddle.inference.XpuConfig()
xpu_config.device_id = device_id
xpu_config.l3_size = 63 * 1024 * 1024
xpu_config.l3_autotune_size = 63 * 1024 * 1024
config.set_xpu_config(xpu_config)
else:
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
config.enable_use_gpu(100, device_id)
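For reference, the XPU branch above maps onto the public Paddle Inference API. A minimal standalone sketch, assuming Paddle is built with XPU support, a static graph has already been exported, and placeholder model/params file names (use whatever export_model.py actually produced):

```python
import os

import paddle.inference as paddle_infer

# Placeholder file names for the exported static graph.
config = paddle_infer.Config("./inference/model.pdmodel", "./inference/model.pdiparams")

# Mirror the settings added in this commit: pick the XPU card from
# FLAGS_selected_xpus and give the workload 63 MiB of L3 cache with autotuning.
device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
config.enable_xpu()
config.set_xpu_device_id(device_id)

xpu_config = paddle_infer.XpuConfig()
xpu_config.device_id = device_id
xpu_config.l3_size = 63 * 1024 * 1024
xpu_config.l3_autotune_size = 63 * 1024 * 1024
config.set_xpu_config(xpu_config)

predictor = paddle_infer.create_predictor(config)
```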
@@ -1331,6 +1347,11 @@ def create_predictor(
tensor_parallel_rank=tensor_parallel_rank,
)
else:
if predictor_args.device == "xpu":
raise ValueError(
"To run a dynamic-graph model on XPU, pass the --block_attn flag. "
"See https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
)
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel as LlamaInferenceModel,
)
@@ -1588,7 +1609,9 @@ def predict():

def benchmark(predictor, predictor_args, model_args):
# Just construct a simple benchmark input. We pad input to the src_length.
benchmark_texts = [predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)]
benchmark_texts = [
predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)
]

batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
print("***********Start Benchmark**********")
5 changes: 4 additions & 1 deletion paddlenlp/experimental/transformers/bloom/modeling.py
@@ -19,7 +19,6 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset, get_padding_offset_v2

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedBlockMultiTransformer,
@@ -219,6 +218,8 @@ def set_input_embeddings(self, new_embeddings: Tensor):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
@@ -592,6 +593,8 @@ def set_transformer_block(self, transformer_config):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset_v2

ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
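The bloom change above, and the chatglm and chatglm_v2 changes below, all follow the same pattern: the module-level `from paddlenlp_ops import ...` is removed and the operator is imported inside `remove_padding` instead, presumably so the modeling modules stay importable when the installed paddlenlp_ops build (for example the XPU one) ships only a subset of the custom kernels. A minimal sketch of the idea, with the function body taken from the diff:

```python
import paddle


def remove_padding(input_ids, seq_lens_this_time):
    cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
    token_num = paddle.sum(seq_lens_this_time)
    # Lazy, call-time import: importing the module that defines this function
    # works even if the local paddlenlp_ops build lacks get_padding_offset; an
    # ImportError surfaces only when padding removal is actually executed.
    from paddlenlp_ops import get_padding_offset

    return get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
```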
3 changes: 2 additions & 1 deletion paddlenlp/experimental/transformers/chatglm/modeling.py
@@ -18,7 +18,6 @@
from paddle import nn
from paddle.distributed import fleet
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedMultiTransformerConfig,
@@ -273,6 +272,8 @@ def __init__(self, config: ChatGLMConfig):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
3 changes: 2 additions & 1 deletion paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -19,7 +19,6 @@
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedMultiTransformerBase,
@@ -202,6 +201,8 @@ def set_input_embeddings(self, value):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
133 changes: 87 additions & 46 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -15,7 +15,7 @@

import paddle
import paddle.distributed as dist
from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.framework import LayerHelper, core, in_dynamic_mode
from paddle.incubate.nn.functional import (
fused_layer_norm,
fused_rms_norm,
@@ -29,23 +29,24 @@
from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
from paddlenlp.utils.log import logger

if is_paddlenlp_ops_available():
if not is_paddlenlp_ops_available():
logger.warning(
"The paddlenlp_ops package is not installed. You can read the docs and install it by hand; "
"see https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
)

from paddlenlp_ops import rebuild_padding_v2

if core.is_compiled_with_cuda():
from paddlenlp_ops import (
dequant_int8,
encode_rotary_qk,
qkv_transpose_split,
quant_int8,
rebuild_padding,
rebuild_padding_v2,
transpose_remove_padding,
write_cache_kv,
)
else:
logger.warning(
"The paddlenlp_ops package is not installed. You can read the docs and install it by hand; "
"see https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
)


__all__ = [
"FusedMultiTransformerConfig",
@@ -1348,6 +1349,9 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
class FusedBlockMultiTransformer(FusedMultiTransformerBase):
def __init__(self, config: FusedMultiTransformerConfig):
super().__init__(config)
if not core.is_compiled_with_cuda():
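# These per-batch max buffers exist only on non-CUDA (i.e. XPU) builds; compute_attn
# below passes them to block_multihead_attention_xpu.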
self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")
self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype="float32")

def compute_attn(
self,
@@ -1375,43 +1379,80 @@ def compute_attn(
v_quant_scales = self.cache_v_scales
k_dequant_scales = self.cache_k_out_scales
v_dequant_scales = self.cache_v_out_scales

fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]

if not core.is_compiled_with_cuda():
fmha_out = paddle.incubate.nn.functional.block_multihead_attention_xpu(
qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
self.cache_k_per_batch_maxs,
self.cache_v_per_batch_maxs,
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]
else:
fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]
out_linear_out = self.compute_out_linear(fmha_out, i)

return out_linear_out