7 changes: 4 additions & 3 deletions gptqmodel/looper/gptq_processor.py
@@ -106,8 +106,9 @@ def is_skipped(self, module: NamedModule) -> bool:
def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
g = self.tasks[name] # noqa: F821
batch_idx = self.current_batch_index()
with tf32_disable_guard():
g.add_batch(inp[0].data, out.data) # noqa: F821
g.add_batch(inp[0].data, out.data, batch_index=batch_idx) # noqa: F821
del inp, out
return tmp
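
The batch_index threaded through add_batch above exists so the GPTQ task can apply Hessian updates in a deterministic order even when hooks fire from multiple worker threads. A toy, self-contained sketch of that idea (illustrative only, not GPTQModel's actual Hessian code):

import threading

import torch


class ToyHessianAccumulator:
    """Toy illustration: buffer per-batch updates, then apply in batch order."""

    def __init__(self, cols: int) -> None:
        self.H = torch.zeros(cols, cols)
        self._pending: list[tuple[int, torch.Tensor]] = []
        self._lock = threading.Lock()

    def add_batch(self, inp: torch.Tensor, batch_index: int) -> None:
        # Hooks may call this from any worker thread; just record the update.
        with self._lock:
            self._pending.append((batch_index, inp.detach()))

    def finalize(self) -> torch.Tensor:
        # Sorting by batch_index makes float accumulation reproducible, which
        # arrival order under multi-threaded forwards is not.
        for _, inp in sorted(self._pending, key=lambda item: item[0]):
            self.H += inp.T @ inp
        return self.H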

@@ -311,7 +312,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
logger=log,
module_name=module_label,
):
pack_module(
packer_label = pack_module(
name=module.full_name,
qModules=qModules,
q_scales=q_scales,
@@ -326,7 +327,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
timer.record(
"submodule_finalize_pack",
time.perf_counter() - pack_start,
source=module_label,
source=f"{module_label} [{packer_label or 'module.pack_original'}]",
)

# TODO: store module quant results in module, not global processor result
17 changes: 16 additions & 1 deletion gptqmodel/looper/loop_processor.py
@@ -175,10 +175,25 @@ def __init__(
log.warn(f"The average length of input_ids of calibration_dataset should be greater than "
f"{min_calibration_dataset_input_ids_avg_length}: actual avg: {avg}.")

self.num_batches = len(calibration)

self.calibration_dataset = calibration

# Track the current calibration batch index on a per-thread basis so
# processors can retrieve deterministic ordering information (e.g. for
# GPTQ's Hessian updates) even when forward passes run on multiple threads.
self._batch_tls = threading.local()

def _set_current_batch_index(self, batch_index: Optional[int]) -> None:
if batch_index is None:
if hasattr(self._batch_tls, "index"):
delattr(self._batch_tls, "index")
else:
self._batch_tls.index = int(batch_index)

def current_batch_index(self) -> Optional[int]:
return getattr(self._batch_tls, "index", None)

def _async_log_writer(self, stat):
with open(self.log_tmp_log_file_name, 'a') as f:
json.dump(stat, f, indent=4)
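
The _batch_tls addition above is plain threading.local() bookkeeping. The same pattern, reduced to a self-contained sketch (illustrative names, not the library's API):

import threading
from typing import Optional


class BatchIndexTracker:
    def __init__(self) -> None:
        self._tls = threading.local()

    def set(self, batch_index: Optional[int]) -> None:
        if batch_index is None:
            # Clearing: drop the attribute so get() falls back to None.
            if hasattr(self._tls, "index"):
                delattr(self._tls, "index")
        else:
            self._tls.index = int(batch_index)

    def get(self) -> Optional[int]:
        return getattr(self._tls, "index", None)


tracker = BatchIndexTracker()


def worker(idx: int) -> None:
    tracker.set(idx)
    assert tracker.get() == idx  # each thread only ever sees its own value
    tracker.set(None)


threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()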
241 changes: 184 additions & 57 deletions gptqmodel/looper/module_looper.py
@@ -65,6 +65,10 @@ class FinalizeProgressInfo(NamedTuple):
layer_idx: Optional[int]


class StopMainLoop(Exception):
"""Signal that the module loop should abort immediately."""


class ModuleLooper():
"""Drive the per-layer quantisation workflow over one or more devices.

@@ -77,6 +81,10 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
self.gptq_model = model
self.support_batch_quantize = model.support_batch_quantize
self.lock = threading.Lock()
self._layer_callback = getattr(model, "layer_callback", None)
self._loop_stop_event = threading.Event()
self._loop_stop_exc: Optional[BaseException] = None
self._loop_stop_waited = False

disk_speed = estimate_disk_io_speed()
disk_speed_mb = disk_speed / (1024 * 1024)
@@ -105,6 +113,81 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
for processor in self.processors:
self._processor_mask_tls(processor)

def register_layer_callback(self, callback) -> None:
"""Register or replace the layer-complete callback target."""
self._layer_callback = callback

def _resolve_layer_callback(self):
for candidate in (
getattr(self, "_layer_callback", None),
getattr(self, "layer_callback", None),
getattr(self.gptq_model, "layer_callback", None),
getattr(self.gptq_model, "callbackup", None),
getattr(self.gptq_model, "callback", None),
):
if candidate is not None:
return candidate
return None

def callbackup(self, layer_idx: int, submodule_finalized: bool):
callback = self._resolve_layer_callback()
if callback is None:
return None

handler = getattr(callback, "layer_complete", None)
if handler is None and callable(callback):
handler = callback
if handler is None:
return None

try:
result = handler(layer_idx=layer_idx, submodule_finalized=submodule_finalized)
except StopMainLoop:
raise
if result is StopMainLoop:
raise StopMainLoop(f"Layer callback requested stop at layer {layer_idx}")
if isinstance(result, StopMainLoop):
raise result
return result
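
For reference, a hypothetical callback this dispatcher accepts: it may expose layer_complete(...) or be a bare callable, and it can abort the loop by raising (or returning) StopMainLoop.

class PrintingLayerCallback:
    """Hypothetical example; only layer_complete's keyword signature is fixed."""

    def __init__(self, stop_after: int) -> None:
        self.stop_after = stop_after

    def layer_complete(self, layer_idx: int, submodule_finalized: bool) -> None:
        print(f"layer {layer_idx} complete (finalized={submodule_finalized})")
        if submodule_finalized and layer_idx >= self.stop_after:
            raise StopMainLoop(f"requested stop after layer {layer_idx}")

# looper.register_layer_callback(PrintingLayerCallback(stop_after=3))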

def _request_loop_stop(self, exc: Optional[BaseException]) -> None:
with self.lock:
if self._loop_stop_exc is None and exc is not None:
self._loop_stop_exc = exc
self._loop_stop_event.set()

def _check_loop_stop(self) -> bool:
if not self._loop_stop_event.is_set():
return False
if not self._loop_stop_waited:
DEVICE_THREAD_POOL.wait()
self._loop_stop_waited = True
if self._loop_stop_exc is not None:
raise self._loop_stop_exc
return True

def _emit_layer_complete(
self,
layer_idx: int,
submodule_finalized: bool,
*,
raise_in_place: bool,
) -> None:
try:
self.callbackup(layer_idx=layer_idx, submodule_finalized=submodule_finalized)
except StopMainLoop:
self._request_loop_stop(None)
return
except BaseException as exc:
if raise_in_place:
raise
log.exception(
"Layer completion callback raised an exception (layer=%s, submodule_finalized=%s)",
layer_idx,
submodule_finalized,
)
self._request_loop_stop(exc)

# Processors capture activations through hooks that need thread-local state
# so masks survive the roundtrip to worker threads.
def _processor_mask_tls(self, processor: LoopProcessor) -> threading.local:
@@ -400,72 +483,76 @@ def _run_forward_batches_single(
stage_label = progress_stage or "Forward"

for batch_idx in range(total_batches):
layer_input = [move_to(inp, device=cur_layer_device, stream=False) for inp in layer_inputs[batch_idx]]
processor._set_current_batch_index(batch_idx)
try:
layer_input = [move_to(inp, device=cur_layer_device, stream=False) for inp in layer_inputs[batch_idx]]

raw_mask = attention_masks[batch_idx]
attn_tensor = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device, stream=False)
raw_mask = attention_masks[batch_idx]
attn_tensor = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device, stream=False)

keep_mask = None
if attn_tensor is not None:
seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None
keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len)
self._set_processor_mask(processor, keep_mask)
keep_mask = None
if attn_tensor is not None:
seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None
keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len)
self._set_processor_mask(processor, keep_mask)

additional_inputs: Dict[str, torch.Tensor] = {}
if self.support_batch_quantize and attn_tensor is not None:
additional_inputs["attention_mask"] = attn_tensor
additional_inputs: Dict[str, torch.Tensor] = {}
if self.support_batch_quantize and attn_tensor is not None:
additional_inputs["attention_mask"] = attn_tensor

if position_ids:
pos = position_ids[batch_idx]
if pos is not None:
additional_inputs["position_ids"] = move_to(pos, device=cur_layer_device, stream=False)
if position_ids:
pos = position_ids[batch_idx]
if pos is not None:
additional_inputs["position_ids"] = move_to(pos, device=cur_layer_device, stream=False)

for key, value in layer_input_kwargs[batch_idx].items():
additional_inputs[key] = nested_move_to(value, device=cur_layer_device, stream=False)
for key, value in layer_input_kwargs[batch_idx].items():
additional_inputs[key] = nested_move_to(value, device=cur_layer_device, stream=False)

if reuse_kv and prev_kv is not None:
additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=cur_layer_device, stream=False)
if reuse_kv and prev_kv is not None:
additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=cur_layer_device, stream=False)

rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)
rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)

module_output = None
try:
if is_lm_head_module:
module_output = module(*layer_input)
else:
module_output = module(*layer_input, **additional_inputs)
except StopForward:
module_output = None
try:
if is_lm_head_module:
module_output = module(*layer_input)
else:
module_output = module(*layer_input, **additional_inputs)
except StopForward:
module_output = None
finally:
self._set_processor_mask(processor, None)

if (
reuse_kv
and module_output is not None
and isinstance(module_output, tuple)
and len(module_output) > 0
and shared_kv_cache_dict.get(layer_index) is None
):
shared_kv_cache_dict[layer_index] = module_output[-1]

if need_outputs and module_output is not None:
primary = module_output[0] if isinstance(module_output, tuple) else module_output
primary = move_to(primary, device=cur_layer_device, stream=False)
outputs.append([primary])

rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0
if rows_for_batch <= 0:
rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1
rows_for_batch = max(rows_for_batch, 1)

processed_rows = min(processed_rows + rows_for_batch, total_rows)
if progress_pb is not None:
if progress_title:
progress_pb.title(progress_title)
progress_pb.current_iter_step = processed_rows
progress_pb.subtitle(
f"{stage_label} rows {processed_rows}/{total_rows}"
).draw()
finally:
self._set_processor_mask(processor, None)

if (
reuse_kv
and module_output is not None
and isinstance(module_output, tuple)
and len(module_output) > 0
and shared_kv_cache_dict.get(layer_index) is None
):
shared_kv_cache_dict[layer_index] = module_output[-1]

if need_outputs and module_output is not None:
primary = module_output[0] if isinstance(module_output, tuple) else module_output
primary = move_to(primary, device=cur_layer_device, stream=False)
outputs.append([primary])

rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0
if rows_for_batch <= 0:
rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1
rows_for_batch = max(rows_for_batch, 1)

processed_rows = min(processed_rows + rows_for_batch, total_rows)
if progress_pb is not None:
if progress_title:
progress_pb.title(progress_title)
progress_pb.current_iter_step = processed_rows
progress_pb.subtitle(
f"{stage_label} rows {processed_rows}/{total_rows}"
).draw()
processor._set_current_batch_index(None)

return outputs
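
Rough shape of the restructured loop above (real body elided): publish the batch index before each forward so hooks can read it, and always clear it again, even if the forward raises.

def run_batches(processor, forward_one_batch, total_batches: int) -> None:
    for batch_idx in range(total_batches):
        processor._set_current_batch_index(batch_idx)
        try:
            # Hooks fired during the forward may call
            # processor.current_batch_index() to tag their updates.
            forward_one_batch(batch_idx)
        finally:
            processor._set_current_batch_index(None)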

@@ -872,6 +959,8 @@ def loop(self, fail_safe: bool = False, **kwargs):
setattr(parent, module_path[-1], hooked_lm_head)

for layer_index in pb:
if self._check_loop_stop():
break
is_lm_head_module = layer_index >= layer_count

if is_lm_head_module:
@@ -928,6 +1017,17 @@ def loop(self, fail_safe: bool = False, **kwargs):
lock_ctx = DEVICE_THREAD_POOL.read_lock(cur_layer_device)
with ctx(lock_ctx, device_ctx(device_for_ctx)):
processor.layer_quantize(module, cur_layer_device, named_childs)
if p_index == len(self.processors) - 1:
self._emit_layer_complete(
layer_idx=layer_index,
submodule_finalized=False,
raise_in_place=True,
)
self._emit_layer_complete(
layer_idx=layer_index,
submodule_finalized=True,
raise_in_place=True,
)
continue

layer_inputs = processor.inputs_cache.layer_inputs
@@ -1339,6 +1439,12 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):

finalize_futures_snapshot = list(finalize_futures)

self._emit_layer_complete(
layer_idx=layer_index,
submodule_finalized=False,
raise_in_place=True,
)

if finalize_futures_snapshot:
known_layers = sorted(
{
@@ -1371,11 +1477,17 @@ def _drain_finalize_futures(
futures,
finalize_pb_local,
finalize_count_local,
layer_idx_for_callback,
):
completed_local = 0
try:
for future in as_completed(futures):
result = future.result()
try:
result = future.result()
except BaseException as exc:
log.exception("Submodule finalize task raised an exception")
self._request_loop_stop(exc)
return

if isinstance(result, FinalizeProgressInfo):
module_label = result.module_label
@@ -1399,6 +1511,11 @@
).subtitle(subtitle).draw()
finally:
finalize_pb_local.close()
self._emit_layer_complete(
layer_idx=layer_idx_for_callback,
submodule_finalized=True,
raise_in_place=False,
)
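
The wiring below runs this drain function on a daemon thread. Its essentials as a standalone sketch (illustrative names): report the first failure, then signal completion exactly once.

import threading
from concurrent.futures import as_completed


def watch_futures(futures, on_error, on_done) -> None:
    def _run() -> None:
        try:
            for fut in as_completed(futures):
                try:
                    fut.result()
                except BaseException as exc:
                    on_error(exc)  # e.g. self._request_loop_stop(exc)
                    return
        finally:
            on_done()  # e.g. emit the layer-complete callback

    threading.Thread(target=_run, name="FinalizeWatcher", daemon=True).start()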

if finalize_futures_snapshot:
# Drain finalize futures asynchronously so the main loop can continue scheduling work.
@@ -1408,14 +1525,23 @@
[future for future, *_ in finalize_futures_snapshot],
finalize_pb,
finalize_count,
layer_index,
),
name="SubmoduleFinalizeWatcher",
daemon=True,
).start()
else:
self._emit_layer_complete(
layer_idx=layer_index,
submodule_finalized=True,
raise_in_place=True,
)

# LifeCycle: all sub-modules have been finalized, meaning quantization work is complete
self._check_loop_stop()
# Ensure ANY remaining tasks the looper submitted have drained
DEVICE_THREAD_POOL.wait() # same as wait('all')
self._check_loop_stop()

# paranoid safety check
# torch_sync()
@@ -1437,6 +1563,7 @@

try:
for index, reverse_p in enumerate(reversed_processors, start=1):
self._check_loop_stop()
if isinstance(reverse_p, GPTQProcessor):
pass
elif isinstance(reverse_p, EoraProcessor):