In [2]:
from oss_v1 import GptOssDecoderLayer as v1_timed

def modify_oss_decoder_layer(__type: str):
    """Monkey-patch the GPT-OSS decoder layer class inside transformers.

    Args:
        __type: Name of the replacement variant. Only ``"v1"`` is recognised;
            it rebinds ``modeling_gpt_oss.GptOssDecoderLayer`` to ``v1_timed``.

    Raises:
        ValueError: If ``__type`` is anything other than ``"v1"``.
    """
    # Reject unsupported variants up front so the patch path reads straight through.
    if __type != "v1":
        raise ValueError(f"Unknown OSS decoder layer type: {__type}")

    # Imported lazily: transformers is only touched when a patch is applied.
    from transformers.models.gpt_oss import modeling_gpt_oss as gmoe

    gmoe.GptOssDecoderLayer = v1_timed
# Swap the transformers GptOssDecoderLayer class for the "v1" timed variant.
# NOTE(review): presumably this has to run before the model is instantiated so
# the patched class is the one that gets built — confirm.
modify_oss_decoder_layer("v1")

In [3]:
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Mxfp4Config

from common import init_timer_registry

# Checkpoint under test and how its MXFP4 weights should be handled on load.
model_id = "openai/gpt-oss-20b"
cfg = AutoConfig.from_pretrained(model_id)
quantization_config = Mxfp4Config(dequantize=False)

# The timer registry is sized from the config and created before the model is loaded.
init_timer_registry(cfg.num_hidden_layers, keep_history=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=cfg,
    quantization_config=quantization_config,
    device_map="cuda:0",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Kept as the cell's last expression: switches to eval mode and displays the module tree.
model.eval()

MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

GptOssForCausalLM(
  (model): GptOssModel(
    (embed_tokens): Embedding(201088, 2880, padding_idx=199999)
    (layers): ModuleList(
      (0-23): 24 x GptOssDecoderLayer(
        (self_attn): GptOssAttention(
          (q_proj): Linear(in_features=2880, out_features=4096, bias=True)
          (k_proj): Linear(in_features=2880, out_features=512, bias=True)
          (v_proj): Linear(in_features=2880, out_features=512, bias=True)
          (o_proj): Linear(in_features=4096, out_features=2880, bias=True)
        )
        (mlp): GptOssMLPV1(
          (router): GptOssTopKRouter()
          (experts): GptOssExperts()
        )
        (input_layernorm): GptOssRMSNorm((2880,), eps=1e-05)
        (post_attention_layernorm): GptOssRMSNorm((2880,), eps=1e-05)
      )
    )
    (norm): GptOssRMSNorm((2880,), eps=1e-05)
    (rotary_emb): GptOssRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2880, out_features=201088, bias=False)
)

In [4]:
import torch

import common

# One short prompt, reused for both the warm-up and the timed run.
text_list = ["explain the qwen"]
tokenizer.padding_side = "left"
# `truncation=True` was dropped: with no max_length available it only emitted a
# "Default to no truncation" warning and truncated nothing.
input_001 = tokenizer(text_list, return_tensors="pt", padding=True).to(model.device)

# Warm up before measuring.
# NOTE(review): the trailing 10 is presumably the number of warm-up iterations — confirm in common.warmup_model.
common.warmup_model(model, tokenizer, text_list, 10)

# Re-create the timer registry after warm-up.
# NOTE(review): presumably this discards the warm-up timings — confirm in common.init_timer_registry.
common.init_timer_registry(model.config.num_hidden_layers, keep_history=True)
with torch.no_grad():
    _ = model(**input_001)  # prefill
    # NOTE(review): generate() also processes the prompt before decoding; check
    # whether that pass is counted in the prefill columns as well.
    _ = model.generate(**input_001, max_new_tokens=64)  # decode

# Wait for all queued CUDA work to finish before reading the timers.
torch.cuda.synchronize()

# Print the results.
common.print_timers_summary()

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


=== Per-layer (ms) ===
layer	attn(PF)	mlp(PF)	gating(PF)	softmax(PF)	expert(PF)	||	attn(DEC)	mlp(DEC)	gating(DEC)	softmax(DEC)	expert(DEC)
L00	1.344		2.459		0.280		0.000		1.921		||	44.423		73.316		8.839		0.000		58.724
L01	1.093		2.335		0.262		0.000		1.906		||	32.604		71.425		7.991		0.000		58.354
L02	1.116		2.306		0.249		0.000		1.893		||	32.788		71.328		7.941		0.000		58.316
L03	1.195		2.312		0.250		0.000		1.899		||	32.293		71.166		7.957		0.000		58.282
L04	1.086		2.313		0.247		0.000		1.892		||	32.099		70.990		7.792		0.000		58.286
L05	1.090		2.297		0.247		0.000		1.894		||	31.462		70.780		7.657		0.000		58.161
L06	1.116		2.301		0.245		0.000		1.896		||	32.028		71.068		7.764		0.000		58.319
L07	1.191		2.344		0.279		0.000		1.899		||	32.087		71.615		7.936		0.000		58.485
L08	1.116		2.301		0.246		0.000		1.893		||	32.352		70.866		7.700		0.000		58.234
L09	1.062		2.338		0.259		0.000		1.895		||	31.356		71.189		7.896		0.000		58.296
L10	1.096		2.306		0.249		0.000		1.896		||	32.010		70.685		7.661		0.000	

In [None]:
prompt = "what is the difference between mixtral and qwen moe models?"
messages = [{"role": "user", "content": prompt}]

# Render the conversation with the model's chat template and append the
# assistant-turn header so generation starts a fresh reply.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=16384,
    # temperature only takes effect in sampling mode, so request it explicitly
    # instead of depending on the checkpoint's default generation config.
    do_sample=True,
    temperature=0.7,
)
# outputs[0] is the full token sequence for the single prompt (prompt tokens
# followed by the generated tokens); special tokens are left in the decoded text.
print(tokenizer.decode(outputs[0]))