In [None]:
# %pip install torch transformers datasets onnx onnxruntime onnxscript
# %pip install nbstripout  # optional: strips notebook outputs before committing

In [None]:
# --- Imports ---
import sys
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification

# Optional ONNX/TensorRT imports
# import onnx
# import onnxscript
# import onnxruntime as ort

# --- Make the project directory importable (local `tokenizer` and `models` modules) ---
# The path is relative to the kernel's working directory; adjust as needed.
PROJECT_ROOT = "LLM/ernie-tensorrt-inference"
if not any(entry == PROJECT_ROOT for entry in sys.path):
    sys.path.append(PROJECT_ROOT)

In [None]:
# --- Device setup: prefer the GPU, fall back to the CPU ---
def get_device():
    """Return the ``cuda`` device when torch can see a GPU, otherwise ``cpu``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
print(f"Device: {device}")

In [None]:
# --- Optimized-model precision (must be one of the supported dtype names) ---
PRECISION_OPTIONS = ["float32", "float16", "bfloat16"]
precision = "float16"  # edit to select another entry of PRECISION_OPTIONS
assert precision in PRECISION_OPTIONS, f"Precision must be one of {PRECISION_OPTIONS}"
print(f"Optimized Model Precision: {precision}")

# --- DataLoader batch size (must be one of the supported values) ---
BATCH_SIZE_OPTIONS = [16, 32, 64, 128, 512]
batch_size = 16  # edit to select another entry of BATCH_SIZE_OPTIONS
assert batch_size in BATCH_SIZE_OPTIONS, f"Batch size must be one of {BATCH_SIZE_OPTIONS}"
print(f"Batch size: {batch_size}")

In [None]:
# --- Validation split of the TNews classification dataset ---
def load_validation_dataset(dataset_name="C-MTEB/TNews-classification", split="validation"):
    """Load ``dataset_name`` with ``datasets.load_dataset`` and return its ``split`` subset."""
    return load_dataset(dataset_name)[split]

val_data = load_validation_dataset()
print(f"Validation dataset size: {len(val_data)}")

In [None]:
# --- Model setup ---
def load_models(model_name="nghuyong/ernie-3.0-base-zh", num_labels=15):
    """Load the baseline classifier and an independent, identical copy to optimize.

    Args:
        model_name: Hugging Face model id or local path.
        num_labels: number of output classes of the classification head.

    Returns:
        ``(model_orig, model_opt)``: two separate model objects with identical weights.
    """
    import copy

    # Original model (baseline)
    model_orig = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    # Optimized model (for future TensorRT / fused layers). Deep-copy rather than
    # calling from_pretrained again: weights absent from the checkpoint (e.g. a new
    # classification head on a base checkpoint) are freshly initialized on each
    # call, so two loads would give the models different weights and make the
    # baseline-vs-optimized comparison meaningless.
    model_opt = copy.deepcopy(model_orig)

    return model_orig, model_opt

# --- Load models ---
model_orig, model_opt = load_models(num_labels=15)

# --- Print configuration info ---
config = model_orig.config
print(f"Number of labels: {config.num_labels}") # outputs: 15

In [None]:
from tokenizer import tokenize_function

# --- Tokenize the dataset and wrap it in a DataLoader ---
def prepare_dataloader(dataset, batch_size=batch_size, max_length=64, shuffle=False):
    """Tokenize ``dataset`` with ``tokenize_function`` and return a torch ``DataLoader``.

    Tokenization runs in batched mode with the given ``max_length``. The model-input
    columns and ``label`` are then exposed as torch tensors.
    """
    def _tokenize(examples):
        return tokenize_function(examples, max_length=max_length)

    encoded = dataset.map(_tokenize, batched=True)
    torch_columns = ["input_ids", "attention_mask", "token_type_ids", "label"]
    encoded.set_format(type="torch", columns=torch_columns)
    return DataLoader(encoded, batch_size=batch_size, shuffle=shuffle)

# --- Create DataLoader ---
val_loader = prepare_dataloader(val_data, batch_size=batch_size, max_length=64, shuffle=False)

# Note: a smaller max_length (e.g., 64 or 128 instead of 512) reduces memory use and speeds up inference.

In [None]:
from models import ErnieEmbeddings, ErnieSelfAttention, ErnieSelfOutput, ErnieIntermediate, ErnieOutput, ErniePooler

# --- Replace ERNIE model components for optimization ---
def _swap_module(parent, name, new_module):
    """Replace ``parent.<name>`` with ``new_module``, carrying over the old module's weights."""
    import warnings

    old_module = getattr(parent, name, None)
    if old_module is not None:
        try:
            result = new_module.load_state_dict(old_module.state_dict(), strict=False)
        except RuntimeError as exc:  # e.g. a parameter shape mismatch
            warnings.warn(f"{type(new_module).__name__}: pretrained weights not copied ({exc})")
        else:
            # strict=False tolerates renamed/extra tensors; surface them instead of
            # silently running with partially random weights.
            if result.missing_keys or result.unexpected_keys:
                warnings.warn(
                    f"{type(new_module).__name__}: partial weight copy "
                    f"(missing={result.missing_keys}, unexpected={result.unexpected_keys})"
                )
    setattr(parent, name, new_module)


def replace_ernie_layers(model, config):
    """Swap ERNIE sub-modules for the project's implementations from ``models``.

    Each replacement is constructed from ``config`` and then loaded with the
    weights of the module it replaces. The optimized model therefore keeps the
    loaded parameters instead of a fresh random initialization. A warning is
    issued when weights cannot be fully transferred.

    Args:
        model: model exposing ``model.ernie.embeddings``, ``model.ernie.encoder.layer``
            and ``model.ernie.pooler``.
        config: model config passed to every replacement constructor.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Replace embeddings
    _swap_module(model.ernie, "embeddings", ErnieEmbeddings(config))

    # Replace each encoder layer's components. No per-layer pooler is attached:
    # only `model.ernie.pooler` is a pooler slot, so a `layer.pooler` would just
    # add unused parameters to every layer.
    for layer in model.ernie.encoder.layer:
        _swap_module(layer.attention, "self", ErnieSelfAttention(config))
        _swap_module(layer.attention, "output", ErnieSelfOutput(config))
        _swap_module(layer, "intermediate", ErnieIntermediate(config))
        _swap_module(layer, "output", ErnieOutput(config))

    _swap_module(model.ernie, "pooler", ErniePooler(config))

    return model

# Swap in the project's ERNIE components; model_orig is a separate object and stays untouched.
model_opt = replace_ernie_layers(model=model_opt, config=config)

In [None]:
# --- Put a model into inference mode on the target device ---
def prepare_model(model, device):
    """Return ``model`` switched to eval mode and moved to ``device``."""
    # nn.Module.eval() returns the module itself, so the calls can be chained.
    return model.eval().to(device)

# --- Apply to both the baseline and the optimized model ---
model_orig = prepare_model(model=model_orig, device=device)
model_opt = prepare_model(model=model_opt, device=device)

In [None]:
# Display the structure of the optimized model. The module repr lists every
# sub-module, so the ERNIE components swapped in by replace_ernie_layers can be
# checked by eye.
model_opt

In [None]:
# --- Set model precision ---
def set_model_precision(model, precision="float32"):
    """Cast the floating-point parameters and buffers of ``model`` to ``precision``.

    Args:
        model: ``torch.nn.Module`` to convert (converted in place and returned).
        precision: one of ``"float32"``, ``"float16"`` or ``"bfloat16"``.

    Returns:
        The converted model.

    Raises:
        ValueError: if ``precision`` is not a supported name. An unknown name must
            not silently leave the model in its current dtype.
    """
    dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if precision not in dtypes:
        raise ValueError(f"Unsupported precision {precision!r}; expected one of {list(dtypes)}")
    # Module.to(dtype) casts only floating-point tensors, so integer buffers keep
    # their dtype (same behavior as .half()). "float32" also upcasts a model that
    # was previously reduced, rather than being a no-op.
    return model.to(dtypes[precision])

# Cast only the optimized model; model_orig keeps the precision it was loaded with.
model_opt = set_model_precision(model=model_opt, precision=precision)

# --- Print model info ---
def print_model_info(model_orig, model_opt, precision):
    """Print the device of both models, then the dtype of the optimized one."""
    for label, model in (("Original", model_orig), ("Optimized", model_opt)):
        print(f"{label} Model Device: {model.device}")
    print(f"Optimized Model Precision: {model_opt.dtype} ({precision})")

# print_model_info(model_orig, model_opt, precision)

In [None]:
# --- Loss function ---
criterion = nn.CrossEntropyLoss()

# --- Metric computation ---
def evaluate_model(model, data_loader, device):
    """Compute the per-sample mean cross-entropy loss and accuracy of ``model``.

    Args:
        model: classifier called with ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` keyword arguments; ``outputs[0]`` is used as logits.
        data_loader: iterable of dict batches with keys ``input_ids``,
            ``attention_mask``, ``token_type_ids`` and ``label``.
        device: device the batch tensors are moved to (should match the model's).

    Returns:
        ``(avg_loss, accuracy)``, both averaged over all samples.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in data_loader:
            # Move inputs and labels to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["label"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )

            # outputs[0] works for tuple outputs and for models without a `logits`
            # attribute. Upcast so the loss of a float16/bfloat16 model is reduced in
            # float32. The cast is exact, so the argmax predictions are unchanged.
            logits = outputs[0].float()  # shape: [batch_size, num_classes]

            n = labels.size(0)
            # criterion averages over the batch; re-weight by batch size so the
            # final average is per sample even when the last batch is smaller.
            total_loss += criterion(logits, labels).item() * n
            total_correct += (logits.argmax(dim=-1) == labels).sum().item()
            total_samples += n

    if total_samples == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")
    return total_loss / total_samples, total_correct / total_samples

# --- Latency benchmarking ---
def _sync(device):
    """Wait for queued CUDA work on ``device``; no-op for non-CUDA devices."""
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)

def measure_latency(model, inputs, device, n_warmup=10, n_iter=100):
    """Return the mean wall-clock latency in milliseconds of ``model(**inputs)``.

    Args:
        model: callable invoked as ``model(**inputs)`` under ``torch.no_grad()``.
        inputs: keyword arguments for the forward pass, already on ``device``.
        device: device the model runs on. CUDA devices are synchronized around
            each timed call so asynchronous kernels are included in the timing.
            CPU devices need no synchronization.
        n_warmup: untimed iterations run first.
        n_iter: number of timed iterations to average; must be >= 1.

    Returns:
        Mean latency in milliseconds (numpy float).

    Raises:
        ValueError: if ``n_iter`` < 1.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    latencies = []
    with torch.no_grad():
        # Warm-up (untimed)
        for _ in range(n_warmup):
            model(**inputs)
        _sync(device)

        # Measure. perf_counter is monotonic and high-resolution, unlike time.time().
        for _ in range(n_iter):
            start = time.perf_counter()
            model(**inputs)
            _sync(device)
            latencies.append((time.perf_counter() - start) * 1000.0)  # ms

    return np.mean(latencies)

In [None]:
# --- Prepare a single batch for latency measurement ---
MODEL_INPUT_KEYS = ("input_ids", "attention_mask", "token_type_ids")
sample_batch = next(iter(val_loader))

# Move the model inputs to the device WITHOUT changing their dtype. The same
# `inputs` dict is fed to the float32 baseline and to the reduced-precision model.
# Casting float tensors to float16 here would therefore break the baseline, and
# it would mismatch a bfloat16 model. Token/segment ids must stay integer.
# NOTE(review): presumably the model casts the attention mask to its own dtype
# internally (HF-style extended mask) — confirm for the custom modules.
inputs = {k: v.to(device) for k, v in sample_batch.items() if k in MODEL_INPUT_KEYS}

def _evaluate_and_report(label, model):
    """Evaluate `model` on val_loader, time it on `inputs`, and print a one-line summary."""
    avg_loss, acc = evaluate_model(model, val_loader, device)
    latency = measure_latency(model, inputs, device)
    print(f"[{label}] Precision: {model.dtype}, Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}, Mean latency: {latency:.2f} ms")
    return avg_loss, acc, latency

# --- Evaluate Original Model ---
avg_loss_orig, acc_orig, latency_orig = _evaluate_and_report("Original", model_orig)

# --- Evaluate Optimized Model ---
avg_loss_opt, acc_opt, latency_opt = _evaluate_and_report("Optimized", model_opt)

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

# Profile only devices that are actually in use. get_device() falls back to the
# CPU, and CUDA activity/time columns are meaningless without a GPU.
use_cuda = device.type == "cuda"
activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if use_cuda else [])
sort_key = "cuda_time_total" if use_cuda else "cpu_time_total"

with profile(
    activities=activities,
    record_shapes=True,
    with_stack=True,
    profile_memory=True
) as prof:
    with record_function("model_inference"):
        with torch.no_grad():
            for _ in range(10):   # run multiple iterations to get stable timings
                outputs = model_opt(**inputs)
        if use_cuda:
            # Make sure all queued kernels finish inside the profiling window.
            torch.cuda.synchronize()

print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

In [None]:
# Layout of `inputs`, the keyword arguments passed as model(**inputs) above:
# {
#   "input_ids": <tensor on device>,        # int64 token ids
#   "attention_mask": <tensor on device>,   # dtype as set by the input-preparation cell
#   "token_type_ids": <tensor on device>    # int64 segment ids
# }

In [None]:
# Optional: write the same profile as a TensorBoard trace to ./log
# (viewable in TensorBoard with the torch-tb-profiler plugin).
# from torch.profiler import tensorboard_trace_handler

# with profile(
#     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#     on_trace_ready=tensorboard_trace_handler("./log"),
#     record_shapes=True,
#     profile_memory=True,
#     with_stack=True
# ) as prof:
#     with record_function("model_inference"):
#         with torch.no_grad():
#             for _ in range(10):
#                 outputs = model_opt(**inputs)