## Qwen3-MoE tracing and analysis module

In [None]:
# # import datasets
# from datasets import load_dataset
# data_id = "wikimedia/wikipedia"
# sub_set_id = "20231101.zh-classical"
# split = "train"
# dataset = load_dataset(data_id, sub_set_id, split=split)

In [1]:
from qwen_v1 import Qwen3MoeDecoderLayerTimed as v1Timed
from qwen_v2 import Qwen3MoeDecoderLayerTimed as v2Timed

version = 'v2'


def modify_qwen3_moe_block(type__: str):
    """Swap the decoder-layer class used by transformers' Qwen3-MoE model.

    The patch replaces the module attribute
    ``modeling_qwen3_moe.Qwen3MoeDecoderLayer``, so it must run before the
    model is constructed; models built earlier keep their existing layers.

    Args:
        type__: ``'v1'`` installs ``qwen_v1.Qwen3MoeDecoderLayerTimed``,
            ``'v2'`` installs ``qwen_v2.Qwen3MoeDecoderLayerTimed``; any other
            value restores the stock transformers layer.
    """
    from transformers.models.qwen3_moe import modeling_qwen3_moe as qmoe

    # Remember the stock class on the transformers module itself the first
    # time through. Re-running this cell with a different version (or a
    # non-v1/v2 value) then swaps cleanly instead of leaving a previously
    # installed timed layer in place.
    stock_attr = "_stock_Qwen3MoeDecoderLayer"
    if not hasattr(qmoe, stock_attr):
        setattr(qmoe, stock_attr, qmoe.Qwen3MoeDecoderLayer)
    stock_layer = getattr(qmoe, stock_attr)

    timed_layers = {'v1': v1Timed, 'v2': v2Timed}
    qmoe.Qwen3MoeDecoderLayer = timed_layers.get(type__, stock_layer)


modify_qwen3_moe_block(version)

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig


# Alternative checkpoint: "Qwen/Qwen3-30B-A3B-Instruct-2507".
model_name = "Qwen/Qwen3-30B-A3B"

# The config is loaded up front because its layer count sizes the v2 timer
# registry.
cfg = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

if version == 'v2':
    # This import binds a module-level name, so it stays available to later
    # cells.
    from common import init_timer_registry
    init_timer_registry(num_layers=cfg.num_hidden_layers, keep_history=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda:0",
    dtype="auto",
)
# Switch to eval mode. As the cell's last expression, this also displays the
# module tree.
model.eval()

Loading checkpoint shards: 100%|██████████| 16/16 [00:14<00:00,  1.07it/s]


Qwen3MoeForCausalLM(
  (model): Qwen3MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-47): 48 x Qwen3MoeDecoderLayerTimed(
        (self_attn): Qwen3MoeAttention(
          (q_proj): Linear(in_features=2048, out_features=4096, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=4096, out_features=2048, bias=False)
          (q_norm): Qwen3MoeRMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3MoeRMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MoeSparseMoeBlockV2(
          (gate): Linear(in_features=2048, out_features=128, bias=False)
          (experts): ModuleList(
            (0-127): 128 x Qwen3MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=768, bias=False)
              (up_proj): Linear(in_features=2048, out_features=768, bias=False)
              (down_proj):

In [3]:
# Prompt batch used by the timing cells below. Padding goes on the left so
# that, in a batch, every prompt ends at the same final position.
text_list = ["explain the qwen"]
tokenizer.padding_side = "left"
input_001 = tokenizer(
    text_list,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
input_001 = input_001.to(model.device)

input_001


{'input_ids': tensor([[94344,   279,  2804, 16948]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1]], device='cuda:0')}

## Test v1

In [None]:
from torch.profiler import profile, ProfilerActivity
import torch
import qwen_v1
# Profile a single decode step through the v1 timed decoder layer.
import os
import warnings

# NOTE(review): the decoder layers are patched in the first cell according to
# `version`. With anything other than 'v1', the qwen_v1 timers are presumably
# never populated, so warn instead of printing a silently empty report.
if version != 'v1':
    warnings.warn(
        f"decoder layers were patched with {version!r}, not 'v1'; "
        "qwen_v1 timing results are likely empty"
    )

# Prefill the prompt once. This also warms up the kernels and yields the
# logits and KV cache that the profiled decode step continues from.
with torch.no_grad():
    model_output = model(**input_001, use_cache=True)
qwen_v1.reset_timers()  # presumably discards the timings recorded by the prefill
# Greedy next token, kept 2-D (batch, 1) so it can be fed back as input ids.
next_token = torch.argmax(model_output.logits[:, -1, :], dim=-1, keepdim=True)
past_kv = model_output.past_key_values

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=False, with_stack=True, profile_memory=False
) as prof:
    with torch.no_grad():
        _ = model(next_token, past_key_values=past_kv, use_cache=True)
qwen_v1.show_res()

# The exporter writes straight to this path, so make sure the directory
# exists first.
trace_path = os.path.join("trace", "trace_qwen3_moe_v1.json")
os.makedirs(os.path.dirname(trace_path), exist_ok=True)
prof.export_chrome_trace(trace_path)

## Test v2

In [4]:
import common
# Display the global timer registry (a private attribute of the project-local
# `common` module) to inspect its state, e.g. num_layers and keep_history.
common._TREG

TimerRegistry(num_layers=48, keep_history=True, timers={})

In [7]:
import common
import torch
# Warm up, then record per-layer timings for one prefill and a 64-token
# generate() call through the v2 timed layer.
common.warmup_model(model, tokenizer, text_list, 10)

# Re-initialise the registry (presumably discarding the warm-up timings).
# Use the module-qualified name: the bare `init_timer_registry` is only bound
# when the model-loading cell ran with version == 'v2'. Otherwise the bare
# call raises NameError.
common.init_timer_registry(model.config.num_hidden_layers, keep_history=True)
with torch.no_grad():
    _ = model(**input_001)  # prefill
    # NOTE(review): generate() runs its own prefill over the same prompt, so
    # prefill timings are presumably recorded twice. Confirm this is intended.
    _ = model.generate(**input_001, max_new_tokens=64)  # decode

# Wait for all queued CUDA work before reading the timers.
torch.cuda.synchronize()

# Print the results.
common.print_timers_summary()

=== Per-layer (ms) ===
layer	attn(PF)	mlp(PF)	gating(PF)	softmax(PF)	expert(PF)	||	attn(DEC)	mlp(DEC)	gating(DEC)	softmax(DEC)	expert(DEC)
L00	0.869		8.183		0.089		0.130		7.635		||	25.405		81.369		2.720		4.037		68.602
L01	0.777		6.521		0.082		0.148		6.104		||	22.920		80.188		2.451		3.958		68.169
L02	0.752		7.083		0.089		0.128		6.684		||	22.606		78.927		2.414		3.915		67.211
L03	0.762		7.364		0.081		0.129		6.966		||	22.655		79.240		2.380		3.947		67.577
L04	0.837		6.965		0.087		0.139		6.542		||	22.668		79.210		2.372		3.905		67.503
L05	0.768		6.752		0.083		0.127		6.362		||	22.831		79.252		2.391		3.928		67.537
L06	0.754		6.881		0.079		0.134		6.488		||	22.663		79.039		2.393		3.952		67.296
L07	0.742		6.616		0.085		0.136		6.203		||	22.561		79.387		2.368		3.947		67.712
L08	0.757		7.101		0.079		0.128		6.715		||	22.709		81.541		2.370		3.980		69.817
L09	0.743		6.282		0.080		0.126		5.899		||	22.549		78.926		2.376		3.923		67.267
L10	0.763		7.348		0.077		0.128		6.964		||	22.706		79.406		2.384		3.935	