From d4ac0efefcf167cea98dd29cade799526fcfd69f Mon Sep 17 00:00:00 2001 From: ceci3 Date: Tue, 19 Dec 2023 18:42:50 +0800 Subject: [PATCH] support nf4 channel wise quant & fix bug when blocksize>512 (#1817) (#1818) --- csrc/lc/dequantize_blockwise.cu | 84 ++++++++++++++++++++--- csrc/lc/quantize_blockwise.cu | 115 ++++++++++++++++++++++++-------- 2 files changed, 162 insertions(+), 37 deletions(-) diff --git a/csrc/lc/dequantize_blockwise.cu b/csrc/lc/dequantize_blockwise.cu index 8046c34a..0bf76a16 100644 --- a/csrc/lc/dequantize_blockwise.cu +++ b/csrc/lc/dequantize_blockwise.cu @@ -201,7 +201,6 @@ template __global__ void kDequantizeBlockwise(const floa //template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(const float *code, const unsigned char * A, const float * absmax, __nv_bfloat16 *out, int blocksize, int n); - template void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n) { int num_blocks = n/blocksize; @@ -226,6 +225,50 @@ template void dequantize_blockwise(const float *code, const unsigned //template void dequantize_blockwise<__nv_bfloat16, FP4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n); //template void dequantize_blockwise<__nv_bfloat16, NF4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n); +template +__global__ void kDequantizeChannelwise(const unsigned char* A, + const float *absmax, + float *out, + int n, + int cout) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + + int num = n / 2; + //int part_n = num / cout; + for (int i = idx; i < num; i += blockDim.x * gridDim.x) { + float local_absmax = absmax[i%cout]; + int idx = 2*(i/cout)* cout + i%cout; + switch(DATA_TYPE) + { + case FP4: + out[i*2 + i%cout] = dDequantizeFP4Tree(A[i] >> 4, local_absmax); + out[i*2 + cout + i%cout] = dDequantizeFP4Tree(A[i] & 0x0F, local_absmax); + break; + case NF4: + out[idx] = dDequantizeNF4(A[i] >> 4)* local_absmax; + out[idx + cout] = dDequantizeNF4(A[i] & 0x0F)* local_absmax; + break; + } + __syncthreads(); + } +} + +template void dequantize_channelwise(const unsigned char *A, const float *absmax, T *out, int n, int cout) +{ + int max_threads = 1024; + int64_t block_size = + std::min(static_cast(n), + static_cast(max_threads/ 4)); + + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (n + block_size - 1) / block_size); + + kDequantizeChannelwise<<>>(A, absmax, out, n, cout); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + std::vector DequantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, const paddle::Tensor& absmax, int blocksize, std::string quant_type) { int64_t input_numel = input.numel(); int n = input_numel; @@ -234,23 +277,44 @@ std::vector DequantizeBlockwise(const paddle::Tensor& input, con out_shape = {input_numel * 2, 1}; n = n * 2; } + if (blocksize == -1) { + out_shape = {input.shape()[0] * 2, input.shape()[1]}; + } auto out = paddle::empty(out_shape, paddle::DataType::FLOAT32, input.place()); - if (quant_type == "8bit") - dequantize_blockwise(code.data(), input.data(), absmax.data(), out.data(), blocksize, n); - else if (quant_type == "nf4") - dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); - else if (quant_type == "fp4") - dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); - else - PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. "); + if (blocksize == -1) { + if (quant_type == "8bit") + PD_THROW("blocksize is -1 only support NF4 and FP4."); + else + blocksize = n / absmax.numel() * 2; + + int cout = input.shape()[1]; + if (quant_type == "nf4") + dequantize_channelwise(input.data(), absmax.data(), out.data(), n, cout); + else if (quant_type == "fp4") + dequantize_channelwise(input.data(), absmax.data(), out.data(), n, cout); + else + PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. "); + } else { + if (quant_type == "8bit") + dequantize_blockwise(code.data(), input.data(), absmax.data(), out.data(), blocksize, n); + else if (quant_type == "nf4") + dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); + else if (quant_type == "fp4") + dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); + else + PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. "); + } return {out}; }; std::vector> GetDequantizeBlockwiseInferShape(const std::vector& input_shape, const std::vector& code_shape, const std::vector& abs_max_shape, int blocksize, std::string quant_type){ int64_t first_shape = input_shape[0] * input_shape[1] * 2; if (quant_type != "8bit") - return {{first_shape, 1}}; + if (blocksize != -1) + return {{first_shape, 1}}; + else + return {{input_shape[0] * 2, input_shape[1]}}; else return {input_shape}; } diff --git a/csrc/lc/quantize_blockwise.cu b/csrc/lc/quantize_blockwise.cu index d4f6ff2c..e8e55b9d 100644 --- a/csrc/lc/quantize_blockwise.cu +++ b/csrc/lc/quantize_blockwise.cu @@ -279,6 +279,7 @@ __global__ void kQuantizeBlockwise(const float * code, const T * __restrict__ A, #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { + packed_4bit = 0; packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); qvals[j] = packed_4bit; @@ -360,9 +361,39 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4) +template +__global__ void kQuantizeChannelwise(const float *code, + const T* A, + unsigned char* out, + float *absmax, + int n, + int cout) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + + int num = n / 2; + for (int i = idx; i < num; i += blockDim.x * gridDim.x) { + int idx = 2*(i/cout)* cout + i%cout; + float local_absmax = absmax[i %cout]; + float inv_local_absmax = 1.0f/local_absmax; + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case FP4: + packed_4bit |= dQuantizeFP4(((float)A[idx])*inv_local_absmax) << 4; + packed_4bit |= dQuantizeFP4(((float)A[idx+cout])*inv_local_absmax); + out[i] = packed_4bit; + break; + case NF4: + packed_4bit |= dQuantizeNF4(((float)A[idx])*inv_local_absmax) << 4; + packed_4bit |= dQuantizeNF4(((float)A[idx+cout])*inv_local_absmax); + out[i] = packed_4bit; + break; + } + } +} -template void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n) +template void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n, int channelwise) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -372,22 +403,43 @@ template void quantize_blockwise(const float num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; const DataType_* A_data = reinterpret_cast(A.data()); - if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else - PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096]."); + if (channelwise == 0) { + if(blocksize == 4096) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 2048) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 1024) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 512) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 256) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 128) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 64) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + } + else { + if (DATA_TYPE == General8bit) + PD_THROW("blocksize is -1 only support NF4 and FP4."); + + int cout = A.shape()[1]; + int max_threads = 1024; + + absmax = A.abs().max({0}); + + int64_t block_size = + std::min(static_cast(n), + static_cast(max_threads/ 4)); + + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (n + block_size - 1) / block_size); + + kQuantizeChannelwise<<>>( + code, A_data, out, absmax.data(), n, cout); + } CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -395,38 +447,44 @@ template void quantize_blockwise(const float std::vector QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) { int n = input.numel(); + int channelwise = 0; std::vector out_shape = input.shape(); if (quant_type != "8bit") { // 4bit out_shape = {(n + 1) / 2, 1}; } + if (blocksize == -1){ + blocksize = input.shape()[0]; + out_shape = {input.shape()[0]/2, input.shape()[1]}; + channelwise = 1; + } auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place()); int64_t absmax_shape = n / blocksize; auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place()); switch(input.type()) { case paddle::DataType::FLOAT32: if (quant_type == "8bit") - quantize_blockwise(code.data(), input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(code.data(), input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "nf4") { - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); } else if (quant_type == "fp4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); return {out, absmax}; case paddle::DataType::FLOAT16: if (quant_type == "8bit") - quantize_blockwise(code.data(), input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(code.data(), input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "nf4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "fp4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); return {out, absmax}; case paddle::DataType::BFLOAT16: if (quant_type == "8bit") - quantize_blockwise(code.data(), input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(code.data(), input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "nf4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "fp4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); return {out, absmax}; default: @@ -440,7 +498,10 @@ std::vector QuantizeBlockwise(const paddle::Tensor& input, const std::vector> GetQuantizeBlockwiseInferShape(const std::vector& input_shape, const std::vector& code_shape, int blocksize, std::string quant_type){ int64_t first_shape = (input_shape[0] * input_shape[1] + 1) / 2; if (quant_type != "8bit") - return {{first_shape, 1}}; + if (blocksize != -1) + return {{first_shape, 1}}; + else + return {{input_shape[0]/2, input_shape[1]}}; else return {input_shape}; }