From d394c056d1687d19c53f8ac586a1f4bbe9fe5047 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 28 Oct 2025 04:01:29 +0000 Subject: [PATCH 1/3] fix ovis2 Signed-off-by: ZX-ModelCloud --- gptqmodel/models/auto.py | 2 + gptqmodel/models/definitions/ovis2.py | 78 ++++++++++++++++++++++ tests/models/ovis/image_to_test_dataset.py | 28 ++++---- tests/models/test_ovis2.py | 58 ++++++++-------- 4 files changed, 119 insertions(+), 47 deletions(-) create mode 100644 gptqmodel/models/definitions/ovis2.py diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index b2169065e..49e381a5a 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -106,6 +106,7 @@ from .definitions.nemotron_h import NemotronHQModel # noqa: E402 from .definitions.opt import OptQModel # noqa: E402 from .definitions.ovis import OvisQModel # noqa: E402 +from .definitions.ovis2 import Ovis2QModel # noqa: E402 from .definitions.pangu_alpha import PanguAlphaQModel # noqa: E402 from .definitions.phi import PhiQModel # noqa: E402 from .definitions.phi3 import Phi3QModel, PhiMoEGPTQForCausalLM # noqa: E402 @@ -204,6 +205,7 @@ "hymba": HymbaQModel, "olmo2": LlamaQModel, # 100% llama clone "ovis": OvisQModel, + "ovis2": Ovis2QModel, "telechat": TeleChat2QModel, "instella": InstellaQModel, "mimo": MimoQModel, diff --git a/gptqmodel/models/definitions/ovis2.py b/gptqmodel/models/definitions/ovis2.py new file mode 100644 index 000000000..5cbe2ff2c --- /dev/null +++ b/gptqmodel/models/definitions/ovis2.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from typing import Dict, Optional + +from PIL import Image +from transformers import AutoModelForImageTextToText, AutoProcessor, ProcessorMixin + +from ...utils.calibration import batched +from ...utils.model import MODALITY +from ..base import BaseQModel +from ...utils.image import extract_vision_info, fetch_image + +class Ovis2QModel(BaseQModel): + loader = AutoModelForImageTextToText + + pre_lm_head_norm_module = "model.language_model.model.norm" + + module_tree = [ + "model", + "language_model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), + } + ] + + modality = [MODALITY.IMAGE_TO_TEXT] + + require_load_processor = True + + def preprocess_dataset(self, sample: Dict) -> Dict: + return sample + + def load_processor(self) -> ProcessorMixin: + return AutoProcessor.from_pretrained(self.model_local_path) + + @staticmethod + def process_vision_info( + conversations: list[dict] | list[list[dict]], + ) -> Optional[list[Image.Image]]: + vision_infos = extract_vision_info(conversations) + # Read images + image_inputs = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info)) + else: + raise ValueError("image, image_url should in content.") + if len(image_inputs) == 0: + image_inputs = None + return image_inputs + + def prepare_dataset(self, calibration_dataset, batch_size: int = 1, **kwargs): + processor = self.load_processor() + calib_data = [] + for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset): + print("batch", batch) + text = processor.apply_chat_template( + batch, 
tokenize=False, add_generation_prompt=True + ) + image_inputs = self.process_vision_info(batch) + inputs = processor( + text=text, + images=image_inputs, + videos=None, + padding=True, + return_tensors="pt", + ) + calib_data.append(inputs) + del processor + return calib_data diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py index c2e7172cf..c66bddb58 100644 --- a/tests/models/ovis/image_to_test_dataset.py +++ b/tests/models/ovis/image_to_test_dataset.py @@ -3,25 +3,23 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel.models import OvisQModel +from gptqmodel.models.definitions.ovis import OvisQModel +from gptqmodel.models.definitions.ovis2 import Ovis2QModel from gptqmodel.models.definitions.base_qwen2_5_omni import BaseQwen2_5_OmniGPTQ from gptqmodel.models.definitions.base_qwen2_vl import BaseQwen2VLGPTQ def format_ovis_dataset(image, assistant): - return { - "image": image, - "conversations": [ - { - "from": "human", - "value": "\nWrite a detailed description of this image, do not forget about the texts on it if they exist. Also, do not forget to mention the type / style of the image. No bullet points. When writing descriptions, prioritize clarity and direct observation over embellishment or interpretation.\nDon't forget these rules:\n1. **Be Direct and Concise**: Provide straightforward descriptions without adding interpretative or speculative elements.\n2. **Use Segmented Details**: Break down details about different elements of an image into distinct sentences, focusing on one aspect at a time.\n3. **Maintain a Descriptive Focus**: Prioritize purely visible elements of the image, avoiding conclusions or inferences.\n4. **Follow a Logical Structure**: Begin with the central figure or subject and expand outward, detailing its appearance before addressing the surrounding setting.\n5. **Avoid Juxtaposition**: Do not use comparison or contrast language; keep the description purely factual.\n6. **Incorporate Specificity**: Mention age, gender, race, and specific brands or notable features when present, and clearly identify the medium if it's discernible." - }, - { - "from": "gpt", - "value": assistant - } - ] - } + return [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Write a detailed description of this image, do not forget about the texts on it if they exist. Also, do not forget to mention the type / style of the image. No bullet points. When writing descriptions, prioritize clarity and direct observation over embellishment or interpretation.\nDon't forget these rules:\n1. **Be Direct and Concise**: Provide straightforward descriptions without adding interpretative or speculative elements.\n2. **Use Segmented Details**: Break down details about different elements of an image into distinct sentences, focusing on one aspect at a time.\n3. **Maintain a Descriptive Focus**: Prioritize purely visible elements of the image, avoiding conclusions or inferences.\n4. **Follow a Logical Structure**: Begin with the central figure or subject and expand outward, detailing its appearance before addressing the surrounding setting.\n5. **Avoid Juxtaposition**: Do not use comparison or contrast language; keep the description purely factual.\n6. 
**Incorporate Specificity**: Mention age, gender, race, and specific brands or notable features when present, and clearly identify the medium if it's discernible."}, + ], + }, + {"role": "assistant", "content": assistant}, + ] def format_qwen2_vl_dataset(image, assistant): @@ -68,7 +66,7 @@ def prepare_dataset(format_func, n_sample: int = 20) -> list[list[dict]]: def get_calib_dataset(model): - if isinstance(model, OvisQModel): + if isinstance(model, OvisQModel) or isinstance(model, Ovis2QModel): return prepare_dataset(format_ovis_dataset, n_sample=20) if isinstance(model, BaseQwen2VLGPTQ): diff --git a/tests/models/test_ovis2.py b/tests/models/test_ovis2.py index 324e1ccbb..cfb1b098c 100644 --- a/tests/models/test_ovis2.py +++ b/tests/models/test_ovis2.py @@ -10,50 +10,44 @@ from PIL import Image + class Test(ModelTest): - NATIVE_MODEL_ID = "/monster/data/model/Ovis2-1B" + NATIVE_MODEL_ID = "/monster/data/model/Ovis2-2B-hf" TRUST_REMOTE_CODE = True APPLY_CHAT_TEMPLATE = False EVAL_BATCH_SIZE = 1 def test_ovis(self): - model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, - dtype=self.TORCH_DTYPE, multimodal_max_length=8192, batch_size=1) - - text_tokenizer = model.get_text_tokenizer() - visual_tokenizer = model.get_visual_tokenizer() - - # enter image path and prompt + model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, batch_size=1) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does this picture show?"}, + ], + }, + ] image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ovis/10016.jpg") image = Image.open(image_path) - text = "What does this picture show?" 
- query = f'\n{text}' + messages = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + print(messages) - # format conversation - prompt, input_ids, pixel_values = model.preprocess_inputs(query, [image]) - attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id) - input_ids = input_ids.unsqueeze(0).to(device=model.device) - attention_mask = attention_mask.unsqueeze(0).to(device=model.device) - pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)] + inputs = processor( + images=[image], + text=messages, + return_tensors="pt", + ) + inputs = inputs.to(model.device) + inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16) - # generate output with torch.inference_mode(): - gen_kwargs = { - "max_new_tokens": 1024, - "do_sample": False, - "top_p": None, - "top_k": None, - "temperature": None, - "repetition_penalty": None, - "eos_token_id": model.generation_config.eos_token_id, - "pad_token_id": text_tokenizer.pad_token_id, - "use_cache": True - } - output_ids = \ - model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0] - output = text_tokenizer.decode(output_ids, skip_special_tokens=True) - + output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False) + generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] + output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] print(f'Output:\n{output}') self.assertIn("snow", output.lower()) From b05e67ff23d1879c4b92ca4003f4613ba323d3dd Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 28 Oct 2025 04:03:47 +0000 Subject: [PATCH 2/3] cleanup Signed-off-by: ZX-ModelCloud --- tests/models/ovis/image_to_test_dataset.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py index c66bddb58..6739ee8f9 100644 --- a/tests/models/ovis/image_to_test_dataset.py +++ b/tests/models/ovis/image_to_test_dataset.py @@ -8,8 +8,22 @@ from gptqmodel.models.definitions.base_qwen2_5_omni import BaseQwen2_5_OmniGPTQ from gptqmodel.models.definitions.base_qwen2_vl import BaseQwen2VLGPTQ - def format_ovis_dataset(image, assistant): + return { + "image": image, + "conversations": [ + { + "from": "human", + "value": "\nWrite a detailed description of this image, do not forget about the texts on it if they exist. Also, do not forget to mention the type / style of the image. No bullet points. When writing descriptions, prioritize clarity and direct observation over embellishment or interpretation.\nDon't forget these rules:\n1. **Be Direct and Concise**: Provide straightforward descriptions without adding interpretative or speculative elements.\n2. **Use Segmented Details**: Break down details about different elements of an image into distinct sentences, focusing on one aspect at a time.\n3. **Maintain a Descriptive Focus**: Prioritize purely visible elements of the image, avoiding conclusions or inferences.\n4. **Follow a Logical Structure**: Begin with the central figure or subject and expand outward, detailing its appearance before addressing the surrounding setting.\n5. **Avoid Juxtaposition**: Do not use comparison or contrast language; keep the description purely factual.\n6. 
**Incorporate Specificity**: Mention age, gender, race, and specific brands or notable features when present, and clearly identify the medium if it's discernible." + }, + { + "from": "gpt", + "value": assistant + } + ] + } + +def format_ovis2_dataset(image, assistant): return [ { "role": "user", @@ -66,9 +80,12 @@ def prepare_dataset(format_func, n_sample: int = 20) -> list[list[dict]]: def get_calib_dataset(model): - if isinstance(model, OvisQModel) or isinstance(model, Ovis2QModel): + if isinstance(model, OvisQModel): return prepare_dataset(format_ovis_dataset, n_sample=20) + if isinstance(model, Ovis2QModel): + return prepare_dataset(format_ovis2_dataset, n_sample=20) + if isinstance(model, BaseQwen2VLGPTQ): return prepare_dataset(format_qwen2_vl_dataset, n_sample=20) From 03523a39ba7e9e6fb44d2d1aa9e0105b12434ff4 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 28 Oct 2025 04:05:03 +0000 Subject: [PATCH 3/3] cleanup Signed-off-by: ZX-ModelCloud --- gptqmodel/models/definitions/ovis2.py | 1 - tests/models/ovis/image_to_test_dataset.py | 2 ++ tests/models/test_ovis2.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/models/definitions/ovis2.py b/gptqmodel/models/definitions/ovis2.py index 5cbe2ff2c..21693f90a 100644 --- a/gptqmodel/models/definitions/ovis2.py +++ b/gptqmodel/models/definitions/ovis2.py @@ -61,7 +61,6 @@ def prepare_dataset(self, calibration_dataset, batch_size: int = 1, **kwargs): processor = self.load_processor() calib_data = [] for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset): - print("batch", batch) text = processor.apply_chat_template( batch, tokenize=False, add_generation_prompt=True ) diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py index 6739ee8f9..ac9ef0969 100644 --- a/tests/models/ovis/image_to_test_dataset.py +++ b/tests/models/ovis/image_to_test_dataset.py @@ -48,6 +48,7 @@ def format_qwen2_vl_dataset(image, assistant): {"role": "assistant", "content": assistant}, ] + def format_qwen2_5_omni_dataset(image, assistant): return [ { @@ -67,6 +68,7 @@ def format_qwen2_5_omni_dataset(image, assistant): {"role": "assistant", "content": assistant}, ] + def prepare_dataset(format_func, n_sample: int = 20) -> list[list[dict]]: from datasets import load_dataset diff --git a/tests/models/test_ovis2.py b/tests/models/test_ovis2.py index cfb1b098c..186b03a50 100644 --- a/tests/models/test_ovis2.py +++ b/tests/models/test_ovis2.py @@ -34,7 +34,6 @@ def test_ovis(self): image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ovis/10016.jpg") image = Image.open(image_path) messages = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - print(messages) inputs = processor( images=[image],
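
Taken together, the three patches register Ovis2QModel under the "ovis2" key, load the HF-native checkpoint through AutoModelForImageTextToText, and route image/text calibration conversations through AutoProcessor.apply_chat_template. Below is a minimal usage sketch of how the new registration would be exercised for quantization. It is illustrative only and not part of the patches: it assumes GPTQModel's standard load/quantize/save entry points, and the checkpoint path, calibration image, and prompt text are placeholders rather than values taken from the patches.

# Illustrative sketch (assumptions noted above): quantize an Ovis2 HF-format
# checkpoint via the newly registered "ovis2" model type.
from gptqmodel import GPTQModel, QuantizeConfig
from PIL import Image

model_path = "/path/to/Ovis2-2B-hf"  # placeholder: any Ovis2 checkpoint in HF format

# Calibration samples follow the chat-message layout that Ovis2QModel.prepare_dataset()
# feeds through AutoProcessor.apply_chat_template(), mirroring format_ovis2_dataset()
# from the test helper above (a list of conversations, each a list of message dicts).
image = Image.open("calibration_image.jpg")  # placeholder calibration image
calibration = [
    [
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Describe this image."},
        ]},
        {"role": "assistant", "content": "A short reference description."},
    ],
]

model = GPTQModel.load(model_path, QuantizeConfig(bits=4, group_size=128),
                       trust_remote_code=True)
model.quantize(calibration, batch_size=1)
model.save(model_path + "-gptq-4bit")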