8 changes: 4 additions & 4 deletions gptqmodel/looper/gptq_processor.py
@@ -85,15 +85,15 @@ def preprocess(self, module: NamedModule, fail_safe: bool):
         qcfg_clone.act_group_aware = act_group_aware_override
         qcfg_clone.damp_percent = self.qcfg.dynamic_get(module.full_name, "damp_percent", qcfg_clone.damp_percent)
         qcfg_clone.static_groups = self.qcfg.dynamic_get(module.full_name, "static_groups", qcfg_clone.static_groups)
-        qcfg_clone.v2 = self.qcfg.dynamic_get(module.full_name, "v2", qcfg_clone.v2)
-        qcfg_clone.v2_alpha = self.qcfg.dynamic_get(module.full_name, "v2_alpha", qcfg_clone.v2_alpha)
+        qcfg_clone.gptaq = self.qcfg.dynamic_get(module.full_name, "gptaq", qcfg_clone.gptaq)
+        qcfg_clone.gptaq_alpha = self.qcfg.dynamic_get(module.full_name, "gptaq_alpha", qcfg_clone.gptaq_alpha)

         qcfg_clone._resolve_activation_ordering(desc_act_override, act_group_aware_override)

         # store last used qcfg_dynamic
         self.qcfg_dynamic = qcfg_clone

-        if qcfg_clone.v2 is True:
+        if qcfg_clone.gptaq is True:
             tmp = GPTQv2(module=module, qcfg=qcfg_clone)
         else:
             tmp = GPTQ(module=module, qcfg=qcfg_clone)
@@ -383,4 +383,4 @@ def verify_calibration_dataset(self, processor_index: int) -> bool:
     def name(self) -> str:
         # TODO fix me..this hacks inherited base class logic, why not override name in gptqv2?
         qcfg = self.qcfg_dynamic if self.qcfg_dynamic is not None else self.qcfg
-        return "gptq v2" if qcfg.v2 else "gptq"
+        return "gptaq" if qcfg.gptaq else "gptq"
2 changes: 1 addition & 1 deletion gptqmodel/looper/module_looper.py
@@ -1006,7 +1006,7 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):
         for p_index, processor in enumerate(self.processors):
             if not processor.verify_calibration_dataset(p_index):
                 if isinstance(processor, EoraProcessor) or\
-                        (isinstance(processor, GPTQProcessor) and self.gptq_model.quantize_config.v2):
+                        (isinstance(processor, GPTQProcessor) and self.gptq_model.quantize_config.gptaq):
                     prev_processor = self.processors[p_index - 1]
                     processor.set_calibration_dataset(prev_processor.calibration_dataset)
                     # If calibration_dataset is None or Empty, the input_cache of the previous processor is used.
20 changes: 10 additions & 10 deletions gptqmodel/looper/native_processor.py
@@ -65,19 +65,19 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             # gptq is mutable.
             inp = inp[0].detach()

-            if self.qcfg.v2_memory_device == "auto":
-                v2_memory_device = DEVICE_1
-            elif self.qcfg.v2_memory_device == "cpu":
+            if self.qcfg.gptaq_memory_device == "auto":
+                target_device = DEVICE_1
+            elif self.qcfg.gptaq_memory_device == "cpu":
                 # slower but >= 4x vram memory reduction
-                v2_memory_device = CPU
-            elif isinstance(self.qcfg.v2_memory_device, str):
-                v2_memory_device = torch.device(self.qcfg.v2_memory_device)
-            elif isinstance(self.qcfg.v2_memory_device, torch.device):
-                v2_memory_device = self.qcfg.v2_memory_device
+                target_device = CPU
+            elif isinstance(self.qcfg.gptaq_memory_device, str):
+                target_device = torch.device(self.qcfg.gptaq_memory_device)
+            elif isinstance(self.qcfg.gptaq_memory_device, torch.device):
+                target_device = self.qcfg.gptaq_memory_device
             else:
-                v2_memory_device = DEVICE_1
+                target_device = DEVICE_1

-            self.native_inp_caches[name] += [inp.to(device=v2_memory_device)]
+            self.native_inp_caches[name] += [inp.to(device=target_device)]
             del inp, out

         return tmp
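Note: the renamed gptaq_memory_device setting controls where NativeProcessor caches layer inputs: "auto" keeps them on the quantization device, "cpu" trades speed for roughly 4x less VRAM (per the comment above), and an explicit device string or torch.device is also accepted. A hedged usage sketch; the field names follow this diff, the other arguments are illustrative:

from gptqmodel import QuantizeConfig

# Offload GPTAQ input caches to CPU RAM to reduce VRAM pressure (slower).
qcfg = QuantizeConfig(
    bits=4,
    group_size=128,
    gptaq=True,
    gptaq_memory_device="cpu",  # "auto" (default), "cpu", or e.g. "cuda:1"
)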
2 changes: 1 addition & 1 deletion gptqmodel/models/base.py
@@ -618,7 +618,7 @@ def quantize(
             GPTQProcessor(**args),
         ]

-        if self.quantize_config.v2 is True:
+        if self.quantize_config.gptaq is True:
             from ..looper.native_processor import NativeProcessor

             # During the deepcopy process, self.prepare_dataset will be deeply copied along with self. However,
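Note: quantize() adds a NativeProcessor alongside the GPTQProcessor when gptaq is enabled, so turning the feature on is just a config flag. A rough end-to-end sketch, assuming the usual GPTQModel.load() / quantize() / save() entry points; the model id, calibration texts, and output path are placeholders:

from gptqmodel import GPTQModel, QuantizeConfig

qcfg = QuantizeConfig(bits=4, group_size=128, gptaq=True, gptaq_alpha=0.25)

model = GPTQModel.load("facebook/opt-125m", qcfg)               # placeholder model id
model.quantize(["example calibration text", "another sample"])  # placeholder calibration data
model.save("opt-125m-gptaq-4bit")                               # placeholder output path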
2 changes: 1 addition & 1 deletion gptqmodel/models/loader.py
@@ -768,7 +768,7 @@ def assign(mod, device_id):

         if qcfg.format == FORMAT.GPTQ:
             # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
-            if not qcfg.sym and not qcfg.is_quantized_by_v2():
+            if not qcfg.sym and not qcfg.is_quantized_by_gptaq():
                 raise ValueError(
                     f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
                 )
41 changes: 32 additions & 9 deletions gptqmodel/models/writer.py
@@ -35,8 +35,8 @@
     META_FIELD_STATIC_GROUPS,
     META_FIELD_TRUE_SEQUENTIAL,
     META_FIELD_URI,
-    META_FIELD_V2_ALPHA,
-    META_FIELD_V2_ENABLED,
+    META_FIELD_GPTAQ_ALPHA,
+    META_FIELD_GPTAQ_ENABLED,
     META_QUANTIZER_GPTQMODEL,
     META_VALUE_URI,
     MIN_VERSION_WITH_V2,
@@ -199,13 +199,13 @@ def save_quantized(
         )

         self.quantize_config.meta_set(
-            key=META_FIELD_V2_ENABLED,
-            value=self.quantize_config.v2
+            key=META_FIELD_GPTAQ_ENABLED,
+            value=self.quantize_config.gptaq
         )

         self.quantize_config.meta_set(
-            key=META_FIELD_V2_ALPHA,
-            value=self.quantize_config.v2_alpha
+            key=META_FIELD_GPTAQ_ALPHA,
+            value=self.quantize_config.gptaq_alpha
         )

         self.quantize_config.meta_set(
@@ -236,9 +236,32 @@ def save_quantized(
         config.quantization_config = quantize_config.to_dict()
         self.model.config = config

-        # Save model config, including generation_config
-        # Use empty state_dict hack to bypass saving weights
-        self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True)
+        def strip_attention_impl_fields(target: Any) -> Dict[str, Any]:
+            removed: Dict[str, Any] = {}
+            for attr in ("attn_implementation", "_attn_implementation"):
+                if hasattr(target, attr):
+                    removed[attr] = getattr(target, attr)
+                    delattr(target, attr)
+            return removed
+
+        generation_config = getattr(self.model, "generation_config", None)
+        removed_config_attention_attrs: Dict[str, Any] = {}
+        removed_generation_attention_attrs: Dict[str, Any] = {}
+
+        try:
+            removed_config_attention_attrs = strip_attention_impl_fields(self.model.config)
+            if generation_config is not None:
+                removed_generation_attention_attrs = strip_attention_impl_fields(generation_config)
+
+            # Save model config, including generation_config
+            # Use empty state_dict hack to bypass saving weights
+            self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True)
+        finally:
+            for attr, value in removed_config_attention_attrs.items():
+                setattr(self.model.config, attr, value)
+            if generation_config is not None:
+                for attr, value in removed_generation_attention_attrs.items():
+                    setattr(generation_config, attr, value)

         gen_config_path = os.path.join(save_dir, "generation_config.json")
         if sanitize_generation_config_file(gen_config_path):
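Note: the try/finally above strips the transient attn_implementation / _attn_implementation attributes before save_pretrained() so they are not serialized, then restores them on the live config objects. A small hedged check one could run after saving; the directory name is a placeholder:

import json
import os

save_dir = "opt-125m-gptaq-4bit"  # placeholder output directory
with open(os.path.join(save_dir, "config.json")) as f:
    cfg = json.load(f)

# Neither attention-impl key should end up in the saved config.
assert "attn_implementation" not in cfg
assert "_attn_implementation" not in cfg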
52 changes: 42 additions & 10 deletions gptqmodel/quantization/config.py
@@ -51,9 +51,9 @@
 META_FIELD_MSE = "mse"
 META_FIELD_ACT_GROUP_AWARE = "act_group_aware"

-META_FIELD_V2_ENABLED = "v2"
-META_FIELD_V2_ALPHA = "v2_alpha"
-META_FIELD_V2_MEMORY_DEVICE = "v2_memory_device"
+META_FIELD_GPTAQ_ENABLED = "gptaq"
+META_FIELD_GPTAQ_ALPHA = "gptaq_alpha"
+META_FIELD_GPTAQ_MEMORY_DEVICE = "gptaq_memory_device"

 ADAPTER_FIELD = "adapter"

@@ -112,10 +112,19 @@ class VRAMStrategy(str, Enum):
     "q_group_size": GROUP_SIZE_FIELD_CODE,
     # AWQ compat
     "version" : FORMAT_FIELD_CODE,
+    "v2": "gptaq",
+    "v2_alpha": "gptaq_alpha",
+    "v2_memory_device": "gptaq_memory_device",
     # map format field (checkpoint_format) to class/code (format)
     FORMAT_FIELD_CHECKPOINT: FORMAT_FIELD_CODE,
 }

+DYNAMIC_FIELD_SYNONYMS = {
+    "gptaq": ("v2",),
+    "gptaq_alpha": ("v2_alpha",),
+    "gptaq_memory_device": ("v2_memory_device",),
+}
+
 def dict_scale_dtype_to_str(d: Dict[str, Any]) -> None:
     """
     Checks whether the passed dictionary and its nested dicts have a *scale_dtype* key and if it's not None,
@@ -145,12 +154,23 @@ def dynamic_get(dynamic: Dict[str, Dict[str, Union[int, bool]]], module_name: st
             # subkey example: Lora override format: `{ "adapter": { "rank": 512 } }`
             if sub_key:
                 sub_value = overrides.get(key, None)
+                if sub_value is None and key in DYNAMIC_FIELD_SYNONYMS:
+                    for legacy_key in DYNAMIC_FIELD_SYNONYMS[key]:
+                        if legacy_key in overrides:
+                            sub_value = overrides[legacy_key]
+                            break
                 if isinstance(sub_value, Dict):
                     return sub_value.get(sub_key, default)
                 else:
                     log.info(f"QuantConfig: Dynamic `sub_key`: `{sub_key}` failed extraction from `sub_value`: `{sub_value}`")
             else:
-                return overrides.get(key, default)
+                if key in overrides:
+                    return overrides[key]
+                if key in DYNAMIC_FIELD_SYNONYMS:
+                    for legacy_key in DYNAMIC_FIELD_SYNONYMS[key]:
+                        if legacy_key in overrides:
+                            return overrides[legacy_key]
+                return default
     return default

 @dataclass
@@ -222,10 +242,10 @@ class QuantizeConfig():
     # use mock quantization to quantize module so the gptq process can continue and not fail
     fail_safe: bool = field(default=False)

-    # gptq v2* only:
-    v2: bool = field(default=False)
-    v2_alpha: float = field(default=0.25)
-    v2_memory_device: str = field(default="auto")
+    # gptaq only:
+    gptaq: bool = field(default=False)
+    gptaq_alpha: float = field(default=0.25)
+    gptaq_memory_device: str = field(default="auto")

     # awq only:
     zero_point: bool = field(default=True)
@@ -449,8 +469,8 @@ def meta_get_versionable(self, key: str) -> List[Tuple[str, str]]:
                 result.append((parts[0].lower(), parts[1].lower()))
         return result

-    # is quantized model quantized or packed by gptqmodel version with v2 format code
-    def is_quantized_by_v2(self) -> bool:
+    # is quantized model quantized or packed by gptqmodel version with gptaq format code
+    def is_quantized_by_gptaq(self) -> bool:
         # check meta.quantizer
         result = self.meta_get_versionable(META_FIELD_QUANTIZER)
         if len(result) > 0:
@@ -550,6 +570,18 @@ def from_quant_config(cls, quantize_cfg, format: str = None):
                     "QuantizeConfig: config does not contain `sym` (symmetric quantization). This may result in silent errors. Defaulting to `sym=True`."
                 )

+        dynamic_overrides = normalized.get("dynamic")
+        if isinstance(dynamic_overrides, dict):
+            for overrides in dynamic_overrides.values():
+                if not isinstance(overrides, dict):
+                    continue
+                if "v2" in overrides and "gptaq" not in overrides:
+                    overrides["gptaq"] = overrides.pop("v2")
+                if "v2_alpha" in overrides and "gptaq_alpha" not in overrides:
+                    overrides["gptaq_alpha"] = overrides.pop("v2_alpha")
+                if "v2_memory_device" in overrides and "gptaq_memory_device" not in overrides:
+                    overrides["gptaq_memory_device"] = overrides.pop("v2_memory_device")
+
         return cls(**normalized)

     @classmethod
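Note: together, DYNAMIC_FIELD_SYNONYMS and the normalization added to from_quant_config keep configs and dynamic overrides written with the old v2* names loadable. A hedged illustration of the dynamic_get fallback; the module name, pattern, and values are made up:

from gptqmodel import QuantizeConfig

# Dynamic override still written with the legacy key names.
qcfg = QuantizeConfig(
    bits=4,
    group_size=128,
    dynamic={r"+:.*down_proj.*": {"v2": True, "v2_alpha": 0.3}},
)

# Lookups under the new names fall back through DYNAMIC_FIELD_SYNONYMS.
print(qcfg.dynamic_get("model.layers.0.mlp.down_proj", "gptaq", False))       # expected: True
print(qcfg.dynamic_get("model.layers.0.mlp.down_proj", "gptaq_alpha", 0.25))  # expected: 0.3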
2 changes: 1 addition & 1 deletion gptqmodel/quantization/gptqv2.py
@@ -177,7 +177,7 @@ def quantize(
         Q = torch.zeros_like(W)

         Hinv, damp = self.hessian_inverse(H)
-        P = self.qcfg.v2_alpha * ((self.dXXT @ Hinv.T).triu(diagonal=1)) @ Hinv
+        P = self.qcfg.gptaq_alpha * ((self.dXXT @ Hinv.T).triu(diagonal=1)) @ Hinv
         del self.dXXT

         for i1 in range(0, self.columns, blocksize):
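Note: read as math, with Hinv the damped inverse Hessian from hessian_inverse() and dXXT the cross-statistic accumulated in self.dXXT, the renamed line computes the correction term roughly as

P = \alpha \,\operatorname{triu}\!\big(\,dXX^{\top} H^{-\top},\ 1\big)\, H^{-1}

where \alpha is qcfg.gptaq_alpha and triu(., 1) keeps the strictly upper-triangular part; only the alpha field name changes in this PR.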
5 changes: 3 additions & 2 deletions tests/test_out_of_model_tensor_files.py
@@ -28,8 +28,9 @@ class _DummyQuantizeConfig:
     static_groups = False
     true_sequential = False
     mse = False
-    v2 = False
-    v2_alpha = 0.0
+    gptaq = False
+    gptaq_alpha = 0.0
+    gptaq_memory_device = "auto"
     act_group_aware = False
     adapter = None
     dynamic = False
5 changes: 3 additions & 2 deletions tests/test_writer_attention.py
@@ -23,8 +23,9 @@ class _DummyQuantizeConfig:
     static_groups = False
     true_sequential = False
     mse = False
-    v2 = False
-    v2_alpha = 0.0
+    gptaq = False
+    gptaq_alpha = 0.0
+    gptaq_memory_device = "auto"
     act_group_aware = False
     adapter = None
     dynamic = False