2 changes: 0 additions & 2 deletions gptqmodel/looper/module_looper.py
@@ -1156,8 +1156,6 @@ def _loop_impl(self, fail_safe: bool = False, **kwargs):

shared_kv_cache_dict = {}

replace_module_with_hooked_legacy(self.gptq_model.model, quant_lm_head=self.gptq_model.quantize_config.lm_head)

if self.gptq_model.quantize_config.lm_head:
lm_head_module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head)
if lm_head_module and isinstance(lm_head_module, torch.nn.Linear):
11 changes: 10 additions & 1 deletion gptqmodel/looper/stage_layer.py
@@ -12,7 +12,8 @@
import time
from concurrent.futures import as_completed
from typing import TYPE_CHECKING, Dict, List, Optional

from ..nn_modules.hooked_linear import replace_module_with_hooked_legacy
from ..nn_modules.converter import MODULE_CONVERTER_MAP
import torch

from .. import DEBUG_ON, DEVICE_THREAD_POOL
@@ -69,6 +70,14 @@ def run_layer_stage(

module = looper.gptq_model.pre_quantize(module)

model_type = looper.gptq_model.model.config.model_type
if model_type in MODULE_CONVERTER_MAP:
converter = MODULE_CONVERTER_MAP[model_type]
module = converter(module, looper.gptq_model.model.config)

replace_module_with_hooked_legacy(module, quant_lm_head=looper.gptq_model.quantize_config.lm_head)

layers[layer_index] = module
if is_lm_head_module:
layer_descriptor = looper.gptq_model.lm_head
elif layers_prefix:
42 changes: 1 addition & 41 deletions gptqmodel/models/definitions/gpt_oss.py
@@ -126,8 +126,6 @@ def forward(self, hidden_states):
return router_scores, router_indices

class GPTOSSGPTQ(BaseQModel):
support_offload_to_disk = False

dynamic_expert_index = "num_local_experts"

pre_lm_head_norm_module = "model.norm"
@@ -154,42 +152,4 @@ def before_model_load(self, load_quantized_model=False):
import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling

gpt_oss_modeling.GptOssExperts = GptOssExpertsNew
gpt_oss_modeling.GptOssTopKRouter = GptOssTopKRouterNew

def after_model_load(self, model, load_quantized_model=False):
if load_quantized_model:
return model

import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
from transformers.integrations.hub_kernels import use_kernel_forward_from_hub

@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
class GptOssMLPNew(nn.Module):
def __init__(self, config, ori_mlp=None):
super().__init__()
self.router = ori_mlp.router
experts_new = GptOssExpertsNew(config, ori_mlp.experts)
self.experts = experts_new

def forward(self, hidden_states):
router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
return routed_out, router_scores

model = model.to("cpu")
def process_module(name, module, model, config):
if isinstance(module, gpt_oss_modeling.GptOssMLP):
new_module = GptOssMLPNew(config=config, ori_mlp=module)
parent, child = name.rsplit(".", maxsplit=1)
parent = model.get_submodule(parent)
setattr(parent, child, new_module)

with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
process_fn = partial(process_module, model=model, config=model.config)
list(executor.map(lambda x: process_fn(x[0], x[1]), model.named_modules()))

return model
gpt_oss_modeling.GptOssTopKRouter = GptOssTopKRouterNew
87 changes: 1 addition & 86 deletions gptqmodel/models/definitions/llama4.py
@@ -12,7 +12,6 @@ class Llama4QModel(BaseQModel):
# some bug in the attention_mask of transformers.modeling_llama4,
# so batch quantization for Llama4 is temporarily not supported.
support_batch_quantize = False
support_offload_to_disk = False
loader = AutoModelForImageTextToText

pre_lm_head_norm_module = "language_model.model.norm"
@@ -82,88 +81,4 @@ def forward(self, hidden_states: torch.Tensor):
return out, router_logits

llama4_modeling.Llama4TextMoe = SequentialLlama4TextMoe


def after_model_load(self, model, load_quantized_model=False):
if load_quantized_model:
return model

import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import torch
from transformers.modeling_utils import no_init_weights
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe

# adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
class SequentialLlama4TextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
self.num_experts = original.gate_up_proj.shape[0]
with no_init_weights():
super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
intermediate_size = original.down_proj.shape[1]

with torch.inference_mode():
# Batch process all expert parameters to avoid loops
gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])

# Batch split and transpose
gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous()
up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous()
down_batch = down_batch.transpose(-2, -1).contiguous()

# Batch assignment
for i in range(self.num_experts):
self[i].gate_proj.weight.data = gate_batch[i]
self[i].up_proj.weight.data = up_batch[i]
self[i].down_proj.weight.data = down_batch[i]

class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config, original):
super().__init__()
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = SequentialLlama4TextExperts(config, original.experts)
self.router = original.router
self.shared_expert = original.shared_expert

def forward(self, hidden_states: torch.Tensor):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
if isinstance(router_logits, tuple):
router_scores, router_logits = router_logits
router_scores = router_scores.t()
else:
# transformers < 4.54.0 only returns router_logits
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

router_scores = (
torch.full_like(router_logits, float("-inf"))
.scatter_(1, router_indices, router_top_value)
.transpose(0, 1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

out = self.shared_expert(hidden_states)
for i in range(self.num_experts):
out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

return out, router_logits

model = model.to("cpu")
def process_module(name, module, model, config):
if isinstance(module, Llama4TextMoe):
new_module = SequentialLlama4TextMoe(config=config, original=module)
parent, child = name.rsplit(".", maxsplit=1)
print("replace moe" + name + child)
parent = model.get_submodule(parent)
setattr(parent, child, new_module)
print("cpu count", os.cpu_count())
with ThreadPoolExecutor(max_workers=8) as executor:
process_fn = partial(process_module, model=model, config=model.config.get_text_config())
list(executor.map(lambda x: process_fn(x[0], x[1]), model.named_modules()))

return model

106 changes: 106 additions & 0 deletions gptqmodel/nn_modules/converter.py
@@ -0,0 +1,106 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium


def convert_gpt_oss_expert_converter(module, config):
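    """Replace every fused GptOssMLP inside `module` with GptOssMLPNew, which keeps
    the original router and wraps the experts in GptOssExpertsNew."""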
import torch.nn as nn
import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling
from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
from ..models.definitions.gpt_oss import GptOssExpertsNew

@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
class GptOssMLPNew(nn.Module):
def __init__(self, config, ori_mlp=None):
super().__init__()
self.router = ori_mlp.router
experts_new = GptOssExpertsNew(config, ori_mlp.experts)
self.experts = experts_new

def forward(self, hidden_states):
router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
return routed_out, router_scores

    # Walk submodules and replace each GptOssMLP with GptOssMLPNew
for name, sub_module in module.named_modules():
if isinstance(sub_module, gpt_oss_modeling.GptOssMLP):
new_module = GptOssMLPNew(config=config, ori_mlp=sub_module)
setattr(module, name, new_module)

return module

def convert_llama4_expert_converter(module, config):
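    """Replace every Llama4TextMoe inside `module` with SequentialLlama4TextMoe, which
    unpacks the fused expert weights into per-expert Llama4TextMLP modules."""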
import torch
from transformers.modeling_utils import no_init_weights
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe

# adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
class SequentialLlama4TextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
self.num_experts = original.gate_up_proj.shape[0]
with no_init_weights():
super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
intermediate_size = original.down_proj.shape[1]

with torch.inference_mode():
# Batch process all expert parameters to avoid loops
gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])

# Batch split and transpose
gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous()
up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous()
down_batch = down_batch.transpose(-2, -1).contiguous()

# Batch assignment
for i in range(self.num_experts):
self[i].gate_proj.weight.data = gate_batch[i]
self[i].up_proj.weight.data = up_batch[i]
self[i].down_proj.weight.data = down_batch[i]

class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config, original):
super().__init__()
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = SequentialLlama4TextExperts(config, original.experts)
self.router = original.router
self.shared_expert = original.shared_expert

def forward(self, hidden_states: torch.Tensor):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
if isinstance(router_logits, tuple):
router_scores, router_logits = router_logits
router_scores = router_scores.t()
else:
# transformers < 4.54.0 only returns router_logits
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

router_scores = (
torch.full_like(router_logits, float("-inf"))
.scatter_(1, router_indices, router_top_value)
.transpose(0, 1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

out = self.shared_expert(hidden_states)
for i in range(self.num_experts):
out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

return out, router_logits

for name, sub_module in module.named_modules():
if isinstance(sub_module, Llama4TextMoe):
new_module = SequentialLlama4TextMoe(config=config.get_text_config(), original=sub_module)
setattr(module, name, new_module)

return module

MODULE_CONVERTER_MAP = {
"llama4": convert_llama4_expert_converter,
"gpt_oss": convert_gpt_oss_expert_converter,
}
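
For reference, a minimal sketch of how this map is consulted per layer, mirroring the run_layer_stage hunk in gptqmodel/looper/stage_layer.py above; the maybe_convert helper name is illustrative only, while the map, the model_type key, and the converter(module, config) call signature come from the diff:

from gptqmodel.nn_modules.converter import MODULE_CONVERTER_MAP

def maybe_convert(module, model_config):
    # Dispatch on the architecture name, e.g. "gpt_oss" or "llama4".
    converter = MODULE_CONVERTER_MAP.get(model_config.model_type)
    if converter is None:
        return module  # no MoE rewrite registered for this model type
    # Converters rewrite the fused MoE submodules in place and return the module.
    return converter(module, model_config)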