diff --git a/format.py b/format.py index 4b886db0a..e73ea9656 100644 --- a/format.py +++ b/format.py @@ -1,6 +1,6 @@ import os import glob -for filename in glob.glob('./**/*.py', recursive=True): +for filename in glob.glob("./**/*.py", recursive=True): print(filename) os.system(f"autopep8 --max-line-length 140 --in-place --aggressive --aggressive {filename}") diff --git a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py index 143d93b23..3ab57b2ed 100644 --- a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py +++ b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py @@ -7,40 +7,87 @@ @triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + 
triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), ], - key=['M', 'N', 'K', 'NO_GROUPS'], + key=["M", "N", "K", "NO_GROUPS"], ) @triton.jit def matmul4_kernel( - a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales_g, stride_scales_n, - stride_zeros_g, stride_zeros_n, - groupsize, NO_GROUPS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales_g, + stride_scales_n, + stride_zeros_g, + stride_zeros_n, + groupsize, + NO_GROUPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, ): """ Compute the matrix multiplication C = A x B. 
@@ -67,17 +114,19 @@ def matmul4_kernel( first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + pid_n = (pid % num_pid_in_group) // group_size_m offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = offs_am[:, None] < M # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,) + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,) # zeros_ptrs is set up such that it repeats elements along the N axis 8 times - zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) # (BLOCK_SIZE_N,) + zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) # (BLOCK_SIZE_N,) # shifter is used to extract the 4 bits of each element in the 32-bit word from B and zeros shifter = (offs_k % infearure_per_bits) * bits zeros_shifter = (offs_bn % infearure_per_bits) * bits @@ -85,7 +134,7 @@ def matmul4_kernel( if NO_GROUPS: # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 # Unpack zeros zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 # zeros = (zeros + 1) * scales # (BLOCK_SIZE_N,) float16 @@ -96,25 +145,25 @@ def matmul4_kernel( # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, num_pid_k): - a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated if not NO_GROUPS: g_id = k // (groupsize // BLOCK_SIZE_K) ptr = scales_ptrs + g_id * stride_scales_g scales = tl.load(ptr) # (BLOCK_SIZE_N,) - ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,) - zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 + ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,) + zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 # Unpack zeros zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 - zeros = (zeros) * scales # (BLOCK_SIZE_N,) float16 + zeros = (zeros) * scales # (BLOCK_SIZE_N,) float16 # Now we need to unpack b (which is 4-bit values) into 32-bit values b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values b = b * scales[None, :] - zeros[None, :] # Scale and shift # print("data type", a, b) accumulator += tl.dot(a, b.to(a.dtype)) a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - c = accumulator.to(c_ptr.dtype.element_ty) + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + c = accumulator.to(c_ptr.dtype.element_ty) # Store the result offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -123,118 +172,289 @@ def matmul4_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) -def matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size, output=None) -> torch.FloatTensor: - """ - Compute the matrix multiplication C = A x B + bias. - Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. - - A is of shape (..., K) float16 - qweight is of shape (K//8, N) int32 - scales is of shape (G, N) float16 - qzeros is of shape (G, N//8) int32 - bias is of shape (1, N) float16 - - groupsize is the number of infeatures in each group. 
- G = K // groupsize - - Returns C of shape (..., N) float16 - """ - assert x.shape[-1] == (qweight.shape[0] * 8), "A must be a multiple of 8 in the last dimension" - assert x.is_contiguous(), "A must be contiguous" - - M, K = x.shape - N = qweight.shape[1] - # This is based on the possible BLOCK_SIZE_Ks - # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" - # # This is based on the possible BLOCK_SIZE_Ns - # assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" - # # This is based on the possible BLOCK_SIZE_Ks - # assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" - - # output = torch.empty((M, N), device='cuda', dtype=torch.float16) - if output is None: - inplace = False - output = torch.empty((M, N), device=x.device, dtype=x.dtype) - else: - inplace = True - - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul4_kernel[grid]( - x, qweight, output, - scales, qzeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), scales.stride(1), - qzeros.stride(0), qzeros.stride(1), - group_size, group_size == K, +def matmul_dequantize_int4_gptq( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size, + output=None, +) -> torch.FloatTensor: + """ + Compute the matrix multiplication C = A x B + bias. + Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. + + A is of shape (..., K) float16 + qweight is of shape (K//8, N) int32 + scales is of shape (G, N) float16 + qzeros is of shape (G, N//8) int32 + bias is of shape (1, N) float16 + + groupsize is the number of infeatures in each group. 
+ G = K // groupsize + + Returns C of shape (..., N) float16 + """ + assert x.shape[-1] == (qweight.shape[0] * 8), "A must be a multiple of 8 in the last dimension" + assert x.is_contiguous(), "A must be contiguous" + + M, K = x.shape + N = qweight.shape[1] + # This is based on the possible BLOCK_SIZE_Ks + # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" + # # This is based on the possible BLOCK_SIZE_Ns + # assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" + # # This is based on the possible BLOCK_SIZE_Ks + # assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" + + # output = torch.empty((M, N), device='cuda', dtype=torch.float16) + if output is None: + inplace = False + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + else: + inplace = True + + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) + matmul4_kernel[grid]( + x, + qweight, + output, + scales, + qzeros, + M, + N, + K, + x.stride(0), + x.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + scales.stride(1), + qzeros.stride(0), + qzeros.stride(1), + group_size, + group_size == K, ) - # return output - if not inplace: - return output + # return output + if not inplace: + return output @triton.autotune( - configs=[ - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 
'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - - ], - key=['M', 'N', 'K'], - reset_to_zero=['c_ptr'] + configs=[ + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + 
triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, 
"BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + ], + key=["M", "N", "K"], + reset_to_zero=["c_ptr"], ) @triton.jit def matmul_kernel( - a_ptr, b_ptr, c_ptr, - bs_ptr, bzp_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_bsk, stride_bsn, - stride_bzpk, stride_bzpn, + a_ptr, + b_ptr, + c_ptr, + bs_ptr, + bzp_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bsk, + stride_bsn, + stride_bzpk, + stride_bzpn, group_size, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr - ): + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, +): """ assert K % (BLOCK_SIZE_K * SPLIT_K) == 0 """ @@ -248,7 +468,7 @@ def matmul_kernel( first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + pid_n = (pid % num_pid_in_group) // group_size_m offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) @@ -266,13 +486,19 @@ def matmul_kernel( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): # Load the next block of A and B. 
- # [BLOCK_K, BLOCK_N] but repeated group_size times in K - bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \ + # [BLOCK_K, BLOCK_N] but repeated group_size times in K + bs_ptrs = ( + bs_ptr + + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk + offs_bn[None, :] * stride_bsn + ) # [BLOCK_K, BLOCK_N] but repeated in K and N - bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \ + bzp_ptrs = ( + bzp_ptr + + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk + (offs_bn[None, :] // 8) * stride_bzpn - b_shift_bits = (offs_k[:, None] % 8) * 4 # assert BLOCK_SIZE_K % 8 == 0 + ) + b_shift_bits = (offs_k[:, None] % 8) * 4 # assert BLOCK_SIZE_K % 8 == 0 bzp_shift_bits = (offs_bn[None, :] % 8) * 4 a = tl.load(a_ptrs) b = tl.load(b_ptrs) @@ -285,7 +511,7 @@ def matmul_kernel( accumulator += tl.dot(a, b.to(a.dtype)) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0 + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk // 8 # assert BLOCK_SIZE_K % 8 == 0 # You can fuse arbitrary activation functions here # while the accumulator is still in FP32! c = accumulator.to(c_ptr.dtype.element_ty) @@ -301,28 +527,44 @@ def matmul_kernel( tl.atomic_add(c_ptrs, c, mask=c_mask) -def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: - """ - """ +def matmul_dequantize_int4_s2( + x: torch.FloatTensor, + qweight: torch.IntTensor, + scales: torch.FloatTensor, + qzeros: torch.IntTensor, + group_size: int = 128, + output=None, +) -> torch.FloatTensor: + """ """ assert x.is_contiguous(), "A must be contiguous" - assert qweight.is_contiguous(), "B must be contiguous" + assert qweight.is_contiguous(), "B must be contiguous" M, K = x.shape N = scales.shape[1] if output is None: - output = torch.zeros((M, N), device=x.device, dtype=x.dtype) + output = torch.zeros((M, N), device=x.device, dtype=x.dtype) grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'], + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + META["SPLIT_K"], ) matmul_kernel[grid]( - x, qweight, output, - scales, qzeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), scales.stride(1), - qzeros.stride(0), qzeros.stride(1), + x, + qweight, + output, + scales, + qzeros, + M, + N, + K, + x.stride(0), + x.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + scales.stride(1), + qzeros.stride(0), + qzeros.stride(1), group_size, ) return output @@ -330,29 +572,39 @@ def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, sc @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 
'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), ], - key=['K', 'N'], + key=["K", "N"], ) @triton.jit def dequantize_kernel( # Pointers to matrices - b_ptr, b_scale_ptr, b_zp_ptr, fpb_ptr, + b_ptr, + b_scale_ptr, + b_zp_ptr, + fpb_ptr, # Matrix dimensions - K, N, group_size, - stride_bk, stride_bn, - stride_bsk, stride_bsn, - stride_bzpk, stride_bzpn, - stride_fpbk, stride_fpbn, + K, + N, + group_size, + stride_bk, + stride_bn, + stride_bsk, + stride_bsn, + stride_bzpk, + stride_bzpn, + stride_fpbk, + stride_fpbn, # Meta-parameters - BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, ): """Dequantize tile [BLOCK_SIZE_K, BLOCK_SIZE_N] in full precision. We should assert BLOCK_SIZE_N % 8 == 0. @@ -383,16 +635,25 @@ def dequantize_int4(b, b_scale, b_zero_point, device, dtype, group_size): K = Kw * 8 fp_b = torch.ones((K, N), device=device, dtype=dtype) grid = lambda META: ( - triton.cdiv(K, META['BLOCK_SIZE_K']), - triton.cdiv(N, META['BLOCK_SIZE_N']), + triton.cdiv(K, META["BLOCK_SIZE_K"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) dequantize_kernel[grid]( - b, b_scale, b_zero_point, fp_b, - K, N, group_size, - b.stride(0), b.stride(1), - b_scale.stride(0), b_scale.stride(1), - b_zero_point.stride(0), b_zero_point.stride(1), - fp_b.stride(0), fp_b.stride(1) + b, + b_scale, + b_zero_point, + fp_b, + K, + N, + group_size, + b.stride(0), + b.stride(1), + b_scale.stride(0), + b_scale.stride(1), + b_zero_point.stride(0), + b_zero_point.stride(1), + fp_b.stride(0), + fp_b.stride(1), ) return fp_b @@ -430,7 +691,7 @@ def quantize_int4(weight, group_size=128, tp_rank=0): weight_max = torch.where(weight_max < 0, 0, weight_max) weight_min = weight.amin(-1, keepdim=True) weight_min = torch.where(weight_min > 0, 0, weight_min) - weight_range = weight_max - weight_min + weight_range = weight_max - weight_min scale = weight_range / (2 ** 4 - 1) zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32) weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2) @@ -447,7 +708,7 @@ def quantize_int4(weight, group_size=128, tp_rank=0): for pack in range(0, h1, 8): for i in range(8): int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4) - ''' + """ fp_weight = torch.zeros(h1, h2).half().to(weight.device) for pack in range(0, h1 // 8): for i in range(8): @@ -462,9 +723,14 @@ def quantize_int4(weight, group_size=128, tp_rank=0): (int_zero_point[pack, :] >> (i * 4)) & 15 print((fp_zp - zero_point).abs().sum()) - ''' + """ weight = None - return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 
0).contiguous(), group_size + return ( + int_weight.transpose(1, 0).contiguous(), + scale.transpose(1, 0).contiguous(), + int_zero_point.transpose(1, 0).contiguous(), + group_size, + ) def unpack_int4(weight, scale, zp): @@ -487,8 +753,9 @@ def unpack_int4(weight, scale, zp): for i in range(8): fp_zero_point[pack * 8 + i, :] = (zp[pack, :] >> (i * 4)) & 0xF for g in range(group_num): - fp_weight[:, g * group_size:(g + 1) * group_size] = (fp_weight[:, g * group_size:(g + 1) * group_size] - \ - fp_zero_point[:, g].unsqueeze(1)) * scale[:, g].unsqueeze(1) + fp_weight[:, g * group_size : (g + 1) * group_size] = ( + fp_weight[:, g * group_size : (g + 1) * group_size] - fp_zero_point[:, g].unsqueeze(1) + ) * scale[:, g].unsqueeze(1) return fp_weight.transpose(1, 0) @@ -496,8 +763,8 @@ def test_int4(M, K, N): import time print("M: {} K: {} N: {}".format(M, K, N)) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) int_b, b_scale, b_zero_point, _ = quantize_int4(b) for _ in range(10): triton_output = matmul_dequantize_int4_s1(a, int_b, b_scale, b_zero_point) @@ -526,8 +793,8 @@ def test_int4(M, K, N): def test_correct_int4_s1(M=32, K=4096, N=4096): group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) cos = torch.nn.CosineSimilarity(0) fp_weight = dequantize_int4(int_b, b_scale, b_zero_point, a.device, a.dtype, group_size) @@ -541,8 +808,8 @@ def test_correct_int4_s1(M=32, K=4096, N=4096): def test_correct_int4_s2(M=32, K=4096, N=4096): group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) cos = torch.nn.CosineSimilarity(0) fp_weight = unpack_int4(int_b, b_scale, b_zero_point) @@ -556,8 +823,8 @@ def test_correct_int4_s2(M=32, K=4096, N=4096): def test_correct_int4_gptq(M=32, K=4096, N=4096): group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) cos = torch.nn.CosineSimilarity(0) fp_weight = unpack_int4(int_b, b_scale, b_zero_point) @@ -571,17 +838,15 @@ def test_correct_int4_gptq(M=32, K=4096, N=4096): @triton.testing.perf_report( triton.testing.Benchmark( - x_names=['M'], # Argument names to use as an x-axis for the plot - x_vals=[4, 8, 16, 32, 64, 128] + [ - 128 * i for i in range(2, 33, 2) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot + x_names=["M"], # Argument names to use as an x-axis for the plot + x_vals=[4, 8, 16, 32, 64, 128] + [128 * i for i in range(2, 33, 2)], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value 
corresponds to a different line in the plot # Possible values for `line_arg` - line_vals=['cublas', 'triton-s1', 'dequantize', 'triton-s2', 'triton-gptq'], + line_vals=["cublas", "triton-s1", "dequantize", "triton-s2", "triton-gptq"], # Label name for the lines line_names=["cuBLAS", "Triton-s1", "Dequant(GB/s)", "Triton-s2", "Triton-gptq"], # Line styles - styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-'), ('yellow', '-')], + styles=[("green", "-"), ("blue", "-"), ("red", "-"), ("purple", "-"), ("yellow", "-")], ylabel="TFLOPS", # Label name for the y-axis plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. args={}, @@ -590,27 +855,35 @@ def test_correct_int4_gptq(M=32, K=4096, N=4096): def benchmark(M, provider): K = 4096 N = 4096 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': + if provider == "cublas": ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-s1': + if provider == "triton-s1": intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_s1(a, intb, b_scale, bzp, 64), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul_dequantize_int4_s1(a, intb, b_scale, bzp, 64), quantiles=quantiles + ) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-s2': + if provider == "triton-s2": intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_s2(a, intb, b_scale, bzp, 64), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul_dequantize_int4_s2(a, intb, b_scale, bzp, 64), quantiles=quantiles + ) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'dequantize': + if provider == "dequantize": intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: dequantize_int4(intb, b_scale, bzp, 'cuda', torch.float16, 64), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: dequantize_int4(intb, b_scale, bzp, "cuda", torch.float16, 64), quantiles=quantiles + ) perf = lambda ms: 2 * M * K * 1e-9 / (ms * 1e-3) - if provider == 'triton-gptq': + if provider == "triton-gptq": intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_gptq(a, intb, b_scale, bzp, 64), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul_dequantize_int4_gptq(a, intb, b_scale, bzp, 64), quantiles=quantiles + ) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py index e2c5c0dc9..f39936d5a 100644 --- a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py +++ b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py @@ -6,38 +6,42 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=3, num_warps=4), - 
triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 256}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), ], - key=['K', 'N'], + key=["K", "N"], ) - - @triton.jit def dequantize_kernel( # Pointers to matrices - b_ptr, b_scale_ptr, fpb_ptr, + b_ptr, + b_scale_ptr, + fpb_ptr, # Matrix dimensions - K, N, - stride_bk, stride_bn, - stride_fpbk, stride_fpbn, + K, + N, + stride_bk, + stride_bn, + stride_fpbk, + stride_fpbn, # Meta-parameters - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): """Kernel for computing the matmul C = A x B. 
A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -46,10 +50,12 @@ def dequantize_kernel( n_block_idx = tl.program_id(axis=1) offs_k = tl.arange(0, BLOCK_SIZE_K) offs_n = tl.arange(0, BLOCK_SIZE_N) - b_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_bk + \ - (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_bn - fpb_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_fpbk + \ - (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_fpbn + b_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_bk + ( + n_block_idx * BLOCK_SIZE_N + offs_n[None, :] + ) * stride_bn + fpb_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_fpbk + ( + n_block_idx * BLOCK_SIZE_N + offs_n[None, :] + ) * stride_fpbn bs_offs = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] n_mask = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] < N mask = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None] < K) & n_mask @@ -72,14 +78,10 @@ def matmul_dequantize_int8(a, b, b_scale, out=None): c = out fp_b = torch.empty((K, N), device=a.device, dtype=a.dtype) grid = lambda META: ( - triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - dequantize_kernel[grid]( - b, b_scale, fp_b, - K, N, - b.stride(0), b.stride(1), - fp_b.stride(0), fp_b.stride(1) + triton.cdiv(K, META["BLOCK_SIZE_K"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + dequantize_kernel[grid](b, b_scale, fp_b, K, N, b.stride(0), b.stride(1), fp_b.stride(0), fp_b.stride(1)) torch.mm(a, fp_b, out=c) return c @@ -87,7 +89,7 @@ def matmul_dequantize_int8(a, b, b_scale, out=None): def quantize_int8(weight, axis=0, tp_rank=0): # Weight shape: [H1, H2] # Scale shape: [H2] - scale = weight.abs().amax(axis, keepdim=True) / 127. + scale = weight.abs().amax(axis, keepdim=True) / 127.0 weight = (weight / scale).to(torch.int8) if axis == 0: weight = weight.t().contiguous().t() @@ -100,8 +102,8 @@ def test_int8(M, K, N): print("M: {} K: {} N: {}".format(M, K, N)) torch.manual_seed(0) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) int_b, b_scale = quantize_int8(b) for _ in range(10): triton_output = matmul_dequantize_int8(a, int_b, b_scale.unsqueeze(0)) @@ -131,48 +133,46 @@ def test_int8(M, K, N): def test_correct_int8(M=512, K=4096, N=4096): import time - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) int_b, b_scale = quantize_int8(b) cos = torch.nn.CosineSimilarity(0) triton_output = matmul_dequantize_int8(a, int_b, b_scale) torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") + print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") print("Output cos ", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot - x_vals=[32, 64, 128, 256] + [ - 512 * i for i in range(1, 33) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[32, 
64, 128, 256] + [512 * i for i in range(1, 33)], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value corresponds to a different line in the plot # Possible values for `line_arg` - line_vals=['cublas', 'triton'], + line_vals=["cublas", "triton"], # Label name for the lines line_names=["cuBLAS", "Triton"], # Line styles - styles=[('green', '-'), ('blue', '-')], + styles=[("green", "-"), ("blue", "-")], ylabel="TFLOPS", # Label name for the y-axis plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. args={}, ) ) - - def benchmark(M, N, K, provider): quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + if provider == "cublas": + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - if provider == 'triton': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + if provider == "triton": + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) intb, b_scale = quantize_int8(b) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int8(a, intb, b_scale), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul_dequantize_int8(a, intb, b_scale), quantiles=quantiles + ) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(min_ms), perf(max_ms) @@ -201,9 +201,9 @@ def test_model_layer(bs, sqe_len, hidden, inter, tp): bs = 32 hidden = 4096 - inter = 11008 + inter = 11008 prefill_len = 512 decode_len = 1 tp = 1 test_model_layer(bs, prefill_len, hidden, inter, tp) - test_model_layer(bs, decode_len, hidden, inter, tp) \ No newline at end of file + test_model_layer(bs, decode_len, hidden, inter, tp) diff --git a/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py b/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py index 4f3f6a385..27e06c0cb 100644 --- a/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py +++ b/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py @@ -11,18 +11,24 @@ triton.Config({}, num_stages=2, num_warps=4), triton.Config({}, num_stages=2, num_warps=2), triton.Config({}, num_stages=2, num_warps=1), - ], - key=['K'], + ], + key=["K"], ) @triton.jit def quantize_int8_perrow_kernel( - fpa_ptr, a_ptr, as_ptr, - M, K, - stride_fpam, stride_fpak, - stride_am, stride_ak, + fpa_ptr, + a_ptr, + as_ptr, + M, + K, + stride_fpam, + stride_fpak, + stride_am, + stride_ak, stride_asm, # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_m = tl.program_id(axis=0) offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -35,7 +41,7 @@ def quantize_int8_perrow_kernel( fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) a_max = tl.maximum(a_max, tl.max(tl.abs(fpa), axis=1)) fpa_ptrs += BLOCK_SIZE_K * stride_fpak - a_scale = (a_max / 127.) 
+ a_scale = a_max / 127.0 fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) @@ -55,73 +61,227 @@ def quantize_int8_perrow(fpa): BLOCK_SIZE_K = triton.next_power_of_2(K) grid = (M // BLOCK_SIZE_M,) quantize_int8_perrow_kernel[grid]( - fpa, a, a_scale, - M, K, - fpa.stride(0), fpa.stride(1), - a.stride(0), a.stride(1), + fpa, + a, + a_scale, + M, + K, + fpa.stride(0), + fpa.stride(1), + a.stride(0), + a.stride(1), a_scale.stride(0), - BLOCK_SIZE_M, BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_K, ) return a, a_scale @triton.autotune( configs=[ - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 
128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, + num_stages=2, + 
num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 1, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"SPLIT_K": 2, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16}, + num_stages=2, + num_warps=4, + ), ], - key=['M', 'N', 'K'], - reset_to_zero=['c_ptr'] + 
key=["M", "N", "K"], + reset_to_zero=["c_ptr"], ) @triton.jit def matmul_kernel( # Pointers to matrices - a_ptr, as_ptr, b_ptr, bs_ptr, c_ptr, + a_ptr, + as_ptr, + b_ptr, + bs_ptr, + c_ptr, # Matrix dimensions - M, N, K, + M, + N, + K, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). - stride_am, stride_ak, + stride_am, + stride_ak, stride_asm, - stride_bk, stride_bn, + stride_bk, + stride_bn, stride_bsn, - stride_cm, stride_cn, + stride_cm, + stride_cn, # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -203,19 +363,28 @@ def matmul_int8(a, a_scale, b, b_scale, out=None): if out == None: c = torch.zeros((M, N), device=a.device, dtype=torch.float16) else: - c = out.fill_(0.) + c = out.fill_(0.0) grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'], + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + META["SPLIT_K"], ) matmul_kernel[grid]( - a, a_scale, b, b_scale, c, - M, N, K, - a.stride(0), a.stride(1), + a, + a_scale, + b, + b_scale, + c, + M, + N, + K, + a.stride(0), + a.stride(1), a_scale.stride(0), - b.stride(0), b.stride(1), + b.stride(0), + b.stride(1), b_scale.stride(0), - c.stride(0), c.stride(1), + c.stride(0), + c.stride(1), ) return c @@ -223,7 +392,7 @@ def matmul_int8(a, a_scale, b, b_scale, out=None): def quantize_int8(weight, axis=0, tp_rank=0): # Weight shape: [H1, H2] # Scale shape: [H2] - scale = weight.abs().amax(axis, keepdim=True) / 127. + scale = weight.abs().amax(axis, keepdim=True) / 127.0 weight = (weight / scale).to(torch.int8) # col major will accelerate i8xi8 kernel. 
if axis == 0: @@ -233,11 +402,14 @@ def quantize_int8(weight, axis=0, tp_rank=0): def test_correct_int8(M=32, N=4096, K=4096): - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) int_a, scale_a = quantize_int8_perrow(a) cos = torch.nn.CosineSimilarity(0) - print("Quantization cos", cos((int_a * scale_a.unsqueeze(1)).flatten().to(torch.float32), a.flatten().to(torch.float32))) + print( + "Quantization cos", + cos((int_a * scale_a.unsqueeze(1)).flatten().to(torch.float32), a.flatten().to(torch.float32)), + ) int_b, scale_b = quantize_int8(b, axis=0) triton_output = matmul_int8(int_a, scale_a, int_b, scale_b) torch_output = torch.matmul(a, b) @@ -252,8 +424,8 @@ def test_int8(M, K, N): print("M: {} K: {} N: {}".format(M, K, N)) torch.manual_seed(0) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16).contiguous() + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).contiguous() int_b, scale_b = quantize_int8(b, axis=0) for _ in range(10): # int_a, a_scale = quantize_int8(a, 1) @@ -263,7 +435,7 @@ def test_int8(M, K, N): iters = 512 t1 = time.time() for _ in range(iters): - #int_a, a_scale, _ = quantize_int8(a, 1) + # int_a, a_scale, _ = quantize_int8(a, 1) int_a, a_scale = quantize_int8_perrow(a) torch.cuda.synchronize() qt2 = time.time() @@ -275,8 +447,11 @@ def test_int8(M, K, N): triton_time = t2 - qt2 triton_tflops = 2 * M * N * K * 1e-12 / (triton_time / iters) quant_bandwith = 2 * M * K * 1e-9 / (quant_time / iters) - print("Triton time cost: {} (tflops {}) + quant: {} (bandwidth {})".format( - triton_time, triton_tflops, quant_time, quant_bandwith)) + print( + "Triton time cost: {} (tflops {}) + quant: {} (bandwidth {})".format( + triton_time, triton_tflops, quant_time, quant_bandwith + ) + ) for _ in range(10): torch_output = torch.matmul(a, b) torch.cuda.synchronize() @@ -294,17 +469,15 @@ def test_int8(M, K, N): @triton.testing.perf_report( triton.testing.Benchmark( - x_names=['M'], # Argument names to use as an x-axis for the plot - x_vals=[32, 64, 128, 256] + [ - 512 * i * 2 for i in range(1, 17) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot + x_names=["M"], # Argument names to use as an x-axis for the plot + x_vals=[32, 64, 128, 256] + [512 * i * 2 for i in range(1, 17)], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value corresponds to a different line in the plot # Possible values for `line_arg` - line_vals=['cublas', 'triton-i8', 'triton-quant-i8', 'quant-perrow'], + line_vals=["cublas", "triton-i8", "triton-quant-i8", "quant-perrow"], # Label name for the lines line_names=["cuBLAS", "Triton-i8", "Triton-Quant-i8", "Quant-perrow(GB/s)"], # Line styles - styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-')], + styles=[("green", "-"), ("blue", "-"), ("red", "-"), ("purple", "-")], ylabel="TFLOPS", # Label name for the y-axis plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. 
args={}, @@ -314,26 +487,30 @@ def benchmark(M, provider): K = 10240 N = 27392 * 2 // 8 quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + if provider == "cublas": + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-i8': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - b = torch.randn((K, N), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() + if provider == "triton-i8": + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.int8).contiguous() + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.int8).contiguous() int_a, a_scale = quantize_int8(a, axis=1) int_b, b_scale = quantize_int8(b, axis=0) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_int8(int_a, a_scale, int_b, b_scale), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul_int8(int_a, a_scale, int_b, b_scale), quantiles=quantiles + ) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-quant-i8': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - b = torch.randn((K, N), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() + if provider == "triton-quant-i8": + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.int8).contiguous() + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.int8).contiguous() int_b, b_scale = quantize_int8(b, axis=0) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_quantize_int8(a, int_b, b_scale), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul_quantize_int8(a, int_b, b_scale), quantiles=quantiles + ) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'quant-perrow': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() + if provider == "quant-perrow": + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.int8).contiguous() ms, min_ms, max_ms = triton.testing.do_bench(lambda: quantize_int8_perrow(a), quantiles=quantiles) perf = lambda ms: 2 * M * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(min_ms), perf(max_ms) @@ -368,7 +545,7 @@ def test_model_layer(bs, sqe_len, hidden, inter, tp): bs = 32 hidden = 4096 - inter = 11008 + inter = 11008 prefill_len = 512 decode_len = 1 tp = 1 diff --git a/lightllm/common/basemodel/triton_kernel/redundancy_topk_ids_repair.py b/lightllm/common/basemodel/triton_kernel/redundancy_topk_ids_repair.py index ba48f414d..692332cc0 100644 --- a/lightllm/common/basemodel/triton_kernel/redundancy_topk_ids_repair.py +++ b/lightllm/common/basemodel/triton_kernel/redundancy_topk_ids_repair.py @@ -22,7 +22,7 @@ def _redundancy_topk_ids_repair_kernel( if ENABLE_COUNTER: tl.atomic_add(expert_counter_ptr + current_topk_ids, 1, mask=mask) - + # Remap original expert IDs to a new space that accounts for redundant expert slots. 
new_current_topk_ids = (current_topk_ids // ep_expert_num) * redundancy_expert_num + current_topk_ids diff --git a/lightllm/common/build_utils.py b/lightllm/common/build_utils.py index fd35c30a0..ed1b4e5be 100644 --- a/lightllm/common/build_utils.py +++ b/lightllm/common/build_utils.py @@ -1,4 +1,3 @@ - def repair_config(config, same_names): find_value = None for name in same_names: @@ -7,4 +6,4 @@ def repair_config(config, same_names): break for name in same_names: config[name] = find_value - return \ No newline at end of file + return diff --git a/lightllm/models/bloom/layer_weights/hf_load_utils.py b/lightllm/models/bloom/layer_weights/hf_load_utils.py index 01c4c5862..fda5041e4 100755 --- a/lightllm/models/bloom/layer_weights/hf_load_utils.py +++ b/lightllm/models/bloom/layer_weights/hf_load_utils.py @@ -6,16 +6,16 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): if isinstance(data_type, str): - data_type = torch.float16 if data_type == 'fp16' else torch.float32 + data_type = torch.float16 if data_type == "fp16" else torch.float32 if pre_post_layer is not None: assert pre_post_layer.data_type_ == data_type, "type is not right" if transformer_layer_list is not None: assert transformer_layer_list[0].data_type_ == data_type, "type is not right" if weight_dict: new_w = {} - for k,v in weight_dict.items(): + for k, v in weight_dict.items(): if "transformer." in k: - new_w[k[len("transformer."):]] = v + new_w[k[len("transformer.") :]] = v else: new_w[k] = v del weight_dict @@ -29,21 +29,21 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye return use_safetensors = True files = os.listdir(weight_dir) - candidate_files = list(filter(lambda x : x.endswith('.safetensors'), files)) + candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) if len(candidate_files) == 0: use_safetensors = False - candidate_files = list(filter(lambda x : x.endswith('.bin'), files)) + candidate_files = list(filter(lambda x: x.endswith(".bin"), files)) assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." for file_ in candidate_files: if use_safetensors: - weights = safe_open(os.path.join(weight_dir, file_), 'pt', 'cpu') + weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") weights = {k: weights.get_tensor(k) for k in weights.keys()} else: - weights = torch.load(os.path.join(weight_dir, file_), 'cpu') + weights = torch.load(os.path.join(weight_dir, file_), "cpu") new_w = {} - for k,v in weights.items(): + for k, v in weights.items(): if "transformer." 
in k: - new_w[k[len("transformer."):]] = v + new_w[k[len("transformer.") :]] = v else: new_w[k] = v del weights diff --git a/lightllm/models/bloom/triton_kernel/layernorm.py b/lightllm/models/bloom/triton_kernel/layernorm.py index 6911d707b..538bb2b13 100644 --- a/lightllm/models/bloom/triton_kernel/layernorm.py +++ b/lightllm/models/bloom/triton_kernel/layernorm.py @@ -24,15 +24,15 @@ def _layer_norm_fwd_fused( _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) _mean += a mean = tl.sum(_mean, axis=0) / N # Compute variance _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) @@ -42,7 +42,7 @@ def _layer_norm_fwd_fused( mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) b = tl.load(B + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = (x - mean) * rstd y = x_hat * w + b # Write output @@ -72,17 +72,18 @@ def _layer_norm_fwd_fused( # BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) # return y + def layernorm_forward(x, weight, bias, eps): return torch.layer_norm(x, (x.shape[-1],), weight, bias, eps) -def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): +def test_layer_norm(M, N, dtype, eps=1e-5, device="cuda"): # create data x_shape = (M, N) - w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device='cuda') - bias = torch.rand(w_shape, dtype=dtype, device='cuda') - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + bias = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") # forward pass y_tri = layernorm_forward(x, weight, bias, eps) y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py index 923d6d83b..afc233012 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py @@ -7,15 +7,27 @@ @triton.jit def _fwd_kernel_token_att1( - Q, K, sm_scale, Alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 + Q, + K, + sm_scale, + Alibi, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 Att_Out, - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - att_stride_h, att_stride_bs, - + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -40,7 +52,11 @@ def _fwd_kernel_token_att1( alibi_m = tl.load(Alibi + 
cur_head) q = tl.load(Q + off_q + start_mark) offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_id + stride_req_to_tokens_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0) + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_id + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -68,12 +84,25 @@ def token_att_fwd(q, k, att_out, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B num_warps = 2 _fwd_kernel_token_att1[grid]( - q, k, sm_scale, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, + q, + k, + sm_scale, + alibi, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, att_out, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - att_out.stride(0), att_out.stride(1), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + att_out.stride(0), + att_out.stride(1), BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, @@ -88,7 +117,9 @@ def torch_att(xq, xk, bs, seqlen, num_head, head_dim): keys = xk xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) - scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + scores = ( + (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + ) print("s ", scores.shape) return scores @@ -99,4 +130,4 @@ def torch_att1(xq, xk, seqlen, num_head, head_dim): logics = torch.sum(xq * xk, dim=-1, keepdim=False) logics = logics.transpose(0, 1) / math.sqrt(head_dim) - return logics \ No newline at end of file + return logics diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py b/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py index 25af80fab..23f12ed4a 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py +++ b/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py @@ -6,11 +6,15 @@ @triton.jit def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -19,16 +23,22 @@ def _fwd_kernel_token_softmax( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(Prob_Out + cur_head * stride_prob_h + 
(cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) return @@ -44,10 +54,14 @@ def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): num_warps = 16 _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - Logics.stride(0), Logics.stride(1), - Prob_Out.stride(0), Prob_Out.stride(1), + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) @@ -107,7 +121,7 @@ def test2(): start = 0 for i in range(B): end = start + b_seq_len[i] - torch_o = Logics[:, start: end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) + torch_o = Logics[:, start:end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) start = end torch_out.append(torch_o) torch_out = torch.cat(torch_out, dim=-1) diff --git a/lightllm/models/llama/layer_weights/ds_load_utils.py b/lightllm/models/llama/layer_weights/ds_load_utils.py index 091c056ca..560b84544 100644 --- a/lightllm/models/llama/layer_weights/ds_load_utils.py +++ b/lightllm/models/llama/layer_weights/ds_load_utils.py @@ -3,36 +3,39 @@ import os import gc -def load_ds_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None, prefix="", num_layer=0): + +def load_ds_weights( + data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None, prefix="", num_layer=0 +): if weight_dict: return weight_dict files = os.listdir(weight_dir) - candidate_files = sorted(list(filter(lambda x : x.endswith('.pt') and x.startswith('layer'), files))) + candidate_files = sorted(list(filter(lambda x: x.endswith(".pt") and x.startswith("layer"), files))) assert len(candidate_files) != 0, "can only support pytorch tensor format for weights." if weight_dict: weights_all = weight_dict else: weights_all = {} for file_ in candidate_files: - file_split = file_.split('-') - layer_num = int(file_split[0].split('_')[-1]) - rank_num = int(file_split[0].split('_')[-1]) - weights = torch.load(os.path.join(weight_dir, file_), 'cpu') - for k,v in weights.items(): - if layer_num >=3 and layer_num < 3 + num_layer: - k = prefix + str(layer_num - 3) + '.' + k + file_split = file_.split("-") + layer_num = int(file_split[0].split("_")[-1]) + rank_num = int(file_split[0].split("_")[-1]) + weights = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in weights.items(): + if layer_num >= 3 and layer_num < 3 + num_layer: + k = prefix + str(layer_num - 3) + "." 
+ k if layer_num == num_layer + 5: - k = 'lm_head.weight' + k = "lm_head.weight" if layer_num == num_layer + 4: - k = 'model.norm.weight' + k = "model.norm.weight" if layer_num == 1: - k = 'model.embed_tokens.weight' + k = "model.embed_tokens.weight" if k not in weights_all: - weights_all[k] = v + weights_all[k] = v else: - if 'q_proj' in k or 'k_proj' in k or 'v_proj' in k or 'gate_proj' in k or 'up_proj' in k: + if "q_proj" in k or "k_proj" in k or "v_proj" in k or "gate_proj" in k or "up_proj" in k: weights_all[k] = torch.cat([weights_all[k], v], dim=0) - elif 'o_proj' in k or 'down_proj' in k: + elif "o_proj" in k or "down_proj" in k: weights_all[k] = torch.cat([weights_all[k], v], dim=1) else: weights_all[k] = v @@ -45,5 +48,6 @@ def load_ds_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye gc.collect() return -if __name__ == '__main__': - load_ds_weight('fp16', '/nvme/baishihao/llama7b', prefix='model.layers.', num_layer=32) \ No newline at end of file + +if __name__ == "__main__": + load_ds_weight("fp16", "/nvme/baishihao/llama7b", prefix="model.layers.", num_layer=32) diff --git a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py b/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py index 86a3af103..74381eaec 100644 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py +++ b/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py @@ -2,21 +2,40 @@ import triton import triton.language as tl + @triton.jit def _fwd_kernel_flash_decode_stage1( - Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, + Q, + K, + V, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, gqa_group_size, - BLOCK_SEQ: tl.constexpr, + BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -30,11 +49,18 @@ def _fwd_kernel_flash_decode_stage1( cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - - block_n_size = tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1) // BLOCK_N - + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - + q = tl.load(Q + off_q) sum_exp = 0.0 @@ -51,7 +77,7 @@ def _fwd_kernel_flash_decode_stage1( att_value *= sm_scale att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - + cur_max_logic = tl.max(att_value, axis=0) 
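# NOTE (reading aid, not an added line of this patch): the surrounding stage-1 code is the
# standard online-softmax recurrence. With running state (max_logic, sum_exp, acc) and a new
# block of masked scores, it computes roughly:
#     new_max   = max(max_logic, block_scores.max())
#     rescale   = exp(max_logic - new_max)
#     acc       = acc * rescale + sum(exp(block_scores - new_max)[:, None] * v)
#     sum_exp   = sum_exp * rescale + sum(exp(block_scores - new_max))
#     max_logic = new_max
# Stage 2 (further down in this diff) merges the per-block partial outputs by weighting each
# with exp of its stored log-sum-exp relative to a running maximum, then divides by the
# accumulated sum_exp to recover the exact softmax-weighted result.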
new_max_logic = tl.maximum(cur_max_logic, max_logic) @@ -62,7 +88,7 @@ def _fwd_kernel_flash_decode_stage1( sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) max_logic = new_max_logic - + need_store = tl.where(block_n_size == 0, 0, 1) for _ in range(0, need_store, 1): off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d @@ -73,7 +99,9 @@ def _fwd_kernel_flash_decode_stage1( @torch.no_grad() -def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq): +def flash_decode_stage1( + q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq +): BLOCK_SEQ = block_seq BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 @@ -85,17 +113,35 @@ def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_ batch, head_num = B_req_idx.shape[0], q.shape[1] grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) gqa_group_size = q.shape[1] // k.shape[1] - + _fwd_kernel_flash_decode_stage1[grid]( - q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, + q, + k, + v, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, mid_out, mid_out_logsumexp, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logsumexp.stride(0), mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), gqa_group_size, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, @@ -103,4 +149,4 @@ def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_ num_warps=1, num_stages=2, ) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py b/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py index 81227f967..6cc39d5a7 100644 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py +++ b/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py @@ -6,14 +6,22 @@ @triton.jit def _fwd_kernel_flash_decode_stage2( B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): + BLOCK_DMODEL: tl.constexpr, +): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -32,14 +40,14 @@ def _fwd_kernel_flash_decode_stage2( tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) new_max_logic = tl.maximum(tlogic, max_logic) - + old_scale = tl.exp(max_logic - 
new_max_logic) acc *= old_scale exp_logic = tl.exp(tlogic - new_max_logic) acc += exp_logic * tv sum_exp = sum_exp * old_scale + exp_logic max_logic = new_max_logic - + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) return @@ -50,15 +58,25 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) - + _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), + B_Seqlen, + mid_out, + mid_out_logexpsum, + O, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + O.stride(0), + O.stride(1), + O.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, num_warps=4, num_stages=2, ) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py index 81227f967..6cc39d5a7 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py +++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py @@ -6,14 +6,22 @@ @triton.jit def _fwd_kernel_flash_decode_stage2( B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): + BLOCK_DMODEL: tl.constexpr, +): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -32,14 +40,14 @@ def _fwd_kernel_flash_decode_stage2( tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) new_max_logic = tl.maximum(tlogic, max_logic) - + old_scale = tl.exp(max_logic - new_max_logic) acc *= old_scale exp_logic = tl.exp(tlogic - new_max_logic) acc += exp_logic * tv sum_exp = sum_exp * old_scale + exp_logic max_logic = new_max_logic - + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) return @@ -50,15 +58,25 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) - + _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), + B_Seqlen, + mid_out, + mid_out_logexpsum, + O, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + O.stride(0), + O.stride(1), + O.stride(2), 
BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, num_warps=4, num_stages=2, ) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py index 7ba0f3b31..04c2fc3cf 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py +++ b/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py @@ -47,7 +47,9 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( abs_data_0 = tl.abs(src_data_0) abs_data_1 = tl.abs(src_data_1) - data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(Out_scale.dtype.element_ty) + data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to( + Out_scale.dtype.element_ty + ) q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) diff --git a/lightllm/models/llama/triton_kernel/rotary_emb.py b/lightllm/models/llama/triton_kernel/rotary_emb.py index c6d4f3010..60ccfa9da 100755 --- a/lightllm/models/llama/triton_kernel/rotary_emb.py +++ b/lightllm/models/llama/triton_kernel/rotary_emb.py @@ -116,7 +116,7 @@ def _rotary_kernel( @torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): +def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): total_len = q.shape[0] head_num_q, head_num_k = q.shape[1], k.shape[1] head_dim = int(q.shape[2] * partial_rotary_factor) diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py index 5e6040ac5..265d7908d 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py @@ -5,11 +5,15 @@ @triton.jit def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -18,18 +22,25 @@ def _fwd_kernel_token_softmax( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) return + @torch.no_grad() def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): BLOCK_SIZE = triton.next_power_of_2(max_input_len) @@ -42,15 +53,20 @@ def token_softmax_fwd(Logics, B_Start_Loc, 
B_Seqlen, Prob_Out, max_input_len): num_warps = 16 _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - Logics.stride(0), Logics.stride(1), - Prob_Out.stride(0), Prob_Out.stride(1), + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) return + def test1(): import torch diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 01e3a7268..f52e71481 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -227,14 +227,14 @@ def _init_datatype(self): def rot_pos_emb(self, grid_thw): pos_ids = [] s = self.spatial_merge_size - for _, h, w in grid_thw: + for t, h, w in grid_thw: pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() cos_full, sin_full = self.rotary_pos_emb(max_grid_size) diff --git a/lightllm/models/whisper/whisper_audio.py b/lightllm/models/whisper/whisper_audio.py index c5959ea1e..d3f7de4c4 100644 --- a/lightllm/models/whisper/whisper_audio.py +++ b/lightllm/models/whisper/whisper_audio.py @@ -12,6 +12,7 @@ from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.server.multimodal_params import AudioItem from rpyc.utils.classic import obtain +import pickle # tokenizer_class removed class WhisperProcessor(ProcessorMixin): @@ -190,7 +191,9 @@ def encode(self, audio_items: List[AudioItem]): audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32) audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1 - ready_audio = obtain(self.cache_client.root.get_items_embed(uuids)) + uuids_blob = pickle.dumps(uuids) + ready_audio = self.cache_client.root.get_items_embed(uuids_blob) + ready_audio = pickle.loads(ready_audio) ids_to_set = [] for i, ready in enumerate(ready_audio): if not ready: @@ -199,4 +202,5 @@ def encode(self, audio_items: List[AudioItem]): create_shm(get_shm_name_embed(uid), cur_embed_bytes) ids_to_set.append(uid) if ids_to_set: - self.cache_client.root.set_items_embed(ids=ids_to_set) + ids_to_set_blob = pickle.dumps(ids_to_set) + self.cache_client.root.set_items_embed(ids_to_set_blob) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 8fa519578..73381a438 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -324,6 +324,9 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" ) + parser.add_argument('--max_tasks_per_worker', type=int, default=32, + help='Maximum number of tasks each worker thread can handle (default: 32)') + parser.add_argument("--concurrent_alloc_workers", type=int, default=2, help="max concurrent threadpool workers") parser.add_argument( "--data_type", type=str, diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py index 709ea5ca2..22ea792be 100644 --- 
a/lightllm/server/audioserver/manager.py +++ b/lightllm/server/audioserver/manager.py @@ -103,7 +103,9 @@ async def loop_for_fwd(self): if disable_prompt_cache: ready_audio = [False] * len(audio_uuids) else: - ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids)) + audio_uuids = pickle.dumps(audio_uuids) + ready_audio = self.cache_client.root.get_items_embed(audio_uuids) + ready_audio = pickle.loads(ready_audio) for audio, ready in zip(multimodal_params.audios, ready_audio): if not ready: diff --git a/lightllm/server/embed_cache/__init__.py b/lightllm/server/embed_cache/__init__.py index 12230a53c..9ce01001a 100644 --- a/lightllm/server/embed_cache/__init__.py +++ b/lightllm/server/embed_cache/__init__.py @@ -1 +1 @@ -from . import impl \ No newline at end of file +from . import impl diff --git a/lightllm/server/embed_cache/impl/__init__.py b/lightllm/server/embed_cache/impl/__init__.py index 2e5630a62..f5e5eb292 100644 --- a/lightllm/server/embed_cache/impl/__init__.py +++ b/lightllm/server/embed_cache/impl/__init__.py @@ -1 +1 @@ -from . import naive_memory_cache \ No newline at end of file +from . import naive_memory_cache diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index f0b68c45e..81c2277b1 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -7,6 +7,7 @@ from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache from rpyc.utils.classic import obtain from lightllm.utils.envs_utils import get_unique_server_name +import pickle class CacheServer(rpyc.Service): @@ -24,31 +25,72 @@ def on_disconnect(self, conn): # (to finalize the service, if needed) pass - def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]: - md5sum_list = obtain(md5sum_list) - token_num_list = obtain(token_num_list) + # def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]: + # md5sum_list = obtain(md5sum_list) + # token_num_list = obtain(token_num_list) + # record = self._impl.alloc(md5sum_list, token_num_list) + # return record + + # def exposed_release(self, ids: list[int]) -> None: + # ids = obtain(ids) + # return self._impl.release(ids) + + # def exposed_set_items_data(self, ids: list[int]) -> None: + # ids = obtain(ids) + # return self._impl.set_items_data(ids) + + # def exposed_get_items_data(self, ids: list[int]) -> list[bool]: + # ids = obtain(ids) + # return self._impl.get_items_data(ids) + + # def exposed_set_items_embed(self, ids: list[int]) -> None: + # ids = obtain(ids) + # return self._impl.set_items_embed(ids) + + # def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: + # ids = obtain(ids) + # return self._impl.get_items_embed(ids) + + def exposed_alloc(self, batch_md5_token_nums: bytes) -> bytes: + """ + batch_md5_token_nums: pickle.dumps([(md5sum, token_num), ...]) + 返回: pickle.dumps(records) + """ + batch_requests = pickle.loads(batch_md5_token_nums) + md5sum_list = [obtain(md5) for md5, num in batch_requests] + token_num_list = [obtain(num) for md5, num in batch_requests] record = self._impl.alloc(md5sum_list, token_num_list) - return record + return pickle.dumps(record) - def exposed_release(self, ids: list[int]) -> None: - ids = obtain(ids) + def exposed_release(self, ids_blob: bytes) -> None: + ids = pickle.loads(ids_blob) + ids = [obtain(id) for id in ids] return self._impl.release(ids) - def exposed_set_items_data(self, ids: list[int]) -> None: - ids = 
obtain(ids) - return self._impl.set_items_data(ids) - - def exposed_get_items_data(self, ids: list[int]) -> list[bool]: - ids = obtain(ids) - return self._impl.get_items_data(ids) - - def exposed_set_items_embed(self, ids: list[int]) -> None: - ids = obtain(ids) - return self._impl.set_items_embed(ids) - - def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: - ids = obtain(ids) - return self._impl.get_items_embed(ids) + def exposed_set_items_data(self, ids_blob: bytes) -> bytes: + ids = pickle.loads(ids_blob) + ids = [obtain(id) for id in ids] + status_list = self._impl.set_items_data(ids) + return pickle.dumps(status_list) + + def exposed_get_items_data(self, ids_blob: bytes) -> bytes: + ids = pickle.loads(ids_blob) + ids = [obtain(id) for id in ids] + status_list = self._impl.get_items_data(ids) + return pickle.dumps(status_list) + + def exposed_set_items_embed(self, ids_blob: bytes) -> None: + + ids = pickle.loads(ids_blob) + ids = [obtain(id) for id in ids] + status_list = self._impl.set_items_embed(ids) + return pickle.dumps(status_list) + + def exposed_get_items_embed(self, ids_blob: bytes) -> bytes: + ids = pickle.loads(ids_blob) + ids = [obtain(id) for id in ids] + status_list = self._impl.get_items_embed(ids) + return pickle.dumps(status_list) def start_cache_manager(port: int, args, pipe_writer): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index dc04d081f..c32b077c6 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import datetime import pickle from frozendict import frozendict +import concurrent.futures asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional @@ -34,6 +35,7 @@ from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.infer_utils import calculate_cpu_time_async from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -114,13 +116,20 @@ def __init__( # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. 
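# NOTE (hedged client-side sketch, not copied from the patch): with the CacheServer methods
# rewritten above, every embed-cache RPC now exchanges one pickled bytes blob instead of rpyc
# netref lists. An allocation round trip, as _alloc_resource below performs it, then looks
# roughly like this (alloc_records is a hypothetical helper name):
import pickle

def alloc_records(cache_client, md5sums, token_nums):
    req_blob = pickle.dumps(list(zip(md5sums, token_nums)))  # [(md5sum, token_num), ...]
    res_blob = cache_client.root.alloc(req_blob)             # server returns pickle.dumps(records)
    return pickle.loads(res_blob)                            # None means the caller should retry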
self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + + # 线程池用于multimodal resource alloc + self.max_concurrent = self.args.concurrent_alloc_workers * self.args.max_tasks_per_worker + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.args.concurrent_alloc_workers) return async def _alloc_resource(self, items, md5sums, token_nums, datas): - + batch_requests = [(md5sum, token_num) for md5sum, token_num in zip(md5sums, token_nums)] while True: - records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) - + t1 = time.time() + req_blob = pickle.dumps(batch_requests) + res_blob = self.cache_client.root.alloc(req_blob) + records = pickle.loads(res_blob) + logger.info(f"cache manager batch alloc time: {(time.time() - t1)*1000} ms") if records is None: await asyncio.sleep(0.1) continue @@ -132,18 +141,37 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): item.token_num = rec["token_num"] uid_list.append(rec["id"]) - ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) - update_data_ids = [] + uid_blob = pickle.dumps(uid_list) + ready_flags = self.cache_client.root.get_items_data(uid_blob) + ready_flags = pickle.loads(ready_flags) + + max_concurrent_shm = min(len(items), self.max_concurrent) # 限制最大并发 + semaphore = asyncio.Semaphore(max_concurrent_shm) + + async def create_shm_with_limit(uid, data): + async with semaphore: + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.executor, create_shm, get_shm_name_data(uid), data) + update_data_ids = [] + shm_tasks = [] for uid, ready, data in zip(uid_list, ready_flags, datas): if not ready: - create_shm(get_shm_name_data(uid), data) + task = create_shm_with_limit(uid, data) + shm_tasks.append(task) update_data_ids.append(uid) + if len(shm_tasks): + t_shm = time.time() + await asyncio.gather(*shm_tasks) + logger.info(f"concurrent create shm time: {(time.time() - t_shm)*1000} ms") + if update_data_ids: - self.cache_client.root.set_items_data(update_data_ids) + update_dataids_blob = pickle.dumps(update_data_ids) + self.cache_client.root.set_items_data(update_dataids_blob) return + @calculate_cpu_time_async(show=True) async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams): # 只有 P 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): @@ -151,28 +179,46 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 async with self._resource_lock: - items, md5sums, tokens_nums, datas = [], [], [], [] - for img in multimodal_params.images: - self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) - data = img.read() - # must after init_imageitem_extral_params - token_num = self.tokenizer.get_image_token_length(img) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) - md5sums.append(md5sum) - tokens_nums.append(token_num) - datas.append(data) - items.append(img) - for audio in multimodal_params.audios: - self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) - data = audio.read() - token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) - md5sums.append(md5sum) - tokens_nums.append(token_num) - 
datas.append(data) - items.append(audio) - - await self._alloc_resource(items, md5sums, tokens_nums, datas) + all_items = multimodal_params.images + multimodal_params.audios + if not all_items: + return + loop = asyncio.get_event_loop() + + def _process_item(item, multimodal_params, sampling_params): + """初始化item参数、读取数据并计算MD5""" + if isinstance(item, ImageItem): # 图片 + self.tokenizer.init_imageitem_extral_params(item, multimodal_params, sampling_params) + elif isinstance(item, AudioItem): + self.tokenizer.init_audioitem_extral_params(item, multimodal_params, sampling_params) + + data = item.read() + md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(item.extra_params))) + return data, md5sum + + chunk_size = self.max_concurrent # 可以根据需要调整 + for i in range(0, len(all_items), chunk_size): + chunk = all_items[i : i + chunk_size] + + # 并发处理chunk内的所有item + process_tasks = [ + loop.run_in_executor(self.executor, _process_item, item, multimodal_params, sampling_params) + for item in chunk + ] + chunk_results = await asyncio.gather(*process_tasks) + chunk_items, chunk_md5sums, chunk_tokens_nums, chunk_datas = [], [], [], [] + for j, item in enumerate(chunk): + data, md5sum = chunk_results[j] + if isinstance(item, ImageItem): + token_num = self.tokenizer.get_image_token_length(item) + elif isinstance(item, AudioItem): + token_num = self.tokenizer.get_audio_token_length(item) + chunk_items.append(item) + chunk_md5sums.append(md5sum) + chunk_tokens_nums.append(token_num) + chunk_datas.append(data) + + await self._alloc_resource(chunk_items, chunk_md5sums, chunk_tokens_nums, chunk_datas) + return async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): @@ -195,7 +241,8 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam audio.token_id = None audio.token_num = None if ids_to_release: - self.cache_client.root.release(ids_to_release) + release_id_blobs = pickle.dumps(ids_to_release) + self.cache_client.root.release(release_id_blobs) return def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): @@ -400,7 +447,6 @@ def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple return image_tokens, audio_tokens async def _log_req_header(self, request_headers, group_request_id: int): - x_request_id = request_headers.get("X-Request-Id", "") x_session_id = request_headers.get("X-Session-Id", "") @@ -501,12 +547,12 @@ async def transfer_to_next_module( self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, - ) + ) else: self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, - ) + ) return if self.pd_mode.is_D(): @@ -517,8 +563,6 @@ async def transfer_to_next_module( ) return - return - assert False, "dead code path" return @@ -531,7 +575,6 @@ async def _wait_to_token_package( req_status: "ReqStatus", request: Request, ): - event = req_status.event unfinished_count = sampling_params.best_of out_token_counter = 0 @@ -637,7 +680,6 @@ async def recycle_resource_loop(self): pre_time_mark = time.time() while True: - try: await asyncio.wait_for(self.recycle_event.wait(), timeout=0.02) except asyncio.TimeoutError: @@ -708,7 +750,6 @@ async def handle_loop(self): for _ in range(read_token_count): if not req.out_tokens_queue.is_empty(): - text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() req.cumlogprob += float(req.shm_logprobs.arr[src_index]) metadata = { diff --git 
a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 652c59d5e..0a22220ec 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -129,7 +129,9 @@ async def loop_for_fwd(self): if disable_prompt_cache: ready_image = [False] * len(img_uuids) else: - ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) + img_uuids = pickle.dumps(img_uuids) + ready_image = self.cache_client.root.get_items_embed(img_uuids) + ready_image = pickle.loads(ready_image) for img, ready in zip(multimodal_params.images, ready_image): if not ready: diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index a25065e42..2eba88359 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -3,6 +3,7 @@ import rpyc import torch import inspect +import pickle from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -93,14 +94,11 @@ def exposed_init_model(self, kvargs): def forward(self, images: List[ImageItem]): return self.model.encode(images) - # @calculate_time(show=False, min_cost_ms=300) - def exposed_encode(self, images: List[ImageItem]): - images = obtain(images) - all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cpu")) - + def alloc_img_embed_resources(self, all_img_embeds, uuids, valid_ids): if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + uuids_blob = pickle.dumps(uuids) + ready_flags_status = self.cache_client.root.get_items_embed(uuids_blob) + ready_flags = pickle.loads(ready_flags_status) ids_to_set = [] for i, ready in enumerate(ready_flags): if ready: @@ -111,9 +109,16 @@ def exposed_encode(self, images: List[ImageItem]): create_shm(get_shm_name_embed(uid), cur_embed_bytes) ids_to_set.append(uid) if ids_to_set: + ids_to_set = pickle.dumps(ids_to_set) self.cache_client.root.set_items_embed(ids_to_set) return + def exposed_encode(self, images: List[ImageItem]): + images = obtain(images) + all_img_embeds, uuids, valid_ids = self.forward(images) + all_img_embeds = all_img_embeds.to(torch.device("cpu")) + self.alloc_img_embed_resources(all_img_embeds, uuids, valid_ids) + class VisualModelRpcClient: def __init__(self, model_rpc, vit_tp, rpc_server_process=None): @@ -179,14 +184,14 @@ async def start_model_process(port, vit_tp, device_id): proc.start() await asyncio.sleep(2) repeat_count = 0 - while repeat_count < 20: + while repeat_count < 30: try: con = rpyc.connect("localhost", port, config={"allow_pickle": True}) break except BaseException: - await asyncio.sleep(1) + await asyncio.sleep(2) repeat_count += 1 - if repeat_count == 20: + if repeat_count == 30: raise Exception("init rpc env error!") assert proc.is_alive() diff --git a/lightllm/utils/infer_utils.py b/lightllm/utils/infer_utils.py index dadd96648..9e12f3e7e 100644 --- a/lightllm/utils/infer_utils.py +++ b/lightllm/utils/infer_utils.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist +import functools import time from typing import Callable @@ -69,6 +70,38 @@ def inner_func(*args, **kwargs): return wrapper +def calculate_cpu_time_sync(show=False): + def wrapper(func): + @functools.wraps(func) + def inner_func(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + cost_time = (time.time() - start_time) * 
1000 + if show: + logger.debug(f"Function {func.__name__} took {cost_time} ms to run.") + return result + + return inner_func + + return wrapper + + +def calculate_cpu_time_async(show=False): + def wrapper(func): + @functools.wraps(func) + async def inner_func(*args, **kwargs): + start_time = time.time() + result = await func(*args, **kwargs) + cost_time = (time.time() - start_time) * 1000 + if show: + logger.debug(f"Async Function {func.__name__} took {cost_time} ms to run.") + return result + + return inner_func + + return wrapper + + def benchmark_time(func: Callable, *args, warmup: int = 1, repeat: int = 5, **kwargs) -> float: torch.cuda.synchronize() for _ in range(warmup): diff --git a/tools/quick_launch_docker.py b/tools/quick_launch_docker.py index 68a111ac1..62e8cb5ed 100755 --- a/tools/quick_launch_docker.py +++ b/tools/quick_launch_docker.py @@ -10,9 +10,7 @@ default="ghcr.io/modeltc/lightllm:main", help="default to ghcr.io/modeltc/lightllm:main", ) -group_container.add_argument( - "--name", type=str, required=False, help="set a name to the container" -) +group_container.add_argument("--name", type=str, required=False, help="set a name to the container") group_container.add_argument( "--keep-container", "-K", @@ -27,22 +25,14 @@ ) group_server = args.add_argument_group("server") -group_server.add_argument( - "-m", "--model", type=str, required=True, help="path to model dir" -) +group_server.add_argument("-m", "--model", type=str, required=True, help="path to model dir") group_server.add_argument("-p", "--port", type=int, default=8080) -group_server.add_argument( - "-n", "--num-proc", type=int, default=1, help="number of process/gpus" -) +group_server.add_argument("-n", "--num-proc", type=int, default=1, help="number of process/gpus") group_server.add_argument("-mt", "--max-total-tokens", type=int, default=4096) args = args.parse_args() model_path = os.path.abspath(args.model) -shm_size = ( - args.shm_size - if args.shm_size - else (os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") // 2) -) +shm_size = args.shm_size if args.shm_size else (os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") // 2) launch_args = [ "docker",