6 changes: 6 additions & 0 deletions gptqmodel/looper/tensorparallel_weight_processor.py
@@ -38,9 +38,15 @@ def __init__(self, *args, **kwargs):
kwargs.setdefault("require_fwd", False)
kwargs.setdefault("fwd_after_process", False)
super().__init__(*args, **kwargs)
qcfg_from_kwargs = kwargs.pop("qcfg", None)
if qcfg_from_kwargs is not None:
self.qcfg = qcfg_from_kwargs

self._target_multiple = math.lcm(*self._TP_TARGETS)

if self.qcfg and hasattr(self.qcfg, 'group_size') and self.qcfg.group_size > 0:
self._target_multiple = math.lcm(self._target_multiple, self.qcfg.group_size)

def preprocess(self, module: NamedModule): # pragma: no cover - simple hook
# The processor operates on every eligible module; no setup required.
pass
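Note: for context, a minimal standalone sketch of the padding-target computation introduced above, assuming the default _TP_TARGETS = (2, 4, 8) that the tests below also reference. The free function target_multiple is hypothetical, written only for illustration, not part of the processor's API:

import math

# Hedged sketch (not library code): mirrors the _target_multiple logic above,
# assuming the default _TP_TARGETS = (2, 4, 8) noted in the tests below.
def target_multiple(tp_targets=(2, 4, 8), group_size=-1):
    multiple = math.lcm(*tp_targets)  # lcm(2, 4, 8) = 8
    if group_size > 0:
        # Fold the quantization group size into the padding target so padded
        # weight shapes stay aligned to whole quantization groups.
        multiple = math.lcm(multiple, group_size)
    return multiple

assert target_multiple() == 8                  # group_size disabled (-1)
assert target_multiple(group_size=128) == 128  # lcm(8, 128) = 128
assert target_multiple(group_size=12) == 24    # lcm(8, 12) = 24

Taking the least common multiple, rather than the maximum, guarantees a padded dimension divides evenly across every tensor-parallel shard count and every quantization group at once.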
3 changes: 3 additions & 0 deletions tests/models/test_longllama.py
@@ -6,6 +6,8 @@
 from model_test import ModelTest
 
 from gptqmodel.utils.eval import EVAL
+from gptqmodel.utils.backend import BACKEND
+
 
 
 class TestLongLlama(ModelTest):
@@ -19,6 +21,7 @@ class TestLongLlama(ModelTest):
     }
     USE_VLLM = False
     USE_FLASH_ATTN = False
+    LOAD_BACKEND = BACKEND.TORCH
 
     def test_longllama(self):
         self.quant_lm_eval()
120 changes: 119 additions & 1 deletion tests/test_tensorparallel_weight_processor.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
+import math
 import torch
 
 from gptqmodel.looper.named_module import NamedModule
@@ -34,7 +35,6 @@ def test_tensorparallel_pre_padding_applies_zero_pad_metadata():
         calibration_concat_size=None,
         calibration_sort=None,
         batch_size=1,
-        logger_board="",
     )
 
     preprocessor.process(named)
@@ -62,3 +62,121 @@

     gptq.free()
     assert "tp_pad_info" not in named.state
+
+
+def test_tensorparallel_weight_processor_with_positive_group_size():
+    """Test that _target_multiple is correctly calculated when group_size > 0."""
+    linear = torch.nn.Linear(10, 7, bias=False)
+    named = NamedModule(linear, name="proj", full_name="layer.0.proj", layer_index=0)
+
+    qcfg = QuantizeConfig(bits=4, mock_quantization=True)
+    qcfg.group_size = 128  # Positive group_size
+    qcfg.desc_act = False
+    qcfg.act_group_aware = False
+
+    calibration_stub = [{"input_ids": torch.ones((1, 1), dtype=torch.long)}]
+
+    preprocessor = TensorParallelWeightProcessor(
+        tokenizer=None,
+        qcfg=qcfg,
+        calibration=calibration_stub,
+        prepare_dataset_func=_noop_prepare_dataset,
+        calibration_concat_size=None,
+        calibration_sort=None,
+        batch_size=1,
+    )
+
+    # Verify that _target_multiple includes group_size in the LCM calculation.
+    # Default _TP_TARGETS = (2, 4, 8), so math.lcm(2, 4, 8) = 8.
+    # With group_size = 128, math.lcm(8, 128) = 128.
+    expected_target_multiple = math.lcm(8, 128)
+    assert preprocessor._target_multiple == expected_target_multiple
+    assert preprocessor._target_multiple == 128
+
+
+def test_tensorparallel_weight_processor_with_negative_group_size():
+    """Test that _target_multiple uses the default value when group_size < 0."""
+    linear = torch.nn.Linear(10, 7, bias=False)
+    named = NamedModule(linear, name="proj", full_name="layer.0.proj", layer_index=0)
+
+    qcfg = QuantizeConfig(bits=4, mock_quantization=True)
+    qcfg.group_size = -1  # Negative group_size
+    qcfg.desc_act = False
+    qcfg.act_group_aware = False
+
+    calibration_stub = [{"input_ids": torch.ones((1, 1), dtype=torch.long)}]
+
+    preprocessor = TensorParallelWeightProcessor(
+        tokenizer=None,
+        qcfg=qcfg,
+        calibration=calibration_stub,
+        prepare_dataset_func=_noop_prepare_dataset,
+        calibration_concat_size=None,
+        calibration_sort=None,
+        batch_size=1,
+    )
+
+    # Verify that _target_multiple only uses _TP_TARGETS when group_size < 0.
+    # Default _TP_TARGETS = (2, 4, 8), so math.lcm(2, 4, 8) = 8.
+    expected_target_multiple = math.lcm(2, 4, 8)
+    assert preprocessor._target_multiple == expected_target_multiple
+    assert preprocessor._target_multiple == 8
+
+
+def test_tensorparallel_weight_processor_group_size_lcm_calculation():
+    """Test LCM calculation with various group_size values."""
+    linear = torch.nn.Linear(10, 7, bias=False)
+    named = NamedModule(linear, name="proj", full_name="layer.0.proj", layer_index=0)
+
+    calibration_stub = [{"input_ids": torch.ones((1, 1), dtype=torch.long)}]
+
+    # Test with group_size = 32
+    qcfg = QuantizeConfig(bits=4, mock_quantization=True)
+    qcfg.group_size = 32
+    qcfg.desc_act = False
+    qcfg.act_group_aware = False
+
+    preprocessor = TensorParallelWeightProcessor(
+        tokenizer=None,
+        qcfg=qcfg,
+        calibration=calibration_stub,
+        prepare_dataset_func=_noop_prepare_dataset,
+        calibration_concat_size=None,
+        calibration_sort=None,
+        batch_size=1,
+    )
+
+    # math.lcm(8, 32) = 32
+    assert preprocessor._target_multiple == 32
+
+    # Test with group_size = 64
+    qcfg.group_size = 64
+    preprocessor = TensorParallelWeightProcessor(
+        tokenizer=None,
+        qcfg=qcfg,
+        calibration=calibration_stub,
+        prepare_dataset_func=_noop_prepare_dataset,
+        calibration_concat_size=None,
+        calibration_sort=None,
+        batch_size=1,
+    )
+
+    # math.lcm(8, 64) = 64
+    assert preprocessor._target_multiple == 64
+
+    # Test with group_size = 12 (not a power of 2)
+    qcfg.group_size = 12
+    preprocessor = TensorParallelWeightProcessor(
+        tokenizer=None,
+        qcfg=qcfg,
+        calibration=calibration_stub,
+        prepare_dataset_func=_noop_prepare_dataset,
+        calibration_concat_size=None,
+        calibration_sort=None,
+        batch_size=1,
+    )
+
+    # math.lcm(8, 12) = 24
+    expected_lcm = math.lcm(8, 12)
+    assert preprocessor._target_multiple == expected_lcm
+    assert preprocessor._target_multiple == 24
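
Note: the pre-padding test above checks that "tp_pad_info" metadata is attached and later cleared; a hedged sketch of the round-up arithmetic such padding implies (pad_amount is illustrative only, not the processor's API):

# Hedged sketch (illustrative only): how far a weight dimension must be
# zero-padded to reach the processor's target multiple; the real padding
# and "tp_pad_info" bookkeeping live inside TensorParallelWeightProcessor.
def pad_amount(size: int, multiple: int) -> int:
    return (-size) % multiple  # 0 when size is already aligned

assert pad_amount(7, 8) == 1      # Linear(10, 7): out_features 7 -> 8 under multiple 8
assert pad_amount(128, 128) == 0  # aligned once group_size 128 dominates the LCM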