From a80ed19377b6a9b046f7a06d5f302188e0b74f6c Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 1 Oct 2025 01:40:44 +0000
Subject: [PATCH 1/2] bypass accelerate's thread-unsafe clear_device_cache

Signed-off-by: Qubitium
---
 gptqmodel/utils/offload.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index 85a1f045f..56864f9c2 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -9,6 +9,7 @@
 import threading
 from typing import Iterable, List, Optional, Set, Tuple
 
+import accelerate
 import torch
 
 # move base_module tensors to disk
@@ -19,11 +20,19 @@
 from ..looper.named_module import NamedModule
 from .device import get_device
-from .torch import CPU, HAS_CUDA, META
+from .torch import CPU, META
 
 _lock = threading.Lock()
 
+# Patch: bypass the thread-unsafe accelerate.utils.modeling.clear_device_cache
+def _fake_clear_device_cache(garbage_collection=False):
+    pass
+
+# keep a reference to the original so it can be restored if needed
+ACCELERATE_CLEAR_DEVICE_CACHE = accelerate.utils.modeling.clear_device_cache
+accelerate.utils.modeling.clear_device_cache = _fake_clear_device_cache
+
 def get_module_fullname(model: torch.nn.Module, module: torch.nn.Module) -> str:
     for name, mod in model.named_modules():
         if mod is module:
             return name
@@ -104,14 +113,6 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):
     if not has_params and not has_buffers:
         return
 
-    # print(f"Offload source device: {m_device}")
-    # print_module_tree(module)
-    # TODO FIXME pending PR upstream: https://github.com/huggingface/accelerate/pull/3796
-    real_cache_flush = None
-    if HAS_CUDA:
-        real_cache_flush = torch.cuda.empty_cache
-        torch.cuda.empty_cache = lambda: None
-
     _ = disk_offload(
         module,
         # device_map={ "" : "disk" }, # only touch this subtree
@@ -120,9 +121,6 @@
         execution_device=m_device,
     )
 
-    if real_cache_flush:
-        torch.cuda.empty_cache = real_cache_flush
-
     # print("offload_disk: list item tree")
     # print_module_tree(module)
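Note on the patch above: accelerate resolves clear_device_cache through its module globals at call time, so rebinding accelerate.utils.modeling.clear_device_cache to a no-op disables the flush process-wide, and the saved ACCELERATE_CLEAR_DEVICE_CACHE handle can restore the original later (accelerate.utils.modeling.clear_device_cache = ACCELERATE_CLEAR_DEVICE_CACHE). Below is a minimal sketch of an alternative, assuming only the signature clear_device_cache(garbage_collection=False) visible above: keep the flush but serialize it behind a lock so concurrent offload threads cannot race. The names _cache_lock and _locked_clear_device_cache are illustrative, not part of this patch.

    import gc
    import threading

    import accelerate
    import torch

    _cache_lock = threading.Lock()

    def _locked_clear_device_cache(garbage_collection: bool = False) -> None:
        # Same signature as accelerate.utils.modeling.clear_device_cache, but
        # the flush is serialized so concurrent offload threads cannot race
        # inside it.
        with _cache_lock:
            if garbage_collection:
                gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    accelerate.utils.modeling.clear_device_cache = _locked_clear_device_cache

The trade-off: the lock preserves memory reclamation at the cost of brief serialization, while the no-op in the patch avoids both the race and the cost of the cache flush entirely.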
From 3de8467c2a89cfaf7972905cf4aeb898dcb882c5 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Wed, 1 Oct 2025 01:40:59 +0000
Subject: [PATCH 2/2] format

Signed-off-by: Qubitium
---
 gptqmodel/models/definitions/base_qwen2_5_omni.py | 6 ++++--
 gptqmodel/models/definitions/ovis.py              | 2 +-
 gptqmodel/models/definitions/qwen3_omni_moe.py    | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/gptqmodel/models/definitions/base_qwen2_5_omni.py b/gptqmodel/models/definitions/base_qwen2_5_omni.py
index 8827b69dc..23f43fc7f 100644
--- a/gptqmodel/models/definitions/base_qwen2_5_omni.py
+++ b/gptqmodel/models/definitions/base_qwen2_5_omni.py
@@ -3,11 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
-from typing import Dict, Optional
 import os
+from typing import Dict, Optional
+
 import torch
 from PIL import Image
 from transformers import AutoModelForTextToWaveform, AutoProcessor, ProcessorMixin
+
 from ...utils.calibration import batched
 from ...utils.image import extract_vision_info, fetch_image
 from ...utils.model import MODALITY
@@ -84,7 +86,7 @@ def pre_quantize_generate_hook_end(self):
             module=self.model.thinker.model.rotary_emb,
             disk_path=self.quantize_config.offload_to_disk_path,
         )
-
+
         for layer in self.model.thinker.model.layers:
             layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU)
diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py
index 8485975c1..234ffa81b 100644
--- a/gptqmodel/models/definitions/ovis.py
+++ b/gptqmodel/models/definitions/ovis.py
@@ -93,7 +93,7 @@ def prepare_dataset(
         calibration_dataset,
         calibration_dataset_concat_size,
         batch_size: int = 1,
-        tokenizer=None, 
+        tokenizer=None,
         **kwargs):
         calib_data = []
         for batch in batched(calibration_dataset, batch_size, self.preprocess_dataset):
diff --git a/gptqmodel/models/definitions/qwen3_omni_moe.py b/gptqmodel/models/definitions/qwen3_omni_moe.py
index 9d7da24c0..30eceae24 100644
--- a/gptqmodel/models/definitions/qwen3_omni_moe.py
+++ b/gptqmodel/models/definitions/qwen3_omni_moe.py
@@ -89,6 +89,6 @@ def after_model_load(self, model, load_quantized_model=False):
         # need to load processor for save processor_config and chat_template
         if not load_quantized_model:
             self.processor = AutoProcessor.from_pretrained(self.model_local_path)
-
+
         return model