## 下载模型

In [2]:
!hf download Qwen/Qwen3-0.6B --local-dir ./Qwen3-0.6B

Fetching 10 files: 100%|██████████████████████| 10/10 [00:00<00:00, 3632.38it/s]
/Users/comistrymo/Qwen3-0.6B Inference/Qwen3-0.6B


## 打印模型结构

In [3]:
from safetensors.torch import load_file
state_dict = load_file('./Qwen3-0.6B/model.safetensors')
for k,p in state_dict.items():
    print(f"key:{k} shape:{p.shape}")

key:lm_head.weight shape:torch.Size([151936, 1024])
key:model.embed_tokens.weight shape:torch.Size([151936, 1024])
key:model.layers.0.input_layernorm.weight shape:torch.Size([1024])
key:model.layers.0.mlp.down_proj.weight shape:torch.Size([1024, 3072])
key:model.layers.0.mlp.gate_proj.weight shape:torch.Size([3072, 1024])
key:model.layers.0.mlp.up_proj.weight shape:torch.Size([3072, 1024])
key:model.layers.0.post_attention_layernorm.weight shape:torch.Size([1024])
key:model.layers.0.self_attn.k_norm.weight shape:torch.Size([128])
key:model.layers.0.self_attn.k_proj.weight shape:torch.Size([1024, 1024])
key:model.layers.0.self_attn.o_proj.weight shape:torch.Size([1024, 2048])
key:model.layers.0.self_attn.q_norm.weight shape:torch.Size([128])
key:model.layers.0.self_attn.q_proj.weight shape:torch.Size([2048, 1024])
key:model.layers.0.self_attn.v_proj.weight shape:torch.Size([1024, 1024])
key:model.layers.1.input_layernorm.weight shape:torch.Size([1024])
key:model.layers.1.mlp.down_proj.w

## 打印模型结构（详细）

In [4]:
from view_model import view_model_info

view_model_info("./Qwen3-0.6B/")

正在加载模型配置...
正在构建无内存占用的模型骨架...

--- 模型结构 (无内存占用) ---
Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNo

## 定义模型结构

In [5]:
from transformers import AutoConfig
import torch
import torch.nn as nn
import torch.nn.functional as F

# 为什么给RMSNorm就不传整个config呢

class Qwen3Model(nn.Module):
    def __init__(self, config):
        print("begin init model")
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # 为什么这几个模块需要自定义
        self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = Qwen3RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids):
        bsz, q_len = input_ids.shape
        pos_ids = torch.arange(q_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        casual_mask = torch.triu(
            torch.full((q_len, q_len), float('-inf'), dtype=torch.float32, device=input_ids.device),
            diagonal=1
        ).unsqueeze(0).unsqueeze(0).expand(bsz, 1, q_len, q_len)
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, pos_ids=pos_ids, attn_mask=casual_mask)
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits
        
class Qwen3DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(config)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size)
        self.mlp = MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size)

    def forward(self, hidden_states, pos_ids=None, attn_mask=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, pos_ids, attn_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

class Qwen3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
    
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        var = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(var + self.eps)
        rms_res = self.weight * hidden_states.to(input_dtype)
        return rms_res
    
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.num_kv_groups = self.num_heads // self.num_kv_heads  #GQA

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm = Qwen3RMSNorm(config.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(config.head_dim, eps=config.rms_norm_eps)
        self.rope_theta = config.rope_theta

    def forward(self, hidden_states, pos_ids=None, attn_mask=None):
        bsz, q_len, _=  hidden_states.size()
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rope(q, k, pos_ids, self.head_dim, self.rope_theta)
        if self.num_kv_groups > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.num_kv_groups, -1, -1).flatten(1, 2)
            v = v.unsqueeze(2).expand(-1, -1, self.num_kv_groups, -1, -1).flatten(1, 2)

        attn_score = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attn_mask is not None:
            attn_score = attn_score + attn_mask
        attn_score = torch.softmax(attn_score, dim=-1, dtype=torch.float32).to(hidden_states.dtype)
        attn_output = torch.matmul(attn_score, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output

def apply_rope(q, k, position_ids, head_dim, rope_theta=1000000.0):
    device = q.device
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim))
    freqs = position_ids.unsqueeze(-1).float() * inv_freq.unsqueeze(0).unsqueeze(0)
    emb = torch.cat([freqs, freqs], dim=-1)
    cos = emb.cos().unsqueeze(1).to(q.dtype)
    sin = emb.sin().unsqueeze(1).to(q.dtype)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        ret=self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
        return ret

### 推理

In [10]:
from tokenizers import Tokenizer
config = AutoConfig.from_pretrained("./Qwen3-0.6B/")

model = Qwen3Model(config)
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith("model."):
        new_state_dict[k[len("model") + 1:]] = v
    else:
        new_state_dict[k] = v

model.load_state_dict(new_state_dict, strict=True)
model.eval()

tokenizer = Tokenizer.from_file(str("./Qwen3-0.6B/tokenizer.json"))
message="<|im_start|>user你好，我是ComistryMo，请多指教！<|im_end|><|im_start|>assistant"
input_ids = tokenizer.encode(message).ids
input_ids = torch.tensor([input_ids], dtype=torch.long)

with torch.no_grad():
    while True:
        logits = model(input_ids)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        if next_token.item() == 151645:
            break
        input_ids = torch.cat([input_ids, next_token], dim=1)

output_text = tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)
print(output_text)

begin init model


### 性能对比

In [9]:
import time
import torch
from tokenizers import Tokenizer
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import load_file

device = "mps" if torch.mps.is_available() else "cpu"
path = "./Qwen3-0.6B/"

config = AutoConfig.from_pretrained(path)
my_model = Qwen3Model(config).to(device) 

state_dict = load_file(f"{path}/model.safetensors", device=device)
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith("model."):
        new_state_dict[k[len("model") + 1:]] = v
    else:
        new_state_dict[k] = v
my_model.load_state_dict(new_state_dict, strict=True)
my_model.eval()

official_model = AutoModelForCausalLM.from_pretrained(
    path, 
    torch_dtype=torch.float16 if device == "mps" else torch.float32,
    trust_remote_code=True
).to(device)
official_model.eval()

tokenizer = Tokenizer.from_file(str(f"{path}/tokenizer.json"))
message = "<|im_start|>user你好，我是ComistryMo，请多指教！<|im_end|><|im_start|>assistant"
input_ids_raw = tokenizer.encode(message).ids
input_ids = torch.tensor([input_ids_raw], dtype=torch.long).to(device)

def measure_time(func, name):
    # 预热一次，防止第一次运行包含了初始化开销
    print(f"正在预热 {name}...")
    if device == "mps":
        torch.mps.synchronize()
    
    start = time.perf_counter()
    result = func()
    
    if device == "mps":
        torch.mps.synchronize()
    end = time.perf_counter()
    
    elapsed = end - start
    print(f"[{name}] 耗时: {elapsed:.4f} 秒")
    return result

def run_my_inference():
    curr_input = input_ids.clone()
    with torch.no_grad():
        while True:
            logits = my_model(curr_input)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            if next_token.item() == 151645:
                break

            curr_input = torch.cat([curr_input, next_token], dim=1)
            
            if curr_input.shape[1] > 200: 
                break
    return tokenizer.decode(curr_input[0].tolist(), skip_special_tokens=True)

def run_official_inference():
    with torch.no_grad():
        output = official_model.generate(
            input_ids, 
            max_new_tokens=200,
            eos_token_id=151645,
            pad_token_id=151645,
            use_cache=True
        )
    return tokenizer.decode(output[0].tolist(), skip_special_tokens=True)

print("--- 开始对比 ---")
output_my = measure_time(run_my_inference, "手动推理 (无KV Cache)")
output_official = measure_time(run_official_inference, "官方推理 (有KV Cache)")

print("\n--- 结果验证 ---")
print(f"手动结果长度: {len(output_my)}")
print(f"官方结果长度: {len(output_official)}")

begin init model


The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


--- 开始对比 ---
正在预热 手动推理 (无KV Cache)...
[手动推理 (无KV Cache)] 耗时: 79.4075 秒
正在预热 官方推理 (有KV Cache)...
[官方推理 (有KV Cache)] 耗时: 54.9702 秒

--- 结果验证 ---
手动结果长度: 389
官方结果长度: 340
