93 changes: 43 additions & 50 deletions gptqmodel/looper/awq_processor.py
@@ -136,8 +136,7 @@ def forward(self, *args, **kwargs):

print(f"AWQProcessor: model parameters are on meta device, using {target_device} instead")

-                with tf32_enable_guard():
-                    self.model(samples.to(torch.device(target_device)), use_cache=False)
+                self.model(samples.to(torch.device(target_device)), use_cache=False)
except ValueError: # work with early exit
pass
modules[0] = modules[0].module # restore
@@ -283,16 +282,15 @@ def _search_best_scale(

# [STEP 3]: Compute output of module
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
-        with ctx(torch.inference_mode(), tf32_enable_guard()):
+        with ctx(torch.inference_mode()):
fp16_output = self._module_forward(inp, module2inspect, module_kwargs)

-        with tf32_disable_guard():
-            fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)
+        fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)

-            # [STEP 4]: Compute loss
-            best_scales, loss = self._compute_best_scale(
-                inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
-            )
+        # [STEP 4]: Compute loss
+        best_scales, loss = self._compute_best_scale(
+            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
+        )

return (
get_op_name(module, prev_op),
@@ -316,10 +314,9 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic

# Transformers >= 4.48.0 requires positional embeddings should be computed before forward pass
if self.module_kwargs.get("position_embeddings") is None:
-            with tf32_enable_guard():
-                self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb(
-                    self.inps, self.module_kwargs["position_ids"]
-                )
+            self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb(
+                self.inps, self.module_kwargs["position_ids"]
+            )

# TODO FIX ME: ???
if (self.module_kwargs.get('attention_mask') is None):
@@ -346,34 +343,31 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic
input_feat = self._get_input_feat(module, named_linears)

# [STEP 2]: Compute and apply scale list
-        with tf32_disable_guard():
-            module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling(
-                module, input_feat, self.module_kwargs
-            )
-            scales_list = [
-                self._search_best_scale(module, **layer)
-                for layer in module_config
-            ]
-            apply_scale(module, scales_list, input_feat_dict=input_feat)
+        module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling(
+            module, input_feat, self.module_kwargs
+        )
+        scales_list = [
+            self._search_best_scale(module, **layer)
+            for layer in module_config
+        ]
+        apply_scale(module, scales_list, input_feat_dict=input_feat)
scales_list = append_str_prefix(
scales_list, get_op_name(self.model, module) + "."
)

# [STEP 3]: Compute and apply clipping list
if self.apply_clip:
-            with tf32_disable_guard():
-                clip_list = self._search_best_clip(
-                    module, named_linears, input_feat
-                )
-                apply_clip(module, clip_list)
+            clip_list = self._search_best_clip(
+                module, named_linears, input_feat
+            )
+            apply_clip(module, clip_list)
clip_list = append_str_prefix(
clip_list, get_op_name(self.model, module) + "."
)

# [STEP 4]: Quantize weights
if not self.export_compatible:
-            with tf32_disable_guard():
-                self._apply_quant(module, named_childs, start, scales_list)
+            self._apply_quant(module, named_childs, start, scales_list)

@torch.inference_mode()
def _search_best_clip(self, layer, named_linears, input_feat):
@@ -386,10 +380,10 @@ def _search_best_clip(self, layer, named_linears, input_feat):
continue

named_linears[name].to(get_best_device())
-            with tf32_disable_guard():
-                max_val = self._compute_best_clip(
-                    named_linears[name].weight, input_feat[name]
-                )
+
+            max_val = self._compute_best_clip(
+                named_linears[name].weight, input_feat[name]
+            )
clip_list.append((name, max_val))
named_linears[name].cpu()

@@ -604,26 +598,25 @@ def _compute_loss(
def _module_forward(
self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
) -> torch.Tensor:
-        with tf32_enable_guard():
-            if self.n_parallel_calib_samples is None:
-                # runs through all samples at once
-                module_output = module(x, **module_kwargs)
-                if isinstance(module_output, tuple):
-                    module_output = module_output[0]
-            else:
-                # memory efficiently runs through all calibration samples
-                # but only n_parallel_calib_samples at a time
-                module_output = []
-                partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
-                for x_partial in partitioned_inputs:
-                    partial_output = module(x_partial, **module_kwargs)
+        if self.n_parallel_calib_samples is None:
+            # runs through all samples at once
+            module_output = module(x, **module_kwargs)
+            if isinstance(module_output, tuple):
+                module_output = module_output[0]
+        else:
+            # memory efficiently runs through all calibration samples
+            # but only n_parallel_calib_samples at a time
+            module_output = []
+            partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
+            for x_partial in partitioned_inputs:
+                partial_output = module(x_partial, **module_kwargs)

-                    if isinstance(partial_output, tuple):
-                        partial_output = partial_output[0]
+                if isinstance(partial_output, tuple):
+                    partial_output = partial_output[0]

-                    module_output.append(partial_output.cpu())
+                module_output.append(partial_output.cpu())

-                module_output = torch.cat(module_output, dim=0)
+            module_output = torch.cat(module_output, dim=0)

return module_output

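Across awq_processor.py the change is uniform: every per-call-site `tf32_enable_guard()` / `tf32_disable_guard()` wrapper is dropped and its body de-indented, leaving the TF32 policy to a single guard around the whole quantization loop (see module_looper.py below). The only context manager that survives is `ctx(torch.inference_mode())`. A minimal sketch of what a multi-context helper like `ctx` could look like, assuming it simply enters each manager it is given (the real helper in gptqmodel is not shown in this diff):

```python
import contextlib

@contextlib.contextmanager
def ctx(*managers):
    # Enter every context manager passed in, in order; exit them in reverse.
    # Sketch only: the actual `ctx` helper used by gptqmodel may differ.
    with contextlib.ExitStack() as stack:
        for manager in managers:
            stack.enter_context(manager)
        yield
```

With a helper like this, removing the TF32 handling is a one-argument change: `ctx(torch.inference_mode(), tf32_enable_guard())` becomes `ctx(torch.inference_mode())`.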
38 changes: 18 additions & 20 deletions gptqmodel/looper/eora_processor.py
@@ -93,14 +93,13 @@ def is_skipped(self, module: NamedModule) -> bool:

def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
-            with tf32_disable_guard():
-                batch_index = self.current_batch_index()
-                batch, contribution, scale = self.eora_process_input(
-                    input=input,
-                    name=name,
-                    sample_size=self.num_batches,
-                    device=module.weight.data.device,
-                )
+            batch_index = self.current_batch_index()
+            batch, contribution, scale = self.eora_process_input(
+                input=input,
+                name=name,
+                sample_size=self.num_batches,
+                device=module.weight.data.device,
+            )

self._accumulate_eora_contribution(
name=name,
@@ -232,20 +231,19 @@ def process(self, module: NamedModule):
assert w_wq_delta.dtype == torch.float32, f"w_wq_delta dtype: {w_wq_delta.dtype}"

# log.info(f"EoRA: module native dtype = `{module_native_dtype}")
-        with tf32_disable_guard():
-            A, B = self.eora_compute_lora(
-                w_wq_delta=w_wq_delta,
-                name=module.name,
-                eigen_scaling_diag_matrix=eigen_scaling_diag_matrix,
-                rank=module.adapter_cfg.rank,
-                dtype=module.module_dtype,
-                device=module.weight.data.device,
-            )
+        A, B = self.eora_compute_lora(
+            w_wq_delta=w_wq_delta,
+            name=module.name,
+            eigen_scaling_diag_matrix=eigen_scaling_diag_matrix,
+            rank=module.adapter_cfg.rank,
+            dtype=module.module_dtype,
+            device=module.weight.data.device,
+        )

-            del eigen_scaling_diag_matrix
+        del eigen_scaling_diag_matrix

-            # wq with A/B applied
-            computed_wq = (wq_device + (B @ A)).to(dtype=wq.dtype, device=target_device)
+        # wq with A/B applied
+        computed_wq = (wq_device + (B @ A)).to(dtype=wq.dtype, device=target_device)

if pad_cols:
computed_wq_trim = computed_wq[:, :original_cols]
6 changes: 2 additions & 4 deletions gptqmodel/looper/gptq_processor.py
@@ -113,8 +113,7 @@ def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tenso
def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
g = self.tasks[name] # noqa: F821
batch_idx = self.current_batch_index()
-            with tf32_disable_guard():
-                g.add_batch(inp[0].data, out.data, batch_index=batch_idx) # noqa: F821
+            g.add_batch(inp[0].data, out.data, batch_index=batch_idx) # noqa: F821
del inp, out
return tmp

@@ -173,8 +172,7 @@ def process(self, module: NamedModule):
f"while processing '{module.full_name}'."
)

-        with tf32_disable_guard():
-            wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()
+        wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()

module.stream_state_payload_to_cpu(
{
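Both hooks in gptq_processor.py now call `add_batch` directly; the precision policy comes from the outer loop instead of a local TF32 toggle. For orientation, a hedged sketch of how a calibration hook like `tmp` is typically attached with PyTorch's `register_forward_hook` (the looper wiring, `self.tasks`, and the `add_batch` internals are not part of this diff):

```python
import torch

def make_add_batch_hook(task):
    # Returns a forward hook that streams each calibration batch into the
    # quantizer task, mirroring the `tmp` hooks above (batch_index omitted).
    def hook(module, inputs, output):
        task.add_batch(inputs[0].detach(), output.detach())
    return hook

# Assumed usage: install on the target linear, run calibration batches, remove.
# handle = linear.register_forward_hook(make_add_batch_hook(task))
# for batch in calibration_batches:
#     model(batch)
# handle.remove()
```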
8 changes: 6 additions & 2 deletions gptqmodel/looper/module_looper.py
@@ -48,7 +48,7 @@
)
from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to, nested_move_to
from ..utils.offload import offload_to_disk
-from ..utils.torch import (CPU, META, timed_gc_collect, torch_sync)
+from ..utils.torch import (CPU, META, timed_gc_collect, torch_sync, tf32_high_precision_guard)
from .. import DEVICE_THREAD_POOL
from .awq_processor import AWQProcessor
from .qqq_processor import QQQProcessor
@@ -899,8 +899,12 @@ def store_input_hook(module, args, kwargs):

return result

-    @torch.inference_mode()
def loop(self, fail_safe: bool = False, **kwargs):
+        with tf32_high_precision_guard():
+            return self._loop_impl(fail_safe=fail_safe, **kwargs)
+
+    @torch.inference_mode()
+    def _loop_impl(self, fail_safe: bool = False, **kwargs):
if self.gptq_model.quantize_config.lm_head:
if self.gptq_model.model.config.tie_word_embeddings and hasattr(self.gptq_model.model.model, "_tied_weights_keys"):
tied_keys = self.gptq_model.model._tied_weights_keys
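This is the other half of the refactor: `loop()` becomes a thin wrapper that enters `tf32_high_precision_guard()` once and delegates to `_loop_impl()`, so the per-site guards deleted above are no longer needed, while `@torch.inference_mode()` moves onto `_loop_impl`. The guard itself lives in `..utils.torch` and its body is not shown in this diff; a minimal sketch of what such a guard might do, assuming it pins float32 matmuls to full precision and restores the caller's settings on exit:

```python
import contextlib
import torch

@contextlib.contextmanager
def tf32_high_precision_guard():
    # Assumption: "high precision" means no TF32 for fp32 matmuls/convolutions
    # while the quantization loop runs; restore whatever the caller configured.
    prev_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn_tf32 = torch.backends.cudnn.allow_tf32
    prev_precision = torch.get_float32_matmul_precision()
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        torch.set_float32_matmul_precision("highest")
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul_tf32
        torch.backends.cudnn.allow_tf32 = prev_cudnn_tf32
        torch.set_float32_matmul_precision(prev_precision)
```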
14 changes: 9 additions & 5 deletions gptqmodel/looper/named_module.py
@@ -187,20 +187,24 @@ def _stream_tensor_dict(
store_callback(host_map)
return host_map

-        stream = torch.cuda.Stream(device=first.device)
-        done_event = torch.cuda.Event(enable_timing=False)
host_map: Dict[str, torch.Tensor] = {}

-        with torch.cuda.stream(stream):
+        copy_device = first.device
+        compute_stream = torch.cuda.current_stream(device=copy_device)
+        copy_stream = torch.cuda.Stream(device=copy_device)
+        done_event = torch.cuda.Event(enable_timing=False, blocking=False)
+
+        with torch.cuda.stream(copy_stream):
+            copy_stream.wait_stream(compute_stream)
for name, tensor in filtered.items():
src = tensor.detach()
host = host_pool.acquire(src.shape, src.dtype, src.layout)
host.copy_(src, non_blocking=True)
host_map[name] = host
-            done_event.record(stream)
+            done_event.record(copy_stream)

with self._state_lock:
events = self.state.setdefault("streaming_events", [])
-            events.append({"event": done_event, "stream": stream})
+            events.append({"event": done_event, "stream": copy_stream})
store_callback(host_map)
return host_map
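The rewritten `_stream_tensor_dict` fixes the ordering of the device-to-host copies: the dedicated copy stream now waits on the current compute stream before issuing non-blocking copies, and the completion event is recorded on that copy stream so consumers can synchronize before touching the host buffers. A self-contained sketch of the same pattern, with a hypothetical helper name and plain pinned allocations standing in for `host_pool`:

```python
import torch

def async_copy_to_host(tensors: dict, device: torch.device):
    # Copy a dict of CUDA tensors to pinned host memory on a dedicated stream
    # that is ordered after the compute stream; return the host copies plus an
    # event the caller must synchronize on before reading them.
    compute_stream = torch.cuda.current_stream(device=device)
    copy_stream = torch.cuda.Stream(device=device)
    done_event = torch.cuda.Event(enable_timing=False)

    host_map = {}
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(compute_stream)  # do not copy half-written results
        for name, tensor in tensors.items():
            src = tensor.detach()
            host = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)
            host.copy_(src, non_blocking=True)
            host_map[name] = host
        done_event.record(copy_stream)
    return host_map, done_event

# Later, before reading host_map on the CPU:
# done_event.synchronize()
```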
6 changes: 2 additions & 4 deletions gptqmodel/looper/qqq_processor.py
@@ -97,8 +97,7 @@ def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tenso
def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
# gptq is mutable.
q = self.tasks[name] # noqa: F821
-            with tf32_disable_guard():
-                q.add_batch(inp[0].data, out.data) # noqa: F821
+            q.add_batch(inp[0].data, out.data) # noqa: F821
return tmp

def process(self, module: NamedModule):
@@ -108,8 +107,7 @@ def process(self, module: NamedModule):
# logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}")
## Need to return the quantized_weight for offloading
q = qqq[module.name]
-        with tf32_disable_guard():
-            wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize()
+        wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize()

q_scales = q_scales.to(CPU)
q_zeros = q_zeros.to(CPU)
2 changes: 1 addition & 1 deletion gptqmodel/models/definitions/qwen3_moe.py
@@ -3,8 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

-from ..base import BaseQModel
from ...quantization import METHOD
+from ..base import BaseQModel


class Qwen3MoeQModel(BaseQModel):
16 changes: 5 additions & 11 deletions gptqmodel/nn_modules/hooked_linear.py
@@ -10,7 +10,6 @@
from torch import nn

from ..utils.logger import setup_logger
-from ..utils.torch import tf32_enable_guard


log = setup_logger()
@@ -42,8 +41,7 @@ def from_conv1d(m: transformers.Conv1D):
@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
-        with tf32_enable_guard():
-            output = super().forward(input)
+        output = super().forward(input)

if self.forward_hook:
self.forward_hook(self, (input,), output)
@@ -101,8 +99,7 @@ def from_conv1d(m: torch.nn.Conv1d):
@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
-        with tf32_enable_guard():
-            output = super().forward(input)
+        output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
Expand Down Expand Up @@ -160,8 +157,7 @@ def from_conv2d(m: torch.nn.Conv2d):
@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
-        with tf32_enable_guard():
-            output = super().forward(input)
+        output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
@@ -187,8 +183,7 @@ def from_conv1d(conv1d: transformers.Conv1D):
@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
-        with tf32_enable_guard():
-            output = super().forward(input)
+        output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
@@ -215,8 +210,7 @@ def from_linear(linear: torch.nn.Linear):
@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
-        with tf32_enable_guard():
-            output = super().forward(input)
+        output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
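With the TF32 guard import and wrappers removed, every hooked module's `forward` reduces to the same three steps: move the input to the weight's device, call the parent `forward`, then invoke the optional hooks. A minimal sketch of that wrapper pattern for the Linear case (illustrative only; the real classes also convert from `transformers.Conv1D` / `torch.nn.Conv1d` / `torch.nn.Conv2d` and support a `forward_hook_last` callback):

```python
import torch

class HookedLinearSketch(torch.nn.Linear):
    # Sketch of the wrapper pattern used by hooked_linear.py: behave exactly
    # like nn.Linear, but hand (input, output) to an optional hook so a
    # processor can collect calibration statistics during the forward pass.
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.forward_hook = None  # set by the quantization looper, if any

    @torch.inference_mode()
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(device=self.weight.data.device)
        output = super().forward(input)
        if self.forward_hook:
            self.forward_hook(self, (input,), output)
        return output
```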