Merged
7 changes: 7 additions & 0 deletions gptqmodel/looper/module_looper.py
@@ -820,6 +820,8 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
num_devices = len(devices)

for index, device in enumerate(devices):
# Split the outstanding batches across devices so that each accelerator
# receives a contiguous slice.
remaining_batches = max(total_batches - segment_start, 0)
remaining_devices = max(num_devices - index, 1)
segment_length = remaining_batches // remaining_devices
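For context, the contiguous split above can be reproduced in isolation; a minimal sketch of the same arithmetic (the device names, batch count, and the assignment/advance of `segment_start` beyond the visible lines are illustrative assumptions):

```python
# Minimal sketch: split total_batches into contiguous per-device segments.
def split_contiguous(total_batches: int, devices: list[str]) -> dict[str, list[int]]:
    segments: dict[str, list[int]] = {}
    segment_start = 0
    num_devices = len(devices)
    for index, device in enumerate(devices):
        remaining_batches = max(total_batches - segment_start, 0)
        remaining_devices = max(num_devices - index, 1)
        segment_length = remaining_batches // remaining_devices
        # Later devices absorb the remainder because the divisor shrinks each turn.
        segments[device] = list(range(segment_start, segment_start + segment_length))
        segment_start += segment_length
    return segments

print(split_contiguous(10, ["cuda:0", "cuda:1", "cuda:2"]))
# {'cuda:0': [0, 1, 2], 'cuda:1': [3, 4, 5], 'cuda:2': [6, 7, 8, 9]}
```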
@@ -841,6 +843,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
max_segment_length = len(indices)

for position in range(max_segment_length):
# Submit one batch per device
futures = []
for device in devices:
segment_indices = device_segments.get(device, [])
@@ -874,6 +877,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
)

for fut in futures:
# Preserve the original batch order
batch_idx, module_output, kv_next = fut.result()
if need_outputs and module_output is not None:
results[batch_idx] = module_output
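The ordering guarantee here comes from each future returning its own batch index and from iterating the futures in submission order. A small sketch of the pattern, with a plain ThreadPoolExecutor standing in for the project's device pool:

```python
from concurrent.futures import ThreadPoolExecutor

def forward_batch(batch_idx: int, batch: int):
    # Stand-in for the per-device forward pass; echoes the index back so the
    # caller can restore submission order no matter which worker finishes first.
    return batch_idx, batch * 2

batches = [1, 2, 3, 4]
results = {}
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(forward_batch, i, b) for i, b in enumerate(batches)]
    for fut in futures:                      # submission order, not completion order
        batch_idx, module_output = fut.result()
        results[batch_idx] = module_output

ordered_outputs = [results[i] for i in range(len(batches))]
assert ordered_outputs == [2, 4, 6, 8]
```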
@@ -903,6 +907,8 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->

ordered_outputs: List[List[torch.Tensor]] = []
for idx in range(total_batches):
# Rebuild the ordered list of batch outputs expected by the next
# stage.
module_output = results.get(idx)
if module_output is None:
raise RuntimeError("Forward batch returned no output; data-parallel execution produced empty result.")
@@ -1105,6 +1111,7 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):

try:
for index, reverse_p in enumerate(reversed_processors, start=1):
# Finalize processors in reverse order
self._check_loop_stop()
if isinstance(reverse_p, GPTQProcessor):
pass
14 changes: 14 additions & 0 deletions gptqmodel/looper/stage_layer.py
@@ -48,6 +48,8 @@ def run_layer_stage(
"""Execute the main per-layer quantization loop."""
log = logger or setup_logger()
for layer_index in pb:
# Iterate over every transformer layer (plus lm_head when enabled) as
# progress-bar controlled units of work.
if looper._check_loop_stop():
break
is_lm_head_module = layer_index >= layer_count
@@ -78,6 +80,8 @@ def run_layer_stage(
full = find_modules(module, name=looper.gptq_model.lm_head if is_lm_head_module else "")

for p_index, processor in enumerate(looper.processors):
# Each processor contributes a quantization phase; walk them in
# order so their caches and side effects line up with the pipeline.
processor.log_call_count = 0 # reset
processor.collect_memory_info(layer_index)

@@ -101,6 +105,8 @@
previous_subset_processed: Optional[Dict[str, NamedModule]] = None

for index, names in enumerate(modules):
# Process the layer in smaller subsets so attention groups or
# MoE experts can be quantized independently within a layer.
if isinstance(processor, AWQProcessor):
log.info(
"StageLayer[awq]: layer=%s subset=%s/%s size=%s names=%s",
@@ -273,6 +279,8 @@ def run_layer_stage(
looper.gptq_model.post_quantize(module)

for finalized in processed_subset.values():
# Reset finalized modules to CPU to guarantee deterministic
# ownership before the next processor touches the layer.
if isinstance(finalized, NamedModule):
setattr(finalized, "target_device", CPU)
inner_module = getattr(finalized, "module", None)
@@ -299,6 +307,8 @@ def run_layer_stage(
finalize_tasks = []

for reverse_p in reversed(looper.processors):
# Collect finalize tasks in reverse to mirror the processor
# execution order and honor downstream dependencies.
for module in processed_subset.values():
actual_module = module.module if isinstance(module, NamedModule) else module

@@ -383,6 +393,8 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):
# ).draw()

for index, (process, module, module_label, target_dev, layer_idx) in enumerate(finalize_tasks, start=1):
# Schedule finalize work on the device thread pool so CPU
# bound tasks do not stall the main orchestration loop.
future = DEVICE_THREAD_POOL.submit(
target_dev,
_finalize_on_worker,
@@ -440,6 +452,8 @@ def _drain_finalize_futures(
completed_local = 0
try:
for future in as_completed(futures):
# Drain futures as they complete to surface errors
# quickly and keep the progress bar in sync.
try:
result = future.result()
except BaseException as exc:
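For reference, `as_completed` yields futures as they finish, so a worker exception is re-raised on the first `result()` call that hits it rather than after the whole batch drains; a contrived sketch of that behavior (the failing task and the counter are illustrative):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def finalize(idx: int) -> int:
    if idx == 2:
        raise RuntimeError(f"finalize failed for task {idx}")
    return idx

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(finalize, i) for i in range(4)]
    completed = 0
    try:
        for future in as_completed(futures):   # completion order, not submission order
            future.result()                    # re-raises the worker's exception here
            completed += 1                     # e.g. advance the progress bar
    except RuntimeError as exc:
        print(f"stopped after {completed} completed tasks: {exc}")
```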
16 changes: 16 additions & 0 deletions gptqmodel/looper/stage_subset.py
@@ -148,6 +148,8 @@ def run_subset_stage(
combined_names.append(candidate)

for sub_name in combined_names:
# Group every expert (including ones outside the current subset) so
# load balancing decisions can span the full MoE family.
group_key = looper._extract_moe_group_key(sub_name)
if group_key is None:
continue
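`_extract_moe_group_key` itself is not shown in this diff; the idea is that expert modules differing only by expert index collapse to a single family key. A hypothetical sketch of such a key function (the regex and naming scheme are assumptions, not the PR's implementation):

```python
import re

def extract_moe_group_key(name: str) -> str | None:
    # Collapse ".experts.<idx>." into a wildcard so all experts of one MLP
    # share a group key; non-expert modules return None and are skipped.
    match = re.search(r"(.*\.experts)\.\d+\.(.*)", name)
    if match is None:
        return None
    return f"{match.group(1)}.*.{match.group(2)}"

print(extract_moe_group_key("model.layers.3.mlp.experts.17.down_proj"))
# model.layers.3.mlp.experts.*.down_proj
print(extract_moe_group_key("model.layers.3.self_attn.q_proj"))
# None
```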
@@ -208,6 +210,9 @@ def run_subset_stage(

subset_size = len(subset)
for idx, (name, m) in enumerate(subset.items()):
# Register the forward hook that captures activations for quantization.
# The final module optionally flips a flag so processors can trigger
# once-per-subset logic after the forward pass.
is_last = (idx == subset_size - 1)
hook_source = getattr(m, "full_name", None)
if hook_source is None:
@@ -320,16 +325,21 @@ def run_subset_stage(
pb.title(layer_title).subtitle("").draw()

for h in handle:
# Detach temporary hooks to avoid leaking state into future passes.
h.remove()

for name in subset:
# Reset inline hook attributes on NamedModule wrappers so future passes
# do not reuse state from this subset run.
if hasattr(subset[name], 'forward_hook'):
subset[name].forward_hook = None
subset[name].forward_hook_last = False
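The hook lifecycle in this file follows the usual PyTorch pattern: register a forward hook per module to capture calibration activations, flag the last module so once-per-subset logic can fire, then remove every handle so later passes start clean. A condensed sketch of that pattern (the capture target and flag name are illustrative, not the processor's actual bookkeeping):

```python
import torch
import torch.nn as nn

captured: dict[str, list[torch.Tensor]] = {}
state = {"subset_done": False}

def make_hook(name: str, is_last: bool):
    def hook(module, inputs, output):
        # Record the activation feeding this module; a quantizer would build
        # its statistics (e.g. a Hessian) from these captures.
        captured.setdefault(name, []).append(inputs[0].detach())
        if is_last:
            state["subset_done"] = True   # once-per-subset trigger
    return hook

subset = {"fc1": nn.Linear(8, 8), "fc2": nn.Linear(8, 8)}
handles = [
    m.register_forward_hook(make_hook(name, idx == len(subset) - 1))
    for idx, (name, m) in enumerate(subset.items())
]

x = torch.randn(2, 8)
for m in subset.values():
    x = m(x)

for h in handles:   # detach temporary hooks so no state leaks into later passes
    h.remove()
```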

moe_skip_modules = []
if isinstance(processor, GPTQProcessor):
for name in subset:
# Skip MoE experts that never fired; they likely lacked calibration
# traffic and would produce invalid statistics.
if processor.tasks[name].fwd_counter == 0:
logger.error(f"`{name}` was not invoked, if it is a MoE module, it may lack sufficient calibration data routed to it.")
moe_skip_modules.append(name)
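Put differently, an expert whose forward counter never moved contributed no calibration activations, so quantizing it would operate on empty statistics. A toy illustration of the filter (the counter and module names are simplified stand-ins for the processor's task bookkeeping):

```python
from collections import Counter

# Pretend router decisions from a calibration run: expert 3 is never selected.
routed_expert_ids = [0, 1, 0, 2, 1, 0, 2]
fwd_counter = Counter(routed_expert_ids)

experts = {f"mlp.experts.{i}.down_proj": fwd_counter[i] for i in range(4)}
moe_skip_modules = [name for name, hits in experts.items() if hits == 0]
print(moe_skip_modules)   # ['mlp.experts.3.down_proj'] -> excluded from quantization
```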
@@ -343,6 +353,8 @@

quant_target_devices: Dict[str, torch.device] = {}
for name, named_module in subset.items():
# Ensure each module has a matching processor task before sending it to
# the worker pool; otherwise freeze it on the current device.
task_map = getattr(processor, "tasks", None)
has_task = bool(task_map and task_map.get(name) is not None)

@@ -424,6 +436,8 @@ def _process_on_worker(
return nm.name, nm

for name, named_module in subset.items():
# Launch processing for every module in the subset; tasks may run in
# parallel as allowed by the device thread pool.
tgt_dev = quant_target_devices.get(name, cur_layer_device)
futures.append(
DEVICE_THREAD_POOL.submit(
Expand All @@ -440,6 +454,8 @@ def _process_on_worker(
)

for fut in futures:
# Collect results in submission order so the final subset map preserves
# deterministic iteration for downstream consumers.
name, named_module = fut.result()
processed_subset[name] = named_module
torch_sync()
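DEVICE_THREAD_POOL's first argument is the target device, which suggests work is queued per accelerator; the pool's real API and scheduling are not shown in this diff, so the following is only an assumption-level stand-in that captures the idea of per-device serialization:

```python
from concurrent.futures import Future, ThreadPoolExecutor

class PerDevicePool:
    """Hypothetical sketch: one single-worker queue per device, so tasks aimed
    at the same accelerator serialize while different devices run in parallel."""

    def __init__(self, devices):
        self._pools = {d: ThreadPoolExecutor(max_workers=1) for d in devices}

    def submit(self, device, fn, *args, **kwargs) -> Future:
        return self._pools[device].submit(fn, *args, **kwargs)

    def shutdown(self) -> None:
        for pool in self._pools.values():
            pool.shutdown(wait=True)

pool = PerDevicePool(["cuda:0", "cuda:1", "cpu"])
futures = [pool.submit("cuda:0", str.upper, name) for name in ("q_proj", "k_proj")]
print([f.result() for f in futures])   # ['Q_PROJ', 'K_PROJ'], in submission order
pool.shutdown()
```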
6 changes: 5 additions & 1 deletion gptqmodel/models/writer.py
@@ -241,7 +241,11 @@ def strip_attention_impl_fields(target: Any) -> Dict[str, Any]:
for attr in ("attn_implementation", "_attn_implementation"):
if hasattr(target, attr):
removed[attr] = getattr(target, attr)
delattr(target, attr)
# Avoid AttributeError: property '_attn_implementation' of 'Qwen2Config' object has no deleter
try:
delattr(target, attr)
except Exception:
pass
return removed
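The guard is needed because `_attn_implementation` can be a class-level property with no deleter, in which case `delattr` raises even though `hasattr`/`getattr` succeed. A minimal reproduction with a stand-in config class (not the actual Qwen2Config):

```python
class DemoConfig:
    @property
    def _attn_implementation(self) -> str:
        return "sdpa"

target = DemoConfig()
removed = {}
for attr in ("attn_implementation", "_attn_implementation"):
    if hasattr(target, attr):
        removed[attr] = getattr(target, attr)
        try:
            delattr(target, attr)   # AttributeError: property has no deleter
        except Exception:
            pass                    # the value is still recorded in `removed`
print(removed)                      # {'_attn_implementation': 'sdpa'}
```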

generation_config = getattr(self.model, "generation_config", None)
24 changes: 8 additions & 16 deletions gptqmodel/quantization/gptq.py
@@ -23,6 +23,7 @@
from ..quantization import QuantizeConfig
from ..utils.device import get_device
from ..utils.logger import setup_logger
from ..utils.torch import torch_sync
from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm
from .quantizer import HF_OPTIMUM, Quantizer

@@ -349,7 +350,9 @@ def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:

if chunk_size is None:
mat32 = matrix.to(dtype=torch.float32)
return torch.matmul(mat32.T, mat32)
xtx = torch.matmul(mat32.T, mat32)
torch_sync(device=xtx.device)
return xtx

xtx_accum = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=matrix.device)

@@ -360,6 +363,7 @@ def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:
materialized32 = materialized
xtx_accum.add_(torch.matmul(materialized32.T, materialized32))

torch_sync(device=xtx_accum.device)
return xtx_accum
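The new `torch_sync` calls matter because CUDA matmuls are asynchronous: handing the product to another thread before its kernel finishes can expose a not-yet-materialized tensor. A rough sketch of the chunked X^T X accumulation with an explicit sync, using `torch.cuda.synchronize` as a stand-in for the project's `torch_sync` helper:

```python
import torch

def hessian_xtx(matrix: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
    cols = matrix.shape[1]
    if chunk_size is None:
        mat32 = matrix.to(dtype=torch.float32)
        xtx = mat32.T @ mat32
    else:
        xtx = torch.zeros((cols, cols), dtype=torch.float32, device=matrix.device)
        for start in range(0, matrix.shape[0], chunk_size):
            chunk = matrix[start:start + chunk_size].to(dtype=torch.float32)
            xtx.add_(chunk.T @ chunk)          # accumulate partial X^T X in fp32
    if xtx.is_cuda:
        torch.cuda.synchronize(xtx.device)     # ensure the result is fully materialized
    return xtx

x = torch.randn(128, 16)
assert torch.allclose(hessian_xtx(x), hessian_xtx(x, chunk_size=32), atol=1e-4)
```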

def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], torch.device]:
@@ -460,7 +464,7 @@ def _select_hessian_target_device(self, requested: Optional[torch.device]) -> torch.device:

return torch.device("cpu")

def _materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:
def materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:
device = self._select_hessian_target_device(target_device)

with self.lock:
@@ -500,19 +504,7 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:

for partial_device, partial in self._device_hessian_partials.items():
if partial.device != result_accum.device or partial.dtype != torch.float32:
# TODO FIXME multi-3090 using P2P is revealing an issue where result_accum and/or partial is not ready for consolidation on the main thread
# when partials are calculated on the individual
try:
tmp = partial.to(device=result_accum.device, dtype=torch.float32)
result_accum.add_(tmp)
del tmp
except:
log.warn(f"Quantization: Module `{self.name}` -> Retry 1/2 partial.to in 0.5s")
time.sleep(0.25)
tmp = partial.to(device=result_accum.device, dtype=torch.float32)
result_accum.add_(tmp)
del tmp

result_accum.add_(partial.to(device=result_accum.device, dtype=torch.float32))
else:
result_accum.add_(partial)
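The replacement for the retry block is a straightforward reduction: each per-device partial is cast and moved onto the accumulator's device before being summed. A CPU-only sketch of the same consolidation (device names and values are illustrative):

```python
import torch

# Per-device partial Hessian accumulators, as if three GPUs each saw a share of
# the calibration batches (kept on CPU so the sketch runs anywhere).
partials = {
    "dev0": torch.full((4, 4), 1.0, dtype=torch.float16),
    "dev1": torch.full((4, 4), 2.0, dtype=torch.float32),
    "dev2": torch.full((4, 4), 3.0, dtype=torch.float32),
}

result_accum = torch.zeros((4, 4), dtype=torch.float32)
for partial in partials.values():
    if partial.device != result_accum.device or partial.dtype != torch.float32:
        # Cast (and, on real hardware, transfer) before accumulating.
        result_accum.add_(partial.to(device=result_accum.device, dtype=torch.float32))
    else:
        result_accum.add_(partial)

assert torch.all(result_accum == 6.0)
```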

@@ -527,7 +519,7 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:
del result_accum

def finalize_hessian(self, target_device: Optional[torch.device] = None) -> torch.Tensor:
self._materialize_global_hessian(target_device=target_device)
self.materialize_global_hessian(target_device=target_device)
if self.H is None:
self.H = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=self._select_hessian_target_device(target_device))
return self.H