bf16 compatibility for omitted kernel (#396)
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
3 people committed Apr 15, 2024
1 parent bcb3212 commit 4b037f6
Showing 7 changed files with 16 additions and 16 deletions.
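All seven files make the same kind of change: the Triton kernels and their Python wrappers stop hard-coding tl.float16 / torch.float16 and instead take the dtype from an input or output tensor (a.dtype, x.dtype, c_ptr.dtype.element_ty, Out.dtype.element_ty), so the same kernels work when the model runs in bf16. The sketch below is not code from this commit; it is a minimal made-up kernel that illustrates the pattern of computing in fp32 and casting to the output pointer's element type.

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(X, Y, n_elements, BLOCK: tl.constexpr):
    # Compute in fp32, then cast to the element type of the output pointer
    # instead of a hard-coded tl.float16, so fp16 and bf16 tensors both work.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask, other=0.0).to(tl.float32)
    y = x * 2.0
    tl.store(Y + offs, y.to(Y.dtype.element_ty), mask=mask)


def scale_by_two(x: torch.Tensor) -> torch.Tensor:
    # The output buffer inherits x.dtype rather than being pinned to float16.
    y = torch.empty_like(x)
    n = x.numel()
    _scale_kernel[(triton.cdiv(n, 1024),)](x, y, n, BLOCK=1024)
    return y


# Works the same for fp16 and bf16, e.g.:
# scale_by_two(torch.randn(4096, device="cuda", dtype=torch.bfloat16))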
14 changes: 7 additions & 7 deletions lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py
@@ -111,10 +111,10 @@ def matmul4_kernel(
         b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values
         b = b * scales[None, :] - zeros[None, :] # Scale and shift
         # print("data type", a, b)
-        accumulator += tl.dot(a, b.to(tl.float16))
+        accumulator += tl.dot(a, b.to(a.dtype))
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
-    c = accumulator.to(tl.float16)
+    c = accumulator.to(c_ptr.dtype.element_ty)
     # Store the result
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -154,7 +154,7 @@ def matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor,
     # output = torch.empty((M, N), device='cuda', dtype=torch.float16)
     if output is None:
         inplace = False
-        output = torch.empty((M, N), device=x.device, dtype=torch.float16)
+        output = torch.empty((M, N), device=x.device, dtype=x.dtype)
     else:
         inplace = True

@@ -281,14 +281,14 @@ def matmul_kernel(
         # We accumulate along the K dimension.
         int_b = (b >> b_shift_bits) & 0xF
         int_bzp = (bzp >> bzp_shift_bits) & 0xF
-        b = ((int_b - int_bzp) * bs).to(tl.float16)
-        accumulator += tl.dot(a.to(tl.float16), b.to(tl.float16))
+        b = ((int_b - int_bzp) * bs).to(a.dtype)
+        accumulator += tl.dot(a, b.to(a.dtype))
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
         b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0
     # You can fuse arbitrary activation functions here
     # while the accumulator is still in FP32!
-    c = accumulator.to(tl.float16)
+    c = accumulator.to(c_ptr.dtype.element_ty)
     # -----------------------------------------------------------
     # Write back the block of the output matrix C with masks.
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -309,7 +309,7 @@ def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, sc
     M, K = x.shape
     N = scales.shape[1]
     if output is None:
-        output = torch.zeros((M, N), device=x.device, dtype=torch.float16)
+        output = torch.zeros((M, N), device=x.device, dtype=x.dtype)
     grid = lambda META: (
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
         META['SPLIT_K'],
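On the wrapper side the effect is that the output buffer now follows the activation dtype instead of always being float16. A small hypothetical illustration (not code from the repository):

import torch

# Before this commit the result buffer was always torch.float16; now it
# inherits the dtype of the activation x, so a bf16 model gets a bf16 result.
x = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
output = torch.empty((x.shape[0], 16), device=x.device, dtype=x.dtype)
assert output.dtype == torch.bfloat16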
[changed file: name not shown]
@@ -69,7 +69,7 @@ def _fwd_kernel_destindex_copy_quantize_kv(
     src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
                        mask=offs_h[:, None] < head_num, other=0.0)
     abs_data = tl.abs(src_data)
-    data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)[:, None]
+    data_scale = (tl.max(abs_data, axis=1) / 127.).to(Out_scale.dtype.element_ty)[:, None]
     q_src_data = (src_data / data_scale).to(tl.int8)
     o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
     os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
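The KV-cache quantization kernel keeps its numerics; only the dtype of the stored per-head scale changes. A rough plain-PyTorch reference of that step, assuming a [head_num, head_dim] block and ignoring the kernel's exact rounding behavior:

import torch

def quantize_kv_head_ref(src_data: torch.Tensor):
    # Per-head scale = max(|x|) / 127, kept in the source dtype (fp16 or bf16)
    # instead of being forced to fp16, matching the scale buffer's dtype.
    data_scale = (src_data.abs().amax(dim=1) / 127.0).to(src_data.dtype)[:, None]
    q_src_data = (src_data / data_scale).to(torch.int8)
    return q_src_data, data_scale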
6 changes: 3 additions & 3 deletions lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py
@@ -49,7 +49,7 @@ def quantize_int8_perrow_kernel(

 def quantize_int8_perrow(fpa):
     a = torch.empty(fpa.shape, device=fpa.device, dtype=torch.int8)
-    a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=torch.float16)
+    a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=fpa.dtype)
     M, K = fpa.shape
     BLOCK_SIZE_M = 1
     BLOCK_SIZE_K = triton.next_power_of_2(K)
@@ -175,7 +175,7 @@ def matmul_kernel(
         b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
     # You can fuse arbitrary activation functions here
     # while the accumulator is still in FP32!
-    c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(tl.float16)
+    c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(c_ptr.dtype.element_ty)
     # -----------------------------------------------------------
     # Write back the block of the output matrix C with masks.
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -201,7 +201,7 @@ def matmul_int8(a, a_scale, b, b_scale, out=None):
     K, N = b.shape
     # Allocates output.
     if out == None:
-        c = torch.zeros((M, N), device=a.device, dtype=torch.float16)
+        c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
     else:
         c = out.fill_(0.)
     grid = lambda META: (
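For the int8 GEMM path only the dtypes of the per-row activation scale and of the output buffer change. A plain-PyTorch approximation of quantize_int8_perrow for reference (the Triton kernel's rounding may differ slightly):

import torch

def quantize_int8_perrow_ref(fpa: torch.Tensor):
    # Per-row scale = max(|row|) / 127, stored in fpa.dtype (fp16 or bf16)
    # rather than hard-coded torch.float16.
    a_scale = (fpa.abs().amax(dim=1) / 127.0).to(fpa.dtype)
    a = torch.round(fpa / a_scale[:, None]).to(torch.int8)
    return a, a_scale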
2 changes: 1 addition & 1 deletion lightllm/models/bloom/triton_kernel/layernorm.py
@@ -46,7 +46,7 @@ def _layer_norm_fwd_fused(
     x_hat = (x - mean) * rstd
     y = x_hat * w + b
     # Write output
-    tl.store(Y + cols, y.to(tl.float16), mask=mask)
+    tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)


 # def layernorm_forward(x, weight, bias, eps):
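The layernorm kernel already normalizes in fp32; the change only affects the dtype used when writing y back. A plain-PyTorch sketch of the same numerics (eps value assumed, not taken from the kernel):

import torch

def layer_norm_ref(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, eps: float = 1e-5):
    # Normalize in fp32, then return the result in x.dtype, mirroring the
    # kernel's y.to(Y.dtype.element_ty) store.
    xf = x.float()
    mean = xf.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(xf.var(dim=-1, unbiased=False, keepdim=True) + eps)
    x_hat = (xf - mean) * rstd
    y = x_hat * w.float() + b.float()
    return y.to(x.dtype)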
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py
@@ -47,7 +47,7 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv(
     abs_data_0 = tl.abs(src_data_0)
     abs_data_1 = tl.abs(src_data_1)

-    data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(tl.float16)
+    data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(Out_scale.dtype.element_ty)
     q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8)
     q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0)
     q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0)
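Same pattern for the int4 KV cache: the scale is max(|x|) / 7 shared across the two halves and the quantized values are clamped to [-7, 7]; only the scale's dtype changes. A rough plain-PyTorch reference (assuming two [head, dim] halves as in the kernel; rounding may differ):

import torch

def quantize_int4_kv_ref(src_data_0: torch.Tensor, src_data_1: torch.Tensor):
    # Shared per-head scale, kept in the source dtype (fp16 or bf16).
    data_scale = (torch.maximum(src_data_0.abs().amax(dim=1),
                                src_data_1.abs().amax(dim=1)) / 7.0).to(src_data_0.dtype)
    q0 = (src_data_0 / data_scale[:, None]).to(torch.int8).clamp_(-7, 7)
    q1 = (src_data_1 / data_scale[:, None]).to(torch.int8).clamp_(-7, 7)
    return q0, q1, data_scale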
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py
@@ -39,7 +39,7 @@ def _fwd_kernel_destindex_copy_quantize_kv(
         other=0.0,
     )
     abs_data = tl.abs(src_data)
-    data_scale = (tl.max(abs_data, axis=1) / 127.0).to(tl.float16)
+    data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)
     q_src_data = (src_data / data_scale[:, None]).to(tl.int8)

     o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
[changed file: name not shown]
@@ -57,7 +57,7 @@ def _fwd_kernel_token_att2(
         )
         acc += tl.sum(p_value[:, None] * v_value, 0)

-    acc = acc.to(tl.float16)
+    acc = acc.to(Out.dtype.element_ty)
     off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc)
@@ -165,7 +165,7 @@ def _fwd_kernel_token_att2_int8v(
         )
         acc += tl.sum(p_value[:, None] * v_value * vs_value, 0)

-    acc = acc.to(tl.float16)
+    acc = acc.to(Out.dtype.element_ty)
     off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc)