
[XPU] llama add xpu support #8282

Merged (9 commits, Apr 29, 2024)

Conversation

@dynamicheart (Contributor) commented Apr 17, 2024

PR types

New features

PR changes

Models

Description

  • Llama model supports XPU
  • Summary: the XPU custom fused operators are integrated via Paddle's C++ Extension mechanism; the XPU custom Paddle operator library is named paddle_xpu (aka fast_paddle)
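The C++ Extension integration mentioned above follows Paddle's standard custom-operator build flow. A hedged build-script sketch is shown below; the source file name is a placeholder, since the real paddle_xpu sources are not part of this PR's diff:

```python
# Build-script sketch for a custom-op library using Paddle's C++ extension
# mechanism. "fused_ops.cc" is hypothetical; only the library name
# paddle_xpu comes from this PR's description.
from paddle.utils.cpp_extension import CppExtension, setup

setup(
    name="paddle_xpu",                # the library name used by this PR
    ext_modules=CppExtension(
        sources=["fused_ops.cc"],     # hypothetical C++ op sources
    ),
)
```

Running `python setup.py install` on such a script compiles the ops and exposes them as an importable package.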

paddle-bot bot commented Apr 17, 2024

Thanks for your contribution!

codecov bot commented Apr 17, 2024

Codecov Report

Attention: Patch coverage is 42.50%, with 46 lines in your changes missing coverage. Please review.

Project coverage is 55.35%. Comparing base (273c593) to head (6e0316a).
Report is 8 commits behind head on develop.

Files Patch % Lines
paddlenlp/transformers/llama/modeling.py 39.21% 31 Missing ⚠️
paddlenlp/transformers/linear_utils.py 48.27% 15 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8282      +/-   ##
===========================================
+ Coverage    55.25%   55.35%   +0.10%     
===========================================
  Files          613      614       +1     
  Lines        95626    95924     +298     
===========================================
+ Hits         52837    53103     +266     
- Misses       42789    42821      +32     


@cqulilujia

LGTM

x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
Contributor Author (@dynamicheart):

When the x and y passed in are integer Python scalars, paddle.where treats them as int64 tensors of shape [1] and performs a broadcast_add; see search.py for details.

llm/run_pretrain.py (outdated, resolved)
paddlenlp/transformers/llama/modeling.py (outdated, resolved)
@ZibinGuo left a comment:

LGTM

@wuhuachaocoding (Contributor):

LGTM

paddlenlp/transformers/llama/modeling.py (4 outdated review threads, all resolved)
LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
pass
Contributor:

What is this for?

Contributor Author (@dynamicheart):

This is an XPU optimization for the accumulate_steps > 1 case; it works together with the Linear layers from paddle_xpu below.
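The guarded-import pattern in the diff above can be sketched as a small helper. Note that the module path for LinearConfig is an assumption; only the class and its two method names appear in this PR's diff:

```python
def maybe_enable_xpu_linear_opt(accumulate_steps: int) -> bool:
    """Try to enable paddle_xpu's gradient-accumulation optimization.

    Returns True if paddle_xpu is installed and was configured, False
    otherwise. The import path below is hypothetical; only LinearConfig
    and its two methods are taken from this PR's diff.
    """
    try:
        from paddle_xpu.layers.nn.linear import LinearConfig  # hypothetical path

        LinearConfig.enable_accumulate_steps_opt()
        LinearConfig.set_accumulate_steps(accumulate_steps)
        return True
    except ImportError:
        # Non-XPU installs silently fall back to the stock Linear layers.
        return False
```

On a build without paddle_xpu the helper simply returns False, mirroring the silent `pass` in run_pretrain.py.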

paddlenlp/transformers/linear_utils.py (3 review threads, resolved)
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
Contributor:

This looks much like the npu logic above; can it be reused?

Contributor Author (@dynamicheart):

In theory it can be reused, but the npu code hard-codes the dtype to float16, while XPU programs may run in either float16 or bfloat16. Should we modify the npu module?

Contributor:

@SylarTiaNII could you take a look?

Contributor Author (@dynamicheart):

Following @wuhuachaocoding's suggestion, this is split into two separate if/elif branches.

Comment on lines +1742 to +1743
logits = self.xpu_parallel_matmul(
hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
Contributor:

Is the training argument really necessary? If the arguments could stay the same, wouldn't it be enough to just replace the parallel_matmul implementation on XPU?

Contributor Author (@dynamicheart):

There are two reasons:

  • An XPU optimization needs parallel_matmul to be an object so it can store certain state
  • XPU needs the training flag to apply its optimizations

@ZHUI (Contributor) left a comment:

LGTM

@sijunhe sijunhe merged commit ba9d9bd into PaddlePaddle:develop Apr 29, 2024
8 of 11 checks passed
Labels: none yet
Projects: none yet
Linked issues: none yet
7 participants