custom_ops/gpu_ops/append_attn/append_attention_func.cuh (125 changes: 56 additions & 69 deletions)
@@ -383,56 +383,71 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}

template <uint32_t block_size,
template <SharedMemFillMode fill_mode,
uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async(
smem_t kv_scale_smem,
const int* block_table_now,
const T* cache_k_scale,
const T* cache_kv_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
const uint32_t tid = ty * 32 + tx;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
if (tid < block_size / 8) {
const T* cache_k_scale_now = cache_kv_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size + tid * 8;
const int kv_idx_this_thread = kv_idx + tid * 8;
kv_scale_smem.load_128b_async<fill_mode>(
tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
__syncthreads();
} else {
// 1 warp 32 tokens
if (tid < block_size / 8 * 2) {
const uint32_t kv_idx_now = kv_idx + block_size * tid / 8;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const int kv_idx_this_thread = kv_idx + tid * 8;
const T* cache_k_scale_now = cache_kv_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size + tid % 8 * 8;
kv_scale_smem.load_128b_async<fill_mode>(
tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
}
}
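
The rewritten loader moves scales in 128-bit chunks rather than one element per thread: assuming a 16-bit scale type, each load_128b_async call carries 8 scales, which is why only block_size / 8 threads participate in the shared-block path. A quick check of that arithmetic, with T = half and block_size = 64 chosen purely for illustration (neither value is fixed by this hunk):

#include <cstdint>
// Illustration of the 128-bit copy arithmetic (assumed values: T = half, block_size = 64).
constexpr uint32_t kBlockSize      = 64;                              // tokens per cache block (assumed)
constexpr uint32_t kBytesPerScale  = 2;                               // sizeof(half)
constexpr uint32_t kBytesPerCopy   = 16;                              // one 128-bit cp.async transaction
constexpr uint32_t kScalesPerCopy  = kBytesPerCopy / kBytesPerScale;  // 8 scales per thread
constexpr uint32_t kCopiesPerBlock = kBlockSize / kScalesPerCopy;     // 8 issuing threads
static_assert(kCopiesPerBlock == kBlockSize / 8,
              "matches the tid < block_size / 8 guard in the NUM_WARP_Q == 4 path");
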

template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg(
T* k_smem_scale, T* cache_k_reg) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
const uint32_t scale_idx = fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
}
}
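
The old produce_k_dynamic_scale read global memory, staged into shared memory, and unpacked into registers in a single call; after this split the caller is expected to issue the asynchronous copy first and unpack only once the copy group has completed. Below is a minimal sketch of that caller-side ordering. The wrapper name, the commit_group() / wait_group<0>() helpers, the SharedMemFillMode::kFillZero value, and the smem_t-from-pointer construction are all assumptions for illustration; the kernel's real call sites are not part of this hunk.

// Sketch only: hypothetical caller-side ordering for the split producers.
template <uint32_t BLOCK_SIZE, uint32_t NUM_FRAGS_Z, uint32_t NUM_WARP_Q, typename T>
__device__ void load_k_dynamic_scales(T* k_smem_scale,                     // smem staging buffer
                                      T (&cache_k_reg)[NUM_FRAGS_Z * 2],   // per-lane register scales
                                      const int* block_table_now,
                                      const T* cache_k_scale,
                                      uint32_t kv_idx,
                                      uint32_t kv_num_heads,
                                      uint32_t kv_head_idx,
                                      uint32_t chunk_end) {
  smem_t k_scale_smem(k_smem_scale);  // 128-bit async-copy view of the same buffer (assumed ctor)
  produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
                                           BLOCK_SIZE, NUM_FRAGS_Z, NUM_WARP_Q>(
      k_scale_smem, block_table_now, cache_k_scale,
      kv_idx, kv_num_heads, kv_head_idx, chunk_end);
  commit_group();   // commit the outstanding cp.async copies (assumed helper)
  wait_group<0>();  // wait for the scale bytes to land in shared memory (assumed helper)
  __syncthreads();
  produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, NUM_FRAGS_Z, NUM_WARP_Q>(
      k_smem_scale, cache_k_reg);
}
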
@@ -441,57 +456,29 @@ template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg(
T* v_smem_scale, T* cache_v_reg) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;

if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
const uint32_t scale_idx = fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
}
}
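
The two smem2reg unpackers differ only in how a lane maps to token positions inside a 16-token fragment: each lane caches two K scales (rows tx / 4 and tx / 4 + 8) but four V scales (rows tx % 4 * 2, plus 1, 8, and 9). A worked check of that indexing for a single lane, with tx = 5 and fz = 0 picked purely for illustration (NUM_WARP_Q == 4 path):

// Worked example of the smem2reg index math (tx = 5, fz = 0 are illustrative values).
const uint32_t tx = 5, fz = 0;
const uint32_t k_scale_idx = fz * 16 + tx / 4;      // 1
// K path: lane 5 reads k_smem_scale[1] and k_smem_scale[9] (rows 1 and 9 of the fragment).
const uint32_t v_scale_idx = fz * 16 + tx % 4 * 2;  // 2
// V path: lane 5 reads v_smem_scale[2], [3], [10], and [11].
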