Make all tests pass (#546)
* fix tests

* all tests pass

* typo
fxmarty committed Feb 15, 2024
1 parent adc2496 commit 75f0d51
Showing 5 changed files with 32 additions and 11 deletions.
7 changes: 5 additions & 2 deletions auto_gptq/modeling/_base.py
@@ -1095,6 +1095,8 @@ def skip(*args, **kwargs):
quantize_config.desc_act,
quantize_config.group_size,
bits=quantize_config.bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
)

# TODO: move this logic in an awq_utils.py file.
@@ -1217,7 +1219,6 @@ def skip(*args, **kwargs):

safe_save(new_state_dict, model_save_name)

# TODO: Move this logic in a marlin_utils.py file.
if use_marlin:
if torch.version.hip:
raise ValueError("Can not use Marlin int4*fp16 kernel with AMD ROCm version of PyTorch as the kernel is not compatible. Please do not use `use_marlin=True` when using ROCm devices.")
@@ -1234,13 +1235,15 @@ def skip(*args, **kwargs):
)

# Load the quant linear type we need.
# TODO: load Marlin directly with the right QuantLinear class.
quant_linear_class = dynamically_import_QuantLinear(
use_triton=use_triton,
desc_act=quantize_config.desc_act,
group_size=quantize_config.group_size,
bits=quantize_config.bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
disable_marlin=True, # Get the "original" QuantLinear class
)

# Prepare model for marlin load.
@@ -1428,7 +1431,7 @@ def make_sure_compatible_with_peft(
):
GeneralQuantLinear.inject_to_model(
model,
dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=use_marlin, use_qigen=use_qigen),
dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, disable_marlin=not use_marlin, use_qigen=use_qigen),
)

def __getattr__(self, item):
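For reference, this is roughly how a caller resolves the backend QuantLinear class with the kernel flags touched in this file; a minimal sketch, assuming the import path and the concrete argument values shown here (they are inferred from the call sites in the diff, not taken from it verbatim).

# Minimal sketch: resolving the backend QuantLinear class with the kernel flags
# used in _base.py. The import path and argument values are assumptions.
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

QuantLinear = dynamically_import_QuantLinear(
    use_triton=False,
    desc_act=False,
    group_size=128,
    bits=4,
    disable_exllama=True,
    disable_exllamav2=False,
    disable_marlin=True,  # pass False only when loading a Marlin-repacked checkpoint
)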
9 changes: 7 additions & 2 deletions auto_gptq/modeling/_utils.py
@@ -61,6 +61,7 @@ def make_quant(
group_size,
name="",
use_triton: bool = False,
use_marlin: bool = False,
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_qigen: bool = False,
@@ -80,6 +81,7 @@
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_marlin=not use_marlin,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen,
@@ -266,6 +268,7 @@ def pack_model(
bits=bits,
disable_exllama=False,
disable_exllamav2=True,
disable_marlin=True,
)

if force_layer_back_to_cpu:
@@ -489,8 +492,10 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
return model


def make_sure_no_tensor_in_meta_device(model, use_triton, desc_act, group_size, bits: int):
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits)
def make_sure_no_tensor_in_meta_device(
model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False,
):
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, disable_marlin=not use_marlin)
for n, m in model.named_modules():
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
m.register_buffer("bias", torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"))
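The meta-device bias fixup that make_sure_no_tensor_in_meta_device performs can be illustrated in isolation; a self-contained sketch, with TinyQuantLinear as a hypothetical stand-in for the dynamically imported QuantLinear class.

# Self-contained sketch of the meta-device bias fixup pattern shown above.
# TinyQuantLinear is hypothetical; it only mimics the bias/outfeatures attributes.
import torch
import torch.nn as nn

class TinyQuantLinear(nn.Module):
    def __init__(self, out_features: int):
        super().__init__()
        self.outfeatures = out_features
        # Simulate a bias left on the meta device by empty-weight initialization.
        self.register_buffer("bias", torch.empty(out_features, device="meta"))

model = nn.Sequential(TinyQuantLinear(8))

for _, module in model.named_modules():
    if isinstance(module, TinyQuantLinear) and module.bias.device == torch.device("meta"):
        # Replace the meta tensor with a real fp16 zero tensor on CPU.
        module.register_buffer("bias", torch.zeros(module.outfeatures, dtype=torch.float16, device="cpu"))

print(model[0].bias.device)  # cpu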
11 changes: 7 additions & 4 deletions tests/test_awq_compatibility_generation.py
@@ -1,8 +1,9 @@
# ruff: noqa: I001
import unittest

import torch
import autogptq_cuda_64
import autogptq_cuda_256
import torch
from transformers import AutoTokenizer

from auto_gptq import AutoGPTQForCausalLM
@@ -12,9 +13,8 @@
try:
from awq import AutoAWQForCausalLM
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"AutoAWQ package (https://github.com/casper-hansen/AutoAWQ) is required to run this test. {e}"
)
AutoAWQForCausalLM = None
AWQ_EXCEPTION = e


class TestAwqCompatibility(unittest.TestCase):
@@ -23,6 +23,9 @@ class TestAwqCompatibility(unittest.TestCase):
# TODO: test exllama v2.

def test_generation_cuda_old_fp32_pytorch(self):
if AutoAWQForCausalLM is None:
self.skipTest(f"AutoAWQ package (https://github.com/casper-hansen/AutoAWQ) is required to run this test. {AWQ_EXCEPTION}")

device = torch.device("cuda:0")
quant_path = "TheBloke/Llama-2-7B-Chat-AWQ"

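The pattern adopted here, record the import failure at import time and skip at run time instead of raising, can be shown on its own; a self-contained sketch where some_optional_pkg is a placeholder package name.

# Self-contained sketch of the optional-dependency skip pattern used above.
# "some_optional_pkg" is a placeholder, not a real dependency.
import unittest

try:
    import some_optional_pkg
except ModuleNotFoundError as e:
    some_optional_pkg = None
    IMPORT_EXCEPTION = e

class TestWithOptionalDependency(unittest.TestCase):
    def test_feature(self):
        if some_optional_pkg is None:
            self.skipTest(f"some_optional_pkg is required to run this test. {IMPORT_EXCEPTION}")
        # Real assertions would go here once the dependency is installed.

if __name__ == "__main__":
    unittest.main()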
10 changes: 9 additions & 1 deletion tests/test_q4.py
@@ -2057,7 +2057,15 @@ def test_generation(self):
device = torch.device("cuda:0")

model_id = "TheBloke/Llama-2-7B-Chat-GPTQ"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_marlin=True)

try:
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_marlin=True)
except ValueError as e:
if torch.version.hip:
self.assertTrue("Can not use Marlin int4*fp16 kernel with AMD ROCm" in e.text)
self.skipTest("Can not run this test on ROCm")
else:
raise e

has_marlin = False
for _, module in model_q.named_modules():
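The truncated hunk above continues into a check that Marlin layers were actually injected after loading; a hedged sketch of such a check, written as a continuation of the test method and matching on the class name so no specific import path has to be assumed.

# Hedged sketch, continuing the test body above: verify that at least one
# Marlin QuantLinear was injected into model_q. Matching on the class name
# avoids assuming the exact import path of the Marlin layer.
has_marlin = any("marlin" in type(module).__name__.lower() for _, module in model_q.named_modules())
self.assertTrue(has_marlin)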
6 changes: 4 additions & 2 deletions tests/test_serialization.py
@@ -47,7 +47,8 @@ def test_marlin_local_serialization(self):

second_load_time = end - start

self.assertTrue(second_load_time < 0.2 * first_load_time)
# Since we use a CUDA kernel to repack weights, the first load time is already small.
self.assertTrue(second_load_time < 0.8 * first_load_time)

def test_marlin_hf_cache_serialization(self):
start = time.time()
@@ -66,4 +67,5 @@ def test_marlin_hf_cache_serialization(self):

second_load_time = end - start

self.assertTrue(second_load_time < 0.2 * first_load_time)
# Since we use a CUDA kernel to repack weights, the first load time is already small.
self.assertTrue(second_load_time < 0.8 * first_load_time)
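
The relaxed 0.8 factor can be exercised in isolation; a self-contained sketch where load_once is a placeholder for AutoGPTQForCausalLM.from_quantized and merely simulates the cold/warm difference.

# Self-contained sketch of the cold-vs-warm load timing check relaxed above.
# load_once is a placeholder; the sleeps only simulate load durations.
import time
import unittest

def load_once(cached: bool) -> None:
    time.sleep(0.05 if cached else 0.2)

class TestLoadTiming(unittest.TestCase):
    def test_second_load_is_faster(self):
        start = time.time()
        load_once(cached=False)
        first_load_time = time.time() - start

        start = time.time()
        load_once(cached=True)
        second_load_time = time.time() - start

        # The loose 0.8 factor mirrors the relaxed threshold in the diff above.
        self.assertTrue(second_load_time < 0.8 * first_load_time)

if __name__ == "__main__":
    unittest.main()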
