Merged
25 changes: 15 additions & 10 deletions gptqmodel/looper/module_looper.py
@@ -669,12 +669,14 @@ def loop(self, fail_safe: bool = False, **kwargs):
 is_lm_head_module = layer_index >= layer_count
 
 if is_lm_head_module:
-    quant_modules_pb.title("Quantizing lm_head").draw()
+    layer_title = "Quantizing lm_head"
     module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head)
 else:
-    quant_modules_pb.title(f"Quantizing layer {layer_index} of {layer_count - 1}").draw()
+    layer_title = f"Quantizing layer {layer_index} of {layer_count - 1}"
     module = layers[layer_index]
 
+quant_modules_pb.title(layer_title).subtitle("").draw()
+
 self.gptq_model.wait_for_turtle_reload()
 
 if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower():
@@ -778,11 +780,11 @@ def loop(self, fail_safe: bool = False, **kwargs):
 need_outputs = not processor.fwd_after_process
 reuse_kv = bool(getattr(module, "reuse_kv", False))
 forward_msg = (
-    "ModuleLooper: forward start "
-    f"(processor=`{processor.name()}`, layer=`{layer_descriptor}`, "
-    f"subset={index + 1}/{subset_total}, batches={batch_count})"
+    "Forward start "
+    f"(layer=`{layer_descriptor}`, subset={index + 1}/{subset_total}, "
+    f"batches={batch_count})"
 )
-log.info(forward_msg)
+quant_modules_pb.title(forward_msg).draw()
 # Drain any background work so the forward spike does not race pooled tasks.
 DEVICE_THREAD_POOL.wait()
 forward_outputs = self._run_forward_batches(
@@ -807,6 +809,8 @@
 fwd_time = time.time() - fwd_start
 processor.set_fwd_time(fwd_time)
 
+quant_modules_pb.title(layer_title).subtitle("").draw()
+
 for h in handle:
     h.remove()
 
@@ -881,11 +885,10 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
 replay_batch_count = 0
 replay_batch_count = replay_batch_count or 0
 replay_msg = (
-    "ModuleLooper: forward replay "
-    f"(processor=`{processor.name()}`, layer=`{layer_descriptor}`, "
-    f"batches={replay_batch_count})"
+    "Forward replay "
+    f"(layer=`{layer_descriptor}`, batches={replay_batch_count})"
 )
-log.info(replay_msg)
+quant_modules_pb.title(replay_msg).draw()
 # Forward replay shares the same VRAM spike; block until the pool drains first.
 DEVICE_THREAD_POOL.wait()
 layer_outputs = self._run_forward_batches(
Expand Down Expand Up @@ -917,6 +920,8 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
processor.receive_layer_inputs(layer_outputs)
layer_inputs = processor.inputs_cache.layer_inputs

quant_modules_pb.title(layer_title).subtitle("").draw()

if p_index == len(self.processors) - 1:
torch_sync()

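Note on the module_looper.py changes above: the per-layer status string is now computed once as layer_title, forward start/replay messages are shown on the quantization progress bar instead of being written via log.info, and the bar's title is restored to layer_title after each forward pass. Below is a minimal sketch of that flow, assuming a hypothetical ProgressBar stand-in with the same chainable title()/subtitle()/draw() interface as quant_modules_pb; it is illustrative only, not the project's implementation.

# Minimal sketch of the progress-bar flow adopted in module_looper.py above.
# Assumptions (not from this PR): ProgressBar is a stand-in for the real
# quant_modules_pb object, and run_forward() replaces self._run_forward_batches().

class ProgressBar:
    def __init__(self) -> None:
        self._title = ""
        self._subtitle = ""

    def title(self, text: str) -> "ProgressBar":
        self._title = text
        return self

    def subtitle(self, text: str) -> "ProgressBar":
        self._subtitle = text
        return self

    def draw(self) -> None:
        # The real bar redraws in place; printing is enough for this sketch.
        print(f"[{self._title}] {self._subtitle}")


def run_forward(layer_index: int) -> None:
    pass  # placeholder for the real forward pass over calibration batches


pb = ProgressBar()
layer_count = 3
for layer_index in range(layer_count):
    # Compute the per-layer title once and reuse it (the new layer_title variable).
    layer_title = f"Quantizing layer {layer_index} of {layer_count - 1}"
    pb.title(layer_title).subtitle("").draw()

    # Forward progress now goes to the bar instead of log.info(...).
    pb.title(f"Forward start (layer=`{layer_index}`, batches=1)").draw()
    run_forward(layer_index)

    # Restore the layer title once the forward pass finishes.
    pb.title(layer_title).subtitle("").draw()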
1 change: 0 additions & 1 deletion gptqmodel/models/base.py
@@ -66,7 +66,6 @@
 
 
 if TYPE_CHECKING:
-    from ..utils.threadx import DeviceThreadPool
     try:
         from datasets import Dataset as HFDatasetType
         from datasets import IterableDataset as HFIterableDatasetType
3 changes: 2 additions & 1 deletion gptqmodel/utils/looper_helpers.py
@@ -18,8 +18,9 @@
 from ..utils.device import get_device
 from ..utils.logger import setup_logger
 from ..utils.model import move_to, nested_move_to
-from ..utils.torch import ALL_DEVICES, CPU
 from ..utils.safe import ThreadSafe
+from ..utils.torch import ALL_DEVICES, CPU
 
+
 _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel)
 
2 changes: 1 addition & 1 deletion tests/test_torch_replicate.py
@@ -135,7 +135,7 @@ def test_replicate_from_cpu_to_multiple_gpu():
 
 devices = [torch.device(f"cuda:{idx}") for idx in range(2)]
 module = _random_linear()
-input_tensor = torch.randn(2, module.in_features)
+torch.randn(2, module.in_features)
 
 torch_replicate(module, devices)
