Merged
1 change: 0 additions & 1 deletion README.md
@@ -418,7 +418,6 @@ quant_config = QuantizeConfig(bits=4, group_size=128, act_group_aware=True)
### Experimental Features

* GPTQ v2: set `v2=True` in quantization config.
* Pass `buffered_fwd = True` to `quantize()` api to potentially speed up quantization if gpu has plenty of vram and can hold all fwd inputs in vram.


### Attribution of Quantization Methods:
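For the `v2=True` bullet retained in the README above, a minimal usage sketch (model id, calibration sentences, and save path are illustrative placeholders, not taken from this diff):

```python
# Sketch only: enabling the experimental GPTQ v2 path via QuantizeConfig.
from gptqmodel import GPTQModel, QuantizeConfig

calibration_dataset = [
    "GPTQModel quantizes transformer weights using a small calibration set.",
    "The calibration sentences here are placeholders.",
]

quant_config = QuantizeConfig(bits=4, group_size=128, v2=True)  # v2=True per the README bullet

model = GPTQModel.load("facebook/opt-125m", quant_config)  # placeholder model id
model.quantize(calibration_dataset)                        # buffered_fwd is no longer a parameter
model.save("opt-125m-gptq-v2")                             # placeholder output path
```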
2 changes: 1 addition & 1 deletion gptqmodel/looper/awq_processor.py
@@ -746,7 +746,7 @@ def _sanitize_kwargs(self, inputs_kwargs, module):
sanitized_kwargs[k] = v
return sanitized_kwargs

def preprocess(self, module: NamedModule, buffered_fwd: bool):
def preprocess(self, module: NamedModule, fail_safe: bool):
# TODO Dynamic is not yet supported
pass

11 changes: 1 addition & 10 deletions gptqmodel/looper/gptq_processor.py
@@ -58,7 +58,7 @@ def log_plotly(self):
def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("GPTQProcessor's calibration_dataset cannot be modified")

def preprocess(self, module: NamedModule, buffered_fwd: bool, fail_safe: bool):
def preprocess(self, module: NamedModule, fail_safe: bool):
# entire module is skipped
if self.qcfg.dynamic_get(layer_name=module.full_name) == False:
return
@@ -87,15 +87,6 @@ def preprocess(self, module: NamedModule, buffered_fwd: bool, fail_safe: bool):
tmp = GPTQ(module=module, qcfg=qcfg_clone)
tmp.fail_safe = fail_safe

# models like DeepSeek v3/r1 has > 256 $ of sub-modules per layer
# use buffered mode go vram don't explode: gptq needs to store fwd inputs per each layer fwd
# all sub-modules within a single layer needs to store all the inputs.
# deepseek has massive # of sub-modules per layer, causing vram pressure
# buffered mode is slower due to gpu<->cpu movement
if buffered_fwd:
log.info.once(f"Quantize: Enabling fwd buffered mode for: `{module.name}`")
tmp.fwd_inputs_buffered = True

tmp.quantizer.configure(
perchannel=True,
)
12 changes: 6 additions & 6 deletions gptqmodel/looper/module_looper.py
@@ -188,7 +188,7 @@ def store_input_hook(module, args, kwargs):
attention_masks=attention_masks)

@torch.inference_mode
def loop(self, calibration_enable_gpu_cache=True, buffered_fwd=False, fail_safe: bool = False, **kwargs):
def loop(self, calibration_enable_gpu_cache=True, fail_safe: bool = False, **kwargs):
if self.gptq_model.quantize_config.lm_head:
if self.gptq_model.model.config.tie_word_embeddings and hasattr(self.gptq_model.model.model, "_tied_weights_keys"):
tied_keys = self.gptq_model.model._tied_weights_keys
@@ -303,7 +303,7 @@ def loop(self, calibration_enable_gpu_cache=True, buffered_fwd=False, fail_safe:
if isinstance(processor, AWQProcessor):
named_childs = dict()
for index, names in enumerate(modules):
named_modules = self.crate_named_modules(buffered_fwd=buffered_fwd, full=full,
named_modules = self.crate_named_modules(full=full,
is_lm_head_module=is_lm_head_module,
layer_index=layer_index, layers_prefix=layers_prefix,
names=names,
@@ -326,7 +326,7 @@ def loop(self, calibration_enable_gpu_cache=True, buffered_fwd=False, fail_safe:
processed_subset = {}

for index, names in enumerate(modules):
subset = self.crate_named_modules(buffered_fwd=buffered_fwd, full=full, is_lm_head_module=is_lm_head_module,
subset = self.crate_named_modules(full=full, is_lm_head_module=is_lm_head_module,
layer_index=layer_index, layers_prefix=layers_prefix,
names=names,
processor=processor,
@@ -598,7 +598,7 @@ def finalize_module(module):

return total_log

def crate_named_modules(self, buffered_fwd, full, is_lm_head_module, layer_index, layers_prefix, names, processor, fail_safe) -> Dict[str, NamedModule]:
def crate_named_modules(self, full, is_lm_head_module, layer_index, layers_prefix, names, processor, fail_safe) -> Dict[str, NamedModule]:
is_awq_quant = isinstance(processor, AWQProcessor)
subset = {}
for n in names:
@@ -626,9 +626,9 @@ def crate_named_modules(self, buffered_fwd, full, is_lm_head_module, layer_index

if not is_awq_quant:
if isinstance(processor, GPTQProcessor):
processor.preprocess(subset[name], buffered_fwd=buffered_fwd, fail_safe=fail_safe)
processor.preprocess(subset[name], fail_safe=fail_safe)
else:
processor.preprocess(subset[name], buffered_fwd=buffered_fwd)
processor.preprocess(subset[name])
# some modules are skipped
if processor.is_skipped(subset[name]):
skipped_modules.append(name)
2 changes: 1 addition & 1 deletion gptqmodel/looper/native_processor.py
@@ -51,7 +51,7 @@ def log_plotly(self):
def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("NativeProcessor's calibration_dataset cannot be modified")

def preprocess(self, module: NamedModule, buffered_fwd: bool):
def preprocess(self, module: NamedModule):
self.native_inp_caches[module.name] = []

def is_skipped(self, module: NamedModule) -> bool:
11 changes: 1 addition & 10 deletions gptqmodel/looper/qqq_processor.py
@@ -55,7 +55,7 @@ def log_plotly(self):
def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("QQQProcessor's calibration_dataset cannot be modified")

def preprocess(self, module: NamedModule, buffered_fwd: bool):
def preprocess(self, module: NamedModule):
# entire module is skipped
if self.qcfg.dynamic_get(layer_name=module.full_name) == False:
return
@@ -75,15 +75,6 @@ def preprocess(self, module: NamedModule, buffered_fwd: bool):

tmp = QQQ(module=module, qcfg=qcfg_clone)

# models like DeepSeek v3/r1 has > 256 $ of sub-modules per layer
# use buffered mode go vram don't explode: gptq needs to store fwd inputs per each layer fwd
# all sub-modules within a single layer needs to store all the inputs.
# deepseek has massive # of sub-modules per layer, causing vram pressure
# buffered mode is slower due to gpu<->cpu movement
if buffered_fwd:
log.info(f"Quantize: Enabling fwd buffered mode for: `{module.name}`")
tmp.fwd_inputs_buffered = True

tmp.quantizer.configure(
perchannel=True,
)
3 changes: 0 additions & 3 deletions gptqmodel/models/auto.py
@@ -613,8 +613,6 @@ def generate(
calibration_enable_gpu_cache: Optional[bool] = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
logger_board: Optional[str] = None,
# Experimental: enables the buffering of fwd inputs to cpu, slower than non-buffered, may reduce vram usage
buffered_fwd: bool = False,
# pass-through vars for load()
trust_remote_code: bool = False,
dtype: Optional[Union[str, torch.dtype]] = None,
@@ -662,6 +660,5 @@ def generate(
calibration_enable_gpu_cache=calibration_enable_gpu_cache,
tokenizer=tokenizer,
logger_board=logger_board,
buffered_fwd=buffered_fwd,
)
return
8 changes: 1 addition & 7 deletions gptqmodel/models/base.py
@@ -32,7 +32,7 @@
from ..utils.model import MODALITY, find_modules, get_device, get_module_by_name_prefix, move_to
from ..utils.offload import offload_to_disk
from ..utils.structure import alias_from_turtle_for_submodule
from ..utils.torch import TORCH_HAS_COMPILE, torch_compile, torch_empty_cache
from ..utils.torch import TORCH_HAS_COMPILE, torch_compile
from ._const import (CALIBRATION_DATASET_CONCAT_CHAR, CPU, DEFAULT_MAX_SHARD_SIZE,
DEVICE, EXPERT_INDEX_PLACEHOLDER, META)
from .loader import ModelLoader
@@ -491,8 +491,6 @@ def quantize(
tokenizer: Optional[PreTrainedTokenizerBase] = None,
logger_board: Optional[str] = None,
backend: Optional[BACKEND] = BACKEND.AUTO,
# Experimental: enables the buffering of fwd inputs to cpu, slower than non-buffered, may reduce vram usage
buffered_fwd: bool = False,
# eora adapter generation needs config Lora(rank=1, path='lora.safetensors')
adapter: Adapter = None,
adapter_calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]] = None,
@@ -672,7 +670,6 @@ def quantize(

return module_looper.loop(
calibration_enable_gpu_cache=calibration_enable_gpu_cache,
buffered_fwd=buffered_fwd,
backend=backend,
fail_safe=self.quantize_config.fail_safe,
)
@@ -689,8 +686,6 @@ def _eora_generate(
calibration_enable_gpu_cache: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
logger_board: Optional[str] = None,
# Experimental: enables the buffering of fwd inputs to cpu, slower than non-buffered, may reduce vram usage
buffered_fwd: bool = False,
):
if self.quantized:
raise EnvironmentError("eora_generate() is called a model that is already quantized")
@@ -735,7 +730,6 @@ def _eora_generate(

module_looper.loop(
calibration_enable_gpu_cache=calibration_enable_gpu_cache,
buffered_fwd=buffered_fwd,
)

self.eora_save(save_dir=adapter.path, model_save_dir=self.model_local_path)
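With `buffered_fwd` removed from `quantize()` in base.py above, a sketch of the post-change call shape; the keyword names mirror the updated call in `tests/models/model_test.py` further down, `BACKEND.AUTO` matches the default in the `quantize()` signature, and the wrapper function itself is hypothetical:

```python
# Sketch: quantize() call after this PR; buffered_fwd is no longer accepted.
from gptqmodel import BACKEND

def run_quantization(model, calibration_dataset):
    # model: a loaded GPTQModel instance; calibration_dataset: list[str] or tokenized samples
    model.quantize(
        calibration_dataset,
        calibration_sort="asc",   # DATASET_SORT used by the tests below
        backend=BACKEND.AUTO,     # default per the quantize() signature
        batch_size=256,
        # buffered_fwd=True,      # removed by this PR
    )
```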
24 changes: 2 additions & 22 deletions gptqmodel/quantization/gptq.py
@@ -21,7 +21,7 @@
from ..looper.named_module import NamedModule
from ..quantization import QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.torch import HAS_CUDA, HAS_XPU, device_next, torch_sync
from ..utils.torch import HAS_CUDA, HAS_XPU, device_next
from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm
from .quantizer import HF_OPTIMUM, Quantizer

@@ -87,10 +87,6 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):

self.quantizer = self.create_quantizer(name=self.name)

# fwd input buffer
self.fwd_inputs_buffered = False
self.fwd_inputs_buffered_data = []

# fwd counter
self.fwd_counter = 0

@@ -142,13 +138,7 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):

# print(f"self.module.target_device = {self.module.target_device}")

if self.fwd_inputs_buffered:
# with torch_streamCtx(self.module.target_device_stream):
# self.fwd_inputs_buffered_data.append(inp.to(device=self.module.target_device, non_blocking=True))

self.fwd_inputs_buffered_data.append(inp.to(device=self.module.target_device, non_blocking=False))
else:
self.process_batch(inp)
self.process_batch(inp)

def process_batch(self, inp: torch.Tensor):
# print(f"inp = {inp}")
@@ -296,16 +286,6 @@ def quantize(
# Use simplified hessian inverse (identity matrix)
self.hessian_inverse = self._mock_hessian_inverse

# process buffered inputs
if len(self.fwd_inputs_buffered_data) > 0:
torch_sync(device=self.module.target_device)

for inp in self.fwd_inputs_buffered_data:
self.process_batch(inp)

# release buffer
del self.fwd_inputs_buffered_data

# if self.device.type not in ["mps", "cpu"]:
# self.module.weight.data = self.module.weight.data.cpu()

8 changes: 0 additions & 8 deletions gptqmodel/quantization/gptqv2.py
@@ -119,14 +119,6 @@ def quantize(
if not TORCH_GTE_28:
self.hessian_inverse = torch_compile(self.hessian_inverse)

# process buffered inputs
for inp in self.fwd_inputs_buffered_data:
torch_sync(device=self.module.target_device)
self.process_batch(inp)

# release buffer
del self.fwd_inputs_buffered_data

# if self.device.type not in ["mps", "cpu"]:
# self.module.weight.data = self.module.weight.data.cpu()

3 changes: 1 addition & 2 deletions tests/models/model_test.py
@@ -61,7 +61,6 @@ class ModelTest(unittest.TestCase):
DATASET_SIZE = 256
DATASET_SORT = "asc"
DELETE_QUANTIZED_MODEL = True
BUFFERED_FWD = False

KERNEL_QUANT = {} # kernel sets
KERNEL_INFERENCE = {} # kernel sets
Expand Down Expand Up @@ -221,7 +220,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
is_ovis_model = model.__class__.__name__ == "OvisGPTQ"
need_create_processor = is_image_to_text_model and not is_ovis_model
if not is_quantized:
model.quantize(calibration_dataset, calibration_sort=self.DATASET_SORT, backend=self.QUANT_BACKEND, batch_size=batch_size, buffered_fwd=self.BUFFERED_FWD)
model.quantize(calibration_dataset, calibration_sort=self.DATASET_SORT, backend=self.QUANT_BACKEND, batch_size=batch_size)

self.check_kernel(model, self.KERNEL_QUANT)

3 changes: 1 addition & 2 deletions tests/test_perplexity.py
@@ -127,7 +127,7 @@ def calculate_native_ppl(self, format):
# (QUANT_METHOD.AUTO_ROUND, FORMAT.GPTQ, 4, 32, False),
]
)
def test_quantized_perplexity(self, method: METHOD, format: FORMAT, bits: int, group_size: int, buffered_fwd: bool = False):
def test_quantized_perplexity(self, method: METHOD, format: FORMAT, bits: int, group_size: int):
if method == METHOD.GPTQ:
quantize_config = QuantizeConfig(
bits=bits,
@@ -156,7 +156,6 @@ def test_quantized_perplexity(self, method: METHOD, format: FORMAT, bits: int, g
model.quantize(
dataset,
batch_size=128 if IS_ROCM else 256,
# buffered_fwd=buffered_fwd, TODO FIX ME
)
quant_time = time.time() - start
