Skip to content

Commit

Permalink
quantization awq gemm + gemv
Browse files Browse the repository at this point in the history
  • Loading branch information
minhthuc2502 committed Jun 19, 2024
1 parent fc24971 commit 46cc53c
Show file tree
Hide file tree
Showing 27 changed files with 1,788 additions and 78 deletions.
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ set(SOURCES
src/ops/transpose.cc
src/ops/nccl_ops.cc
src/ops/nccl_ops_cpu.cc
src/ops/awq/dequantize.cc
src/ops/awq/gemm.cc
src/ops/awq/gemv.cc
src/ops/sum.cc
src/padder.cc
src/profiler.cc
src/random.cc
Expand Down Expand Up @@ -595,6 +599,9 @@ if (WITH_CUDA)
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
src/ops/awq/gemm_gpu.cu
src/ops/awq/gemv_gpu.cu
src/ops/awq/dequantize_gpu.cu
)
elseif(WITH_CUDNN)
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
Expand Down
3 changes: 3 additions & 0 deletions include/ctranslate2/layers/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,19 @@ namespace ctranslate2 {
const StorageView& _weight;
const StorageView* _bias;
const StorageView* _qscale;
const StorageView* _qzero;
const StorageView* _u8_shift_compensation;
StorageView _partial_weight;
StorageView _partial_bias;
StorageView _partial_qscale;
StorageView _partial_u8_shift_compensation;
const DataType _output_type;
const models::QUANTIZATION_TYPE _quant_method;
const bool _quantized_gemm;
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
const ops::ActivationType* _activation_type;
const bool _is_layer_out;
};

Expand Down
17 changes: 16 additions & 1 deletion include/ctranslate2/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
namespace ctranslate2 {
namespace models {

// Selector for the quantization method associated with a model.
// Enumerator values are written out explicitly; they are the same values the
// compiler would assign implicitly (0, 1, 2), so nothing changes for callers.
// NOTE(review): the meaning attached to each enumerator below is inferred from
// its name only — confirm against the code that consumes this enum.
enum class QUANTIZATION_TYPE {
  CT2 = 0,       // presumably the project's native quantization path
  AWQ_GEMM = 1,  // presumably AWQ weights handled by the GEMM variant
  AWQ_GEMV = 2   // presumably AWQ weights handled by the GEMV variant
};

static const size_t current_binary_version = 6;

// Checks whether the provided path could contain a CTranslate2 model.
Expand Down Expand Up @@ -90,6 +96,14 @@ namespace ctranslate2 {
return _use_flash_attention;
}

// Returns the quantization method currently stored in `_quant_method`.
// Const accessor: it only reads the member and does not modify the object.
QUANTIZATION_TYPE quant_method() const {
return _quant_method;
}

// Overwrites the stored quantization method (`_quant_method`) with `type`.
// No validation or other side effect is performed here.
void set_quant_method(QUANTIZATION_TYPE type) {
_quant_method = type;
}

// Virtual hook; this implementation unconditionally returns true.
// NOTE(review): presumably derived classes override this to opt out of a
// single global scale for int16 quantization — confirm against the subclasses.
virtual bool use_global_int16_scale() const {
return true;
}
Expand Down Expand Up @@ -160,7 +174,7 @@ namespace ctranslate2 {

private:
void process_linear_weights();
void set_compute_type(ComputeType type, Device device, int device_index);
void set_compute_type(ComputeType type, Device device, int device_index, bool update_weight=true);
void ensure_dtype(const std::string& name,
StorageView& variable,
const DataType target_dtype);
Expand All @@ -177,6 +191,7 @@ namespace ctranslate2 {
std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
bool _use_flash_attention = false;
bool _tensor_parallel = false;
QUANTIZATION_TYPE _quant_method = QUANTIZATION_TYPE::CT2;
};

template<>
Expand Down
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/awq/dequantize_awq.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "../op.h"

namespace ctranslate2 {
namespace ops {

/// Operator exposing a dequantization call that takes a scale tensor and a
/// zeros tensor in addition to the input.
///
/// Only declarations live here; the constructor, the call operator and the
/// private worker are defined out of line.
/// NOTE(review): the names suggest `input` carries AWQ-quantized weights while
/// `scale` / `zeros` carry the matching scales and zero points — confirm
/// against the implementation files.
class DequantizeAwq : public Op {
public:
  DequantizeAwq();

  /// Runs the operator.
  /// @param input  Source tensor (read-only).
  /// @param scale  Scale tensor (read-only).
  /// @param zeros  Zeros tensor (read-only).
  /// @param output Destination tensor, written by the call.
  void operator()(const StorageView& input,
                  const StorageView& scale,
                  const StorageView& zeros,
                  StorageView& output) const;

private:
  /// Worker parameterized on the device and on the source / destination
  /// element types; takes the same four tensors as the call operator.
  template <Device Dev, typename SrcT, typename DstT>
  void dequantize(const StorageView& input,
                  const StorageView& scale,
                  const StorageView& zeros,
                  StorageView& output) const;
};

}
}
27 changes: 27 additions & 0 deletions include/ctranslate2/ops/awq/gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "../activation.h"
#include "../gemm.h"

namespace ctranslate2 {
namespace ops {
/// Gemm-derived operator whose call takes a scale tensor and a zero tensor
/// alongside the two operands.
///
/// All Gemm constructors are inherited through the using-declaration, so the
/// class is constructed exactly like Gemm.
/// NOTE(review): presumably `b` holds AWQ-quantized weights with `scale` /
/// `zero` as their quantization parameters — confirm against the
/// implementation files.
class GemmAwq : public Gemm {
public:
  using Gemm::Gemm;

  /// Runs the operator.
  /// @param a     First operand (read-only).
  /// @param b     Second operand (read-only).
  /// @param scale Scale tensor (read-only).
  /// @param zero  Zero tensor (read-only).
  /// @param c     Result tensor, written by the call.
  /// @param bias  Optional bias; defaults to nullptr.
  void operator()(const StorageView& a,
                  const StorageView& b,
                  const StorageView& scale,
                  const StorageView& zero,
                  StorageView& c,
                  const StorageView* bias = nullptr) const;

private:
  /// Worker parameterized on the device and on the input / output element
  /// types. It takes no bias argument.
  template <Device Dev, typename InputT, typename OutputT>
  void compute(const StorageView& a,
               const StorageView& b,
               const StorageView& scale,
               const StorageView& zero,
               StorageView& c) const;
};
}
}
33 changes: 33 additions & 0 deletions include/ctranslate2/ops/awq/gemv.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "../activation.h"
#include "../gemm.h"

namespace ctranslate2 {
namespace ops {
/// Gemm-derived operator whose call takes a scale tensor and a zero tensor
/// alongside the two operands, backed by two private worker templates.
///
/// All Gemm constructors are inherited through the using-declaration.
/// NOTE(review): how the call operator chooses between `compute_gemv` and
/// `compute_gemv2` is not visible from this header — check the implementation
/// files.
class GemvAwq : public Gemm {
public:
  using Gemm::Gemm;

  /// Runs the operator.
  /// @param a     First operand (read-only).
  /// @param b     Second operand (read-only).
  /// @param scale Scale tensor (read-only).
  /// @param zero  Zero tensor (read-only).
  /// @param c     Result tensor, written by the call.
  /// @param bias  Optional bias; defaults to nullptr.
  void operator()(const StorageView& a,
                  const StorageView& b,
                  const StorageView& scale,
                  const StorageView& zero,
                  StorageView& c,
                  const StorageView* bias = nullptr) const;

private:
  /// First worker variant, parameterized on device and element types.
  template <Device Dev, typename InputT, typename OutputT>
  void compute_gemv(const StorageView& a,
                    const StorageView& b,
                    const StorageView& scale,
                    const StorageView& zero,
                    StorageView& c) const;

  /// Second worker variant with the same parameter list as `compute_gemv`.
  template <Device Dev, typename InputT, typename OutputT>
  void compute_gemv2(const StorageView& a,
                     const StorageView& b,
                     const StorageView& scale,
                     const StorageView& zero,
                     StorageView& c) const;
};
}
}
3 changes: 2 additions & 1 deletion include/ctranslate2/ops/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ namespace ctranslate2 {
const dim_t k,
const dim_t n,
const float alpha);
protected:
const ActivationType* _activation_type;

private:
float _alpha;
Expand All @@ -47,7 +49,6 @@ namespace ctranslate2 {
bool _trans_b;
bool _a_is_packed;
bool _b_is_packed;
const ActivationType* _activation_type;

template <Device D, typename In, typename Out>
void compute(const StorageView& a,
Expand Down
3 changes: 2 additions & 1 deletion include/ctranslate2/ops/mean.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ namespace ctranslate2 {

void operator()(const StorageView& input, StorageView& output) const override;

private:
protected:
// Worker declaration parameterized on the device and the element type.
// Takes the input tensor, three extents (outer, axis, inner), a boolean flag
// and the output tensor.
// NOTE(review): `get_sum` presumably makes the reduction return the sum over
// the axis instead of the mean — confirm against the definition.
template <Device D, typename T>
void compute(const StorageView& input,
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
const bool get_sum,
StorageView& output) const;

const dim_t _axis;
Expand Down
4 changes: 4 additions & 0 deletions include/ctranslate2/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@
#include "slide.h"
#include "nccl_ops.h"
#include "flash_attention.h"
#include "awq/gemm.h"
#include "awq/gemv.h"
#include "awq/dequantize_awq.h"
#include "sum.h"
17 changes: 17 additions & 0 deletions include/ctranslate2/ops/sum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include "op.h"
#include "mean.h"

namespace ctranslate2 {
namespace ops {

/// Operator derived from Mean that overrides the call operator.
///
/// Only declarations live here; the constructor and the call operator are
/// defined out of line.
/// NOTE(review): presumably this reduces `input` along `axis` by summation,
/// reusing Mean's machinery — confirm against the implementation file.
class Sum : public Mean {
public:
  /// @param axis Axis argument; how it is used is defined out of line.
  Sum(const dim_t axis);

  /// Runs the operator.
  /// @param input  Source tensor (read-only).
  /// @param output Destination tensor, written by the call.
  void operator()(const StorageView& input, StorageView& output) const override;
};

}
}
Loading

0 comments on commit 46cc53c

Please sign in to comment.