diff --git a/examples/quantization/basic_usage_wikitext2.py b/examples/quantization/basic_usage_wikitext2.py
index 5a737ba56..bf94e6dc0 100644
--- a/examples/quantization/basic_usage_wikitext2.py
+++ b/examples/quantization/basic_usage_wikitext2.py
@@ -20,7 +20,7 @@ def get_wikitext2(tokenizer, nsamples, seqlen):
     return [tokenizer(example["text"]) for example in traindata.select(range(nsamples))]


-@torch.no_grad()
+@torch.inference_mode()
 def calculate_avg_ppl(model, tokenizer):
     from gptqmodel.utils.perplexity import Perplexity
diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py
index 93a39ec0c..eae540407 100644
--- a/gptqmodel/looper/awq_processor.py
+++ b/gptqmodel/looper/awq_processor.py
@@ -236,7 +236,7 @@ def cat_and_assert(k, v):
         input_feat = {k: cat_and_assert(k, v) for k, v in input_feat.items()}
         return input_feat

-    @torch.no_grad()
+    @torch.inference_mode()
     def _search_best_scale(
         self,
         module,
@@ -296,7 +296,7 @@ def _search_best_scale(
         clear_memory(x_sum)

         # [STEP 3]: Compute output of module
-        with torch.no_grad():
+        with torch.inference_mode():
             module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
             fp16_output = self._module_forward(inp, module2inspect, module_kwargs)
             fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)
@@ -387,7 +387,7 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic

         clear_memory()

-    @torch.no_grad()
+    @torch.inference_mode()
     def _search_best_clip(self, layer, named_linears, input_feat):
         clip_list = []
         avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
@@ -406,7 +406,7 @@ def _search_best_clip(self, layer, named_linears, input_feat):

         return clip_list

-    @torch.no_grad()
+    @torch.inference_mode()
     def _compute_best_clip(
         self,
         w: torch.Tensor,
@@ -580,7 +580,7 @@ def _compute_best_scale(

         return best_scales.detach().cpu(), best_error

-    @torch.no_grad()
+    @torch.inference_mode()
     def _compute_loss(
         self,
         fp16_output: torch.Tensor,
@@ -612,7 +612,7 @@ def _compute_loss(

         return loss

-    @torch.no_grad()
+    @torch.inference_mode()
     def _module_forward(
         self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
     ) -> torch.Tensor:
diff --git a/gptqmodel/models/definitions/gpt_oss.py b/gptqmodel/models/definitions/gpt_oss.py
index 090a8747d..a4cfbeb69 100644
--- a/gptqmodel/models/definitions/gpt_oss.py
+++ b/gptqmodel/models/definitions/gpt_oss.py
@@ -43,7 +43,7 @@ def __init__(self, config, ori_experts=None):
                 d_w_src = ori_experts.down_proj[i].detach().t().contiguous()
                 d_b_src = ori_experts.down_proj_bias[i].detach()

-                with torch.no_grad():
+                with torch.inference_mode():
                     tgt_gu_w.copy_(gu_w_src)
                     tgt_gu_b.copy_(gu_b_src)
                     tgt_d_w.copy_(d_w_src)
@@ -113,7 +113,7 @@ def __init__(self, config, ori_router=None):
         self.bias = nn.Parameter(torch.empty(self.num_experts))

         if ori_router is not None:
-            with torch.no_grad():
+            with torch.inference_mode():
                 self.weight.copy_(ori_router.weight.detach())
                 self.bias.copy_(ori_router.bias.detach())
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index be48d8d85..146fd90bc 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -103,7 +103,7 @@ def __init__(self, config, original):
         super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
         intermediate_size = original.down_proj.shape[1]

-        with torch.no_grad():
+        with torch.inference_mode():
             # Batch process all expert parameters to avoid loops
             gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
             down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])
diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index b9dc99589..6a7dfa975 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -592,6 +592,7 @@ def _pack_rows_3(int32_blk_32xN: t.Tensor, dst: t.Tensor, dst_rows_base: int):
         # ---------- thread task: process a single [i0,i1) block ----------
         block_in = max(word_bits, (block_in // word_bits) * word_bits)

+        @t.inference_mode()
         def _process_block(i0: int, i1: int):
             blk = i1 - i0
             # [out, blk]
diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py
index f0240b3e0..032cfae5e 100644
--- a/gptqmodel/nn_modules/qlinear/awq_gemm.py
+++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py
@@ -98,7 +98,7 @@ def forward(self, x: torch.Tensor):
                 self.out_features,
             )
         else:
-            with torch.no_grad():
+            with torch.inference_mode():
                 out = WQLinearMMFunction.apply(
                     x,
                     self.qweight,
diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py b/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py
index cb62c64bc..d3a5c75ea 100644
--- a/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py
+++ b/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py
@@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor):
         out_shape = x.shape[:-1] + (self.out_features,)

         if hasattr(self, "ipex_linear"):
-            with torch.no_grad():
+            with torch.inference_mode():
                 out = self.ipex_linear(x)
         else:
             out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size).to(x.dtype)
diff --git a/gptqmodel/quantization/awq/modules/linear/gemm.py b/gptqmodel/quantization/awq/modules/linear/gemm.py
index 6b6563f1e..bd28d5553 100644
--- a/gptqmodel/quantization/awq/modules/linear/gemm.py
+++ b/gptqmodel/quantization/awq/modules/linear/gemm.py
@@ -270,7 +270,7 @@ def forward(self, x):
                 self.out_features,
             )
         else:
-            with torch.no_grad():
+            with torch.inference_mode():
                 out = WQLinearMMFunction.apply(
                     x,
                     self.qweight,
diff --git a/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py b/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py
index 95c975995..a36a1ce76 100644
--- a/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py
+++ b/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py
@@ -101,7 +101,7 @@ def forward(self, x):
             self.init_ipex = True

         if hasattr(self, "ipex_linear"):
-            with torch.no_grad():
+            with torch.inference_mode():
                 outputs = self.ipex_linear(x)
         else:
             outputs = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(x.dtype)
diff --git a/gptqmodel/quantization/awq/modules/linear/gemv.py b/gptqmodel/quantization/awq/modules/linear/gemv.py
index 071cb6912..68f0dcc82 100644
--- a/gptqmodel/quantization/awq/modules/linear/gemv.py
+++ b/gptqmodel/quantization/awq/modules/linear/gemv.py
@@ -153,7 +153,7 @@ def from_linear(
         awq_linear.qzeros = qzeros
         return awq_linear

-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(self, x):
         if awq_ext is None:
             raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg)
diff --git a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py b/gptqmodel/quantization/awq/modules/linear/gemv_fast.py
index c716efa47..8a5bca338 100644
--- a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py
+++ b/gptqmodel/quantization/awq/modules/linear/gemv_fast.py
@@ -182,7 +182,7 @@ def from_linear(

         return awq_linear

-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(self, x):
         if awq_v2_ext is None:
             raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." + msg)
diff --git a/gptqmodel/quantization/awq/modules/linear/marlin.py b/gptqmodel/quantization/awq/modules/linear/marlin.py
index 626788d3d..9cc921a13 100644
--- a/gptqmodel/quantization/awq/modules/linear/marlin.py
+++ b/gptqmodel/quantization/awq/modules/linear/marlin.py
@@ -168,7 +168,7 @@ def post_init(self):
             persistent=False,
         )

-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(self, x):
         assert hasattr(self, "workspace"), (
             "module.post_init() must be called before module.forward(). "
diff --git a/gptqmodel/quantization/awq/quantize/scale.py b/gptqmodel/quantization/awq/quantize/scale.py
index 09f9c5cd3..b34b85070 100644
--- a/gptqmodel/quantization/awq/quantize/scale.py
+++ b/gptqmodel/quantization/awq/quantize/scale.py
@@ -22,7 +22,7 @@
 ]


-@torch.no_grad()
+@torch.inference_mode()
 def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
     for name, max_val in clip_list:
         layer: nn.Linear = get_op_by_name(module, name)
@@ -85,7 +85,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         scales.cpu()


-@torch.no_grad()
+@torch.inference_mode()
 def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
     if not isinstance(fcs, list):
         fcs = [fcs]
@@ -114,7 +114,7 @@ def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
         assert torch.isnan(p).sum() == 0


-@torch.no_grad()
+@torch.inference_mode()
 def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
     assert isinstance(fc1, nn.Linear)
     assert isinstance(fc2, nn.Linear)
@@ -133,7 +133,7 @@ def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
         assert torch.isnan(p).sum() == 0


-@torch.no_grad()
+@torch.inference_mode()
 def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
     if not isinstance(fcs, list):
         fcs = [fcs]
@@ -154,7 +154,7 @@ def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
         assert torch.isnan(p).sum() == 0


-@torch.no_grad()
+@torch.inference_mode()
 def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
     assert any(isinstance(gelu, t) for t in allowed_act_fns)
     assert isinstance(fc, nn.Linear)
diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py
index 945f3c509..a37a67146 100644
--- a/gptqmodel/utils/bitblas.py
+++ b/gptqmodel/utils/bitblas.py
@@ -59,7 +59,7 @@ def prepare_model_for_bitblas_load(
     return model


-@torch.no_grad()
+@torch.inference_mode()
 def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool, desc_act: bool, repack: bool):
     """
     Converts GPTQ-packed weights to the Bitblas format.
diff --git a/gptqmodel/utils/mmlupro.py b/gptqmodel/utils/mmlupro.py
index 367d13432..7f939b754 100644
--- a/gptqmodel/utils/mmlupro.py
+++ b/gptqmodel/utils/mmlupro.py
@@ -121,7 +121,7 @@ def batch_inference(model, tokenizer, inference_batchs, batch_size):
     for batch in pb:
         input_tensor = tokenizer(batch, return_tensors="pt", padding=True, truncation=True,
                                  max_length=max_model_length, padding_side='left').to(model.device)
-        with torch.no_grad():
+        with torch.inference_mode():
             outputs = model.generate(
                 input_ids=input_tensor["input_ids"],
                 tokenizer=tokenizer,
@@ -166,7 +166,7 @@ def save_res(res, output_path):
     return accu, corr, wrong


-@torch.no_grad()
+@torch.inference_mode()
 def eval_cot(subject, model, tokenizer, val_df, test_df, output_path, ntrain, batch_size):
     global choices
     log.info("evaluating " + subject)
diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index 908139aca..49020d833 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -214,7 +214,7 @@ def _restore_leaves_from_weights_map(mod: nn.Module, device: torch.device, dtype
         except Exception:
             return False

-    with torch.no_grad():
+    with torch.inference_mode():
         for name, tensor, is_param in list(_iter_leaf_tensors(mod, include_buffers=True)):
             is_meta = getattr(tensor, "is_meta", False) or tensor.device is META
             if not is_meta:
@@ -265,7 +265,7 @@ def undo_offload_to_disk(
     offload_dirs: Set[str] = set()

     # 1) Materialize all offloaded leaves as real tensors on the target device/dtype.
-    with torch.no_grad():
+    with torch.inference_mode():
         for sub in module.modules():
             if not has_offloaded_params(sub):
                 continue
diff --git a/gptqmodel/utils/openai_server.py b/gptqmodel/utils/openai_server.py
index b3d15459e..e84653ecb 100644
--- a/gptqmodel/utils/openai_server.py
+++ b/gptqmodel/utils/openai_server.py
@@ -58,7 +58,7 @@ async def create_completion(request: OpenAiRequest):
                                                    return_tensors='pt').to(self.model.device)
             do_sample = True if request.temperature != 0.0 else False

-            with torch.no_grad():
+            with torch.inference_mode():
                 outputs = self.model.generate(
                     inputs_tensor,
                     max_length=inputs_tensor.shape[0] + request.max_tokens,
diff --git a/gptqmodel/utils/perplexity.py b/gptqmodel/utils/perplexity.py
index 6886ab4fd..51b31456a 100644
--- a/gptqmodel/utils/perplexity.py
+++ b/gptqmodel/utils/perplexity.py
@@ -242,6 +242,6 @@ def _compute_batch_logits(self, tokens, batch_start, batch_size):
             The logits for the batch of tokens.
         """
         # Compute the logits without keeping track of gradients
-        with torch.no_grad():
+        with torch.inference_mode():
             outputs = self._model(tokens[:, batch_start: batch_start + batch_size])
         return outputs.logits.detach()
diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py
index 1b203588d..f187ee096 100644
--- a/gptqmodel/utils/structure.py
+++ b/gptqmodel/utils/structure.py
@@ -536,7 +536,7 @@ def alias_from_turtle_for_submodule(
     # ---- copy params/buffers CPU->GPU into target_submodule (your existing code) ----
     t_params = dict(target_submodule.named_parameters(recurse=True))
     s_params = dict(src_sub.named_parameters(recurse=True))
-    with torch.no_grad():
+    with torch.inference_mode():
         for name, s_p in s_params.items():
             t_p = t_params.get(name)
             if t_p is None or t_p.shape != s_p.shape:
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index 9b47c5b93..d8b4e09a4 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -173,9 +173,10 @@ def check_kernel(self, model, expected_kernels):

     def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", need_eval=True, batch_size: int = QUANT_BATCH_SIZE, **kwargs):
         quantize_config = QuantizeConfig(
+            quant_method=self.METHOD,
+            format=self.FORMAT,
             bits=self.BITS,
             group_size=self.GROUP_SIZE,
-            format=self.FORMAT,
             desc_act=self.DESC_ACT if not self.ACT_GROUP_AWARE else False,
             act_group_aware=self.ACT_GROUP_AWARE,
             fail_safe=self.FAIL_SAFE,
@@ -296,7 +297,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
         for framework, tasks in task_groups.items():
             log.info(f"TEST: EVAL starting: backend = {self.LOAD_BACKEND}")
             results = GPTQModel.eval(
-                model_or_id_or_path=model,
+                model_or_id_or_path=model.model_local_path,
                 llm_backend="vllm" if self.USE_VLLM else "gptqmodel",
                 model_args=model_args,
                 output_path=tmp_dir,
diff --git a/tests/test_q4_exllama_v1.py b/tests/test_q4_exllama_v1.py
index da5fd839a..f61866d50 100644
--- a/tests/test_q4_exllama_v1.py
+++ b/tests/test_q4_exllama_v1.py
@@ -1122,7 +1122,7 @@ def test_exllama(self):

         inp = torch.rand(1, m, k, dtype=torch.float16).to(device)

-        with torch.no_grad():
+        with torch.inference_mode():
             res = linear(inp)[0][0]

         reference = REFERENCE.to(device)
diff --git a/tests/test_q4_exllama_v2.py b/tests/test_q4_exllama_v2.py
index 57a9834a0..458b7419d 100644
--- a/tests/test_q4_exllama_v2.py
+++ b/tests/test_q4_exllama_v2.py
@@ -70,7 +70,7 @@ def test_exllamav2(self):

         inp = torch.rand(1, m, k, dtype=torch.float16).to(device)

-        with torch.no_grad():
+        with torch.inference_mode():
             res = linear(inp)[0][0]

         reference = REFERENCE.to(device)