Refine fuse_mt code
penPenf28 committed May 20, 2024
1 parent fa960e7 commit ecca46f
Showing 3 changed files with 20 additions and 103 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -107,6 +107,3 @@ paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/*
paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen_tmp/*
paddle/fluid/pybind/static_op_function.*
paddle/fluid/pybind/ops_api.cc

# these files are auto-generated by memory_efficient_fmha_variable
autogen*
10 changes: 9 additions & 1 deletion paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -13,7 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h"

#include "paddle/fluid/framework/op_registry.h"

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
110 changes: 11 additions & 99 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -23,7 +23,6 @@ limitations under the License. */
#include <iomanip>

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/common/flags.h"
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
@@ -151,8 +150,6 @@ struct Masked_multihead_attention_params {
bool neox_rotary_style;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T,
int Dh,
int Dh_MAX,
@@ -248,18 +245,12 @@ __global__ void masked_multihead_attention_kernel(

Qk_vec q;
zero(q);
// q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
// : q;
if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(q, qk_offset + hi * Dh);
}

Qk_vec k;
zero(k);
// k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
// : k;
if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(
k, params.num_head * Dh + qk_offset + kv_hi * Dh);
@@ -321,41 +312,15 @@ __global__ void masked_multihead_attention_kernel(
kv_hi * Dh + right_id * QK_VEC_SIZE;
Qk_vec q_right;
zero(q_right);
// q_right =
// (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&q_base[qk_right_offset])
// : q_right;
if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(q_right, q_right_offset);
}
Qk_vec k_right;
zero(k_right);
// k_right =
// (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_right_offset])
// : k_right;
if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(k_right, k_right_offset);
}

// if (params.add_qkv_bias) {
// Qk_vec q_right_bias;
// zero(q_right_bias);
// q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(
// &q_bias_base[qk_right_bias_offset])
// : q_right_bias;
// Qk_vec k_right_bias;
// zero(k_right_bias);
// k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(
// &k_bias_base[qk_right_bias_offset])
// : k_right_bias;

// q_right = add(q_right, q_right_bias);
// k_right = add(k_right, k_right_bias);
// }

Qk_vec_RoPE cos_emb;
zero(cos_emb);
cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
@@ -383,15 +348,11 @@ __global__ void masked_multihead_attention_kernel(

int co = tid / QK_VECS_IN_16B;
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
// int offset = bhi * params.max_seq_length * Dh +
// co * params.max_seq_length * QK_ELTS_IN_16B +
// act_time_step * QK_ELTS_IN_16B + ci;

int offset = bi * params.gqa_group_size * params.max_seq_length * Dh +
kv_hi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
act_time_step * QK_ELTS_IN_16B + ci;
// quant k and store the int8 value into cache kv
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
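
The offset arithmetic just above is the core of the refined cache layout: keys are stored per batch and per KV head (the GQA grouping selected by gqa_group_size), and each 16-byte chunk of the head dimension is laid out contiguously along the time axis. A minimal host-side sketch of that index computation, reusing the kernel's names; the helper itself is illustrative only and not part of the commit:

#include <cstdint>

// Illustrative sketch, not part of the commit: mirrors the offset used by
// masked_multihead_attention_kernel when storing the current key vector.
// Assumed cache layout:
//   [bsz, gqa_group_size, Dh / QK_ELTS_IN_16B, max_seq_length, QK_ELTS_IN_16B]
inline int64_t CacheKOffset(int bi, int kv_hi, int co, int ci,
                            int act_time_step, int gqa_group_size,
                            int max_seq_length, int Dh, int qk_elts_in_16b) {
  return static_cast<int64_t>(bi) * gqa_group_size * max_seq_length * Dh +
         static_cast<int64_t>(kv_hi) * max_seq_length * Dh +
         static_cast<int64_t>(co) * max_seq_length * qk_elts_in_16b +
         static_cast<int64_t>(act_time_step) * qk_elts_in_16b + ci;
}
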
@@ -450,18 +411,10 @@ __global__ void masked_multihead_attention_kernel(
T *k_cache =
&params.cache_kv[bi * params.gqa_group_size * params.max_seq_length * Dh +
kv_hi * params.max_seq_length * Dh + ki];
// T *k_cache_batch = &params.cache_kv[bbhi * params.max_seq_length * Dh +
// ki];

int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;

const int *beam_offsets = params.beam_cache_offset
? &params.beam_cache_offset[bi_seq_len_offset]
: nullptr;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head *
params.max_seq_length * Dh
: 0;
K_vec k[K_VECS_PER_THREAD];
K_vec k_vec_zero;
zero(k_vec_zero);
@@ -470,20 +423,11 @@
int jj = ii * params.max_seq_length + ti;
// get k from the cache_kv, and dequant k for qk operation
if (ti < act_time_step) {
if (beam_offset) {
// k[ii] =
// (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh *
// params.max_seq_length)
// ? *reinterpret_cast<const K_vec *>(
// &k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])
// : k_vec_zero;
} else {
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache[jj * QK_ELTS_IN_16B])
: k_vec_zero;
}
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache[jj * QK_ELTS_IN_16B])
: k_vec_zero;
}
}

@@ -549,7 +493,6 @@

// FIXME(wangxi): need add 1.e-6f?
float inv_sum = __fdividef(1.f, sum + 1.e-6f);

for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
}
@@ -566,10 +509,6 @@
params.max_seq_length * Dh +
bi * params.gqa_group_size * params.max_seq_length * Dh +
kv_hi * params.max_seq_length * Dh + vi];
// T *v_cache_batch = &params.cache_kv[params.batch_size * params.num_head *
// params.max_seq_length * Dh +
// bbhi * params.max_seq_length * Dh +
// vi];

#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
@@ -583,17 +522,8 @@
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
const int beam_offset =
beam_offsets
? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh
: 0;
V_vec v;
if (beam_offset) {
// v = *reinterpret_cast<const V_vec *>(
// &v_cache_batch[beam_offset + ti * Dh]);
} else {
v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
}
v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
@@ -617,8 +547,6 @@
V_vec v_bias;
zero(v_bias);
if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
// V_vec v = *reinterpret_cast<const V_vec *>(
// &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
V_vec v;
load_func.template load<V_vec>(v,
params.num_head * Dh +
@@ -668,15 +596,11 @@

if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
// convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh +
// vi]),
// out);
V_vec tmp_out;
convert_from_float(tmp_out, out);
store_func.template store<V_vec>(tmp_out,
thi != -1 ? thi * Dh + vi : bhi * Dh + vi);
#else
// *reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]) = out;
store_func.template store<V_vec>(out,
thi != -1 ? thi * Dh + vi : bhi * Dh + vi);
#endif
@@ -943,18 +867,12 @@ __global__ void write_cache_k_kernel(T *cache_k,
const int seq_len,
const int max_seq_len) {
const int bi = blockIdx.y;
const int seq_len_now = seq_len;
const int len = seq_len_now;
if (len == 0) {
return;
}

const int hi = blockIdx.z;
constexpr int X_ELEMS = VEC_16B / sizeof(T);

// [bsz, num_head, seq_len, dim_head/x, x]
auto k_src = reinterpret_cast<const uint4 *>(
k + bi * num_head * seq_len_now * dim_head + hi * seq_len_now * dim_head);
k + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head);
// [bsz, num_head, dim_head/x, max_seq_len, x]
auto k_dst = reinterpret_cast<uint4 *>(
cache_k + bi * num_head * max_seq_len * dim_head +
@@ -974,7 +892,7 @@
idx = idx / max_seq_len;
const int k_vec_id = idx % dim_head_div_x;

if (k_seq_len_id < len) {
if (k_seq_len_id < seq_len) {
k_dst[out_idx] = k_src[k_seq_len_id * dim_head_div_x + k_vec_id];
}
}
Expand All @@ -987,17 +905,11 @@ __global__ void write_cache_v_kernel(T *cache_v,
const int seq_len,
const int max_seq_len) {
const int bi = blockIdx.y;
const int seq_len_now = seq_len;
const int len = seq_len_now;
if (len == 0) {
return;
}

const int hi = blockIdx.z;

// [bsz, num_head, seq_len, dim_head/x, x]
auto v_src = reinterpret_cast<const uint4 *>(
v + bi * num_head * seq_len_now * dim_head + hi * seq_len_now * dim_head);
v + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head);
// [bsz, num_head, max_seq_len, dim_head/x, x]
auto v_dst = reinterpret_cast<uint4 *>(
cache_v + bi * num_head * max_seq_len * dim_head +
@@ -1007,7 +919,7 @@
constexpr int X_ELEMS = VEC_16B / sizeof(T);
const int dim_head_div_x = dim_head / X_ELEMS;

if (idx >= dim_head_div_x * len) return;
if (idx >= dim_head_div_x * seq_len) return;

v_dst[idx] = v_src[idx];
}
@@ -1034,7 +946,7 @@ void write_cache_kv(const phi::GPUContext &dev_ctx,
"dim_head=%d must be divisible by vec_size=%d", dim_head, x));

int max_size = max_seq_len * dim_head / x;
int size = (seq_len * dim_head) / x;
int size = seq_len * dim_head / x;
dim3 grid(div_up(max_size, block_sz), bsz, num_head);
dim3 grid_v(div_up(size, block_sz), bsz, num_head);

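
For readers tracing the cache-write path above, the index mapping performed by write_cache_k_kernel can be expressed as a plain CPU reference loop. This is an illustrative sketch only (WriteCacheKReference is a hypothetical name, not a function in the codebase); x stands for the number of elements per 16-byte vector, i.e. VEC_16B / sizeof(T) in the kernel:

#include <cstddef>

// CPU reference for the cache-K layout transform done by write_cache_k_kernel:
//   src k:       [bsz, num_head, seq_len,     dim_head/x, x]
//   dst cache_k: [bsz, num_head, dim_head/x, max_seq_len, x]
template <typename T>
void WriteCacheKReference(const T *k, T *cache_k, int bsz, int num_head,
                          int seq_len, int max_seq_len, int dim_head, int x) {
  const int dim_head_div_x = dim_head / x;
  for (int bi = 0; bi < bsz; ++bi)
    for (int hi = 0; hi < num_head; ++hi)
      for (int ti = 0; ti < seq_len; ++ti)
        for (int vi = 0; vi < dim_head_div_x; ++vi)
          for (int e = 0; e < x; ++e) {
            const size_t src =
                ((static_cast<size_t>(bi) * num_head + hi) * seq_len + ti) *
                    dim_head +
                static_cast<size_t>(vi) * x + e;
            const size_t dst =
                ((static_cast<size_t>(bi) * num_head + hi) * dim_head_div_x +
                 vi) *
                    max_seq_len * x +
                static_cast<size_t>(ti) * x + e;
            cache_k[dst] = k[src];
          }
}

Destination slots past seq_len are simply left untouched, which matches the kernel's bounds check (if (k_seq_len_id < seq_len)); write_cache_v_kernel is the same idea without the time-axis transpose.
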
