55 changes: 46 additions & 9 deletions tests/models/model_test.py
@@ -44,6 +44,7 @@
 from transformers import AutoProcessor, AutoTokenizer # noqa: E402
 
 from gptqmodel import BACKEND, GPTQModel # noqa: E402
+from gptqmodel.models.base import BaseQModel # noqa: E402
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402
 from gptqmodel.quantization.config import QuantizeConfig # noqa: E402
@@ -68,7 +69,7 @@ class ModelTest(unittest.TestCase):
     TORCH_DTYPE = "auto"
     EVAL_BATCH_SIZE = "auto"
     QUANT_BATCH_SIZE = 1
-    LOAD_BACKEND = BACKEND.AUTO
+    LOAD_BACKEND = BACKEND.TORCH
     QUANT_BACKEND = BACKEND.AUTO
     USE_VLLM = False
     INPUTS_MAX_LENGTH = 2048
@@ -223,7 +224,12 @@ def run_arc_challenge_eval(self, model, backend, trust_remote_code=False):
     def perform_post_quant_validation(self, model_path, trust_remote_code=False):
         inference_records = {}
         arc_records = {}
-        compare_backends = (BACKEND.MARLIN, BACKEND.TORCH) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM)
+        reuse_candidates = {}
+
+        compare_backends = (BACKEND.TORCH,) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM)
+        target_backend = self.LOAD_BACKEND
+        can_reuse = target_backend not in (BACKEND.AUTO, BACKEND.AUTO_TRAINABLE)
+
         for backend in compare_backends:
             log.info(f"Loading post-quant model with backend `{backend.name}`")
             model = self.loadQuantModel(
@@ -233,14 +239,23 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False):
             )
             tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code)
             inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend)
+
+            should_reuse = can_reuse and backend == target_backend and not self.USE_VLLM
+
             try:
                 arc_records[backend] = self.run_arc_challenge_eval(model, backend, trust_remote_code=trust_remote_code)
             finally:
-                del model
+                if should_reuse:
+                    reuse_candidates[backend] = model
+                else:
+                    del model
                 torch_empty_cache()
+
         self.render_inference_summary(inference_records)
         self.render_arc_summary(arc_records)
 
+        return reuse_candidates
+
     @staticmethod
     def _human_size(num_bytes: int) -> str:
         step = 1024.0
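
The two hunks above turn perform_post_quant_validation from a pure validation pass into a producer of reusable models: the per-backend model is no longer deleted unconditionally after the ARC eval, and the copy whose backend matches LOAD_BACKEND is cached and returned. A minimal standalone sketch of that pattern follows; load_model, evaluate, and empty_cache are illustrative placeholders, not the harness's real helpers.

# Minimal sketch of the "validate, then cache for reuse" pattern introduced above.
# load_model/evaluate/empty_cache are hypothetical callables, not the real test API.
def validate_backends(backends, target_backend, load_model, evaluate, empty_cache):
    reuse_candidates = {}
    for backend in backends:
        model = load_model(backend)
        try:
            evaluate(model, backend)
        finally:
            if backend == target_backend:
                # Keep this model alive so the caller can skip a second load from disk.
                reuse_candidates[backend] = model
            else:
                del model
                empty_cache()
    return reuse_candidates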
@@ -563,9 +578,17 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
         tokenizer.save_pretrained(path)
         self._print_post_quant_artifacts(path)
         log.info(f"Quantized Model saved to tmp dir: {path}")
-        self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code)
-        q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code)
-        q_tokenizer = q_model.tokenizer
+
+        target_backend = self.LOAD_BACKEND
+        reuse_candidates = self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code)
+
+        q_model = reuse_candidates.pop(target_backend, None)
+        if q_model is None:
+            q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code)
+        else:
+            log.info(f"Reusing post-quant validation model for backend `{target_backend.name}`")
+
+        q_tokenizer = q_model.tokenizer or self.load_tokenizer(path, trust_remote_code=trust_remote_code)
         if need_create_processor:
             processor = AutoProcessor.from_pretrained(path)
 
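On the consumer side, quantModel now asks the validation step for its cached model and only reloads from disk when nothing was cached (for example when LOAD_BACKEND is AUTO or AUTO_TRAINABLE, which are never reused). A rough sketch of that flow; validate and load_from_disk stand in for the real harness methods.

# Rough consumer-side sketch (assumed simplification of quantModel's new flow;
# validate and load_from_disk are hypothetical stand-ins for the harness methods).
def get_quant_model(path, target_backend, validate, load_from_disk):
    reuse_candidates = validate(path)  # backend -> already-loaded model
    q_model = reuse_candidates.pop(target_backend, None)
    if q_model is None:
        # Nothing cached for this backend (e.g. AUTO): fall back to a fresh load.
        q_model = load_from_disk(path)
    return q_model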
@@ -609,9 +632,13 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa
     def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, delete_quantized_model=False, extra_args:dict=None):
         try:
             with tempfile.TemporaryDirectory() as tmp_dir:
+                model_path = getattr(model, "model_local_path", None)
+                if isinstance(model, str):
+                    model_path = model
+
                 if self.USE_VLLM:
                     model_args = {
-                        "pretrained": model.model_local_path,
+                        "pretrained": model_path,
                         "dtype": "auto", #"float16",
                         "gpu_memory_utilization": 0.8,
                         "tensor_parallel_size": 1,
@@ -630,9 +657,19 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
 
                 for framework, tasks in task_groups.items():
                     log.info(f"TEST: EVAL starting: backend = {self.LOAD_BACKEND}")
-                    log.info(f"Inference from model path: {model.model_local_path}")
+                    if model_path:
+                        log.info(f"Inference from model path: {model_path}")
+
+                    if isinstance(model, BaseQModel) and not self.USE_VLLM:
+                        eval_target = model
+                    else:
+                        eval_target = model_path
+
+                    if eval_target is None:
+                        raise ValueError("Model evaluation target could not be determined.")
+
                     results = GPTQModel.eval(
-                        model_or_id_or_path=model.model_local_path,
+                        model_or_id_or_path=eval_target,
                         llm_backend="vllm" if self.USE_VLLM else "gptqmodel",
                         model_args=model_args,
                         output_path=tmp_dir,
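
Finally, lm_eval now accepts either an already-loaded model object or a path string: a loaded model is handed to GPTQModel.eval directly (so the reused post-quant model is not loaded a second time), while vLLM runs and plain strings fall back to the local path. A condensed stand-in for that resolution logic, where any non-string object is treated as a loaded model:

# Condensed stand-in for the eval-target resolution added above. The real check
# uses isinstance(model, BaseQModel); here any non-string object counts as loaded.
def resolve_eval_target(model, use_vllm: bool):
    model_path = getattr(model, "model_local_path", None)
    if isinstance(model, str):
        model_path = model
    if not isinstance(model, str) and not use_vllm:
        return model  # evaluate the in-memory model directly
    if model_path is None:
        raise ValueError("Model evaluation target could not be determined.")
    return model_path  # vLLM (or a bare string) needs a filesystem path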