106 changes: 58 additions & 48 deletions gptqmodel/looper/awq_processor.py
@@ -29,17 +29,17 @@
from ..quantization.config import FORMAT, METHOD, QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.model import get_module_by_name_prefix, move_to
from ..utils.torch import CPU, torch_sync
from ..utils.torch import CPU, tf32_disable_guard, tf32_enable_guard, torch_sync

log = setup_logger()

class AWQProcessor(LoopProcessor):
def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func,
calibration_concat_size: Optional[int], batch_size: int, gptq_model, model,
calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, gptq_model, model,
logger_board: str = "", require_fwd: bool = True, calculate_w_wq_diff: bool = False):

super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration,
calibration_concat_size=calibration_concat_size,
calibration_concat_size=calibration_concat_size, calibration_sort=calibration_sort,
prepare_dataset_func=prepare_dataset_func, batch_size=batch_size,
logger_board=logger_board, require_fwd=require_fwd)

@@ -147,8 +147,9 @@ def forward(self, *args, **kwargs):
target_device = model_device

print(f"AWQProcessor: model parameters are on meta device, using {target_device} instead")

self.model(samples.to(torch.device(target_device)), use_cache=False)

with tf32_enable_guard():
self.model(samples.to(torch.device(target_device)), use_cache=False)
except ValueError: # work with early exit
pass
modules[0] = modules[0].module # restore
@@ -295,15 +296,18 @@ def _search_best_scale(
clear_memory(x_sum)

# [STEP 3]: Compute output of module
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
with torch.inference_mode():
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
fp16_output = self._module_forward(inp, module2inspect, module_kwargs)
fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)
with tf32_enable_guard():
fp16_output = self._module_forward(inp, module2inspect, module_kwargs)

# [STEP 4]: Compute loss
best_scales, loss = self._compute_best_scale(
inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
)
with tf32_disable_guard():
fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)

# [STEP 4]: Compute loss
best_scales, loss = self._compute_best_scale(
inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
)

return (
get_op_name(module, prev_op),
@@ -326,10 +330,11 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic
# self.gptq_model.move_embed(common_device)

# Transformers >= 4.48.0 requires positional embeddings should be computed before forward pass
if (self.module_kwargs.get("position_embeddings") is None):
self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb(
self.inps, self.module_kwargs["position_ids"]
)
if self.module_kwargs.get("position_embeddings") is None:
with tf32_enable_guard():
self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb(
self.inps, self.module_kwargs["position_ids"]
)

# TODO FIX ME: ???
if (self.module_kwargs.get('attention_mask') is None):
@@ -358,31 +363,34 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic
clear_memory()

# [STEP 2]: Compute and apply scale list
module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling(
module, input_feat, self.module_kwargs
)
scales_list = [
self._search_best_scale(module, **layer)
for layer in module_config
]
apply_scale(module, scales_list, input_feat_dict=input_feat)
with tf32_disable_guard():
module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling(
module, input_feat, self.module_kwargs
)
scales_list = [
self._search_best_scale(module, **layer)
for layer in module_config
]
apply_scale(module, scales_list, input_feat_dict=input_feat)
scales_list = append_str_prefix(
scales_list, get_op_name(self.model, module) + "."
)

# [STEP 3]: Compute and apply clipping list
if self.apply_clip:
clip_list = self._search_best_clip(
module, named_linears, input_feat
)
apply_clip(module, clip_list)
with tf32_disable_guard():
clip_list = self._search_best_clip(
module, named_linears, input_feat
)
apply_clip(module, clip_list)
clip_list = append_str_prefix(
clip_list, get_op_name(self.model, module) + "."
)

# [STEP 4]: Quantize weights
if not self.export_compatible:
self._apply_quant(module, named_childs, start, scales_list)
with tf32_disable_guard():
self._apply_quant(module, named_childs, start, scales_list)

clear_memory()

@@ -397,9 +405,10 @@ def _search_best_clip(self, layer, named_linears, input_feat):
continue

named_linears[name].to(get_best_device())
max_val = self._compute_best_clip(
named_linears[name].weight, input_feat[name]
)
with tf32_disable_guard():
max_val = self._compute_best_clip(
named_linears[name].weight, input_feat[name]
)
clip_list.append((name, max_val))
named_linears[name].cpu()

@@ -615,25 +624,26 @@ def _compute_loss(
def _module_forward(
self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
) -> torch.Tensor:
if self.n_parallel_calib_samples is None:
# runs through all samples at once
module_output = module(x, **module_kwargs)
if isinstance(module_output, tuple):
module_output = module_output[0]
else:
# memory efficiently runs through all calibration samples
# but only n_parallel_calib_samples at a time
module_output = []
partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
for x_partial in partitioned_inputs:
partial_output = module(x_partial, **module_kwargs)
with tf32_enable_guard():
if self.n_parallel_calib_samples is None:
# runs through all samples at once
module_output = module(x, **module_kwargs)
if isinstance(module_output, tuple):
module_output = module_output[0]
else:
# memory efficiently runs through all calibration samples
# but only n_parallel_calib_samples at a time
module_output = []
partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
for x_partial in partitioned_inputs:
partial_output = module(x_partial, **module_kwargs)

if isinstance(partial_output, tuple):
partial_output = partial_output[0]
if isinstance(partial_output, tuple):
partial_output = partial_output[0]

module_output.append(partial_output.cpu())
module_output.append(partial_output.cpu())

module_output = torch.cat(module_output, dim=0)
module_output = torch.cat(module_output, dim=0)

return module_output

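Note: the `tf32_enable_guard` / `tf32_disable_guard` helpers used throughout this diff are imported from `gptqmodel/utils/torch.py`, which is not part of the changed files shown here. A minimal sketch of how such scoped guards could be implemented, assuming they simply save, override, and restore the global TF32 backend flags (the real implementation may differ):

```python
# Hypothetical sketch only; the actual guards live in gptqmodel/utils/torch.py
# and are not shown in this diff.
from contextlib import contextmanager

import torch


@contextmanager
def _tf32_guard(enabled: bool):
    # Save the current global TF32 flags, override them for the duration of
    # the block, then restore the previous values on exit.
    prev_matmul = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = enabled
    torch.backends.cudnn.allow_tf32 = enabled
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul
        torch.backends.cudnn.allow_tf32 = prev_cudnn


def tf32_enable_guard():
    # Allow TF32 inside the block (used around calibration forward passes).
    return _tf32_guard(True)


def tf32_disable_guard():
    # Force full FP32 matmul precision inside the block (used around quantization math).
    return _tf32_guard(False)
```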
42 changes: 22 additions & 20 deletions gptqmodel/looper/eora_processor.py
@@ -96,13 +96,14 @@ def is_skipped(self, module: NamedModule) -> bool:

def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
self.eora_process_input(
input=input,
name=name,
eigen_scaling_diag_matrix=self.eigen_scaling_diag_matrix,
sample_size=self.num_batches,
device=module.target_device,
)
with tf32_disable_guard():
self.eora_process_input(
input=input,
name=name,
eigen_scaling_diag_matrix=self.eigen_scaling_diag_matrix,
sample_size=self.num_batches,
device=module.target_device,
)
return tmp

def pre_process_streaming(self, module: NamedModule):
@@ -131,19 +132,20 @@ def process(self, module: NamedModule):
assert w_wq_delta.dtype == torch.float32, f"w_wq_delta dtype: {w_wq_delta.dtype}"

# log.info(f"EoRA: module native dtype = `{module_native_dtype}")
A, B = self.eora_compute_lora(
w_wq_delta=w_wq_delta,
name=module.name,
eigen_scaling_diag_matrix=eigen_scaling_diag_matrix,
rank=module.adapter_cfg.rank,
dtype=module.module_dtype,
device=module.target_device,
)

del eigen_scaling_diag_matrix

# wq with A/B applied
computed_wq = wq + (B @ A)
with tf32_disable_guard():
A, B = self.eora_compute_lora(
w_wq_delta=w_wq_delta,
name=module.name,
eigen_scaling_diag_matrix=eigen_scaling_diag_matrix,
rank=module.adapter_cfg.rank,
dtype=module.module_dtype,
device=module.target_device,
)

del eigen_scaling_diag_matrix

# wq with A/B applied
computed_wq = wq + (B @ A)

module.state.update({
"wq": move_to(wq, device=CPU, stream=self.stream),
8 changes: 5 additions & 3 deletions gptqmodel/looper/gptq_processor.py
@@ -21,7 +21,7 @@
from ..utils.importer import select_quant_linear
from ..utils.logger import setup_logger
from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
from ..utils.torch import HAS_CUDA, torch_streamCtx, torch_sync
from ..utils.torch import HAS_CUDA, tf32_disable_guard, torch_streamCtx, torch_sync

log = setup_logger()
lock = threading.Lock()
@@ -103,7 +103,8 @@ def is_skipped(self, module: NamedModule) -> bool:
def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
g = self.tasks[name] # noqa: F821
g.add_batch(inp[0].data, out.data) # noqa: F821
with tf32_disable_guard():
g.add_batch(inp[0].data, out.data) # noqa: F821
del inp, out
return tmp

@@ -125,7 +126,8 @@ def process(self, module: NamedModule):
with self.lock:
g = self.tasks[module.name]

wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()
with tf32_disable_guard():
wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()

q_scales = q_scales.to(CPU)
q_zeros = q_zeros.to(CPU)
8 changes: 5 additions & 3 deletions gptqmodel/looper/qqq_processor.py
@@ -20,7 +20,7 @@
from ..quantization.qqq import QQQ
from ..utils.logger import setup_logger
from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
from ..utils.torch import CPU, DEVICE_0, torch_streamCtx, torch_sync
from ..utils.torch import CPU, DEVICE_0, tf32_disable_guard, torch_streamCtx, torch_sync

log = setup_logger()

@@ -103,7 +103,8 @@ def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tenso
def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
# gptq is mutable.
q = self.tasks[name] # noqa: F821
q.add_batch(inp[0].data, out.data) # noqa: F821
with tf32_disable_guard():
q.add_batch(inp[0].data, out.data) # noqa: F821
return tmp

def pre_process_streaming(self, module: NamedModule):
@@ -121,7 +122,8 @@ def process(self, module: NamedModule):
# logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}")
## Need to return the quantized_weight for offloading
q = qqq[module.name]
wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize()
with tf32_disable_guard():
wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize()
## Assign the quantized weight to the weight
#gptq[name].layer.weight.data = q_full_weight.to(device=gptq[name].device)

23 changes: 18 additions & 5 deletions gptqmodel/nn_modules/hooked_linear.py
@@ -10,10 +10,12 @@
from torch import nn

from ..utils.logger import setup_logger
from ..utils.torch import tf32_enable_guard


log = setup_logger()


class StopForward(Exception):
"""Signal an intentional early stop of the forward pass."""
pass
@@ -37,9 +39,12 @@ def from_conv1d(m: transformers.Conv1D):
custom.bias = m.bias
return custom

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
output = super().forward(input)
with tf32_enable_guard():
output = super().forward(input)

if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
@@ -93,9 +98,11 @@ def from_conv1d(m: torch.nn.Conv1d):
custom.bias = m.bias
return custom

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
output = super().forward(input)
with tf32_enable_guard():
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
@@ -150,9 +157,11 @@ def from_conv2d(m: torch.nn.Conv2d):
custom.bias = m.bias
return custom

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
output = super().forward(input)
with tf32_enable_guard():
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
@@ -175,9 +184,11 @@ def from_conv1d(conv1d: transformers.Conv1D):
custom.bias = conv1d.bias
return custom

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
output = super().forward(input)
with tf32_enable_guard():
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
@@ -201,9 +212,11 @@ def from_linear(linear: torch.nn.Linear):
custom_linear.bias = linear.bias
return custom_linear

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(device=self.weight.data.device)
output = super().forward(input)
with tf32_enable_guard():
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, (input,), output)
if self.forward_hook_last:
2 changes: 0 additions & 2 deletions gptqmodel/quantization/gptq.py
@@ -30,8 +30,6 @@
log = setup_logger()

lock = threading.Lock()
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# TODO: is there a buffer init threading init bug in torch.linalg?
# bypass strange threading bug by warming up torch.linalg.cholesky to setup internal setup calls
3 changes: 0 additions & 3 deletions gptqmodel/quantization/qqq.py
@@ -16,9 +16,6 @@

DEBUG = False

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

log = setup_logger()

def quantize(x, scale, zero, maxq, sym, groupsize):
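The module-level `torch.backends.cuda.matmul.allow_tf32 = False` / `torch.backends.cudnn.allow_tf32 = False` assignments removed from `gptq.py` and `qqq.py` above disabled TF32 globally at import time; after this PR the same control is applied per call site through the guards. A hedged usage sketch of that pattern (assumes the `gptqmodel` package is installed; the linear layer and tensors below are placeholders, not code from the repository):

```python
import torch
from torch import nn

# Import path as used in the diffs above (gptqmodel/utils/torch.py).
from gptqmodel.utils.torch import tf32_disable_guard, tf32_enable_guard

device = "cuda" if torch.cuda.is_available() else "cpu"
module = nn.Linear(16, 16, dtype=torch.float32).to(device)
x = torch.randn(4, 16, device=device)

# Calibration forward passes tolerate TF32, so it is enabled only here.
with tf32_enable_guard():
    out = module(x)

# Quantization math (Hessian accumulation, scale/clip search) is precision
# sensitive, so TF32 is disabled only around those computations.
with tf32_disable_guard():
    err = (module(x).float() - out.float()).pow(2).mean()

print(err.item())
```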