cleanup #2066 (Merged)

18 changes: 9 additions & 9 deletions README.md
@@ -131,15 +131,15 @@ GPT-QModel not only supports GPTQ but also QQQ, GPTQv2, Eora with more quantizat

GPT-QModel has a modular design supporting multiple quantization methods and feature extensions.

| Quantization Feature | GPT-QModel | Transformers | vLLM | SGLang | Lora Training |
|----------------------|------------|---|---|---|---------------|
| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ |
| EoRA | ✅ | ✅ | ✅ | ✅ | x |
| AWQ | ✅ | ✅* | ✅* | ✅* | ✅* |
| GPTQ v2 | ✅ | ✅ | ✅ | ✅ | ✅ |
| QQQ | ✅ | x | x | x | x |
| Rotation | ✅ | x | x | x | x |
| Group Aware Activitation Reordering (GPTQ) | ✅ | ✅ | ✅ | ✅ | ✅ |
| Quantization Feature | GPT-QModel | Transformers | vLLM | SGLang | Lora Training |
|----------------------------|------------|---|---|---|---------------|
| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ |
| EoRA | ✅ | ✅ | ✅ | ✅ | x |
| Group Aware Act Reordering | ✅ | ✅ | ✅ | ✅ | ✅ |
| AWQ | ✅ | ✅* | ✅* | ✅* | ✅* |
| QQQ | ✅ | x | x | x | x |
| Rotation | ✅ | x | x | x | x |
| GPTQ v2* | ✅ | ✅ | ✅ | ✅ | ✅ |
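
For orientation alongside the table above, here is a minimal end-to-end sketch of the GPTQ path. This is a hedged illustration: it assumes the current GPTQModel.load / quantize / save API and uses a placeholder model id and toy calibration data.

from gptqmodel import GPTQModel, QuantizeConfig

# Placeholder model id and toy calibration text -- substitute real data.
calibration = ["The quick brown fox jumps over the lazy dog."] * 256
config = QuantizeConfig(bits=4, group_size=128)

model = GPTQModel.load("meta-llama/Llama-3.2-1B", config)
model.quantize(calibration)
model.save("Llama-3.2-1B-gptq-4bit")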

## Multi-Modal

8 changes: 8 additions & 0 deletions gptqmodel/eora/eora.py
@@ -5,6 +5,14 @@
# EoRA Official Repo: https://github.com/NVlabs/EoRA
# This file has been modified by ModelCloud.AI team and qubitium@modelcloud.ai for adoption into GPT-QModel

# EoRA
# @article{liu2024eora,
# title={EoRA: Training-free Compensation for Compressed LLM with Eigenspace Low-Rank Approximation},
# author={Liu, Shih-Yang and Yang, Huck and Wang, Chien-Yi and Fung, Nai Chit and Yin, Hongxu and Sakr, Charbel and Muralidharan, Saurav and Cheng, Kwang-Ting and Kautz, Jan and Wang, Yu-Chiang Frank and others},
# journal={arXiv preprint arXiv:2410.21271},
# year={2024}
# }

from typing import Sequence, Tuple

import torch
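For readers landing on the EoRA citation above, a compact sketch of the underlying idea, not this file's implementation: EoRA compensates the quantization error W - W_q with a low-rank term. The plain truncated SVD below stands in for the activation-eigenspace projection the paper actually proposes.

import torch

def lowrank_error_compensation(W: torch.Tensor, W_q: torch.Tensor, rank: int):
    # Factor the quantization error E = W - W_q into rank-r terms B @ A.
    # EoRA weights this decomposition by an eigenspace of calibration
    # activations; a plain SVD is used here purely for illustration.
    U, S, Vh = torch.linalg.svd(W - W_q, full_matrices=False)
    B = U[:, :rank] * S[:rank]  # (out_features, rank)
    A = Vh[:rank, :]            # (rank, in_features)
    return B, A                 # inference: y = x @ (W_q + B @ A).T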
4 changes: 4 additions & 0 deletions gptqmodel/models/auto.py
@@ -28,6 +28,10 @@
if 'CUDA_VISIBLE_DEVICES' in os.environ and 'ROCR_VISIBLE_DEVICES' in os.environ:
    del os.environ['ROCR_VISIBLE_DEVICES']

if not os.environ.get("NCCL_SHM_DISABLE", None):
    os.environ["NCCL_SHM_DISABLE"] = '1'
    log.info("ENV: Auto setting NCCL_SHM_DISABLE=1 for multi-gpu memory safety.")

import sys # noqa: E402


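Note the guard added above only fires when NCCL_SHM_DISABLE is unset, so callers can opt out by exporting any value before this module is imported. A small sketch, assuming gptqmodel is imported after the assignment:

import os

os.environ["NCCL_SHM_DISABLE"] = "0"  # preset before import; the guard is skipped

import gptqmodel  # noqa: E402  -- sees the user's value and leaves it untouched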
85 changes: 9 additions & 76 deletions gptqmodel/quantization/gar.py
@@ -3,6 +3,15 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

# Based on Group Aware Reordering (GAR)
# @article{gar,
#   title={Dual Precision Quantization for Efficient and Accurate Deep Neural Networks Inference},
#   author={Gafni, T. and Karnieli, A. and Hanani, Y.},
#   journal={arXiv preprint arXiv:2505.14638},
#   note={CVPR Workshops (eLVM), 2025},
#   year={2025}
# }
# https://openaccess.thecvf.com/content/CVPR2025W/eLVM/html/Gafni_Dual_Precision_Quantization_for_Efficient_and_Accurate_Deep_Neural_Networks_CVPRW_2025_paper.html

import torch

from gptqmodel.utils import setup_logger
@@ -136,82 +145,6 @@ def compose_final_perm(local_perms, global_perm, groupsize: int) -> torch.Tensor
    perm2d = (local + base)[global_perm.to(device=local.device, dtype=torch.long)]  # (G,S)
    return perm2d.reshape(-1)  # (G*S,)


# original algo
def compute_local_perms_original(diag_H, groupsize):
    """
    For each group, compute a permutation that orders the indices in descending order
    based on the corresponding diagonal values of H.

    Args:
        diag_H (Tensor): 1D tensor representing the diagonal of the Hessian.
        groupsize (int): Number of columns/weights per group.

    Returns:
        local_perms (list of Tensors): Each element is a permutation (indices) for that group.
    """
    n = diag_H.numel()
    num_groups = n // groupsize
    local_perms = []
    for g in range(num_groups):
        start = g * groupsize
        end = start + groupsize
        sub_diag = diag_H[start:end]
        # Get local permutation: indices that would sort sub_diag in descending order.
        local_perm = torch.argsort(sub_diag, descending=True)
        local_perms.append(local_perm)
    return local_perms

# original algo
def compute_global_perm_original(diag_H, groupsize):
    """
    Compute a permutation for the groups themselves. Here we choose the maximum diagonal value
    within each group as the group metric and sort the groups in descending order.

    Args:
        diag_H (Tensor): 1D tensor representing the diagonal of the Hessian.
        groupsize (int): Number of columns/weights per group.

    Returns:
        global_perm (Tensor): 1D tensor of length num_groups with the new order of groups.
    """
    n = diag_H.numel()
    num_groups = n // groupsize
    group_metric = []
    for g in range(num_groups):
        start = g * groupsize
        end = start + groupsize
        group_metric.append(diag_H[start:end].max().item())
    # Create a tensor on the same device as diag_H.
    group_metric = torch.tensor(group_metric, device=diag_H.device)
    global_perm = torch.argsort(group_metric, descending=True)
    return global_perm

def compose_final_perm_original(local_perms, global_perm, groupsize):
    """
    Compose the final overall permutation from the local and global permutations.

    Args:
        local_perms (list of Tensors): Local permutation for each group.
        global_perm (Tensor): Global group permutation.
        groupsize (int): Number of indices per group.

    Returns:
        final_perm (Tensor): 1D tensor that maps original indices to new positions.
    """
    num_groups = len(local_perms)
    final_perm = []
    # Process groups in the order specified by global_perm.
    for new_group in range(num_groups):
        # Get the original group index.
        orig_group = global_perm[new_group].item()
        offset = orig_group * groupsize
        local_perm = local_perms[orig_group]
        # Adjust local indices to the full index space.
        for idx in local_perm:
            final_perm.append(idx.item() + offset)
    return torch.tensor(final_perm, dtype=torch.long)

def invert_perm(perm):
    """
    Compute the inverse of a permutation vector.
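Since the deleted reference implementations move to gar_ref.py (the next file), a natural cross-check is to compare the vectorized path against them. A test sketch, assuming compose_final_perm accepts the same list-of-local-perms input as the reference version:

import torch

from gptqmodel.quantization.gar import compose_final_perm, invert_perm
from gptqmodel.quantization.gar_ref import (compose_final_perm_original,
                                            compute_global_perm_original,
                                            compute_local_perms_original)

diag_H = torch.rand(256)
local_perms = compute_local_perms_original(diag_H, groupsize=32)
global_perm = compute_global_perm_original(diag_H, groupsize=32)

ref = compose_final_perm_original(local_perms, global_perm, groupsize=32)
fast = compose_final_perm(local_perms, global_perm, groupsize=32)
assert torch.equal(fast, ref)
# invert_perm should return the inverse mapping: composing the two is identity.
assert torch.equal(invert_perm(ref)[ref], torch.arange(256))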
91 changes: 91 additions & 0 deletions gptqmodel/quantization/gar_ref.py
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

# Based on Group Aware Reordering (GAR)
# @article{gar,
#   title={Dual Precision Quantization for Efficient and Accurate Deep Neural Networks Inference},
#   author={Gafni, T. and Karnieli, A. and Hanani, Y.},
#   journal={arXiv preprint arXiv:2505.14638},
#   note={CVPR Workshops (eLVM), 2025},
#   year={2025}
# }
# https://openaccess.thecvf.com/content/CVPR2025W/eLVM/html/Gafni_Dual_Precision_Quantization_for_Efficient_and_Accurate_Deep_Neural_Networks_CVPRW_2025_paper.html

import torch


def compute_local_perms_original(diag_H, groupsize):
    """
    For each group, compute a permutation that orders the indices in descending order
    based on the corresponding diagonal values of H.

    Args:
        diag_H (Tensor): 1D tensor representing the diagonal of the Hessian.
        groupsize (int): Number of columns/weights per group.

    Returns:
        local_perms (list of Tensors): Each element is a permutation (indices) for that group.
    """
    n = diag_H.numel()
    num_groups = n // groupsize
    local_perms = []
    for g in range(num_groups):
        start = g * groupsize
        end = start + groupsize
        sub_diag = diag_H[start:end]
        # Get local permutation: indices that would sort sub_diag in descending order.
        local_perm = torch.argsort(sub_diag, descending=True)
        local_perms.append(local_perm)
    return local_perms


def compute_global_perm_original(diag_H, groupsize):
    """
    Compute a permutation for the groups themselves. Here we choose the maximum diagonal value
    within each group as the group metric and sort the groups in descending order.

    Args:
        diag_H (Tensor): 1D tensor representing the diagonal of the Hessian.
        groupsize (int): Number of columns/weights per group.

    Returns:
        global_perm (Tensor): 1D tensor of length num_groups with the new order of groups.
    """
    n = diag_H.numel()
    num_groups = n // groupsize
    group_metric = []
    for g in range(num_groups):
        start = g * groupsize
        end = start + groupsize
        group_metric.append(diag_H[start:end].max().item())
    # Create a tensor on the same device as diag_H.
    group_metric = torch.tensor(group_metric, device=diag_H.device)
    global_perm = torch.argsort(group_metric, descending=True)
    return global_perm


def compose_final_perm_original(local_perms, global_perm, groupsize):
    """
    Compose the final overall permutation from the local and global permutations.

    Args:
        local_perms (list of Tensors): Local permutation for each group.
        global_perm (Tensor): Global group permutation.
        groupsize (int): Number of indices per group.

    Returns:
        final_perm (Tensor): 1D tensor that maps original indices to new positions.
    """
    num_groups = len(local_perms)
    final_perm = []
    # Process groups in the order specified by global_perm.
    for new_group in range(num_groups):
        # Get the original group index.
        orig_group = global_perm[new_group].item()
        offset = orig_group * groupsize
        local_perm = local_perms[orig_group]
        # Adjust local indices to the full index space.
        for idx in local_perm:
            final_perm.append(idx.item() + offset)
    return torch.tensor(final_perm, dtype=torch.long)
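
A worked toy run of the three reference helpers above, with eight Hessian diagonal entries and groupsize 4 (values chosen purely for illustration):

import torch

diag_H = torch.tensor([0.1, 3.0, 0.5, 2.0, 7.0, 0.2, 1.0, 4.0])
local_perms = compute_local_perms_original(diag_H, 4)  # [tensor([1, 3, 2, 0]), tensor([0, 3, 2, 1])]
global_perm = compute_global_perm_original(diag_H, 4)  # tensor([1, 0]) -- group max 7.0 beats 3.0
final_perm = compose_final_perm_original(local_perms, global_perm, 4)
print(final_perm)  # tensor([4, 7, 6, 5, 1, 3, 2, 0])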
14 changes: 5 additions & 9 deletions gptqmodel/utils/offload.py
@@ -92,8 +92,7 @@ def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:
    tensors: dict[str, torch.Tensor] = {}

    with torch.inference_mode():
        with _OFFLOAD_LOCK:
            state_items = list(module.state_dict().items())
        state_items = list(module.state_dict().items())

        for key, tensor in state_items:
            cpu_tensor = tensor.detach().to("cpu")
@@ -132,8 +131,8 @@ def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:


def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."):
    #with _OFFLOAD_LOCK:
    _offload_to_disk_impl(module=module, model=model, disk_path=disk_path)
    with _OFFLOAD_LOCK:
        _offload_to_disk_impl(module=module, model=model, disk_path=disk_path)


def _offload_to_disk_impl(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."):
@@ -193,11 +192,8 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):
    module_offload_dir = os.path.join(disk_path, name)

    total_bytes = 0
    try:
        with _OFFLOAD_LOCK:
            state_items = list(module.state_dict().values())
    except Exception:
        state_items = []

    state_items = list(module.state_dict().values())

    for tensor in state_items:
        total_bytes += _tensor_nbytes(tensor)
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@
import sys
from pathlib import Path

from setuptools import find_packages, find_namespace_packages, setup
from setuptools import find_namespace_packages, find_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel

