diff --git a/gptqmodel/looper/stage_inputs_capture.py b/gptqmodel/looper/stage_inputs_capture.py index 5a8c36614..f1d6f954f 100644 --- a/gptqmodel/looper/stage_inputs_capture.py +++ b/gptqmodel/looper/stage_inputs_capture.py @@ -187,6 +187,8 @@ def store_input_hook(module, args, kwargs): **example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS, ) + elif is_ovis: + self.gptq_model.model.generate(inputs=example.pop("input_ids"), **example) else: self.gptq_model.model(**example, use_cache=use_cache) except StopForward: diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py index 234ffa81b..a9144ed19 100644 --- a/gptqmodel/models/definitions/ovis.py +++ b/gptqmodel/models/definitions/ovis.py @@ -14,6 +14,7 @@ from ...utils.model import MODALITY, move_to from .._const import CPU from ..base import BaseQModel +from ...utils.offload import offload_to_disk class OvisQModel(BaseQModel): @@ -48,11 +49,24 @@ def monkey_patch(self): self.model.visual_tokenizer = self.model.visual_tokenizer.to(dtype=self.model.llm.dtype) self.model.vte = self.model.vte.to(dtype=self.model.llm.dtype) + self.model.llm.generation_config.max_length = 8192 + def pre_quantize_generate_hook_start(self): - self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, device=self.quantize_config.device) - self.model.vte = move_to(self.model.vte, device=self.quantize_config.device) + self.shell_module_materialize(self.model.visual_tokenizer, self.quantize_config.device) + self.shell_module_materialize(self.model.vte, self.quantize_config.device) def pre_quantize_generate_hook_end(self): + if self.quantize_config.offload_to_disk: + offload_to_disk(model=self.model, + module=self.model.visual_tokenizer, + disk_path=self.quantize_config.offload_to_disk_path, + ) + offload_to_disk(model=self.model, + module=self.model.vte, + disk_path=self.quantize_config.offload_to_disk_path, + ) + return + self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, device=CPU) self.model.vte = move_to(self.model.vte, device=CPU)