diff --git a/README.md b/README.md index c3115b122..04a5d6fe1 100644 --- a/README.md +++ b/README.md @@ -131,15 +131,15 @@ GPT-QModel not only supports GPTQ but also QQQ, GPTQv2, Eora with more quantizat GPT-QModel is a modular design supporting multiple quantization methods and feature extensions. -| Quantization Feature | GPT-QModel | Transformers | vLLM | SGLang | Lora Training | -|----------------------|------------|---|---|---|---------------| -| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ | -| EoRA | ✅ | ✅ | ✅ | ✅ | x | -| AWQ | ✅ | ✅* | ✅* | ✅* | ✅* | -| GPTQ v2 | ✅ | ✅ | ✅ | ✅ | ✅ | -| QQQ | ✅ | x | x | x | x | -| Rotation | ✅ | x | x | x | x | -| Group Aware Activitation Reordering (GPTQ) | ✅ | ✅ | ✅ | ✅ | ✅ | +| Quantization Feature | GPT-QModel | Transformers | vLLM | SGLang | Lora Training | +|----------------------------|------------|---|---|---|---------------| +| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ | +| EoRA | ✅ | ✅ | ✅ | ✅ | x | +| Group Aware Act Reordering | ✅ | ✅ | ✅ | ✅ | ✅ | +| AWQ | ✅ | ✅* | ✅* | ✅* | ✅* | +| QQQ | ✅ | x | x | x | x | +| Rotation | ✅ | x | x | x | x | +| GPTQ v2* | ✅ | ✅ | ✅ | ✅ | ✅ | ## Multi-Modal diff --git a/gptqmodel/eora/eora.py b/gptqmodel/eora/eora.py index eeff9df13..9fd9f0bf8 100644 --- a/gptqmodel/eora/eora.py +++ b/gptqmodel/eora/eora.py @@ -5,6 +5,14 @@ # EoRA Official Repo: https://github.com/NVlabs/EoRA # This file has been modified by ModelCloud.AI team and qubitium@modelcloud.ai for adoption into GPT-QModel +# EoRA +# @article{liu2024eora, +# title={EoRA: Training-free Compensation for Compressed LLM with Eigenspace Low-Rank Approximation}, +# author={Liu, Shih-Yang and Yang, Huck and Wang, Chien-Yi and Fung, Nai Chit and Yin, Hongxu and Sakr, Charbel and Muralidharan, Saurav and Cheng, Kwang-Ting and Kautz, Jan and Wang, Yu-Chiang Frank and others}, +# journal={arXiv preprint arXiv:2410.21271}, +# year={2024} +# } + from typing import Sequence, Tuple import torch diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index f57c218a1..026010a1e 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -28,6 +28,10 @@ if 'CUDA_VISIBLE_DEVICES' in os.environ and 'ROCR_VISIBLE_DEVICES' in os.environ: del os.environ['ROCR_VISIBLE_DEVICES'] +if not os.environ.get("NCCL_SHM_DISABLE", None): + os.environ["NCCL_SHM_DISABLE"] = '1' + log.info("ENV: Auto setting NCCL_SHM_DISABLE=1 for multi-gpu memory safety.") + import sys # noqa: E402 diff --git a/gptqmodel/quantization/gar.py b/gptqmodel/quantization/gar.py index 0b54bbcb6..a09572c14 100644 --- a/gptqmodel/quantization/gar.py +++ b/gptqmodel/quantization/gar.py @@ -3,6 +3,15 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +# Based on Group Aware Reordering (GAR) +# @article{gar, +# title={Dual Precision Quantization for Efficient and Accurate Deep Neural Networks Inference, CVPRW 2025.}, +# author={T. Gafni, A. Karnieli, Y. Hanani}, +# journal={arXiv preprint arXiv:2505.14638}, +# year={2025} +# } +# https://openaccess.thecvf.com/content/CVPR2025W/eLVM/html/Gafni_Dual_Precision_Quantization_for_Efficient_and_Accurate_Deep_Neural_Networks_CVPRW_2025_paper.html + import torch from gptqmodel.utils import setup_logger @@ -136,82 +145,6 @@ def compose_final_perm(local_perms, global_perm, groupsize: int) -> torch.Tensor perm2d = (local + base)[global_perm.to(device=local.device, dtype=torch.long)] # (G,S) return perm2d.reshape(-1) # (G*S,) - -# original algo -def compute_local_perms_original(diag_H, groupsize): - """ - For each group, compute a permutation that orders the indices in descending order - based on the corresponding diagonal values of H. - - Args: - diag_H (Tensor): 1D tensor representing the diagonal of the Hessian. - groupsize (int): Number of columns/weights per group. - - Returns: - local_perms (list of Tensors): Each element is a permutation (indices) for that group. - """ - n = diag_H.numel() - num_groups = n // groupsize - local_perms = [] - for g in range(num_groups): - start = g * groupsize - end = start + groupsize - sub_diag = diag_H[start:end] - # Get local permutation: indices that would sort sub_diag in descending order. - local_perm = torch.argsort(sub_diag, descending=True) - local_perms.append(local_perm) - return local_perms - -# original algo -def compute_global_perm_original(diag_H, groupsize): - """ - Compute a permutation for the groups themselves. Here we choose the maximum diagonal value - within each group as the group metric and sort the groups in descending order. - - Args: - diag_H (Tensor): 1D tensor representing the diagonal of the Hessian. - groupsize (int): Number of columns/weights per group. - - Returns: - global_perm (Tensor): 1D tensor of length num_groups with the new order of groups. - """ - n = diag_H.numel() - num_groups = n // groupsize - group_metric = [] - for g in range(num_groups): - start = g * groupsize - end = start + groupsize - group_metric.append(diag_H[start:end].max().item()) - # Create a tensor on the same device as diag_H. - group_metric = torch.tensor(group_metric, device=diag_H.device) - global_perm = torch.argsort(group_metric, descending=True) - return global_perm - -def compose_final_perm_original(local_perms, global_perm, groupsize): - """ - Compose the final overall permutation from the local and global permutations. - - Args: - local_perms (list of Tensors): Local permutation for each group. - global_perm (Tensor): Global group permutation. - groupsize (int): Number of indices per group. - - Returns: - final_perm (Tensor): 1D tensor that maps original indices to new positions. - """ - num_groups = len(local_perms) - final_perm = [] - # Process groups in the order specified by global_perm. - for new_group in range(num_groups): - # Get the original group index. - orig_group = global_perm[new_group].item() - offset = orig_group * groupsize - local_perm = local_perms[orig_group] - # Adjust local indices to the full index space. - for idx in local_perm: - final_perm.append(idx.item() + offset) - return torch.tensor(final_perm, dtype=torch.long) - def invert_perm(perm): """ Compute the inverse of a permutation vector. diff --git a/gptqmodel/quantization/gar_ref.py b/gptqmodel/quantization/gar_ref.py new file mode 100644 index 000000000..864d9f9dd --- /dev/null +++ b/gptqmodel/quantization/gar_ref.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +# Based on Group Aware Reordering (GAR) +# @article{gar, +# title={Dual Precision Quantization for Efficient and Accurate Deep Neural Networks Inference, CVPRW 2025.}, +# author={T. Gafni, A. Karnieli, Y. Hanani}, +# journal={arXiv preprint arXiv:2505.14638}, +# year={2025} +# } +# https://openaccess.thecvf.com/content/CVPR2025W/eLVM/html/Gafni_Dual_Precision_Quantization_for_Efficient_and_Accurate_Deep_Neural_Networks_CVPRW_2025_paper.html + +import torch + + +def compute_local_perms_original(diag_H, groupsize): + """ + For each group, compute a permutation that orders the indices in descending order + based on the corresponding diagonal values of H. + + Args: + diag_H (Tensor): 1D tensor representing the diagonal of the Hessian. + groupsize (int): Number of columns/weights per group. + + Returns: + local_perms (list of Tensors): Each element is a permutation (indices) for that group. + """ + n = diag_H.numel() + num_groups = n // groupsize + local_perms = [] + for g in range(num_groups): + start = g * groupsize + end = start + groupsize + sub_diag = diag_H[start:end] + # Get local permutation: indices that would sort sub_diag in descending order. + local_perm = torch.argsort(sub_diag, descending=True) + local_perms.append(local_perm) + return local_perms + + +def compute_global_perm_original(diag_H, groupsize): + """ + Compute a permutation for the groups themselves. Here we choose the maximum diagonal value + within each group as the group metric and sort the groups in descending order. + + Args: + diag_H (Tensor): 1D tensor representing the diagonal of the Hessian. + groupsize (int): Number of columns/weights per group. + + Returns: + global_perm (Tensor): 1D tensor of length num_groups with the new order of groups. + """ + n = diag_H.numel() + num_groups = n // groupsize + group_metric = [] + for g in range(num_groups): + start = g * groupsize + end = start + groupsize + group_metric.append(diag_H[start:end].max().item()) + # Create a tensor on the same device as diag_H. + group_metric = torch.tensor(group_metric, device=diag_H.device) + global_perm = torch.argsort(group_metric, descending=True) + return global_perm + + +def compose_final_perm_original(local_perms, global_perm, groupsize): + """ + Compose the final overall permutation from the local and global permutations. + + Args: + local_perms (list of Tensors): Local permutation for each group. + global_perm (Tensor): Global group permutation. + groupsize (int): Number of indices per group. + + Returns: + final_perm (Tensor): 1D tensor that maps original indices to new positions. + """ + num_groups = len(local_perms) + final_perm = [] + # Process groups in the order specified by global_perm. + for new_group in range(num_groups): + # Get the original group index. + orig_group = global_perm[new_group].item() + offset = orig_group * groupsize + local_perm = local_perms[orig_group] + # Adjust local indices to the full index space. + for idx in local_perm: + final_perm.append(idx.item() + offset) + return torch.tensor(final_perm, dtype=torch.long) diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index 97568bf0f..8d0ee5b36 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -92,8 +92,7 @@ def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict: tensors: dict[str, torch.Tensor] = {} with torch.inference_mode(): - with _OFFLOAD_LOCK: - state_items = list(module.state_dict().items()) + state_items = list(module.state_dict().items()) for key, tensor in state_items: cpu_tensor = tensor.detach().to("cpu") @@ -132,8 +131,8 @@ def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict: def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."): - #with _OFFLOAD_LOCK: - _offload_to_disk_impl(module=module, model=model, disk_path=disk_path) + with _OFFLOAD_LOCK: + _offload_to_disk_impl(module=module, model=model, disk_path=disk_path) def _offload_to_disk_impl(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."): @@ -193,11 +192,8 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."): module_offload_dir = os.path.join(disk_path, name) total_bytes = 0 - try: - with _OFFLOAD_LOCK: - state_items = list(module.state_dict().values()) - except Exception: - state_items = [] + + state_items = list(module.state_dict().values()) for tensor in state_items: total_bytes += _tensor_nbytes(tensor) diff --git a/setup.py b/setup.py index 56132edb1..3be472315 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ import sys from pathlib import Path -from setuptools import find_packages, find_namespace_packages, setup +from setuptools import find_namespace_packages, find_packages, setup from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel