diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 1a8cab28d..e2259cdf7 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -1156,8 +1156,6 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):
 
         shared_kv_cache_dict = {}
 
-        replace_module_with_hooked_legacy(self.gptq_model.model, quant_lm_head=self.gptq_model.quantize_config.lm_head)
-
         if self.gptq_model.quantize_config.lm_head:
             lm_head_module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head)
             if lm_head_module and isinstance(lm_head_module, torch.nn.Linear):
diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py
index 1d2cfa759..134329e7e 100644
--- a/gptqmodel/looper/stage_layer.py
+++ b/gptqmodel/looper/stage_layer.py
@@ -12,7 +12,8 @@
 import time
 from concurrent.futures import as_completed
 from typing import TYPE_CHECKING, Dict, List, Optional
-
+from ..nn_modules.hooked_linear import replace_module_with_hooked_legacy
+from ..nn_modules.converter import MODULE_CONVERTER_MAP
 import torch
 
 from .. import DEBUG_ON, DEVICE_THREAD_POOL
@@ -69,6 +70,14 @@ def run_layer_stage(
 
     module = looper.gptq_model.pre_quantize(module)
 
+    model_type = looper.gptq_model.model.config.model_type
+    if model_type in MODULE_CONVERTER_MAP:
+        converter = MODULE_CONVERTER_MAP[model_type]
+        module = converter(module, looper.gptq_model.model.config)
+
+    replace_module_with_hooked_legacy(module, quant_lm_head=looper.gptq_model.quantize_config.lm_head)
+
+    layers[layer_index] = module
     if is_lm_head_module:
         layer_descriptor = looper.gptq_model.lm_head
     elif layers_prefix:
diff --git a/gptqmodel/models/definitions/gpt_oss.py b/gptqmodel/models/definitions/gpt_oss.py
index 9bbe72fac..a9297b06c 100644
--- a/gptqmodel/models/definitions/gpt_oss.py
+++ b/gptqmodel/models/definitions/gpt_oss.py
@@ -126,8 +126,6 @@ def forward(self, hidden_states):
         return router_scores, router_indices
 
 class GPTOSSGPTQ(BaseQModel):
-    support_offload_to_disk = False
-
     dynamic_expert_index = "num_local_experts"
 
     pre_lm_head_norm_module = "model.norm"
@@ -154,42 +152,4 @@ def before_model_load(self, load_quantized_model=False):
         import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
 
         gpt_oss_modeling.GptOssExperts = GptOssExpertsNew
-        gpt_oss_modeling.GptOssTopKRouter = GptOssTopKRouterNew
-
-    def after_model_load(self, model, load_quantized_model=False):
-        if load_quantized_model:
-            return model
-
-        import os
-        from concurrent.futures import ThreadPoolExecutor
-        from functools import partial
-
-        import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
-        from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
-
-        @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
-        class GptOssMLPNew(nn.Module):
-            def __init__(self, config, ori_mlp=None):
-                super().__init__()
-                self.router = ori_mlp.router
-                experts_new = GptOssExpertsNew(config, ori_mlp.experts)
-                self.experts = experts_new
-
-            def forward(self, hidden_states):
-                router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)
-                routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-                return routed_out, router_scores
-
-        model = model.to("cpu")
-        def process_module(name, module, model, config):
-            if isinstance(module, gpt_oss_modeling.GptOssMLP):
-                new_module = GptOssMLPNew(config=config, ori_mlp=module)
-                parent, child = name.rsplit(".", maxsplit=1)
-                parent = model.get_submodule(parent)
-                setattr(parent, child, new_module)
-
-        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
-            process_fn = partial(process_module, model=model, config=model.config)
-            list(executor.map(lambda x: process_fn(x[0], x[1]), model.named_modules()))
-
-        return model
+        gpt_oss_modeling.GptOssTopKRouter = GptOssTopKRouterNew
\ No newline at end of file
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index 16742d2a0..aa3018dbd 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -12,7 +12,6 @@ class Llama4QModel(BaseQModel):
     # some bug in the attention_mask of transformers.modeling_llama4,
     # so batch quantization for Llama4 is temporarily not supported.
     support_batch_quantize = False
-    support_offload_to_disk = False
     loader = AutoModelForImageTextToText
 
     pre_lm_head_norm_module = "language_model.model.norm"
@@ -82,88 +81,4 @@ def forward(self, hidden_states: torch.Tensor):
                 return out, router_logits
 
         llama4_modeling.Llama4TextMoe = SequentialLlama4TextMoe
-
-
-    def after_model_load(self, model, load_quantized_model=False):
-        if load_quantized_model:
-            return model
-
-        import os
-        from concurrent.futures import ThreadPoolExecutor
-        from functools import partial
-
-        import torch
-        from transformers.modeling_utils import no_init_weights
-        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe
-
-        # adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
-        class SequentialLlama4TextExperts(torch.nn.ModuleList):
-            def __init__(self, config, original):
-                self.num_experts = original.gate_up_proj.shape[0]
-                with no_init_weights():
-                    super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
-                intermediate_size = original.down_proj.shape[1]
-
-                with torch.inference_mode():
-                    # Batch process all expert parameters to avoid loops
-                    gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
-                    down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])
-
-                    # Batch split and transpose
-                    gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous()
-                    up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous()
-                    down_batch = down_batch.transpose(-2, -1).contiguous()
-
-                    # Batch assignment
-                    for i in range(self.num_experts):
-                        self[i].gate_proj.weight.data = gate_batch[i]
-                        self[i].up_proj.weight.data = up_batch[i]
-                        self[i].down_proj.weight.data = down_batch[i]
-
-        class SequentialLlama4TextMoe(torch.nn.Module):
-            def __init__(self, config, original):
-                super().__init__()
-                self.top_k = config.num_experts_per_tok
-                self.hidden_dim = config.hidden_size
-                self.num_experts = config.num_local_experts
-                self.experts = SequentialLlama4TextExperts(config, original.experts)
-                self.router = original.router
-                self.shared_expert = original.shared_expert
-
-            def forward(self, hidden_states: torch.Tensor):
-                hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-                router_logits = self.router(hidden_states)
-                if isinstance(router_logits, tuple):
-                    router_scores, router_logits = router_logits
-                    router_scores = router_scores.t()
-                else:
-                    # transformers < 4.54.0 only returns router_logits
-                    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
-
-                    router_scores = (
-                        torch.full_like(router_logits, float("-inf"))
-                        .scatter_(1, router_indices, router_top_value)
-                        .transpose(0, 1)
-                    )
-                    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-
-                out = self.shared_expert(hidden_states)
-                for i in range(self.num_experts):
-                    out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
-
-                return out, router_logits
-
-        model = model.to("cpu")
-        def process_module(name, module, model, config):
-            if isinstance(module, Llama4TextMoe):
-                new_module = SequentialLlama4TextMoe(config=config, original=module)
-                parent, child = name.rsplit(".", maxsplit=1)
-                print("replace moe" + name + child)
-                parent = model.get_submodule(parent)
-                setattr(parent, child, new_module)
-        print("cpu count", os.cpu_count())
-        with ThreadPoolExecutor(max_workers=8) as executor:
-            process_fn = partial(process_module, model=model, config=model.config.get_text_config())
-            list(executor.map(lambda x: process_fn(x[0], x[1]), model.named_modules()))
-
-        return model
+
\ No newline at end of file
diff --git a/gptqmodel/nn_modules/converter.py b/gptqmodel/nn_modules/converter.py
new file mode 100644
index 000000000..02a674319
--- /dev/null
+++ b/gptqmodel/nn_modules/converter.py
@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+
+def convert_gpt_oss_expert_converter(module, config):
+    import torch.nn as nn
+    import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
+    from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
+    from ..models.definitions.gpt_oss import GptOssExpertsNew
+
+    @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+    class GptOssMLPNew(nn.Module):
+        def __init__(self, config, ori_mlp=None):
+            super().__init__()
+            self.router = ori_mlp.router
+            experts_new = GptOssExpertsNew(config, ori_mlp.experts)
+            self.experts = experts_new
+
+        def forward(self, hidden_states):
+            router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)
+            routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+            return routed_out, router_scores
+
+    # loop sub module to replace GptOssMLP with GptOssMLPNew
+    for name, sub_module in module.named_modules():
+        if isinstance(sub_module, gpt_oss_modeling.GptOssMLP):
+            new_module = GptOssMLPNew(config=config, ori_mlp=sub_module)
+            setattr(module, name, new_module)
+
+    return module
+
+def convert_llama4_expert_converter(module, config):
+    import torch
+    from transformers.modeling_utils import no_init_weights
+    from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe
+
+    # adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
+    class SequentialLlama4TextExperts(torch.nn.ModuleList):
+        def __init__(self, config, original):
+            self.num_experts = original.gate_up_proj.shape[0]
+            with no_init_weights():
+                super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
+            intermediate_size = original.down_proj.shape[1]
+
+            with torch.inference_mode():
+                # Batch process all expert parameters to avoid loops
+                gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
+                down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])
+
+                # Batch split and transpose
+                gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous()
+                up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous()
+                down_batch = down_batch.transpose(-2, -1).contiguous()
+
+                # Batch assignment
+                for i in range(self.num_experts):
+                    self[i].gate_proj.weight.data = gate_batch[i]
+                    self[i].up_proj.weight.data = up_batch[i]
+                    self[i].down_proj.weight.data = down_batch[i]
+
+    class SequentialLlama4TextMoe(torch.nn.Module):
+        def __init__(self, config, original):
+            super().__init__()
+            self.top_k = config.num_experts_per_tok
+            self.hidden_dim = config.hidden_size
+            self.num_experts = config.num_local_experts
+            self.experts = SequentialLlama4TextExperts(config, original.experts)
+            self.router = original.router
+            self.shared_expert = original.shared_expert
+
+        def forward(self, hidden_states: torch.Tensor):
+            hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+            router_logits = self.router(hidden_states)
+            if isinstance(router_logits, tuple):
+                router_scores, router_logits = router_logits
+                router_scores = router_scores.t()
+            else:
+                # transformers < 4.54.0 only returns router_logits
+                router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+
+                router_scores = (
+                    torch.full_like(router_logits, float("-inf"))
+                    .scatter_(1, router_indices, router_top_value)
+                    .transpose(0, 1)
+                )
+                router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+            out = self.shared_expert(hidden_states)
+            for i in range(self.num_experts):
+                out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+
+            return out, router_logits
+
+    for name, sub_module in module.named_modules():
+        if isinstance(sub_module, Llama4TextMoe):
+            new_module = SequentialLlama4TextMoe(config=config.get_text_config(), original=sub_module)
+            setattr(module, name, new_module)
+
+    return module
+
+MODULE_CONVERTER_MAP = {
+    "llama4": convert_llama4_expert_converter,
+    "gpt_oss": convert_gpt_oss_expert_converter,
+}
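Usage sketch (illustrative only, not part of the patch): each converter registered in MODULE_CONVERTER_MAP has the signature converter(module, config) and returns the module with its fused MoE experts rewritten into per-expert MLPs; stage_layer.py looks converters up by config.model_type as shown above. Here `model` and `layer` are assumed placeholders for an already-loaded Hugging Face model and one of its decoder layers:

    from gptqmodel.nn_modules.converter import MODULE_CONVERTER_MAP

    # `model`: an already-loaded HF model; `layer`: one of its decoder layers (placeholders)
    converter = MODULE_CONVERTER_MAP.get(model.config.model_type)
    if converter is not None:
        # rewrite the layer's MoE block in place, mirroring the run_layer_stage hook
        layer = converter(layer, model.config)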