bf16 compatibility for omitted kernel (#396) #397

Merged (1 commit) on Apr 15, 2024
14 changes: 7 additions & 7 deletions lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py
@@ -111,10 +111,10 @@ def matmul4_kernel(
b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values
b = b * scales[None, :] - zeros[None, :] # Scale and shift
# print("data type", a, b)
-accumulator += tl.dot(a, b.to(tl.float16))
+accumulator += tl.dot(a, b.to(a.dtype))
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
-c = accumulator.to(tl.float16)
+c = accumulator.to(c_ptr.dtype.element_ty)
# Store the result
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -154,7 +154,7 @@ def matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor,
# output = torch.empty((M, N), device='cuda', dtype=torch.float16)
if output is None:
inplace = False
-output = torch.empty((M, N), device=x.device, dtype=torch.float16)
+output = torch.empty((M, N), device=x.device, dtype=x.dtype)
else:
inplace = True

@@ -281,14 +281,14 @@ def matmul_kernel(
# We accumulate along the K dimension.
int_b = (b >> b_shift_bits) & 0xF
int_bzp = (bzp >> bzp_shift_bits) & 0xF
-b = ((int_b - int_bzp) * bs).to(tl.float16)
-accumulator += tl.dot(a.to(tl.float16), b.to(tl.float16))
+b = ((int_b - int_bzp) * bs).to(a.dtype)
+accumulator += tl.dot(a, b.to(a.dtype))
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0
# You can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
-c = accumulator.to(tl.float16)
+c = accumulator.to(c_ptr.dtype.element_ty)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -309,7 +309,7 @@ def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, sc
M, K = x.shape
N = scales.shape[1]
if output is None:
-output = torch.zeros((M, N), device=x.device, dtype=torch.float16)
+output = torch.zeros((M, N), device=x.device, dtype=x.dtype)
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
META['SPLIT_K'],
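The hunks above all follow one pattern: keep the fp32 accumulator, but take the compute and store dtypes from the kernel arguments (`a.dtype` for operands, `c_ptr.dtype.element_ty` for the output) instead of hard-coding `tl.float16`, and allocate host-side outputs with `dtype=x.dtype`. The snippet below is a minimal standalone sketch of that pattern, not code from this PR; the kernel and helper names are made up for illustration.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _cast_store_kernel(X, Y, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # Work in fp32, as the GEMM kernels above do with their accumulators.
    x = tl.load(X + offs, mask=mask, other=0.0).to(tl.float32)
    # Cast to whatever element type the output pointer carries (fp16 or bf16),
    # instead of a hard-coded tl.float16.
    tl.store(Y + offs, x.to(Y.dtype.element_ty), mask=mask)


def cast_store(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)  # output dtype follows the input, mirroring dtype=x.dtype above
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    _cast_store_kernel[grid](x, y, n, BLOCK=1024)
    return y
```

With a kernel written this way, the same code path serves `torch.float16` and `torch.bfloat16` callers; only the final cast differs.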
@@ -69,7 +69,7 @@ def _fwd_kernel_destindex_copy_quantize_kv(
src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
mask=offs_h[:, None] < head_num, other=0.0)
abs_data = tl.abs(src_data)
-data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)[:, None]
+data_scale = (tl.max(abs_data, axis=1) / 127.).to(Out_scale.dtype.element_ty)[:, None]
q_src_data = (src_data / data_scale).to(tl.int8)
o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
6 changes: 3 additions & 3 deletions lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py
@@ -49,7 +49,7 @@ def quantize_int8_perrow_kernel(

def quantize_int8_perrow(fpa):
a = torch.empty(fpa.shape, device=fpa.device, dtype=torch.int8)
-a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=torch.float16)
+a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=fpa.dtype)
M, K = fpa.shape
BLOCK_SIZE_M = 1
BLOCK_SIZE_K = triton.next_power_of_2(K)
@@ -175,7 +175,7 @@ def matmul_kernel(
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
# You can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
-c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(tl.float16)
+c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(c_ptr.dtype.element_ty)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -201,7 +201,7 @@ def matmul_int8(a, a_scale, b, b_scale, out=None):
K, N = b.shape
# Allocates output.
if out == None:
-c = torch.zeros((M, N), device=a.device, dtype=torch.float16)
+c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
else:
c = out.fill_(0.)
grid = lambda META: (
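For the int8 GEMM path, the change means the per-row scale and the output buffer inherit the activation dtype instead of being forced to `torch.float16`. Below is a rough PyTorch reference of the per-row quantization plus a small dtype check; the reference function is illustrative (the real work happens in `quantize_int8_perrow_kernel`), the import path is inferred from the file listing above, and the return order of `quantize_int8_perrow` is assumed from the allocations shown.

```python
import torch
from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import quantize_int8_perrow


def quantize_int8_perrow_ref(fpa: torch.Tensor):
    """Illustrative per-row int8 quantization (absmax / 127), not the Triton kernel."""
    a_scale = (fpa.abs().amax(dim=1) / 127.0).to(fpa.dtype)  # scale keeps the activation dtype
    a_int8 = (fpa / a_scale[:, None]).to(torch.int8)
    return a_int8, a_scale


x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
x_int8, x_scale = quantize_int8_perrow(x)
assert x_int8.dtype == torch.int8
assert x_scale.dtype == torch.bfloat16  # previously hard-coded to torch.float16
```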
2 changes: 1 addition & 1 deletion lightllm/models/bloom/triton_kernel/layernorm.py
@@ -46,7 +46,7 @@ def _layer_norm_fwd_fused(
x_hat = (x - mean) * rstd
y = x_hat * w + b
# Write output
-tl.store(Y + cols, y.to(tl.float16), mask=mask)
+tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)


# def layernorm_forward(x, weight, bias, eps):
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py
@@ -47,7 +47,7 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv(
abs_data_0 = tl.abs(src_data_0)
abs_data_1 = tl.abs(src_data_1)

-data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(tl.float16)
+data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(Out_scale.dtype.element_ty)
q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8)
q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0)
q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0)
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py
@@ -39,7 +39,7 @@ def _fwd_kernel_destindex_copy_quantize_kv(
other=0.0,
)
abs_data = tl.abs(src_data)
-data_scale = (tl.max(abs_data, axis=1) / 127.0).to(tl.float16)
+data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)
q_src_data = (src_data / data_scale[:, None]).to(tl.int8)

o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
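The two ppl KV-cache copy kernels apply the same idea to their quantization scales: the absmax-based scale is cast to `Out_scale.dtype.element_ty`, so a bf16 scale cache works exactly as a fp16 one did. A rough PyTorch equivalent of the int8 variant, with assumed shapes and names, for orientation:

```python
import torch


def quantize_kv_head_ref(src: torch.Tensor, out_scale_dtype: torch.dtype):
    """Illustrative per-head int8 KV quantization; src is assumed to be [head_num, head_dim]."""
    scale = (src.abs().amax(dim=1) / 127.0).to(out_scale_dtype)  # follows Out_scale's dtype
    q = (src / scale[:, None]).to(torch.int8)  # direct cast, as in the kernel
    return q, scale
```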
@@ -57,7 +57,7 @@ def _fwd_kernel_token_att2(
)
acc += tl.sum(p_value[:, None] * v_value, 0)

-acc = acc.to(tl.float16)
+acc = acc.to(Out.dtype.element_ty)
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@@ -165,7 +165,7 @@ def _fwd_kernel_token_att2_int8v(
)
acc += tl.sum(p_value[:, None] * v_value * vs_value, 0)

-acc = acc.to(tl.float16)
+acc = acc.to(Out.dtype.element_ty)
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)