diff --git a/gptqmodel/eora/eora.py b/gptqmodel/eora/eora.py index 56b1fa7b1..6c037471f 100644 --- a/gptqmodel/eora/eora.py +++ b/gptqmodel/eora/eora.py @@ -5,7 +5,7 @@ # EoRA Official Repo: https://github.com/NVlabs/EoRA # This file has been modified by ModelCloud.AI team and qubitium@modelcloud.ai for adoption into GPT-QModel -from typing import Tuple +from typing import Sequence, Tuple import torch from torch import Tensor @@ -24,9 +24,8 @@ def eora_process_input( ) -> Tuple[int, torch.Tensor, float]: """Prepare the per-batch covariance contribution required for EoRA. - The contribution is computed on the target device for throughput, then - transferred to CPU so it can be safely aggregated across worker threads and - devices when the interpreter runs without the GIL. + The contribution remains on the originating device so multi-GPU execution + can accumulate locally before a single merge step. """ inp = input[0].to(device=device, dtype=torch.float32) @@ -37,7 +36,7 @@ def eora_process_input( adds = torch.matmul(inp.transpose(1, 2), inp) adds_sum = torch.sum(adds, dim=0).detach() - contribution = adds_sum.to(device=torch.device("cpu"), dtype=torch.float32) + contribution = adds_sum.to(dtype=torch.float32) contribution /= float(sample_size) # Adding batch to denominator is only for mathematical stability @@ -47,6 +46,30 @@ def eora_process_input( return batch, contribution, scale + +def merge_eora_segments(segments: Sequence[Tuple[torch.Tensor, float]]) -> torch.Tensor: + """Combine pre-aggregated EoRA segments using their scale products. + + Each segment entry is a tuple ``(total, scale_product)`` where ``total`` is + the sequential accumulation result for that segment starting from zero, and + ``scale_product`` is the product of per-batch scale factors encountered in + the same segment. The function mutates the first segment tensor in place + and returns it as the merged result. 
+ """ + if not segments: + raise ValueError("EoRA merge requires at least one segment.") + + result: torch.Tensor | None = None + for total, scale_product in segments: + if result is None: + result = total + else: + result.mul_(float(scale_product)) + result.add_(total) + + assert result is not None + return result + def eora_compute_lora( w_wq_delta: Tensor, # need the w (original weight) and wq (quantized qweight) delta in float32 name: str, diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py index 9834c8cea..a3812b7f8 100644 --- a/gptqmodel/looper/eora_processor.py +++ b/gptqmodel/looper/eora_processor.py @@ -4,15 +4,14 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import copy -import heapq import time -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch.nn import Module from ..adapter.adapter import Lora -from ..eora.eora import eora_compute_lora, eora_process_input +from ..eora.eora import eora_compute_lora, eora_process_input, merge_eora_segments from ..looper.loop_processor import DTYPE_SIZE_COLUMN, MODULE_FEATURE_COLUMN, LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel @@ -20,6 +19,7 @@ PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_USED_MEMORY) from ..quantization.config import QuantizeConfig from ..utils.logger import setup_logger +from ..utils.device import get_device from ..utils.model import move_to from ..utils.torch import CPU, DEVICE_0, DEVICE_1, torch_streamCtx, torch_sync from ..utils.torch import HAS_CUDA, tf32_disable_guard, torch_streamCtx, torch_sync @@ -38,10 +38,10 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, require_fwd=require_fwd) - # dict: key is module name, value is the accumulated eigen_scaling_diag_matrix - self.eigen_scaling_diag_matrix: Dict[str, torch.Tensor] = {} - self._pending_contributions: Dict[str, List[Tuple[int, int, torch.Tensor, float]]] = {} - self._next_batch_index: Dict[str, int] = {} + # Track per-module segment accumulators keyed by device so we can merge + # contributions without repeatedly moving data through the CPU. 
+ self._segment_accumulators: Dict[str, Dict[torch.device, Dict[str, Any]]] = {} + self._module_target_devices: Dict[str, torch.device] = {} # Increase the dynamo cache size limit, default of 8 is too low if torch._dynamo.config.cache_size_limit < 64: @@ -78,9 +78,12 @@ def preprocess(self, module: NamedModule, **kwargs): # hack store property inside module module.adapter_cfg = adapter_cfg - self.eigen_scaling_diag_matrix[module.name] = None - self._pending_contributions[module.name] = [] - self._next_batch_index[module.name] = 0 + target_device = get_device(module.module) + if target_device.type == "meta": + target_device = torch.device("cpu") + + self._module_target_devices[module.name] = torch.device(target_device) + self._segment_accumulators[module.name] = {} return @@ -99,7 +102,7 @@ def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor): device=module.weight.data.device, ) - self._stage_eora_contribution( + self._accumulate_eora_contribution( name=name, batch_index=batch_index, batch=batch, @@ -108,7 +111,7 @@ def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor): ) return tmp - def _stage_eora_contribution( + def _accumulate_eora_contribution( self, *, name: str, @@ -120,49 +123,78 @@ def _stage_eora_contribution( if batch <= 0: return - with self.lock: - queue = self._pending_contributions.setdefault(name, []) - self._next_batch_index.setdefault(name, 0) - - pending_index = batch_index if batch_index is not None else self._next_batch_index[name] - heapq.heappush(queue, (pending_index, batch, contribution, scale)) - self._flush_eora_pending_locked(name) - - def _flush_eora_pending_locked(self, name: str) -> None: - queue = self._pending_contributions.get(name) - if not queue: - return + contribution = contribution.detach() + device = torch.device(contribution.device) + scale_value = float(scale) - expected = self._next_batch_index.get(name, 0) - - while queue: - index, _batch, contribution, scale = queue[0] - - if index < expected: - heapq.heappop(queue) - continue + with self.lock: + accumulators = self._segment_accumulators.setdefault(name, {}) + record = accumulators.get(device) + + index_value = int(batch_index) if batch_index is not None else 0 + + if record is None: + record = { + "total": contribution, + "scale_product": scale_value, + "start_index": index_value, + "end_index": index_value, + "count": 1, + } + accumulators[device] = record + return + + total = record["total"] + if total.device != contribution.device: + total = total.to(device=contribution.device) + + total.mul_(scale_value) + total.add_(contribution) + + record["total"] = total + record["scale_product"] *= scale_value + record["count"] += 1 + + if batch_index is not None: + batch_value = int(batch_index) + if record["start_index"] is None or batch_value < record["start_index"]: + record["start_index"] = batch_value + if record["end_index"] is None or batch_value > record["end_index"]: + record["end_index"] = batch_value + else: + if record.get("start_index") is None: + record["start_index"] = record["count"] - 1 + record["end_index"] = record["count"] - 1 - if index != expected: - break + del contribution - heapq.heappop(queue) + def _finalize_eigen_scaling_matrix(self, name: str) -> torch.Tensor: + with self.lock: + segments = self._segment_accumulators.pop(name, {}) + target_device = self._module_target_devices.pop(name, None) - current = self.eigen_scaling_diag_matrix.get(name) + if not segments: + raise RuntimeError( + f"EoRA statistics for module '{name}' were not 
collected before processing." + ) - if isinstance(current, torch.Tensor): - if current.device != torch.device("cpu"): - current = current.to(device=torch.device("cpu"), dtype=torch.float32) + ordered_segments = sorted( + segments.values(), + key=lambda record: record.get("start_index", 0), + ) - current.mul_(scale) - current.add_(contribution) - self.eigen_scaling_diag_matrix[name] = current - del contribution - else: - self.eigen_scaling_diag_matrix[name] = contribution + if target_device is None: + first_total = ordered_segments[0]["total"] + target_device = torch.device(first_total.device) - expected = index + 1 + segment_pairs = [] + for record in ordered_segments: + total = record["total"] + if total.device != target_device: + total = total.to(device=target_device, dtype=torch.float32) + segment_pairs.append((total, float(record["scale_product"]))) - self._next_batch_index[name] = expected + return merge_eora_segments(segment_pairs) def process(self, module: NamedModule): assert isinstance(module.adapter_cfg, Lora) @@ -171,16 +203,7 @@ def process(self, module: NamedModule): start = time.time() - with self.lock: - self._flush_eora_pending_locked(module.name) - eigen_scaling_diag_matrix = self.eigen_scaling_diag_matrix.pop(module.name) - self._pending_contributions.pop(module.name, None) - self._next_batch_index.pop(module.name, None) - - if not isinstance(eigen_scaling_diag_matrix, torch.Tensor): - raise RuntimeError( - f"EoRA statistics for module '{module.name}' were not collected before processing." - ) + eigen_scaling_diag_matrix = self._finalize_eigen_scaling_matrix(module.name) tp_info = module.state.get("tp_pad_info") pad_cols = 0 @@ -309,9 +332,8 @@ def finalize(self, model: BaseQModel, **kwargs): if self.stream: torch_sync() - del self.eigen_scaling_diag_matrix - del self._pending_contributions - del self._next_batch_index + del self._segment_accumulators + del self._module_target_devices # hack: store loras into model until `save()` is called model.lora_results = self.results() diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 5da3ed302..4ff916554 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -11,7 +11,7 @@ import sys import threading import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import numpy as np import torch diff --git a/tests/test_eora_merge.py b/tests/test_eora_merge.py new file mode 100644 index 000000000..adc59860f --- /dev/null +++ b/tests/test_eora_merge.py @@ -0,0 +1,69 @@ +import math + +import pytest +import torch + +from gptqmodel.eora.eora import eora_process_input, merge_eora_segments + + +def _segment_reduce(contributions, scales): + total = None + scale_product = 1.0 + for contribution, scale in zip(contributions, scales): + if total is None: + total = contribution.clone() + scale_product = float(scale) + else: + total.mul_(float(scale)) + total.add_(contribution) + scale_product *= float(scale) + return total, scale_product + + +@pytest.mark.parametrize("segments", [1, 2, 4]) +def test_eora_merge_matches_sequential(segments): + torch.manual_seed(0) + + sample_size = 96 + cols = 8 + base_contributions = [ + torch.randn(cols, cols, dtype=torch.float32) + for _ in range(sample_size) + ] + def scale_fn(batch): + return sample_size / (sample_size + batch) + scales = [scale_fn(1) for _ in range(sample_size)] + + sequential = torch.zeros((cols, cols), dtype=torch.float32) + for contribution, scale in zip(base_contributions, scales): + 
sequential.mul_(float(scale)) + sequential.add_(contribution) + + segment_length = math.ceil(sample_size / segments) + segment_pairs = [] + for start in range(0, sample_size, segment_length): + end = min(start + segment_length, sample_size) + contributions = base_contributions[start:end] + seg_scales = scales[start:end] + segment_pairs.append(_segment_reduce(contributions, seg_scales)) + + merged = merge_eora_segments(segment_pairs) + torch.testing.assert_close(merged, sequential, atol=5e-6, rtol=5e-6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_eora_process_input_preserves_device(): + device = torch.device("cuda", 0) + sample = torch.randn(1, 4, 6, device=device, dtype=torch.float16) + + batch, contribution, scale = eora_process_input( + input=(sample,), + name="test", + sample_size=32, + device=device, + ) + + assert batch == 1 + assert contribution.device == device + assert contribution.dtype == torch.float32 + assert isinstance(scale, float) diff --git a/tests/test_gc.py b/tests/test_gc.py index 3a8bca08d..20bcbe16a 100644 --- a/tests/test_gc.py +++ b/tests/test_gc.py @@ -5,8 +5,10 @@ from queue import Queue import pytest + from gptqmodel.utils.safe import gc as safe_gc + torch = pytest.importorskip("torch", reason="requires PyTorch") @@ -27,7 +29,7 @@ def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None: safe_gc.collect() else: gc.collect() - except Exception as exc: # pragma: no cover - stress test safeguard + except Exception: # pragma: no cover - stress test safeguard # Preserve the traceback so the failing test shows context from worker threads. errors.put(traceback.format_exc()) diff --git a/tests/test_gptq_queue.py b/tests/test_gptq_queue.py index 0a8faf77e..3bb9f2f6a 100644 --- a/tests/test_gptq_queue.py +++ b/tests/test_gptq_queue.py @@ -1,7 +1,7 @@ import copy -import torch import pytest +import torch from gptqmodel.quantization.config import QuantizeConfig from gptqmodel.quantization.gptq import GPTQ diff --git a/tests/test_hessian_chunk.py b/tests/test_hessian_chunk.py index f416fd866..5149f0845 100644 --- a/tests/test_hessian_chunk.py +++ b/tests/test_hessian_chunk.py @@ -79,7 +79,7 @@ def test_hessian_chunk_consistency_matches_full_precision(): assert full_device == chunked_device assert full_xtx is not None and chunked_xtx is not None - assert torch.allclose(full_xtx, chunked_xtx, atol=5e-6, rtol=5e-6) + assert torch.allclose(full_xtx, chunked_xtx, atol=3e-6, rtol=3e-6) def test_hessian_chunk_invocations_and_workspace_shape(): diff --git a/tests/test_parameter_count.py b/tests/test_parameter_count.py index 677d18281..5f4cdbe35 100644 --- a/tests/test_parameter_count.py +++ b/tests/test_parameter_count.py @@ -3,14 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import os import os.path import tempfile import torch.cuda from models.model_test import ModelTest from safetensors.torch import load_file -from safetensors import safe_open -import os from gptqmodel import GPTQModel, QuantizeConfig from gptqmodel.utils.tensor import tensor_parameters diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py index ab9c01d61..92055bc89 100644 --- a/tests/test_quant_and_eora.py +++ b/tests/test_quant_and_eora.py @@ -97,7 +97,7 @@ def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT): torch_empty_cache() # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, - for backend in [ BACKEND.TORCH ]: # BACKEND.IPEX, 
BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN + for backend in [ BACKEND.MARLIN ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2, BACKEND.MARLIN base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora)
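
Note (appended for review, not part of the patch): the new merge_eora_segments helper relies on the identity that the sequential EoRA update H = H * s_i + c_i can be split into independently reduced segments (T_k, P_k), where T_k is the same recurrence run from zero over segment k and P_k is the product of that segment's scale factors, and the segments are then replayed in order as result = result * P_k + T_k. Below is a minimal, self-contained sketch of that identity; it mirrors the logic in tests/test_eora_merge.py and uses only illustrative variable names.

    import torch

    # Per-batch contributions c_i and scale factors s_i.
    contributions = [torch.randn(4, 4) for _ in range(6)]
    scales = [0.9, 0.8, 0.95, 0.7, 0.85, 0.9]

    # Reference: strictly sequential accumulation starting from zero.
    sequential = torch.zeros(4, 4)
    for c, s in zip(contributions, scales):
        sequential = sequential * s + c

    def reduce_segment(cs, ss):
        # Reduce one segment to (T_k, P_k): the same recurrence started
        # from zero, plus the product of the segment's scales.
        total = torch.zeros_like(cs[0])
        prod = 1.0
        for c, s in zip(cs, ss):
            total = total * s + c
            prod *= s
        return total, prod

    segments = [
        reduce_segment(contributions[:3], scales[:3]),
        reduce_segment(contributions[3:], scales[3:]),
    ]

    # Merge in segment order: result = result * P_k + T_k. P_1 is never
    # applied because the accumulator starts at zero.
    merged = segments[0][0].clone()
    for total, prod in segments[1:]:
        merged = merged * prod + total

    torch.testing.assert_close(merged, sequential)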