diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/awq_gemv.py index eabd08733..50add314f 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv.py @@ -37,7 +37,7 @@ class AwqGEMVQuantLinear(AWQuantLinear): SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] # for transformers/optimum tests compat - QUANT_TYPE = "awq_gemm" + QUANT_TYPE = "awq_gemv" def __init__( self, diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py index 225f02749..e12337046 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py @@ -37,7 +37,7 @@ class AwqGEMVFastQuantLinear(AWQuantLinear): SUPPORTS_DTYPES = [torch.float16] # for transformers/optimum tests compat - QUANT_TYPE = "awq_gemm" + QUANT_TYPE = "awq_gemv_fast" def __init__( self, diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index c964ea98c..fd55fcd5b 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -58,7 +58,7 @@ class AwqMarlinQuantLinear(AWQuantLinear): REQUIRES_FORMAT_V2 = False # for transformers/optimum tests compat - QUANT_TYPE = "marlin" + QUANT_TYPE = "awq_marlin" # num_bits -> type TYPE_MAP = { diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index f36e47e07..633c324af 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -54,7 +54,7 @@ class TorchFusedQuantLinear(PackableQuantLinear): REQUIRES_FORMAT_V2 = True # for transformers/optimum tests compat - QUANT_TYPE = "torch" + QUANT_TYPE = "torch_fused" def __init__( self,