In [1]:
# %pip install onnxruntime      # CPU build (x86, Arm-based CPUs, and macOS)
# %pip install onnxruntime-gpu  # CUDA GPU build (install only one of the two packages)

In [2]:
from transformers import AutoTokenizer, AutoModel
import torch



In [3]:
# Load the pretrained ERNIE tokenizer and model; both come from the same
# Hugging Face checkpoint, so its name is defined once.
ERNIE_CHECKPOINT = "nghuyong/ernie-1.0-base-zh"
tokenizer = AutoTokenizer.from_pretrained(ERNIE_CHECKPOINT)
model = AutoModel.from_pretrained(ERNIE_CHECKPOINT)

In [None]:
# 1. Just like BERT, there are usually 12 or 24 stacked Transformer layers (depending on the model size).

# 2. Each layer is an instance of a block like the following:
# class TransformerBlock(nn.Module):
#     def __init__(...):
#         self.attention = MultiheadAttention(...)
#         self.norm1 = LayerNorm(...)
#         self.mlp = FeedForward(...)
#         self.norm2 = LayerNorm(...)

# 3. To optimize ERNIE end-to-end:
# - Modify MultiheadAttention 
# - Ensure the TransformerBlock wraps around your optimized version.
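# - Do your changes automatically propagate to all N layers? Only if you change the
#   class itself (e.g. patch the attention class's forward). Each entry of
#   `model.encoder.layer` is its own module instance, so replacing a submodule of
#   `model.encoder.layer[0]` affects layer 0 only; loop over `model.encoder.layer`
#   to apply the replacement to every layer.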

In [47]:
# Pick one encoder layer to inspect. BERT-style ERNIE keeps its stacked
# Transformer layers in `model.encoder.layer`; index 0 is the first one.
LAYER_IDX = 0
block = model.encoder.layer[LAYER_IDX]

In [48]:
# Show the submodule tree of the selected encoder layer (rendered below).
block

ErnieLayer(
  (attention): ErnieAttention(
    (self): ErnieSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): ErnieSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): ErnieIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): ReLU()
  )
  (output): ErnieOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [None]:
# Fused QKV (what's happening internally?):
# - A single Linear layer maps input_dim -> 3 * embed_dim (its weight has shape
#   [3 * embed_dim, input_dim], because nn.Linear stores weights as (out, in)).
# - During the forward pass, torch.chunk splits its output into 3 tensors (Q, K, V).
# - This needs one matrix-multiply kernel launch instead of three separate ones.

In [32]:
# Keep handles to the layer's existing query/key/value projections so their
# pretrained parameters can be copied into a fused QKV module.
attn_self = block.attention.self
query, key, value = attn_self.query, attn_self.key, attn_self.value

In [33]:
# Step 1: Define and create the fused module.
class FusedQKV(torch.nn.Module):
    """Compute query, key and value with a single fused linear projection.

    One ``Linear(input_dim, 3 * embed_dim)`` replaces three separate
    ``Linear(input_dim, embed_dim)`` layers; its output is split back into
    Q, K and V with ``torch.chunk`` along the last dimension.

    Args:
        input_dim: size of the last dimension of the incoming hidden states.
        embed_dim: output size of each of the Q, K and V projections.
        num_heads: number of attention heads; must divide ``embed_dim``.
        bias: whether the fused projection has a bias term.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim: int, embed_dim: int, num_heads: int, bias: bool = True):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Per-head size, kept for callers that reshape Q/K/V into heads.
        self.head_dim = embed_dim // num_heads
        # Weight shape is (out, in) = (3 * embed_dim, input_dim): Q rows, then K, then V.
        self.qkv_proj = torch.nn.Linear(input_dim, 3 * embed_dim, bias=bias)

    def forward(self, hidden_states: torch.Tensor):
        """Return ``(q, k, v)``, each with last dimension ``embed_dim``."""
        qkv = self.qkv_proj(hidden_states)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        return q, k, v


# 768 matches the in/out features of the query/key/value layers printed above.
fused_qkv = FusedQKV(input_dim=768, embed_dim=768, num_heads=12)

In [None]:
# Step 2: Concatenate the pretrained weights and biases.
# torch.cat() joins a sequence of tensors along an existing dimension.
# nn.Linear stores its weight as (out_features, in_features), so stacking the
# Q, K and V weights along dim 0 gives [3 * embed_dim, input_dim].
# `.detach()` is used instead of the discouraged `.data` attribute.
qkv_layers = (query, key, value)

# The fused layout only works if all three projections have identical shapes.
assert len({tuple(layer.weight.shape) for layer in qkv_layers}) == 1, "Q/K/V weight shapes differ"
assert all(layer.bias is not None for layer in qkv_layers), "expected biased Q/K/V projections"

# Weight: [3*embed_dim, input_dim]
qkv_weight = torch.cat([layer.weight.detach() for layer in qkv_layers], dim=0)  # Shape: [3*768, 768]

# Bias: [3*embed_dim]
qkv_bias = torch.cat([layer.bias.detach() for layer in qkv_layers], dim=0)  # Shape: [3*768]

In [None]:
# Step 3: Assign the concatenated parameters to the fused projection layer.
# This overwrites the randomly initialized weights of `qkv_proj` with the
# concatenated Q, K and V weights of the original ERNIE layer.
# Parameters are modified in place under `torch.no_grad()` (instead of via `.data`),
# and nothing is returned, so the cell no longer echoes a large tensor.
with torch.no_grad():
    fused_qkv.qkv_proj.weight.copy_(qkv_weight)
    fused_qkv.qkv_proj.bias.copy_(qkv_bias)

# Sanity check: splitting the fused output into three chunks must reproduce the
# outputs of the original, separate query/key/value layers.
with torch.no_grad():
    probe = torch.randn(
        1, 4, query.in_features, generator=torch.Generator().manual_seed(0)
    )
    fused_outputs = fused_qkv.qkv_proj(probe).chunk(3, dim=-1)
    for fused_out, original_layer in zip(fused_outputs, (query, key, value)):
        torch.testing.assert_close(fused_out, original_layer(probe))

In [None]:
# Stack them into one linear operation
# QKV = W_qkv @ x + b_qkv

In [None]:
# If you're deploying this model efficiently:
# - TorchScript may fuse some operations (mainly elementwise ones) automatically,
#   but don't rely on it to merge the separate Q/K/V matmuls for you.
# - Libraries such as FlashAttention, xFormers, NVIDIA Apex (fused dense layers),
#   or custom Triton kernels offer fused attention/QKV kernels with GPU-level optimizations.

In [35]:
# Explore / overwrite the following components.
# block.attention.self.dropout
# block.attention.output.dense
# block.attention.output.LayerNorm
# block.attention.output.dropout
# block.intermediate.dense
# block.intermediate.intermediate_act_fn
# block.output.dense
# block.output.LayerNorm
# block.output.dropout

In [40]:
# block_1 = model.encoder.layer[11] 
# block_1

In [7]:
# # Export to ONNX
# dummy_input = torch.randint(0, 100, (1, 128))  # (batch, seq_len)
# torch.onnx.export(
#     model, (dummy_input,),
#     "ernie.onnx",
#     input_names=["input_ids"],
#     output_names=["output"],
#     dynamic_axes={"input_ids": {0: "batch_size", 1: "seq_len"}},
#     opset_version=14
# )

# With opset_version=13 the export above failed with:
#   UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention'
#   to ONNX opset version 13 is not supported. Support for this operator was added in
#   version 14, try exporting with this version.
# which is why opset_version=14 is used.

In [8]:
# Verify in ONNX Runtime
# import onnxruntime as ort
# ort.InferenceSession("ernie.onnx")

https://docs.pytorch.org/docs/stable/onnx.html

In [9]:
# An example of exporting a model from PyTorch to ONNX (Open Neural Network eXchange).
# ONNX is an open standard format for representing ML models. The torch.onnx module
# captures the computation graph of a native PyTorch torch.nn.Module model and converts
# it into an ONNX graph.

# class MyModel(torch.nn.Module):
#     def __init__(self):
#         super(MyModel, self).__init__()
#         self.conv1 = torch.nn.Conv2d(1, 128, 5)

#     def forward(self, x):
#         return torch.relu(self.conv1(x))

# input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)

# model = MyModel()
# # model
# # input_tensor.dtype

# torch.onnx.export(
#     model,                 # model to export
#     (input_tensor,),       # inputs of the model
#     "my_model.onnx",       # filename of the ONNX model
#     input_names=["input"], # Rename inputs for the ONNX model
#     dynamo=True            # True or False to select the exporter to use
#)