70 changes: 52 additions & 18 deletions gptqmodel/looper/module_looper.py
@@ -19,7 +19,7 @@
import time
from concurrent.futures import as_completed
from contextlib import nullcontext
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, NamedTuple, Optional, TYPE_CHECKING

import torch

@@ -59,6 +59,12 @@
from logbar.progress import ProgressBar


class FinalizeProgressInfo(NamedTuple):
module_label: Optional[str]
process_name: str
layer_idx: Optional[int]


class ModuleLooper():
"""Drive the per-layer quantisation workflow over one or more devices.

@@ -488,6 +494,9 @@ def _run_forward_batches_parallel(
"""Fan batches across device clones and preserve result ordering."""
module_replicas = clone_module_for_devices(module, devices)

# Ensure any async replication/memcpy ops are complete before threads start fanning out.
torch_sync()

prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None

results: Dict[int, torch.Tensor | tuple | None] = {}
@@ -1308,6 +1317,9 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):
source=resolved_label,
)

process_name = process.name() if process is not None else "<processor>"
return FinalizeProgressInfo(module_label, process_name, layer_idx)

# pb.subtitle(
# f"{process.name()}: layer:{layer_idx} Finalized {idx}/{total} {module_label}"
# ).draw()
@@ -1328,34 +1340,57 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):
finalize_futures_snapshot = list(finalize_futures)

if finalize_futures_snapshot:
known_layers = sorted(
{
layer_idx
for _, _, _, _, layer_idx in finalize_futures_snapshot
if layer_idx is not None
}
)
includes_unknown = any(
layer_idx is None
for _, _, _, _, layer_idx in finalize_futures_snapshot
)

layer_heading = "Layer ?"
if known_layers:
sample_layers = ", ".join(str(idx) for idx in known_layers[:3])
if len(known_layers) > 3:
sample_layers += ", …"
suffix = ", ?" if includes_unknown else ""
prefix = "Layer" if len(known_layers) == 1 else "Layers"
layer_heading = f"{prefix} {sample_layers}{suffix}"
elif includes_unknown:
layer_heading = "Layer ?"

finalize_pb.title(
f"Submodule finalize 0/{finalize_count}"
f"{layer_heading} Submodule finalize 0/{finalize_count}"
).subtitle("Waiting for completions...").draw()

future_metadata = {
future: (module_label, process, layer_idx)
for future, _, module_label, process, layer_idx in finalize_futures_snapshot
}

def _drain_finalize_futures(
futures,
finalize_pb_local,
finalize_count_local,
future_metadata_local,
):
completed_local = 0
try:
for future in as_completed(futures):
module_label, process, layer_idx = future_metadata_local.get(
future, (None, None, None)
)

future.result()

layer_label = f"Layer {layer_idx}" if layer_idx is not None else "layer ?"
result = future.result()

if isinstance(result, FinalizeProgressInfo):
module_label = result.module_label
process_name = result.process_name
layer_idx = result.layer_idx
elif isinstance(result, tuple) and len(result) == 3:
module_label, process_name, layer_idx = result
else:
module_label = None
process_name = "<processor>"
layer_idx = None

layer_label = f"Layer {layer_idx}" if layer_idx is not None else "Layer ?"
display_module = module_label or "<unnamed>"
processor_name = process.name() if process is not None else "<processor>"
subtitle = f"{processor_name}: {display_module}"
subtitle = f"{process_name}: {display_module}"

completed_local += 1
finalize_pb_local.next()
@@ -1373,7 +1408,6 @@ def _drain_finalize_futures(
[future for future, *_ in finalize_futures_snapshot],
finalize_pb,
finalize_count,
future_metadata,
),
name="SubmoduleFinalizeWatcher",
daemon=True,
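The module_looper.py change above has each finalize worker return a `FinalizeProgressInfo` NamedTuple, and a single watcher thread drains the futures in completion order to update the progress bar. A minimal, self-contained sketch of that pattern, with illustrative worker and module names that are not from the repository:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import NamedTuple, Optional


class FinalizeProgressInfo(NamedTuple):
    module_label: Optional[str]
    process_name: str
    layer_idx: Optional[int]


def finalize_worker(module_label: str, layer_idx: int) -> FinalizeProgressInfo:
    # Real code would finalize (pack/offload) the submodule here.
    return FinalizeProgressInfo(module_label, "gptq", layer_idx)


def drain(futures, total: int) -> None:
    # Consume completions in finish order; result() re-raises worker exceptions.
    for done, future in enumerate(as_completed(futures), start=1):
        info = future.result()
        layer = f"Layer {info.layer_idx}" if info.layer_idx is not None else "Layer ?"
        module = info.module_label or "<unnamed>"
        print(f"{layer} Submodule finalize {done}/{total} | {info.process_name}: {module}")


with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(finalize_worker, f"mlp.experts.{i}.down_proj", i) for i in range(8)]
    drain(futures, len(futures))
```

Returning the metadata from the worker itself is what lets this diff delete the `future_metadata` lookup dict that previously mapped futures back to their labels.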
5 changes: 2 additions & 3 deletions gptqmodel/models/auto.py
@@ -56,6 +56,7 @@
from .base import BaseQModel, QuantizeConfig # noqa: E402
from .definitions.apertus import ApertusQModel # noqa: E402
from .definitions.baichuan import BaiChuanQModel # noqa: E402
from .definitions.bailing_moe import BailingMoeQModel # noqa: E402
from .definitions.bloom import BloomQModel # noqa: E402
from .definitions.chatglm import ChatGLMQModel # noqa: E402
from .definitions.codegen import CodeGenQModel # noqa: E402
@@ -85,6 +86,7 @@
from .definitions.internlm import InternLMQModel # noqa: E402
from .definitions.internlm2 import InternLM2QModel # noqa: E402
from .definitions.klear import KlearQModel # noqa: E402
from .definitions.lfm2_moe import LFM2MoeQModel # noqa: E402
from .definitions.llama import LlamaQModel # noqa: E402
from .definitions.llama4 import Llama4QModel # noqa: E402
from .definitions.llava_qwen2 import LlavaQwen2QModel # noqa: E402
@@ -118,9 +120,6 @@
from .definitions.starcoder2 import Starcoder2QModel # noqa: E402
from .definitions.telechat2 import TeleChat2QModel
from .definitions.xverse import XverseQModel # noqa: E402
from .definitions.bailing_moe import BailingMoeQModel # noqa: E402
from .definitions.lfm2_moe import LFM2MoeQModel # noqa: E402



# make quants and inference more deterministic
4 changes: 2 additions & 2 deletions gptqmodel/models/base.py
@@ -37,14 +37,14 @@
HFDataset = None
HFIterableDataset = None

from .. import DEVICE_THREAD_POOL
from ..adapter.adapter import Adapter
from ..nn_modules.qlinear import BaseQuantLinear
from ..nn_modules.qlinear.torch import TorchQuantLinear
from ..nn_modules.qlinear.lookahead import configure_default_lookahead
from ..nn_modules.qlinear.torch import TorchQuantLinear
from ..quantization import QuantizeConfig
from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, dynamic_get
from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model
from .. import DEVICE_THREAD_POOL
from ..utils.backend import BACKEND
from ..utils.data import collate_data
from ..utils.device import get_device
2 changes: 1 addition & 1 deletion gptqmodel/models/definitions/lfm2_moe.py
@@ -20,7 +20,7 @@ class LFM2MoeQModel(BaseQModel):
"operator_norm": ("operator_norm:!",),
"conv": ("in_proj", "out_proj"),
"self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
"ffn_norm": ("ffn_norm:!",),
"ffn_norm": ("ffn_norm:!",),
"feed_forward": {
"gate": ("gate:!",),
"": ("w1:0", "w3:0", "w2:1"),
2 changes: 1 addition & 1 deletion gptqmodel/models/writer.py
@@ -9,10 +9,10 @@
import csv
import json
import os
import pcre as re
from os.path import isfile, join
from typing import Any, Dict, Optional, Union

import pcre as re
import torch
import transformers
from safetensors.torch import save_file
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -10,7 +10,6 @@
from typing import List, Optional, Tuple

import numpy as np
import threadpoolctl
import torch as t # conflict with torch.py
import torch.nn as nn
import transformers
@@ -20,6 +19,7 @@
from ...models._const import DEVICE, PLATFORM
from ...utils.backend import BACKEND
from ...utils.logger import setup_logger
from ...utils.safe import THREADPOOLCTL


log = setup_logger()
@@ -890,7 +890,7 @@ def _pack_rows_3(int32_blk_32xN: t.Tensor, dst: t.Tensor, dst_rows_base: int):
del weight, scales_dev, zeros_dev, scale_zeros_dev, qweight_dev, qzeros_dev

def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_idx: t.Tensor=None):
with threadpoolctl.threadpool_limits(1):
with THREADPOOLCTL.threadpool_limits(1):
# TODO why did we need to clone? at packing, the original weight is no longer used by other processors?
# W = linear.weight.data.clone()
W = linear.weight.data
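This file now reaches threadpoolctl through `...utils.safe.THREADPOOLCTL` instead of importing it directly. `gptqmodel/utils/safe.py` is not part of this diff, so the following is only a guess at the shape of such an optional-import shim; the call-site fallback mirrors the nullcontext pattern used in bitblas.py further down:

```python
from contextlib import nullcontext

# Hypothetical optional-import shim (the real utils/safe.py may differ).
try:
    import threadpoolctl as THREADPOOLCTL
except ImportError:
    THREADPOOLCTL = None

# Call-site pattern: limit BLAS threads when threadpoolctl is available, no-op otherwise.
threadpool_limits = (
    THREADPOOLCTL.threadpool_limits
    if THREADPOOLCTL is not None
    else (lambda *args, **kwargs: nullcontext())
)

with threadpool_limits(limits=1):
    pass  # single-threaded packing work would run here
```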
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/awq_gemm.py
@@ -7,7 +7,7 @@

from ...adapter.adapter import Adapter, Lora
from ...models._const import DEVICE, PLATFORM
from ...nn_modules.qlinear import AWQuantLinear, PackableQuantLinear
from ...nn_modules.qlinear import AWQuantLinear
from ...quantization.awq.modules.linear.gemm import WQLinearMMFunction
from ...utils.backend import BACKEND
from ...utils.logger import setup_logger
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/awq_gemv.py
@@ -7,7 +7,7 @@

from ...adapter.adapter import Adapter, Lora
from ...models._const import DEVICE, PLATFORM
from ...nn_modules.qlinear import AWQuantLinear, PackableQuantLinear
from ...nn_modules.qlinear import AWQuantLinear
from ...quantization.awq.utils.module import try_import
from ...utils.backend import BACKEND
from ...utils.gemv import calculate_zeros_width
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/awq_gemv_fast.py
@@ -7,7 +7,7 @@

from ...adapter.adapter import Adapter, Lora
from ...models._const import DEVICE, PLATFORM
from ...nn_modules.qlinear import AWQuantLinear, PackableQuantLinear
from ...nn_modules.qlinear import AWQuantLinear
from ...quantization.awq.utils.module import try_import
from ...utils.backend import BACKEND
from ...utils.gemv import calculate_zeros_width
1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/pack_block_ext.py
@@ -13,6 +13,7 @@
from torch import Tensor
from torch.utils.cpp_extension import load


log = logging.getLogger(__name__)

_EXTENSION = None
3 changes: 2 additions & 1 deletion gptqmodel/nn_modules/qlinear/torch.py
@@ -19,6 +19,7 @@
from ...utils.logger import setup_logger
from ...utils.torch import torch_compile


try:
from ..triton_utils.dequant import dequant as triton_dequant

@@ -259,7 +260,7 @@ def schedule(tile_idx: int, buffer_idx: int):
compute_stream.wait_stream(stream_dequant)
width = widths[buffer_idx]
start = tile_idx * tile
end = start + width
start + width

out_slice = out.narrow(1, start, width)
out_slice.zero_()
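The `schedule()` hunk above pairs a dequant stream with a compute stream so dequantization of the next tile can overlap with compute on the current one. A hedged sketch of that producer/consumer stream pattern in plain PyTorch (not the repository's code; the tensor shapes and the "dequant" work are stand-ins):

```python
import torch

if torch.cuda.is_available():
    dequant_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()

    x = torch.randn(1024, 1024, device="cuda")
    buf = torch.empty_like(x)

    # Producer: run the "dequant" work on its own stream.
    with torch.cuda.stream(dequant_stream):
        buf.copy_(x)
        buf.mul_(0.5)

    # Consumer: order the compute stream after the dequant stream before reading buf.
    compute_stream.wait_stream(dequant_stream)
    y = buf @ buf.t()
    torch.cuda.synchronize()
```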
21 changes: 12 additions & 9 deletions gptqmodel/quantization/awq/modules/triton/gemm.py
@@ -18,6 +18,7 @@
import triton
import triton.language as tl


AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

def get_same_device_cm(t):
@@ -282,10 +283,11 @@ def awq_dequantize_triton(
Y = qweight.shape[0] # num rows
X = qweight.shape[1] # num cols

grid = lambda META: (
triton.cdiv(X, META["BLOCK_SIZE_X"]),
triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
)
def grid(META):
return (
triton.cdiv(X, META["BLOCK_SIZE_X"]),
triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
)
with get_same_device_cm(qweight):
awq_dequantize_kernel[grid](
qweight,
@@ -330,10 +332,11 @@ def awq_gemm_triton(
assert group_size <= K
assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
split_k_iters,
)
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
split_k_iters,
)

result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)

@@ -356,4 +359,4 @@ def awq_gemm_triton(
SPLIT_K=split_k_iters,
)

return result
return result
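Both Triton launches above replace `grid = lambda META: ...` with a named `def grid(META)`. The two are interchangeable: Triton accepts any callable that maps the kernel's meta-parameters to a launch grid, and the named form is easier to read and debug. A small standalone sketch (the kernel and sizes are illustrative, not from this repository):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) * 2.0, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
n_elements = x.numel()


def grid(META):
    # Equivalent to: grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
    return (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)


scale_kernel[grid](x, y, n_elements, BLOCK_SIZE=1024)
```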
2 changes: 1 addition & 1 deletion gptqmodel/quantization/config.py
@@ -5,12 +5,12 @@

import json
import os.path
import pcre as re
from dataclasses import dataclass, field, fields
from enum import Enum
from os.path import join
from typing import Any, Dict, List, Optional, Tuple, Union

import pcre as re
import torch
from packaging import version
from random_word import random_word
11 changes: 9 additions & 2 deletions gptqmodel/utils/bitblas.py
@@ -4,14 +4,15 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import os
from contextlib import nullcontext

import threadpoolctl as tctl
import torch

from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear
from ..quantization import FORMAT, QuantizeConfig
from ..utils.logger import setup_logger
from .model import load_checkpoint_in_model_then_tie_weights
from .safe import THREADPOOLCTL
from .torch import torch_empty_cache


@@ -76,7 +77,13 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool
message = "Overriding QuantLinear layers to use BitBLAS's QuantLinear..."

# TODO: need to benchmark to see multiple threads help with bitblas/tvm compilation and runtime
with tctl.threadpool_limits(limits=1):
threadpool_limits = (
THREADPOOLCTL.threadpool_limits
if THREADPOOLCTL is not None
else (lambda *args, **kwargs: nullcontext())
)

with threadpool_limits(limits=1):
os.environ["NUMEXPR_MAX_THREADS"] = "1"

# Note that due to tvm compilation of per layer modules shapes, the first layer loop is
2 changes: 1 addition & 1 deletion gptqmodel/utils/mmlupro.py
@@ -7,9 +7,9 @@
import json
import os
import random
import pcre as re
import time

import pcre as re
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
2 changes: 1 addition & 1 deletion gptqmodel/utils/model.py
@@ -11,7 +11,6 @@
import math
import operator
import os
import pcre as re
import shutil
import struct
import threading
@@ -22,6 +21,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import accelerate
import pcre as re
import torch
import torch.nn as nn
import transformers