diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index fa0a00bca..a5d03f125 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -669,12 +669,14 @@ def loop(self, fail_safe: bool = False, **kwargs):
             is_lm_head_module = layer_index >= layer_count
 
             if is_lm_head_module:
-                quant_modules_pb.title("Quantizing lm_head").draw()
+                layer_title = "Quantizing lm_head"
                 module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head)
             else:
-                quant_modules_pb.title(f"Quantizing layer {layer_index} of {layer_count - 1}").draw()
+                layer_title = f"Quantizing layer {layer_index} of {layer_count - 1}"
                 module = layers[layer_index]
 
+            quant_modules_pb.title(layer_title).subtitle("").draw()
+
             self.gptq_model.wait_for_turtle_reload()
 
             if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower():
@@ -778,11 +780,11 @@ def loop(self, fail_safe: bool = False, **kwargs):
                 need_outputs = not processor.fwd_after_process
                 reuse_kv = bool(getattr(module, "reuse_kv", False))
                 forward_msg = (
-                    "ModuleLooper: forward start "
-                    f"(processor=`{processor.name()}`, layer=`{layer_descriptor}`, "
-                    f"subset={index + 1}/{subset_total}, batches={batch_count})"
+                    "Forward start "
+                    f"(layer=`{layer_descriptor}`, subset={index + 1}/{subset_total}, "
+                    f"batches={batch_count})"
                 )
-                log.info(forward_msg)
+                quant_modules_pb.title(forward_msg).draw()
                 # Drain any background work so the forward spike does not race pooled tasks.
                 DEVICE_THREAD_POOL.wait()
                 forward_outputs = self._run_forward_batches(
@@ -807,6 +809,8 @@ def loop(self, fail_safe: bool = False, **kwargs):
                 fwd_time = time.time() - fwd_start
                 processor.set_fwd_time(fwd_time)
 
+                quant_modules_pb.title(layer_title).subtitle("").draw()
+
                 for h in handle:
                     h.remove()
 
@@ -881,11 +885,10 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
                     replay_batch_count = 0
                 replay_batch_count = replay_batch_count or 0
                 replay_msg = (
-                    "ModuleLooper: forward replay "
-                    f"(processor=`{processor.name()}`, layer=`{layer_descriptor}`, "
-                    f"batches={replay_batch_count})"
+                    "Forward replay "
+                    f"(layer=`{layer_descriptor}`, batches={replay_batch_count})"
                 )
-                log.info(replay_msg)
+                quant_modules_pb.title(replay_msg).draw()
                 # Forward replay shares the same VRAM spike; block until the pool drains first.
                 DEVICE_THREAD_POOL.wait()
                 layer_outputs = self._run_forward_batches(
@@ -917,6 +920,8 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
             processor.receive_layer_inputs(layer_outputs)
             layer_inputs = processor.inputs_cache.layer_inputs
 
+            quant_modules_pb.title(layer_title).subtitle("").draw()
+
             if p_index == len(self.processors) - 1:
                 torch_sync()
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 8e9747af4..8a9cea4b5 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -66,7 +66,6 @@ if TYPE_CHECKING:
-    from ..utils.threadx import DeviceThreadPool
 
 try:
     from datasets import Dataset as HFDatasetType
     from datasets import IterableDataset as HFIterableDatasetType
diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index b2b1486ef..ab0035ca9 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -18,8 +18,9 @@ from ..utils.device import get_device
 from ..utils.logger import setup_logger
 from ..utils.model import move_to, nested_move_to
-from ..utils.torch import ALL_DEVICES, CPU
 from ..utils.safe import ThreadSafe
+from ..utils.torch import ALL_DEVICES, CPU
+
 
 _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel)
diff --git a/tests/test_torch_replicate.py b/tests/test_torch_replicate.py
index 034d98dba..bfef18498 100644
--- a/tests/test_torch_replicate.py
+++ b/tests/test_torch_replicate.py
@@ -135,7 +135,7 @@ def test_replicate_from_cpu_to_multiple_gpu():
     devices = [torch.device(f"cuda:{idx}") for idx in range(2)]
     module = _random_linear()
-    input_tensor = torch.randn(2, module.in_features)
+    torch.randn(2, module.in_features)
 
     torch_replicate(module, devices)