
Conversation

@Qubitium
Collaborator

@Qubitium Qubitium commented Sep 26, 2025

@avtc Reverting the previous threading fix for multi-GPU, as I believe the recent thread code I pushed to main has nullified this issue at the source. Please test. Also, massive CPU memory usage has been eliminated, including the separate packing stage, since packing is now done inline with quantization.
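
For context, a minimal sketch of the inline-packing idea (illustrative names only, not the actual GPTQModel internals): each module is packed immediately after it is quantized, so only one module's intermediate tensors are alive at a time instead of a later packing pass holding everything in CPU RAM.

# Illustrative sketch only; quantize_module/pack_module stand in for the library's own steps.
def quantize_and_pack_inline(named_modules, quantize_module, pack_module):
    packed = {}
    for name, module in named_modules:
        q_weight, scales, zeros, g_idx = quantize_module(module)    # quantize one module
        packed[name] = pack_module(q_weight, scales, zeros, g_idx)  # pack it right away
        del q_weight, scales, zeros, g_idx                          # free intermediates now,
                                                                    # not in a separate packing pass
    return packed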

Second fix: #1923

Signed-off-by: Qubitium <Qubitium@modelcloud.ai>
@Qubitium Qubitium merged commit 04c00ed into main Sep 26, 2025
4 checks passed
@Qubitium Qubitium deleted the remove-prev-thead-fix branch September 26, 2025 01:22
@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium the same error appears again. GLM-4.5-Air, mock_quantization: true, Python 3.13t, during layer 1 (the first layer with experts).

Traceback (most recent call last):
  File "/home/ubuntu/Documents/Quantize/quantize-glm4.5-air-gptqmodel.py", line 326, in <module>
    model.quantize(
    ~~~~~~~~~~~~~~^
        calibration_dataset,
        ^^^^^^^^^^^^^^^^^^^^
    ...<6 lines>...
        #fail_safe=True,
        ^^^^^^^^^^^^^^^^
        )
        ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/models/base.py", line 670, in quantize
    return module_looper.loop(
           ~~~~~~~~~~~~~~~~~~^
        backend=backend,
        ^^^^^^^^^^^^^^^^
        fail_safe=self.quantize_config.fail_safe,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 466, in loop
    name, m = future.result()
              ~~~~~~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ~~~~~~~~~~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 458, in process_module
    processor.process(module=m)
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/gptq_processor.py", line 128, in process
    wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()
                                                                                 ~~~~~~~~~~^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/quantization/gptq.py", line 560, in quantize
    Q = Q.to(device=self.module.weight.data.device, non_blocking=False)
torch.AcceleratorError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@Qubitium
Collaborator Author

@avtc Ok. Still working on fixing free-threaded (GIL=0) issues on main; I guess there are still bugs there. Can you give me a reproducible script? Thanks! No need for private calib data, etc. Just something simple I can run on my local multi-GPU setup.

@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium
bash part:

source ~/venvs/gptqmodelt/bin/activate
# with 4 gpus failed to reproduce
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6"
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:128"
export PYTHON_GIL=0
#export CUDA_LAUNCH_BLOCKING=1
python /home/ubuntu/Documents/Quantize/quantize-glm4.5-air-gptqmodel-clean.py

python part quantize-glm4.5-air-gptqmodel-clean.py:

from datasets import load_dataset

MODEL_NAME = "zai-org/GLM-4.5-Air"
BITS = 4
GROUP_SIZE = 128
V2 = 0
DUMP = 0.02
BATCH_SIZE = 1
SAMPLES = 1
MOCK_QUANTIZATION = 1
OUTPUT_DIR = f"/home/ubuntu/models/gptqmodel/GLM-4.5-Air-gptqmodel-w{BITS}g{GROUP_SIZE}-v2-{V2}-dump{DUMP}-bs{BATCH_SIZE}-s{SAMPLES}-b7699aa1"

model_path = MODEL_NAME

from gptqmodel import GPTQModel, QuantizeConfig

calibration_dataset = load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train"
    ).select(range(SAMPLES))["text"]

dynamic = {
    r"-:model.embed_tokens.weight": {},
    r"-:.*shared_experts": {},
    r"-:.*shared_head": {},
    r"-:lm_head.weight": {},
    #r"-:.*self_attn..*_proj": {},
    }

quant_config = QuantizeConfig(
    bits=BITS,
    group_size=GROUP_SIZE,
    v2=V2,
    v2_memory_device="auto",
    sym=True,
    desc_act=False,
    dynamic=dynamic,
    damp_percent=DUMP,
    mock_quantization=MOCK_QUANTIZATION,
    )

print(f"Loading model from {model_path}...")

model = GPTQModel.load(model_path, quant_config)

model.quantize(
    calibration_dataset,
    batch_size=BATCH_SIZE,
    )
model.save(OUTPUT_DIR)

print("Finished! Quantized model saved to", OUTPUT_DIR)

@avtc
Contributor

avtc commented Sep 26, 2025

There are new runtime errors in the output:

Traceback (most recent call last):
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/utils/threads.py", line 33, in _runner
    return fn()
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/module_looper.py", line 550, in finalize_module
    reverse_p.submodule_finalize(module, self.gptq_model)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/looper/gptq_processor.py", line 236, in submodule_finalize
    pack_module(
    ~~~~~~~~~~~^
        name=module.full_name,
        ^^^^^^^^^^^^^^^^^^^^^^
    ...<6 lines>...
        lock=self.lock,
        ^^^^^^^^^^^^^^^
    )
    ^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/utils/model.py", line 603, in pack_module
    module.pack(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx)
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/nn_modules/qlinear/__init__.py", line 636, in pack
    f.result()
    ~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ~~~~~~~~~~~~~~~~~^^
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/home/ubuntu/.pyenv/versions/3.13.7t/lib/python3.13t/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/nn_modules/qlinear/__init__.py", line 614, in _process_block
    _pack_rows_2_4_8(sub, qweight, dst_rows)
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/venvs/gptqmodelt/lib/python3.13t/site-packages/gptqmodel/nn_modules/qlinear/__init__.py", line 571, in _pack_rows_2_4_8
    dst[dst_rows_base:dst_rows_base + bits] = ((packed64 & MASK32).to(t.int32))
    ~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.
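
The workaround the message points at is roughly this (a generic sketch, not the actual fix in #1930; `dst` stands in for the buffer that `_pack_rows_2_4_8` writes into):

import torch

def safe_inplace_update(dst: torch.Tensor, rows: slice, values: torch.Tensor) -> torch.Tensor:
    # Tensors created under torch.inference_mode() reject in-place writes once
    # inference mode has been exited; cloning yields a normal tensor that accepts
    # the update. The caller must keep using the returned (possibly cloned) tensor.
    if dst.is_inference():
        dst = dst.clone()
    dst[rows] = values
    return dst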

@Qubitium
Collaborator Author

@avtc check #1930

@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium
Just to note, the reproducible script is two comments above:
#1916 (comment)

With 4 GPUs the issue is not reproduced, but with 5 GPUs it is.
As far as I remember, adding a try/except with retry works well (update: not always). It only throws on the first layer with experts.
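
The retry mentioned above could look roughly like this (a sketch, not the patch actually used; the broad except is because torch.AcceleratorError only exists on newer PyTorch builds):

import time
import torch

def to_device_with_retry(t: torch.Tensor, device, retries: int = 3, delay: float = 0.5):
    # Retry a cross-device copy that intermittently fails with
    # "CUDA error: invalid argument" under free-threaded multi-GPU runs.
    for attempt in range(retries):
        try:
            return t.to(device=device, non_blocking=False)
        except (RuntimeError, getattr(torch, "AcceleratorError", RuntimeError)) as e:
            if "CUDA error" not in str(e) or attempt == retries - 1:
                raise
            torch.cuda.synchronize()  # let pending async work settle before retrying
            time.sleep(delay)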

@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium
Got a CUDA OOM on layer 17. How can I use auto_gc now?

@Qubitium
Collaborator Author

Qubitium commented Sep 26, 2025

@avtc Can you open a new issue so I can track this? I removed auto_gc because I don't think it matters now; maybe it doesn't. How many GPUs are you running for the GLM-4.5-Air quant? Also show the "Tokens/Pad Tokens" data output by the new main code so I can see how many tokens you are feeding to the model for quantization.

Btw, please check your CPU memory usage. It should be much lower than before.

@avtc
Contributor

avtc commented Sep 26, 2025

The latest main does not have the fix for threads, so I can only use 4 x 3090. I used a single sample with 80 tokens, quantizing to 4-bit with the script above. I am not sure where to look for the Tokens/Pad Tokens output; I will check later.

@Qubitium
Collaborator Author

Qubitium commented Sep 26, 2025

@avtc Check for this output:

INFO  Calibration: Sort in descending order by length
INFO  Calibration: Total padded tokens: 0
INFO  Calibration: Total non-padded tokens: 345522
INFO  Calibration: Total tokens: 345522
INFO  Calibration: Sort in descending order by length
INFO  Calibration: Total padded tokens: 0
INFO  Calibration: Total non-padded tokens: 345522
INFO  Calibration: Total tokens: 345522

@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium

WARN  Calibration dataset size should be more than 256. Current: 1.
INFO  Calibration: Native order
INFO  Calibration: Total padded tokens: 0
INFO  Calibration: Total non-padded tokens: 80
INFO  Calibration: Total tokens: 80
WARN  The average length of input_ids of calibration_dataset should be greater than 256: actual avg: 80.0

@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium
I have tried with auto_gc=True (rolled back the changes); CUDA OOM happens on layer 22 with 4 x 3090 (it was on layer 17 without auto_gc=True), and on layer 19 with 7 x 3090.
I see that with each layer the VRAM usage on GPU0 rises, so maybe there is a VRAM memory leak somewhere.
Will add more details in a separate issue over the weekend.

The max RAM usage was around 46 GB until layer 19, a great result! But I need to check with more samples...
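
One quick way to confirm whether GPU0 really grows each layer is to log per-device memory after every layer (a generic sketch, not part of GPTQModel):

import torch

def log_vram(tag: str) -> None:
    # Print allocated/reserved VRAM for every visible GPU; calling this after
    # each layer shows whether GPU0 keeps accumulating memory.
    for i in range(torch.cuda.device_count()):
        alloc_gib = torch.cuda.memory_allocated(i) / 1024**3
        reserved_gib = torch.cuda.memory_reserved(i) / 1024**3
        print(f"[{tag}] cuda:{i} allocated={alloc_gib:.2f} GiB reserved={reserved_gib:.2f} GiB")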

@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium
With offload_to_disk=False, the VRAM is reclaimed after each layer, so the issue is related to offload.

@avtc
Contributor

avtc commented Sep 26, 2025

@Qubitium
Looks like I have found a solution, please check:
avtc@2a1f705

The VRAM with this fix is reclaimed after each layer.

The offload path turns off empty_cache, so the reclaim did not work with auto_gc=True alone:

    real_cache_flush = None
    ...
        real_cache_flush = torch.cuda.empty_cache
        torch.cuda.empty_cache = lambda: None
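
The pattern being pointed at is roughly a save/disable/restore dance around the offload, with an explicit flush once the offload work is done (a sketch with illustrative names; see avtc@2a1f705 for the actual change):

import torch

def offload_with_cache_flush(do_offload):
    # The offload path replaces torch.cuda.empty_cache with a no-op for speed.
    # If the original is not restored and called afterwards, freed layer tensors
    # stay in the caching allocator and the VRAM is never reclaimed.
    real_cache_flush = torch.cuda.empty_cache
    torch.cuda.empty_cache = lambda: None
    try:
        do_offload()
    finally:
        torch.cuda.empty_cache = real_cache_flush
        torch.cuda.empty_cache()  # reclaim VRAM once the offload has finished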

@Qubitium
Collaborator Author

@avtc Yes! I have confirmed that adding the post-layer torch.cuda.empty_cache() call lowers GPU memory by about 10% on the Llama 3.2 1B model, so you should see much bigger savings for MoE.

But instead of this, which is very expensive (slow) for non-MoE models where each layer does not have that many modules, I will add a new auto_gc_bytes feature that triggers auto_gc based on how much estimated memory has been freed. This way we only make (and tune) this call when absolutely necessary. Should be done today.
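
The auto_gc_bytes idea could be sketched like this (illustrative only; the class name, threshold, and accounting are assumptions, not the shipped API):

import gc
import torch

class FreedBytesGC:
    # Run gc.collect()/empty_cache() only after roughly `threshold_bytes` of
    # tensor memory has been released, instead of after every module or layer.
    def __init__(self, threshold_bytes: int = 2 * 1024**3):  # e.g. 2 GiB
        self.threshold_bytes = threshold_bytes
        self.freed = 0

    def note_freed(self, tensor: torch.Tensor) -> None:
        self.freed += tensor.numel() * tensor.element_size()
        if self.freed >= self.threshold_bytes:
            gc.collect()
            torch.cuda.empty_cache()
            self.freed = 0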

@avtc
Contributor

avtc commented Sep 27, 2025

@Qubitium the main idea behind the fix was to wait for the offload to finish after each layer (ASYNC_WORKER.join()) so that the layer's VRAM can be released; right now ASYNC_WORKER.join() is called only after all layers.

For small models that fit into VRAM this was unnoticeable, but for large models this fix keeps only a single layer in VRAM and prevents VRAM OOM.

I have checked with auto_gc=False and it still works; the GPU0 VRAM in nvtop is freed automatically after reaching its maximum. The max RAM usage was around 24 GB for my 1-sample test. The final save still takes more: 120 GB.

With this fix I have noticed a slowdown, so a further improvement would be to monitor VRAM usage and call ASYNC_WORKER.join() less frequently (see the sketch below).
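
The less-frequent join could be gated on actual VRAM pressure, something like this sketch (ASYNC_WORKER here is a stand-in for GPTQModel's background offload worker, and the threshold is arbitrary):

import torch

def maybe_wait_for_offload(async_worker, device_index: int = 0,
                           high_water_fraction: float = 0.85) -> None:
    # Only block on the background offload worker when the GPU is close to full,
    # so small models keep the async overlap while large MoE layers still get drained.
    free, total = torch.cuda.mem_get_info(device_index)
    if (total - free) / total >= high_water_fraction:
        async_worker.join()        # wait until queued offloads have released their VRAM
        torch.cuda.empty_cache()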

@Qubitium
Collaborator Author

Qubitium commented Sep 27, 2025

@avtc Check my latest PR merge; I tagged you on the VRAM auto-gc update. It should be fixed but needs tuning. #1934

The max RAM usage was around 24 GB for my 1-sample test. The final save still takes more: 120 GB.

Data looks good and as expected. The offload is working, but model save is not smart yet: it does not use the offloaded meta files directly and instead copies them back to CPU and then writes them out to disk again. This is on my todo list, so that on model save the CPU RAM usage is also 24 GB or less.
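
One bounded-memory way to do the save would be to process the offloaded shards one at a time instead of materializing the whole state dict (a generic sketch; the real offload layout and save path in GPTQModel may differ):

import glob
import os
from safetensors.torch import load_file, save_file

def repack_offloaded_shards(offload_dir: str, output_dir: str) -> None:
    # Peak CPU RAM is bounded by a single shard, since each shard is loaded,
    # written to the output directory, and dropped before the next one.
    os.makedirs(output_dir, exist_ok=True)
    for shard in sorted(glob.glob(os.path.join(offload_dir, "*.safetensors"))):
        tensors = load_file(shard)
        save_file(tensors, os.path.join(output_dir, os.path.basename(shard)))
        del tensors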

@avtc
Contributor

avtc commented Sep 27, 2025

@Qubitium on the latest main the CUDA OOM happens on layer 6 (with 4 x 3090, 4-bit, 1 sample).
I think that to properly release VRAM we need to add ASYNC_WORKER.join() somewhere (after the layer is done, or triggered by monitoring VRAM on GPU0).

@avtc
Contributor

avtc commented Sep 28, 2025

@Qubitium I have moved SERIAL_BG_QUEUE.join() to execute after each layer, and it now works without CUDA OOM, but it takes much longer than ASYNC_WORKER.join() (executed after each layer).

avtc added a commit to avtc/GPTQModel that referenced this pull request Sep 28, 2025
…1916)"

This reverts commit 04c00ed.

# Conflicts:
#	gptqmodel/quantization/gptq.py
@Qubitium
Collaborator Author

@avtc You are right that the placement improves the situation for GLM 4.5 Air MoE. I am just trying to think of a way that works for all situations. There are models with 4x more MoE experts that will surely OOM well before the entire layer's submodule_finalize has completed. There is a way to do this more generically without sacrificing speed; I want the speed and the VRAM savings at the same time. The Python/torch hooks also carry a speed penalty, so that route is not looking too great right now. Testing other methods today.

@Qubitium
Collaborator Author

Qubitium commented Sep 29, 2025

So the symptoms are correlated.

  1. Accelerate's disk_offload API actually wrongly calls torch.cuda.empty_cache(), which appeared to fix this memory usage before my change but is just bad code calling it in the wrong place at the right time (and it is too slow and can cause CUDA assert bugs since we are now doing packing and offload in separate threads). I have already submitted a fix upstream to Accelerate.

  2. The location of SERIAL_WORKER.join(). Placing it after each layer does help with VRAM, since it forces the main thread to wait before starting forwarding(), which is highly VRAM intensive. But there is a huge cost to this: the work dumped onto the SERIAL_WORKER queue should not depend on forwarding, should already be on CPU before task submission, and consists of separate modules that do not participate in forward (see the sketch after this list). It also doesn't fix the VRAM issue for the 512+ massive shared experts / horizontal MoE layout that many models are using.

  3. Need to clean up which lifecycle stage has the module on which device; it is a little mixed right now, and the work needs to go on the proper threads. I think the VRAM issue is correlated with some VRAM tensors still being bound to the modules when submodule_finalize is called, which ideally should not happen and signals another bug.
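
Point 2 in code terms might look roughly like this (a sketch with illustrative names): move everything to CPU before the task goes onto the background queue, so the queued work never pins GPU memory.

import queue
import torch

def submit_pack_task(work_queue: "queue.Queue", name: str,
                     q_scales: torch.Tensor, q_zeros: torch.Tensor, q_g_idx: torch.Tensor) -> None:
    # Detach and copy to CPU *before* queueing, so the background packing worker
    # holds no references to GPU tensors and the layer's VRAM can be reclaimed
    # as soon as forwarding is done with it.
    work_queue.put({
        "name": name,
        "scales": q_scales.detach().to("cpu"),
        "zeros": q_zeros.detach().to("cpu"),
        "g_idx": q_g_idx.detach().to("cpu"),
    })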
