In [1]:
# Model configuration
from dataclasses import dataclass

@dataclass
class TransformerConfig:
    """Hyper-parameters passed to ``tiny_transformer.Transformer``.

    The values are checked in ``__post_init__``. An inconsistent configuration
    therefore fails immediately with a clear ``ValueError``, instead of failing
    later inside the model's attention reshapes.

    Raises:
        ValueError: if ``n_head`` is not positive, if ``n_embd`` is not
            divisible by ``n_head``, or if ``dropout`` is outside [0, 1].
    """

    block_size: int = 1024  # maximum sequence length the model accepts
    vocab_size: int = 50304  # number of token ids; presumably GPT-2's 50257 padded to a multiple of 64 — TODO confirm
    n_layer: int = 4  # number of layers; how they split between encoder/decoder is defined in tiny_transformer
    n_head: int = 4  # number of attention heads
    n_embd: int = 768  # embedding / hidden width, split evenly across the heads
    dropout: float = 0.0  # dropout probability
    bias: bool = True  # NOTE(review): presumably toggles bias terms in Linear/LayerNorm layers — verify in tiny_transformer

    def __post_init__(self):
        """Validate cross-field constraints right after construction."""
        if self.n_head <= 0:
            raise ValueError(f"n_head must be positive, got {self.n_head}")
        # Attention reshapes the hidden size C into (n_head, C // n_head),
        # which only works when n_embd divides evenly across the heads.
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")

# A deliberately tiny configuration so the notebook runs in seconds on CPU.
model_config = TransformerConfig(
    vocab_size=10,  # 10 distinct token ids
    block_size=12,  # longest sequence the model will accept
    n_layer=2,
    n_head=4,
    n_embd=16,      # 16 / 4 heads -> 4 dims per head
    dropout=0.0,
    bias=True,
)


In [2]:
# Build the model
import torch

from tiny_transformer import Transformer

# Seed torch's global RNG *before* constructing the model, so that the
# following are reproducible under Restart & Run All:
# - weight initialisation,
# - the random inputs drawn in later cells,
# - any sampling done by `generate`.
SEED = 1337
torch.manual_seed(SEED)

model = Transformer(model_config)

number of parameters: 0.02M


In [3]:
# Forward pass on a random batch to check the output shape

import torch

# Random token ids drawn from [1, vocab_size), the same range as before.
# The bound comes from the config instead of a hard-coded 10.
idx = torch.randint(1, model_config.vocab_size, (4, 8))  # (batch=4, seq_len=8)
with torch.no_grad():  # shape check only: no autograd graph is needed
    logits, _ = model(idx)
print("logits",logits.size())

idx torch.Size([4, 8])
tok_emb torch.Size([4, 8, 16])
x after wpe: torch.Size([4, 8, 16])
enc_out: torch.Size([4, 8, 16])
x after decoder: torch.Size([4, 8, 16])
logits torch.Size([4, 1, 10])


In [4]:
# Inference: autoregressively extend every sequence in `idx` by a few tokens.
new_token_count = 3
result = model.generate(idx, new_token_count)
print("generate result", result.size())

idx torch.Size([4, 8])
tok_emb torch.Size([4, 8, 16])
x after wpe: torch.Size([4, 8, 16])
enc_out: torch.Size([4, 8, 16])
x after decoder: torch.Size([4, 8, 16])
idx torch.Size([4, 9])
tok_emb torch.Size([4, 9, 16])
x after wpe: torch.Size([4, 9, 16])
enc_out: torch.Size([4, 9, 16])
x after decoder: torch.Size([4, 9, 16])
idx torch.Size([4, 10])
tok_emb torch.Size([4, 10, 16])
x after wpe: torch.Size([4, 10, 16])
enc_out: torch.Size([4, 10, 16])
x after decoder: torch.Size([4, 10, 16])
generate result torch.Size([4, 11])


In [5]:
# Show the generated token ids: one row per sequence in the batch.
# Presumably each row is the input prompt followed by the newly generated
# tokens; confirm against Transformer.generate.
result

tensor([[6, 3, 9, 6, 2, 2, 3, 6, 2, 2, 6],
        [9, 6, 7, 5, 8, 3, 1, 9, 3, 5, 4],
        [2, 7, 2, 9, 2, 4, 5, 4, 1, 3, 1],
        [3, 9, 7, 8, 1, 7, 8, 1, 3, 4, 7]])

In [6]:
# Export the model to ONNX format
import torch.onnx

model.eval()  # evaluation mode: disables dropout and other training-only behaviour

# Dummy input used for tracing: a single sequence at the full context length.
# Token ids are drawn from the model's own vocabulary rather than a hard-coded 10.
dummy_input = torch.randint(
    1, model_config.vocab_size, (1, model_config.block_size), dtype=torch.long
)

# Output file name
onnx_file_path = "transformer.onnx"

# NOTE: the model's Python-side check `t <= block_size` is evaluated during
# tracing (see the TracerWarning) and is NOT part of the exported graph.
# Callers of the ONNX model must keep the sequence length <= block_size themselves.
torch.onnx.export(
    model,  # model to export
    dummy_input,  # example input used for tracing
    onnx_file_path,  # where to write the .onnx file
    export_params=True,  # store the weights inside the file
    opset_version=13,  # ONNX operator-set version; pick one your runtime supports
    do_constant_folding=True,  # fold constant sub-graphs at export time
    input_names=['input'],  # name of the graph input
    output_names=['output'],  # name of the graph output
    # Batch size and sequence length are both variable on the input.
    # The forward pass only returns logits for the last position, i.e. shape
    # (batch, 1, vocab_size), so its dim 1 is always 1 and only batch is dynamic.
    dynamic_axes={
        'input': {0: 'batch', 1: 'sequence'},
        'output': {0: 'batch'},
    },
)

print(f"模型已导出到：{onnx_file_path}")

  assert t <= self.config.block_size, f"不能计算该序列，该序列长度为 {t}, 最大序列长度只有 {self.config.block_size}"
  k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
  q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
  v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))


idx torch.Size([1, 12])
tok_emb torch.Size([1, 12, 16])
x after wpe: torch.Size([1, 12, 16])
enc_out: torch.Size([1, 12, 16])
x after decoder: torch.Size([1, 12, 16])
模型已导出到：transformer.onnx
