33 changes: 28 additions & 5 deletions gptqmodel/eora/eora.py
@@ -5,7 +5,7 @@
# EoRA Official Repo: https://github.com/NVlabs/EoRA
# This file has been modified by ModelCloud.AI team and qubitium@modelcloud.ai for adoption into GPT-QModel

from typing import Tuple
from typing import Sequence, Tuple

import torch
from torch import Tensor
@@ -24,9 +24,8 @@ def eora_process_input(
) -> Tuple[int, torch.Tensor, float]:
"""Prepare the per-batch covariance contribution required for EoRA.

The contribution is computed on the target device for throughput, then
transferred to CPU so it can be safely aggregated across worker threads and
devices when the interpreter runs without the GIL.
The contribution remains on the originating device so multi-GPU execution
can accumulate locally before a single merge step.
"""

inp = input[0].to(device=device, dtype=torch.float32)
@@ -37,7 +36,7 @@
adds = torch.matmul(inp.transpose(1, 2), inp)
adds_sum = torch.sum(adds, dim=0).detach()

contribution = adds_sum.to(device=torch.device("cpu"), dtype=torch.float32)
contribution = adds_sum.to(dtype=torch.float32)
contribution /= float(sample_size)

# Adding the batch size to the denominator is only for numerical stability
@@ -47,6 +46,30 @@

return batch, contribution, scale


def merge_eora_segments(segments: Sequence[Tuple[torch.Tensor, float]]) -> torch.Tensor:
"""Combine pre-aggregated EoRA segments using their scale products.

Each segment entry is a tuple ``(total, scale_product)`` where ``total`` is
the sequential accumulation result for that segment starting from zero, and
``scale_product`` is the product of per-batch scale factors encountered in
the same segment. The function mutates the first segment tensor in place
and returns it as the merged result.
"""
if not segments:
raise ValueError("EoRA merge requires at least one segment.")

result: torch.Tensor | None = None
for total, scale_product in segments:
if result is None:
result = total
else:
result.mul_(float(scale_product))
result.add_(total)

assert result is not None
return result

def eora_compute_lora(
w_wq_delta: Tensor, # need the w (original weight) and wq (quantized qweight) delta in float32
name: str,
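A short sketch of why merging per-device segments with their scale products is exact, using the recurrence the accumulator code applies (`total ← scale · total + contribution`):

$$T_i = s_i\,T_{i-1} + c_i,\quad T_0 = 0 \;\Rightarrow\; T_n = \sum_{i=1}^{n} c_i \prod_{j=i+1}^{n} s_j.$$

Splitting the batches at index $k$ into segments $A = 1..k$ and $B = k{+}1..n$, each accumulated from zero, with segment totals $T_A$, $T_B$ and $P_B = \prod_{j=k+1}^{n} s_j$ the product of segment $B$'s scales, gives

$$T_n = P_B\,T_A + T_B,$$

which is exactly `result.mul_(scale_product); result.add_(total)` applied segment by segment, provided segments are visited in start-index order: the scale products of later segments multiply earlier totals, never the reverse.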
142 changes: 82 additions & 60 deletions gptqmodel/looper/eora_processor.py
@@ -4,22 +4,22 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import copy
import heapq
import time
from typing import Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch.nn import Module

from ..adapter.adapter import Lora
from ..eora.eora import eora_compute_lora, eora_process_input
from ..eora.eora import eora_compute_lora, eora_process_input, merge_eora_segments
from ..looper.loop_processor import DTYPE_SIZE_COLUMN, MODULE_FEATURE_COLUMN, LoopProcessor
from ..looper.named_module import NamedModule
from ..models import BaseQModel
from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE,
PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_USED_MEMORY)
from ..quantization.config import QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.device import get_device
from ..utils.model import move_to
from ..utils.torch import CPU, DEVICE_0, DEVICE_1, torch_streamCtx, torch_sync
from ..utils.torch import HAS_CUDA, tf32_disable_guard, torch_streamCtx, torch_sync
@@ -38,10 +38,10 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset
prepare_dataset_func=prepare_dataset_func, batch_size=batch_size,
require_fwd=require_fwd)

# dict: key is module name, value is the accumulated eigen_scaling_diag_matrix
self.eigen_scaling_diag_matrix: Dict[str, torch.Tensor] = {}
self._pending_contributions: Dict[str, List[Tuple[int, int, torch.Tensor, float]]] = {}
self._next_batch_index: Dict[str, int] = {}
# Track per-module segment accumulators keyed by device so we can merge
# contributions without repeatedly moving data through the CPU.
self._segment_accumulators: Dict[str, Dict[torch.device, Dict[str, Any]]] = {}
self._module_target_devices: Dict[str, torch.device] = {}

# Increase the dynamo cache size limit, default of 8 is too low
if torch._dynamo.config.cache_size_limit < 64:
@@ -78,9 +78,12 @@ def preprocess(self, module: NamedModule, **kwargs):
# hack store property inside module
module.adapter_cfg = adapter_cfg

self.eigen_scaling_diag_matrix[module.name] = None
self._pending_contributions[module.name] = []
self._next_batch_index[module.name] = 0
target_device = get_device(module.module)
if target_device.type == "meta":
target_device = torch.device("cpu")

self._module_target_devices[module.name] = torch.device(target_device)
self._segment_accumulators[module.name] = {}

return

@@ -99,7 +102,7 @@ def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
device=module.weight.data.device,
)

self._stage_eora_contribution(
self._accumulate_eora_contribution(
name=name,
batch_index=batch_index,
batch=batch,
@@ -108,7 +111,7 @@ def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
)
return tmp

def _stage_eora_contribution(
def _accumulate_eora_contribution(
self,
*,
name: str,
@@ -120,49 +123,78 @@ def _stage_eora_contribution(
if batch <= 0:
return

with self.lock:
queue = self._pending_contributions.setdefault(name, [])
self._next_batch_index.setdefault(name, 0)

pending_index = batch_index if batch_index is not None else self._next_batch_index[name]
heapq.heappush(queue, (pending_index, batch, contribution, scale))
self._flush_eora_pending_locked(name)

def _flush_eora_pending_locked(self, name: str) -> None:
queue = self._pending_contributions.get(name)
if not queue:
return
contribution = contribution.detach()
device = torch.device(contribution.device)
scale_value = float(scale)

expected = self._next_batch_index.get(name, 0)

while queue:
index, _batch, contribution, scale = queue[0]

if index < expected:
heapq.heappop(queue)
continue
with self.lock:
accumulators = self._segment_accumulators.setdefault(name, {})
record = accumulators.get(device)

index_value = int(batch_index) if batch_index is not None else 0

if record is None:
record = {
"total": contribution,
"scale_product": scale_value,
"start_index": index_value,
"end_index": index_value,
"count": 1,
}
accumulators[device] = record
return

total = record["total"]
if total.device != contribution.device:
total = total.to(device=contribution.device)

total.mul_(scale_value)
total.add_(contribution)

record["total"] = total
record["scale_product"] *= scale_value
record["count"] += 1

if batch_index is not None:
batch_value = int(batch_index)
if record["start_index"] is None or batch_value < record["start_index"]:
record["start_index"] = batch_value
if record["end_index"] is None or batch_value > record["end_index"]:
record["end_index"] = batch_value
else:
if record.get("start_index") is None:
record["start_index"] = record["count"] - 1
record["end_index"] = record["count"] - 1

if index != expected:
break
del contribution

heapq.heappop(queue)
def _finalize_eigen_scaling_matrix(self, name: str) -> torch.Tensor:
with self.lock:
segments = self._segment_accumulators.pop(name, {})
target_device = self._module_target_devices.pop(name, None)

current = self.eigen_scaling_diag_matrix.get(name)
if not segments:
raise RuntimeError(
f"EoRA statistics for module '{name}' were not collected before processing."
)

if isinstance(current, torch.Tensor):
if current.device != torch.device("cpu"):
current = current.to(device=torch.device("cpu"), dtype=torch.float32)
ordered_segments = sorted(
segments.values(),
key=lambda record: record.get("start_index", 0),
)

current.mul_(scale)
current.add_(contribution)
self.eigen_scaling_diag_matrix[name] = current
del contribution
else:
self.eigen_scaling_diag_matrix[name] = contribution
if target_device is None:
first_total = ordered_segments[0]["total"]
target_device = torch.device(first_total.device)

expected = index + 1
segment_pairs = []
for record in ordered_segments:
total = record["total"]
if total.device != target_device:
total = total.to(device=target_device, dtype=torch.float32)
segment_pairs.append((total, float(record["scale_product"])))

self._next_batch_index[name] = expected
return merge_eora_segments(segment_pairs)

def process(self, module: NamedModule):
assert isinstance(module.adapter_cfg, Lora)
@@ -171,16 +203,7 @@ def process(self, module: NamedModule):

start = time.time()

with self.lock:
self._flush_eora_pending_locked(module.name)
eigen_scaling_diag_matrix = self.eigen_scaling_diag_matrix.pop(module.name)
self._pending_contributions.pop(module.name, None)
self._next_batch_index.pop(module.name, None)

if not isinstance(eigen_scaling_diag_matrix, torch.Tensor):
raise RuntimeError(
f"EoRA statistics for module '{module.name}' were not collected before processing."
)
eigen_scaling_diag_matrix = self._finalize_eigen_scaling_matrix(module.name)

tp_info = module.state.get("tp_pad_info")
pad_cols = 0
@@ -309,9 +332,8 @@ def finalize(self, model: BaseQModel, **kwargs):
if self.stream:
torch_sync()

del self.eigen_scaling_diag_matrix
del self._pending_contributions
del self._next_batch_index
del self._segment_accumulators
del self._module_target_devices

# hack: store loras into model until `save()` is called
model.lora_results = self.results()
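Condensing the accumulator changes above: each module keeps one record per device, and every forwarded batch folds its contribution into the record for the device it ran on. A minimal standalone sketch of that layout and fold rule, with hypothetical shapes and without the `self.lock` guard the real processor takes:

```python
import torch

# One record per (module, device); fields mirror those used by
# _accumulate_eora_contribution. Shapes here are hypothetical.
def new_record(contribution: torch.Tensor, scale: float, index: int) -> dict:
    return {
        "total": contribution,          # running scale-weighted sum on this device
        "scale_product": float(scale),  # product of every scale folded into "total"
        "start_index": index,           # first batch index seen on this device
        "end_index": index,             # last batch index seen on this device
        "count": 1,
    }

def fold(record: dict, contribution: torch.Tensor, scale: float, index: int) -> None:
    # Same update rule as the processor: weight the running total by the new
    # batch scale, add the new contribution, and extend the bookkeeping.
    record["total"].mul_(float(scale)).add_(contribution)
    record["scale_product"] *= float(scale)
    record["start_index"] = min(record["start_index"], index)
    record["end_index"] = max(record["end_index"], index)
    record["count"] += 1
```

At `process()` time, `_finalize_eigen_scaling_matrix` orders these records by `start_index`, moves each total to the module's target device in float32, and hands the resulting `(total, scale_product)` pairs to `merge_eora_segments` for the single merge step.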
2 changes: 1 addition & 1 deletion gptqmodel/quantization/gptq.py
@@ -11,7 +11,7 @@
import sys
import threading
import time
from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import numpy as np
import torch
69 changes: 69 additions & 0 deletions tests/test_eora_merge.py
@@ -0,0 +1,69 @@
import math

import pytest
import torch

from gptqmodel.eora.eora import eora_process_input, merge_eora_segments


def _segment_reduce(contributions, scales):
total = None
scale_product = 1.0
for contribution, scale in zip(contributions, scales):
if total is None:
total = contribution.clone()
scale_product = float(scale)
else:
total.mul_(float(scale))
total.add_(contribution)
scale_product *= float(scale)
return total, scale_product


@pytest.mark.parametrize("segments", [1, 2, 4])
def test_eora_merge_matches_sequential(segments):
torch.manual_seed(0)

sample_size = 96
cols = 8
base_contributions = [
torch.randn(cols, cols, dtype=torch.float32)
for _ in range(sample_size)
]
def scale_fn(batch):
return sample_size / (sample_size + batch)
scales = [scale_fn(1) for _ in range(sample_size)]

sequential = torch.zeros((cols, cols), dtype=torch.float32)
for contribution, scale in zip(base_contributions, scales):
sequential.mul_(float(scale))
sequential.add_(contribution)

segment_length = math.ceil(sample_size / segments)
segment_pairs = []
for start in range(0, sample_size, segment_length):
end = min(start + segment_length, sample_size)
contributions = base_contributions[start:end]
seg_scales = scales[start:end]
segment_pairs.append(_segment_reduce(contributions, seg_scales))

merged = merge_eora_segments(segment_pairs)
torch.testing.assert_close(merged, sequential, atol=5e-6, rtol=5e-6)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_eora_process_input_preserves_device():
device = torch.device("cuda", 0)
sample = torch.randn(1, 4, 6, device=device, dtype=torch.float16)

batch, contribution, scale = eora_process_input(
input=(sample,),
name="test",
sample_size=32,
device=device,
)

assert batch == 1
assert contribution.device == device
assert contribution.dtype == torch.float32
assert isinstance(scale, float)
4 changes: 3 additions & 1 deletion tests/test_gc.py
@@ -5,8 +5,10 @@
from queue import Queue

import pytest

from gptqmodel.utils.safe import gc as safe_gc


torch = pytest.importorskip("torch", reason="requires PyTorch")


@@ -27,7 +29,7 @@ def _worker(barrier: threading.Barrier, safe: bool, errors: Queue) -> None:
safe_gc.collect()
else:
gc.collect()
except Exception as exc: # pragma: no cover - stress test safeguard
except Exception: # pragma: no cover - stress test safeguard
# Preserve the traceback so the failing test shows context from worker threads.
errors.put(traceback.format_exc())

2 changes: 1 addition & 1 deletion tests/test_gptq_queue.py
@@ -1,7 +1,7 @@
import copy

import torch
import pytest
import torch

from gptqmodel.quantization.config import QuantizeConfig
from gptqmodel.quantization.gptq import GPTQ
2 changes: 1 addition & 1 deletion tests/test_hessian_chunk.py
@@ -79,7 +79,7 @@ def test_hessian_chunk_consistency_matches_full_precision():

assert full_device == chunked_device
assert full_xtx is not None and chunked_xtx is not None
assert torch.allclose(full_xtx, chunked_xtx, atol=5e-6, rtol=5e-6)
assert torch.allclose(full_xtx, chunked_xtx, atol=3e-6, rtol=3e-6)


def test_hessian_chunk_invocations_and_workspace_shape():