[XPU] fused_rotary_position_embedding op support GQA for XPU (#63557)
zhangyk0314 committed Apr 30, 2024
1 parent 898ecab commit 9b29f62
Showing 3 changed files with 65 additions and 30 deletions.
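For context, the op exercised by this change is exposed in Python as `paddle.incubate.nn.functional.fused_rotary_position_embedding`; after this commit the XPU kernels accept `k`/`v` tensors with fewer heads than `q` (grouped-query attention). The sketch below mirrors the GQA shapes added to the test file; the shapes, the random sin/cos tables, and the variable names are illustrative only and assume a device build that provides the fused kernel.

```python
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

# Illustrative GQA shapes, matching the new XPU test case:
# q carries 2 query heads, k/v share a single KV head.
batch, seq_len, head_dim = 2, 8, 16
q = paddle.randn([batch, seq_len, 2, head_dim], dtype="float32")
k = paddle.randn([batch, seq_len, 1, head_dim], dtype="float32")
v = paddle.randn([batch, seq_len, 1, head_dim], dtype="float32")
for t in (q, k, v):
    t.stop_gradient = False

# sin/cos broadcast over batch and heads: [1, seq_len, 1, head_dim].
# Random tables are enough for a shape check, as in the bf16 test below.
sin = paddle.randn([1, seq_len, 1, head_dim], dtype="float32")
cos = paddle.randn([1, seq_len, 1, head_dim], dtype="float32")
position_ids = paddle.stack(
    [paddle.arange(0, seq_len, dtype="int64") for _ in range(batch)], axis=0
)

out_q, out_k, out_v = fused_rotary_position_embedding(
    q, k, v, sin=sin, cos=cos,
    position_ids=position_ids,
    use_neox_rotary_style=False,
)
# Each output keeps its own head count:
# [2, 8, 2, 16], [2, 8, 1, 16], [2, 8, 1, 16]
print(out_q.shape, out_k.shape, out_v.shape)

# The backward pass (fused_rope_grad_kernel.cc) likewise returns gradients
# with the per-tensor head counts.
paddle.autograd.backward(
    [out_q, out_k, out_v],
    [paddle.randn(out_q.shape), paddle.randn(out_k.shape), paddle.randn(out_v.shape)],
    True,
)
print(q.grad.shape, k.grad.shape, v.grad.shape)
```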
paddle/phi/kernels/fusion/xpu/fused_rope_grad_kernel.cc (7 additions, 3 deletions)

@@ -74,6 +74,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
"with use_neox_rotary_style set."));
} else {
if (head_dim * sizeof(T) <= 1024 && head_dim % 64 == 0 && dout_k) {
int64_t num_heads_k = dout_k->dims()[2];
auto* dq_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(dq));
auto* dk_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(dk));
int ret = xpu::rotary_no_freqs_qk_embedding_v2_grad<XPUType>(
@@ -87,7 +88,8 @@ void FusedRopeGradKernel(const Context& dev_ctx,
{batch_size, seq_len, num_heads, head_dim},
{batch_size, seq_len, 1, head_dim},
{seq_len * num_heads * head_dim, num_heads * head_dim, head_dim, 1},
{seq_len * head_dim, head_dim, head_dim, 1});
{seq_len * head_dim, head_dim, head_dim, 1},
num_heads_k);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rotary_no_freqs_qk_embedding_v2_grad");
} else {
auto* dq_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(dq));
@@ -104,6 +106,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
true);

if (dout_k.get_ptr()) {
int64_t num_heads_k = dout_k->dims()[2];
auto* dk_data =
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(dk));
XPUFusedRotaryHalf<XPUType, Context>(
@@ -114,13 +117,14 @@ void FusedRopeGradKernel(const Context& dev_ctx,
dk_data,
batch_size,
seq_len,
num_heads,
num_heads_k,
head_dim,
true);
}
}

if (dout_v.get_ptr()) {
int64_t num_heads_v = dout_v->dims()[2];
auto* dv_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(dv));
XPUFusedRotaryHalf<XPUType, Context>(
dev_ctx,
@@ -130,7 +134,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
dv_data,
batch_size,
seq_len,
num_heads,
num_heads_v,
head_dim,
true);
}
paddle/phi/kernels/fusion/xpu/fused_rope_kernel.cc (7 additions, 3 deletions)

@@ -79,6 +79,7 @@ void FusedRopeKernel(const Context& dev_ctx,
"XPU do not support rotary_embedding with use_neox_rotary_style set."));
} else {
if (head_dim * sizeof(T) <= 1024 && head_dim % 64 == 0 && k) {
int64_t num_heads_k = k->dims()[2];
auto* outq_data =
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out_q));
auto* outk_data =
@@ -94,7 +95,8 @@ void FusedRopeKernel(const Context& dev_ctx,
{batch_size, seq_len, num_heads, head_dim},
{batch_size, seq_len, 1, head_dim},
{seq_len * num_heads * head_dim, num_heads * head_dim, head_dim, 1},
{seq_len * head_dim, head_dim, head_dim, 1});
{seq_len * head_dim, head_dim, head_dim, 1},
num_heads_k);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rotary_no_freqs_qk_embedding_v2");
} else {
auto* outq_data =
@@ -111,6 +113,7 @@ void FusedRopeKernel(const Context& dev_ctx,
head_dim);

if (k) {
int64_t num_heads_k = k->dims()[2];
auto* outk_data =
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out_k));
XPUFusedRotaryHalf<XPUType, Context>(
@@ -121,12 +124,13 @@ void FusedRopeKernel(const Context& dev_ctx,
outk_data,
batch_size,
seq_len,
num_heads,
num_heads_k,
head_dim);
}
}

if (v) {
int64_t num_heads_v = v->dims()[2];
auto* outv_data =
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out_v));
XPUFusedRotaryHalf<XPUType, Context>(
@@ -137,7 +141,7 @@ void FusedRopeKernel(const Context& dev_ctx,
outv_data,
batch_size,
seq_len,
num_heads,
num_heads_v,
head_dim);
}
}
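Both kernels keep the existing dispatch: the vendor qk kernel is taken only when `head_dim * sizeof(T) <= 1024`, `head_dim` is a multiple of 64, and `k` is present; otherwise they fall back to `XPUFusedRotaryHalf`, which now receives the per-tensor head count. A small, hypothetical helper (the name and byte sizes are mine, mirroring the condition visible above) to see which path a configuration would take:

```python
def xpu_rope_fast_path(head_dim: int, dtype_bytes: int, has_k: bool) -> bool:
    """Mirror of the dispatch condition in fused_rope_kernel.cc (illustrative)."""
    return head_dim * dtype_bytes <= 1024 and head_dim % 64 == 0 and has_k

print(xpu_rope_fast_path(128, 4, True))  # float32, head_dim=128 -> True (fast path)
print(xpu_rope_fast_path(16, 4, True))   # float32, head_dim=16  -> False (fallback)
print(xpu_rope_fast_path(512, 2, True))  # fp16/bf16, head_dim=512 -> True
```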
test/xpu/test_fused_rotary_position_embedding_op_xpu.py (51 additions, 24 deletions)

@@ -135,11 +135,16 @@ def setUp(self):
self.init_threshold()

def init_case(self):
self.shape = [2, 8, 2, 16]
self.shape_q = [2, 8, 2, 128]
self.shape_k = [2, 8, 2, 128]
self.shape_v = [2, 8, 2, 128]
self.dtype = 'float32'

def get_paddle_tensor(self):
tmp = paddle.randn(self.shape, self.dtype)
def get_paddle_tensor(self, shape):
if shape is None:
return None

tmp = paddle.randn(shape, self.dtype)
tmp.stop_gradient = False
return tmp

@@ -157,9 +162,9 @@ def init_threshold(self):
def get_inputs(self, seed, with_sin_cos, dtype="float32"):
paddle.disable_static()
paddle.seed(seed)
tensor_q = self.get_paddle_tensor()
tensor_k = self.get_paddle_tensor()
tensor_v = self.get_paddle_tensor()
tensor_q = self.get_paddle_tensor(self.shape_q)
tensor_k = self.get_paddle_tensor(self.shape_k)
tensor_v = self.get_paddle_tensor(self.shape_v)

tensor_sin, tensor_cos = (
get_sin_cos_tensor(tensor_q.shape[1], tensor_q.shape[3], 1, dtype)
Expand Down Expand Up @@ -198,8 +203,8 @@ def get_forward_backward(
fw.append(out_v)
paddle.seed(seed + 1)
out_gq = paddle.randn(out_q.shape, self.dtype)
out_gk = paddle.randn(out_q.shape, self.dtype)
out_gv = paddle.randn(out_q.shape, self.dtype)
out_gk = paddle.randn(out_k.shape, self.dtype)
out_gv = paddle.randn(out_v.shape, self.dtype)

paddle.autograd.backward(
[out_q, out_k, out_v], [out_gq, out_gk, out_gv], True
@@ -310,9 +315,15 @@ def test_static(self):

paddle.enable_static()
with base.program_guard(base.Program(), base.Program()):
q = paddle.static.data(name="q", shape=self.shape, dtype=self.dtype)
k = paddle.static.data(name="k", shape=self.shape, dtype=self.dtype)
v = paddle.static.data(name="v", shape=self.shape, dtype=self.dtype)
q = paddle.static.data(
name="q", shape=self.shape_q, dtype=self.dtype
)
k = paddle.static.data(
name="k", shape=self.shape_k, dtype=self.dtype
)
v = paddle.static.data(
name="v", shape=self.shape_v, dtype=self.dtype
)
sin = paddle.static.data(
name="sin",
shape=(1, tensor_q.shape[1], 1, tensor_q.shape[3]),
@@ -360,24 +371,28 @@ class XPUTestFusedRotaryPositionEmbeddingFp16_1(
XPUTestFusedRotaryPositionEmbedding
):
def init_case(self):
self.shape = [2, 8, 2, 16]
self.shape_q = [2, 8, 2, 16]
self.shape_k = [2, 8, 2, 16]
self.shape_v = [2, 8, 2, 16]
self.dtype = "float16"


class XPUTestFusedRotaryPositionEmbeddingBf16_1(unittest.TestCase):
def setUp(self):
self.shape = [2, 8, 2, 16]
self.shape_q = [2, 8, 2, 16]
self.shape_k = [2, 8, 2, 16]
self.shape_v = [2, 8, 2, 16]

def test_api(self):
paddle.disable_static()
q_bf16 = paddle.randn(self.shape, dtype="bfloat16")
k_bf16 = paddle.randn(self.shape, dtype="bfloat16")
v_bf16 = paddle.randn(self.shape, dtype="bfloat16")
q_bf16 = paddle.randn(self.shape_q, dtype="bfloat16")
k_bf16 = paddle.randn(self.shape_k, dtype="bfloat16")
v_bf16 = paddle.randn(self.shape_v, dtype="bfloat16")
sin_bf16 = paddle.randn(
[1, self.shape[1], 1, self.shape[3]], dtype="bfloat16"
[1, self.shape_q[1], 1, self.shape_q[3]], dtype="bfloat16"
)
cos_bf16 = paddle.randn(
[1, self.shape[1], 1, self.shape[3]], dtype="bfloat16"
[1, self.shape_q[1], 1, self.shape_q[3]], dtype="bfloat16"
)
q_bf16.stop_gradient = False
k_bf16.stop_gradient = False
@@ -388,9 +403,9 @@ def test_api(self):
sin_fp32 = paddle.to_tensor(sin_bf16, dtype="float32")
cos_fp32 = paddle.to_tensor(cos_bf16, dtype="float32")

position_ids = paddle.arange(0, self.shape[1], dtype="int64")
position_ids = paddle.arange(0, self.shape_q[1], dtype="int64")
position_ids = paddle.stack(
[position_ids for _ in range(self.shape[0])], axis=0
[position_ids for _ in range(self.shape_q[0])], axis=0
)
out_bf16 = fused_rotary_position_embedding(
q_bf16,
@@ -402,9 +417,9 @@ def test_api(self):
use_neox_rotary_style=False,
)

grad_out_q_bf16 = paddle.randn(self.shape, dtype="bfloat16")
grad_out_k_bf16 = paddle.randn(self.shape, dtype="bfloat16")
grad_out_v_bf16 = paddle.randn(self.shape, dtype="bfloat16")
grad_out_q_bf16 = paddle.randn(self.shape_q, dtype="bfloat16")
grad_out_k_bf16 = paddle.randn(self.shape_k, dtype="bfloat16")
grad_out_v_bf16 = paddle.randn(self.shape_v, dtype="bfloat16")

paddle.autograd.backward(
out_bf16, [grad_out_q_bf16, grad_out_k_bf16, grad_out_v_bf16], True
@@ -445,7 +460,19 @@ class XPUTestFusedRotaryPositionEmbeddingBf16_2(
XPUTestFusedRotaryPositionEmbeddingBf16_1
):
def setUp(self):
self.shape = [2, 2048, 16, 128]
self.shape_q = [2, 2048, 16, 128]
self.shape_k = [2, 2048, 16, 128]
self.shape_v = [2, 2048, 16, 128]


class XPUTestFusedRotaryPositionEmbeddingGQA(
XPUTestFusedRotaryPositionEmbedding
):
def init_case(self):
self.shape_q = [2, 8, 2, 16]
self.shape_k = [2, 8, 1, 16]
self.shape_v = [2, 8, 1, 16]
self.dtype = "float32"


# too long for CI

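A note on why passing the key/value head count is the only change the kernels need: rotary embedding rotates each (batch, position, head) slice of length `head_dim` independently, so the head axis only sets loop bounds and never enters the per-element math. The standalone sketch below illustrates that property; it assumes the rotate-half pairing and base 10000, so treat it as illustrative rather than a bit-exact reference for the fused op (whose pairing depends on `use_neox_rotary_style`).

```python
import numpy as np

def rope_reference(x, sin, cos):
    """Rotate-half rotary embedding on x of shape [B, S, H, D].

    sin/cos have shape [1, S, 1, D] and broadcast over batch and heads, so the
    same tables serve q and k/v no matter how many heads each tensor carries.
    """
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * cos + rotated * sin

batch, seq_len, head_dim = 2, 8, 16
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))
freqs = np.outer(np.arange(seq_len), inv_freq)             # [S, D/2]
emb = np.concatenate([freqs, freqs], axis=-1)              # [S, D]
sin = np.sin(emb)[None, :, None, :].astype("float32")
cos = np.cos(emb)[None, :, None, :].astype("float32")

q = np.random.randn(batch, seq_len, 2, head_dim).astype("float32")  # 2 query heads
k = np.random.randn(batch, seq_len, 1, head_dim).astype("float32")  # 1 KV head (GQA)

# The same function handles both head counts; only the output shape differs.
print(rope_reference(q, sin, cos).shape)  # (2, 8, 2, 16)
print(rope_reference(k, sin, cos).shape)  # (2, 8, 1, 16)
```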