18 changes: 5 additions & 13 deletions gptqmodel/looper/awq_processor.py
@@ -25,7 +25,7 @@
from ..quantization.awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_GEMVFast, WQLinear_Marlin
from ..quantization.awq.quantize.scale import apply_clip, apply_scale
from ..quantization.awq.utils.module import append_str_prefix, get_op_name, set_op_by_name
-from ..quantization.awq.utils.utils import clear_memory, get_best_device
+from ..quantization.awq.utils.utils import get_best_device
from ..quantization.config import FORMAT, METHOD, QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.model import get_module_by_name_prefix, move_to
@@ -166,8 +166,6 @@ def forward(self, *args, **kwargs):
# we no longer need embed, reduce vram
self.gptq_model.move_embed("cpu")

-clear_memory()

if layer_kwargs.get("attention_mask") is not None:
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
best_device
@@ -271,7 +269,7 @@ def _search_best_scale(
w_scale = w_scale.view(org_shape)
# Gets the average rescaled magnitude for each output channel
w_mean = w_scale.mean(0)
-clear_memory(weight)
+del weight

# [STEP 2]: Compute per-channel mean of the input activation with chunking
# move inp to cpu to avoid memory leak
@@ -293,7 +291,7 @@
x_sum += chunk_sum.to(inp.device)

x_mean = (x_sum / num_elements).to(inp.dtype)
-clear_memory(x_sum)
+del x_sum

# [STEP 3]: Compute output of module
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
@@ -360,8 +358,6 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic

input_feat = self._get_input_feat(module, named_linears)

-clear_memory()

# [STEP 2]: Compute and apply scale list
with tf32_disable_guard():
module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling(
@@ -392,8 +388,6 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic
with tf32_disable_guard():
self._apply_quant(module, named_childs, start, scales_list)

-clear_memory()

@torch.inference_mode()
def _search_best_clip(self, layer, named_linears, input_feat):
clip_list = []
@@ -469,9 +463,8 @@ def _compute_best_clip(
best_max_val_all.append(best_max_val)

best_max_val = torch.cat(best_max_val_all, dim=0)

-clear_memory(input_feat)
-clear_memory(org_out)
+del input_feat
+del org_out

return best_max_val.squeeze(1)

@@ -705,7 +698,6 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time
linear_layer.cpu()
q_linear.to(next(module.parameters()).device)
set_op_by_name(module, name, q_linear)
-clear_memory()

# records
duration = time.time() - start_time
7 changes: 0 additions & 7 deletions gptqmodel/quantization/awq/utils/utils.py
@@ -74,13 +74,6 @@ def set_module_name(model, name, value):
setattr(parent, child_name, value)


-def clear_memory(weight=None):
-    if weight is not None:
-        del weight
-    # gc.collect()
-    # torch.cuda.empty_cache()


def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
memory_pct = (