
Conversation

Qubitium (Collaborator) commented Sep 29, 2025

@avtc Thanks for your patience in trying different strategies. This PR by itself will not fix your OOM issue out of the gate, but it sets the foundation for data parallelism and also gives us free metrics (without overhead) on when and how much GPU work has been done, based on tasks submitted/completed on each individual GPU.

Right now torch.cuda.empty_cache() will be called for every 14 background GPU tasks that have been submitted to the GPU work queue.

The limit is currently set in the module_looper __init__:

        # Create a single pool for the entire looper lifecycle.
        # Eagerly discovers devices and pins worker threads per device.
        # Tune worker counts here if desired (example policy shown).
        self.pool = DeviceThreadPool(
            inference_mode=True,
            workers={
                "cuda:per": 1,
                "xpu:per": 1,
                "mps": 1,
                "cpu": 1,
            },
            empty_cache_every_n=14,  # disable auto GC during quant loops; enable if you want
        )

This gives us the flexibility to implement multiple strategies based on subset, layer, etc. It can mimic the old way (wait for all background work to complete per layer), and it can also do fine-grained control, like running a cleanup task for every N submodules we process (see the sketch below). Flexibility is good here since one strategy will not fit all. For normal Llama-like models there are about 7 quantized modules per layer; for MoE, this can vary wildly.
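
A minimal sketch of how such a policy could be parameterized, assuming a small helper inside the looper; only DeviceThreadPool(..., empty_cache_every_n=...), pool.wait('all') and torch.cuda.empty_cache() appear in this PR, the rest of the names are illustrative:

import torch
from enum import Enum

class CleanupPolicy(Enum):
    PER_LAYER = "per_layer"            # mimic the old behavior: drain once per layer
    EVERY_N_SUBMODULES = "every_n"     # fine-grained: clean up every N quantized submodules

def maybe_cleanup(pool, policy, submodules_done, n=14, layer_done=False):
    # Drain the device pool and release cached blocks according to the chosen policy.
    if policy is CleanupPolicy.PER_LAYER and layer_done:
        pool.wait('all')               # block until every background task on every device is done
        torch.cuda.empty_cache()
    elif policy is CleanupPolicy.EVERY_N_SUBMODULES and submodules_done % n == 0:
        torch.cuda.empty_cache()       # comparable to empty_cache_every_n inside the pool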

avtc (Contributor) commented Sep 29, 2025

@Qubitium I ran GLM-4.5-Air at 4-bit, mock_quantization, samples: 1 on this branch.
It looks like VRAM is reclaimed after each layer (according to the reported VRAM usage in the per-module log).
It runs longer than ASYNC_WORKER.join() but faster than SERIAL_BG_QUEUE.join(); I have not played with the settings, just ran it as is.

And after a few layers I got an error:

ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #213).                                                                                  
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #214).                                                                                  
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #215).                                                                                  
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
Quantizing mlp.experts.65.gate_proj in layer   [14 of 45] ██████████----------------------| 0:23:51 / 1:13:08 [15/46] 32.6%Traceback (most recent call last):
Quantizing mlp.experts.98.gate_proj in layer     [14 of 45] ██████████----------------------| 0:23:51 / 1:13:08 [15/46] 32.6ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #216).                                                                                  
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #217).                                                                                  
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.001s (pass #218).                                                                                  
Quantizing mlp.experts.107.up_proj in layer    [14 of 45] ██████████----------------------| 0:23:51 / 1:13:08 [15/46] 32.6%  File "/home/ubuntu/Documents/Quantize/quantize-glm4.5-air-gptqmodel-clean.py", line 58, in <module>
    model.quantize(
    ~~~~~~~~~~~~~~^
        calibration_dataset,
        ^^^^^^^^^^^^^^^^^^^^
        batch_size=BATCH_SIZE,
        ^^^^^^^^^^^^^^^^^^^^^^
        #auto_gc=False,
        ^^^^^^^^^^^^^^^
        )
        ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/models/base.py", line 693, in quantize
    return module_looper.loop(
           ~~~~~~~~~~~~~~~~~~^
        backend=backend,
        ^^^^^^^^^^^^^^^^
        fail_safe=self.quantize_config.fail_safe,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 563, in loop
    name, m = fut.result()
              ~~~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ~~~~~~~~~~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/utils/threadx.py", line 262, in _run
    result = fn(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 553, in _process_on_worker
    proc.process(module=nm)
    ~~~~~~~~~~~~^^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/gptq_processor.py", line 129, in process
    wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()
                                                                                 ~~~~~~~~~~^^
Quantizing mlp.experts.25.up_proj in layer     [14 of 45] ██████████----------------------| 0:23:51 / 1:13:08 [15/46] 32.6%  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/quantization/gptq.py", line 305, in quantize
    W = self._clone_module(device=self.module.target_device)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/quantization/gptq.py", line 128, in _clone_module
    clone = self.module.weight.data.to(copy=copy, device=device)
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
torch.AcceleratorError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #219).                                                                                  
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #220).                                                                                  
DEBUG GC trigger received; acquiring global exclusive lock…                                                                
ERROR GC pass encountered an error: AcceleratorError('CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
INFO  GC completed in 0.000s (pass #221).                                                                                  
Quantizing mlp.experts.120.up_proj in layer    [14 of 45] ██████████----------------------| 0:23:51 / 1:13:08 [15/46] 32.6%(gptqmodelt)

I have also tried running this branch with ASYNC_WORKER.join() at the layer end, on GLM-4.5 (which has more experts and is bigger), with the normal number of samples (1080), but got a CUDA OOM after the first layer with experts (after layer 3, at the beginning of layer 4). Now I am interested in distributing weights over several GPUs to make it work.
Update: adjusted the calibration dataset, and now proceeding with GLM-4.5.

Qubitium (Collaborator, Author) commented Sep 29, 2025

@avtc The latest commits just unlocked another 20%+ speed improvement in MoE layer quantization. I am getting carried away, so I am not concentrating on memory usage right at this moment.

@Qubitium Qubitium marked this pull request as ready for review September 29, 2025 12:18
@Qubitium Qubitium merged commit 805484a into main Sep 29, 2025
5 checks passed
@Qubitium Qubitium deleted the threadx branch September 29, 2025 12:22
avtc (Contributor) commented Sep 29, 2025

The error log from the latest main branch:

Quantizing mlp.experts.112.down_proj in layer  [1 of 45] █----------------------------------| 0:00:26 / 0:09:58 [2/46] 4.3%Traceback (most recent call last):
  File "/home/ubuntu/Documents/Quantize/quantize-glm4.5-air-gptqmodel-clean.py", line 58, in <module>
    model.quantize(
    ~~~~~~~~~~~~~~^
        calibration_dataset,
        ^^^^^^^^^^^^^^^^^^^^
        batch_size=BATCH_SIZE,
        ^^^^^^^^^^^^^^^^^^^^^^
        #auto_gc=False,
        ^^^^^^^^^^^^^^^
        )
        ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/models/base.py", line 705, in quantize
    return module_looper.loop(
           ~~~~~~~~~~~~~~~~~~^
        backend=backend,
        ^^^^^^^^^^^^^^^^
        fail_safe=self.quantize_config.fail_safe,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 665, in loop
    fut.result()
    ~~~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ~~~~~~~~~~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/utils/threadx.py", line 313, in _run
    result = fn(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 648, in finalize_module
    offload_to_disk(
    ~~~~~~~~~~~~~~~^
        model=self.gptq_model.model,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        module=self.gptq_model.model.get_submodule(module.full_name),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/utils/offload.py", line 81, in offload_to_disk
    _offload_disk(module=module, name=full_name, disk_path=disk_path)
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/utils/offload.py", line 115, in _offload_disk
    _ = disk_offload(
        module,
    ...<3 lines>...
        execution_device=m_device,
    )
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/big_modeling.py", line 297, in disk_offload
    attach_align_device_hook(
    ~~~~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<4 lines>...
        preload_module_classes=preload_module_classes,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/hooks.py", line 521, in attach_align_device_hook
    add_hook_to_module(module, hook, append=True)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/hooks.py", line 166, in add_hook_to_module
    module = hook.init_hook(module)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/hooks.py", line 111, in init_hook
    module = hook.init_hook(module)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/hooks.py", line 313, in init_hook
    set_module_tensor_to_device(module, name, "meta")
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/utils/modeling.py", line 408, in set_module_tensor_to_device
    clear_device_cache()
    ~~~~~~~~~~~~~~~~~~^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/utils/memory.py", line 63, in clear_device_cache
    elif is_cuda_available():
         ~~~~~~~~~~~~~~~~~^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/utils/imports.py", line 126, in is_cuda_available
    with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
         ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/contextlib.py", line 141, in __enter__
    return next(self.gen)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/accelerate/utils/environment.py", line 347, in patch_environment
    existing_vars[key] = os.environ[key]
                         ~~~~~~~~~~^^^^^
  File "<frozen os>", line 717, in __getitem__
KeyError: 'PYTORCH_NVML_BASED_CUDA_CHECK'

Will check if setting the env var helps.
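
For reference, a minimal sketch of pre-setting that variable from Python, assuming it is done at the top of the quantization script before accelerate is imported (the placement is an assumption):

import os

# Pre-populate the env var that accelerate's patch_environment later expects to read back,
# so the threaded check-after-set cannot hit a missing key.
os.environ.setdefault("PYTORCH_NVML_BASED_CUDA_CHECK", "1")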

avtc (Contributor) commented Sep 29, 2025

Setting PYTORCH_NVML_BASED_CUDA_CHECK=1 helps.
q.to failed for me even with 4 GPUs; applying the lock on q.to helped (a sketch of that workaround is shown after the traceback below), and the final error:

Quantizing mlp.experts.124.down_proj in layer  [45 of 45] ███████████████████████████████| 0:30:44 / 0:30:44 [46/46] 100.0%Traceback (most recent call last):
  File "/home/ubuntu/Documents/Quantize/quantize-glm4.5-air-gptqmodel-clean.py", line 58, in <module>
    model.quantize(
    ~~~~~~~~~~~~~~^
        calibration_dataset,
        ^^^^^^^^^^^^^^^^^^^^
        batch_size=BATCH_SIZE,
        ^^^^^^^^^^^^^^^^^^^^^^
        #auto_gc=False,
        ^^^^^^^^^^^^^^^
        )
        ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/models/base.py", line 705, in quantize
    return module_looper.loop(
           ~~~~~~~~~~~~~~~~~~^
        backend=backend,
        ^^^^^^^^^^^^^^^^
        fail_safe=self.quantize_config.fail_safe,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 669, in loop
    self.pool.wait()  # same as wait('all')
    ~~~~~~~~~~~~~~^^
TypeError: DeviceThreadPool.wait() missing 1 required positional argument: 'scope'
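
For reference, a minimal sketch of the "lock on q.to" workaround mentioned above, assuming a single process-wide threading.Lock that serializes the weight moves (the actual patch may differ):

import threading

import torch

_Q_TO_LOCK = threading.Lock()  # hypothetical: one global lock for all device moves

def locked_to(tensor: torch.Tensor, device) -> torch.Tensor:
    # Serialize .to() copies so only one worker thread moves a tensor at a time.
    with _Q_TO_LOCK:
        return tensor.to(device=device)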

avtc (Contributor) commented Sep 29, 2025

diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index fee50bdb..cc16d0f9 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -666,7 +666,7 @@ class ModuleLooper():
 
         # LifeCycle: All sub-modules have finalized meaning quantization work is complete
         # Ensure ANY remaining tasks the looper submitted have drained
-        self.pool.wait()  # same as wait('all')
+        self.pool.wait('all')  # wait for all devices
 
         # paranoid safety check
         # torch_sync()

Works now; VRAM is reclaimed and the estimated time looks promising: 30 minutes for GLM-4.5-Air, 4-bit, samples: 1, mock_quantization=True.

Qubitium (Collaborator, Author) commented:
@avtc PYTORCH_NVML_BASED_CUDA_CHECK=1 is a new one. That looks to be an upstream bug. It looks like they tried to set an env var and then check whether it was set, and the check-after-set failed hard.
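
A simplified reconstruction of the suspected race, paraphrasing accelerate's patch_environment; the exact interleaving is an assumption:

import os
from contextlib import contextmanager

@contextmanager
def patch_environment(**kwargs):   # simplified paraphrase, not accelerate's exact code
    existing_vars = {}
    for key, value in kwargs.items():
        key = key.upper()
        if key in os.environ:                     # thread A: key is present...
            existing_vars[key] = os.environ[key]  # ...but thread B's cleanup below removed it
                                                  # in between, so this read raises KeyError
        os.environ[key] = str(value)
    try:
        yield
    finally:
        for key in kwargs:
            key = key.upper()
            if key in existing_vars:
                os.environ[key] = existing_vars[key]
            else:
                os.environ.pop(key, None)         # the delete that races with the read above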

Qubitium (Collaborator, Author) commented:
Regarding the accelerate bug, it is a thread-safety race condition. A lot of code in Python and related libraries has never been tested in a true threading environment.
