Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support nf4 channel wise quant & fix bug when blocksize>512 #1817

Merged
merged 4 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 74 additions & 10 deletions csrc/lc/dequantize_blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(const floa
//template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(const float *code, const unsigned char * A, const float * absmax, __nv_bfloat16 *out, int blocksize, int n);



template<typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n)
{
int num_blocks = n/blocksize;
Expand All @@ -226,6 +225,50 @@ template void dequantize_blockwise<float, NF4>(const float *code, const unsigned
//template void dequantize_blockwise<__nv_bfloat16, FP4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);
//template void dequantize_blockwise<__nv_bfloat16, NF4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);

template <typename T, int DATA_TYPE>
__global__ void kDequantizeChannelwise(const unsigned char* A,
const float *absmax,
float *out,
int n,
int cout) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;

int num = n / 2;
//int part_n = num / cout;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
float local_absmax = absmax[i%cout];
int idx = 2*(i/cout)* cout + i%cout;
switch(DATA_TYPE)
{
case FP4:
out[i*2 + i%cout] = dDequantizeFP4Tree(A[i] >> 4, local_absmax);
out[i*2 + cout + i%cout] = dDequantizeFP4Tree(A[i] & 0x0F, local_absmax);
break;
case NF4:
out[idx] = dDequantizeNF4(A[i] >> 4)* local_absmax;
out[idx + cout] = dDequantizeNF4(A[i] & 0x0F)* local_absmax;
break;
}
__syncthreads();
}
}

template<typename T, int DATA_TYPE> void dequantize_channelwise(const unsigned char *A, const float *absmax, T *out, int n, int cout)
{
int max_threads = 1024;
int64_t block_size =
std::min(static_cast<int64_t>(n),
static_cast<int64_t>(max_threads/ 4));

const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (n + block_size - 1) / block_size);

kDequantizeChannelwise<T, DATA_TYPE><<<grid_size, block_size>>>(A, absmax, out, n, cout);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, const paddle::Tensor& absmax, int blocksize, std::string quant_type) {
int64_t input_numel = input.numel();
int n = input_numel;
Expand All @@ -234,23 +277,44 @@ std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, con
out_shape = {input_numel * 2, 1};
n = n * 2;
}
if (blocksize == -1) {
out_shape = {input.shape()[0] * 2, input.shape()[1]};
}
auto out = paddle::empty(out_shape, paddle::DataType::FLOAT32, input.place());

if (quant_type == "8bit")
dequantize_blockwise<float, General8bit>(code.data<float>(), input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "nf4")
dequantize_blockwise<float, NF4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "fp4")
dequantize_blockwise<float, FP4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else
PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
if (blocksize == -1) {
if (quant_type == "8bit")
PD_THROW("blocksize is -1 only support NF4 and FP4.");
else
blocksize = n / absmax.numel() * 2;

int cout = input.shape()[1];
if (quant_type == "nf4")
dequantize_channelwise<float, NF4>(input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), n, cout);
else if (quant_type == "fp4")
dequantize_channelwise<float, FP4>(input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), n, cout);
else
PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
} else {
if (quant_type == "8bit")
dequantize_blockwise<float, General8bit>(code.data<float>(), input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "nf4")
dequantize_blockwise<float, NF4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "fp4")
dequantize_blockwise<float, FP4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else
PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
}
return {out};
};

std::vector<std::vector<int64_t>> GetDequantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, const std::vector<int64_t>& abs_max_shape, int blocksize, std::string quant_type){
int64_t first_shape = input_shape[0] * input_shape[1] * 2;
if (quant_type != "8bit")
return {{first_shape, 1}};
if (blocksize != -1)
return {{first_shape, 1}};
else
return {{input_shape[0] * 2, input_shape[1]}};
else
return {input_shape};
}
Expand Down
115 changes: 88 additions & 27 deletions csrc/lc/quantize_blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ __global__ void kQuantizeBlockwise(const float * code, const T * __restrict__ A,
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit = 0;
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
Expand Down Expand Up @@ -360,9 +361,39 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)

template <typename T, int DATA_TYPE>
__global__ void kQuantizeChannelwise(const float *code,
const T* A,
unsigned char* out,
float *absmax,
int n,
int cout) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;

int num = n / 2;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
int idx = 2*(i/cout)* cout + i%cout;
float local_absmax = absmax[i %cout];
float inv_local_absmax = 1.0f/local_absmax;

unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
case FP4:
packed_4bit |= dQuantizeFP4(((float)A[idx])*inv_local_absmax) << 4;
packed_4bit |= dQuantizeFP4(((float)A[idx+cout])*inv_local_absmax);
out[i] = packed_4bit;
break;
case NF4:
packed_4bit |= dQuantizeNF4(((float)A[idx])*inv_local_absmax) << 4;
packed_4bit |= dQuantizeNF4(((float)A[idx+cout])*inv_local_absmax);
out[i] = packed_4bit;
break;
}
}
}

template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n, int channelwise)
{
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
Expand All @@ -372,61 +403,88 @@ template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
if(blocksize == 4096)
kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax, out, n);
else if(blocksize == 2048)
kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax, out, n);
else if(blocksize == 1024)
kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
else if(blocksize == 512)
kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
else if(blocksize == 256)
kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax, out, n);
else if(blocksize == 128)
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
else if(blocksize == 64)
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
else
PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
if (channelwise == 0) {
if(blocksize == 4096)
kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 2048)
kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 1024)
kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 512)
kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 256)
kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 128)
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 64)
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax.data<float>(), out, n);
}
else {
if (DATA_TYPE == General8bit)
PD_THROW("blocksize is -1 only support NF4 and FP4.");

int cout = A.shape()[1];
int max_threads = 1024;

absmax = A.abs().max({0});

int64_t block_size =
std::min(static_cast<int64_t>(n),
static_cast<int64_t>(max_threads/ 4));

const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (n + block_size - 1) / block_size);

kQuantizeChannelwise<DataType_, DATA_TYPE><<<grid_size, block_size, 0>>>(
code, A_data, out, absmax.data<float>(), n, cout);
}


CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
int n = input.numel();
int channelwise = 0;
std::vector<int64_t> out_shape = input.shape();
if (quant_type != "8bit") { // 4bit
out_shape = {(n + 1) / 2, 1};
}
if (blocksize == -1){
blocksize = input.shape()[0];
out_shape = {input.shape()[0]/2, input.shape()[1]};
channelwise = 1;
}
auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place());
int64_t absmax_shape = n / blocksize;
auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place());
switch(input.type()) {
case paddle::DataType::FLOAT32:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4") {
quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
}
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};
case paddle::DataType::FLOAT16:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4")
quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};
case paddle::DataType::BFLOAT16:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4")
quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};

default:
Expand All @@ -440,7 +498,10 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
std::vector<std::vector<int64_t>> GetQuantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, int blocksize, std::string quant_type){
int64_t first_shape = (input_shape[0] * input_shape[1] + 1) / 2;
if (quant_type != "8bit")
return {{first_shape, 1}};
if (blocksize != -1)
return {{first_shape, 1}};
else
return {{input_shape[0]/2, input_shape[1]}};
else
return {input_shape};
}
Expand Down