In [1]:

import onnxruntime
import torch
import numpy as np

# 创建模型配置文件
from dataclasses import dataclass

@dataclass
class TransformerConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True

model_config = TransformerConfig(vocab_size=10, block_size=12, n_layer=2, n_head=4, n_embd=16, dropout=0.0, bias=True)



In [2]:

# 创建模型

from tiny_transformer import Transformer

model = Transformer(model_config)


# ONNX 模型路径
onnx_model_path = "transformer.onnx"

# 创建 ONNX Runtime 会话
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# 准备输入数据 (必须是 numpy 数组)
# 确保输入形状与导出时使用的 dummy_input 匹配，并且数据类型正确
# 例如: (batch_size, sequence_length)
batch_size = 1
sequence_length = model_config.block_size  # 使用 block_size
input_data = torch.randint(1, 10, (batch_size, sequence_length)).numpy().astype(np.int64)

# 运行 ONNX 推理
ort_inputs = {'input': input_data}  # 输入名称必须与导出时指定的一致
ort_outputs = ort_session.run(None, ort_inputs) # None 表示获取所有输出

# ort_outputs 是一个列表，包含所有输出的结果
#  根据模型定义，假设我们只有一个输出
onnx_output = ort_outputs[0]

# 转换回 PyTorch Tensor (如果需要)
onnx_output_tensor = torch.from_numpy(onnx_output)

# 打印结果
print("ONNX Output Tensor Shape:", onnx_output_tensor.shape)
print("ONNX Output Tensor:", onnx_output_tensor)

# 可选: 与 PyTorch 模型的结果进行比较，以验证 ONNX 转换的正确性
# 注意: 需要使用相同的输入数据，并且确保 PyTorch 模型处于 eval 模式
model.eval()
torch_input = torch.from_numpy(input_data)
with torch.no_grad():
    torch_output, _ = model(torch_input)

print("PyTorch Output Tensor Shape:", torch_output.shape)
print("PyTorch Output Tensor:", torch_output)



number of parameters: 0.02M
ONNX Output Tensor Shape: torch.Size([1, 1, 10])
ONNX Output Tensor: tensor([[[-0.0452,  0.0436, -0.0248,  0.1356,  0.0727, -0.0507, -0.1029,
          -0.1260,  0.0055,  0.0508]]])
idx torch.Size([1, 12])
tok_emb torch.Size([1, 12, 16])
x after wpe: torch.Size([1, 12, 16])
enc_out: torch.Size([1, 12, 16])
x after decoder: torch.Size([1, 12, 16])
PyTorch Output Tensor Shape: torch.Size([1, 1, 10])
PyTorch Output Tensor: tensor([[[-0.0924, -0.0464,  0.1344, -0.0711,  0.1091, -0.0670,  0.1169,
          -0.1366,  0.0530,  0.0289]]])


In [3]:

#Compare outputs
if torch_output.shape == onnx_output_tensor.shape:
    diff = torch.abs(torch_output - onnx_output_tensor)
    max_diff = torch.max(diff)
    print("Max difference between PyTorch and ONNX outputs:", max_diff.item())
else:
    print("PyTorch and ONNX outputs have different shapes, cannot compare directly.")


Max difference between PyTorch and ONNX outputs: 0.21988385915756226
