# Export HuggingFace Models to ONNX

This notebook shows how to export:
- A distilbert embedding model (encoder-only)  
- A BART summarization model (seq2seq-LM)

…to ONNX format, ready for GPU inference (e.g. with TensorRT).  


In [1]:
!pip install transformers torch onnxruntime



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [12]:
from pathlib import Path
import torch
from itertools import chain
from transformers import (
    DistilBertModel, DistilBertTokenizer,
    AutoModelForSeq2SeqLM, AutoTokenizer
)
from transformers.onnx import FeaturesManager


In [3]:
import logging
import time
from torch.jit._trace import TracerWarning
import warnings

# --- suppress tracer warnings if you like ---
warnings.filterwarnings("ignore", category=TracerWarning)

# --- basic logging setup ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)


In [9]:
def export_embed_model(
    model_dir: str,
    onnx_output_path: str,
    opset: int = 17
):
    """
    Export a DistilBERT bi-encoder from `model_dir` to ONNX format at `onnx_output_path`.
    
    Args:
        model_dir (str): Path or HF identifier of the pretrained DistilBERT model.
        onnx_output_path (str): Path where the ONNX file will be written.
        opset (int): ONNX opset version to target (default: 17).
    """
    t0 = time.time()
    logger.info(f"⏳ Loading DistilBERT model from `{model_dir}`")
    # load the DistilBERT encoder
    model = DistilBertModel.from_pretrained(model_dir)
    tokenizer = DistilBertTokenizer.from_pretrained(model_dir)
    model.eval()
    logger.info(f"✅ Loaded model & tokenizer in {time.time() - t0:.1f}s")

    t1 = time.time()
    logger.info("⏳ Building ONNX config for encoder-only export")
    # ensure the model is supported and get the ONNXConfig class
    model_kind, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(
        model, feature="default"
    )
    onnx_config = onnx_config_cls(model.config)
    # create dummy inputs (PyTorch tensors) for export
    dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework="pt")
    logger.info(f"✅ Prepared ONNX config & dummy inputs in {time.time() - t1:.1f}s")

    # extract the names & dynamic axes
    input_names = list(onnx_config.inputs.keys())     # e.g. ['input_ids','attention_mask']
    output_names = list(onnx_config.outputs.keys())   # e.g. ['last_hidden_state']
    dynamic_axes = {**onnx_config.inputs, **onnx_config.outputs}

    # build the tuple of example inputs in the correct order
    example_inputs = tuple(dummy_inputs[name] for name in input_names)

    t2 = time.time()
    logger.info(f"⏳ Exporting to ONNX (opset {opset}) → {onnx_output_path}")
    # perform the export
    torch.onnx.export(
        model,                        # model to export
        example_inputs,               # tuple of torch.Tensor matching input_names
        onnx_output_path,             # where to save the .onnx
        input_names=input_names,      # names of the model inputs
        output_names=output_names,    # names of the model outputs
        dynamic_axes=dynamic_axes,    # which axes are dynamic (batch, seq)
        opset_version=opset,          # ONNX opset version
        do_constant_folding=True      # fold constants for optimization
    )
    logger.info(f"✅ Exported ONNX in {time.time() - t2:.1f}s")
    logger.info(f"🏁 Total embed export time: {time.time() - t0:.1f}s")

In [10]:
def export_summarization_model(
    hf_name_or_dir: str,
    onnx_output_path: str,
    opset: int = 17
):
    t0 = time.time()
    logger.info(f"⏳ Loading seq2seq model from `{hf_name_or_dir}`")
    model = AutoModelForSeq2SeqLM.from_pretrained(hf_name_or_dir)
    tokenizer = AutoTokenizer.from_pretrained(hf_name_or_dir)
    model.eval()
    logger.info(f"✅ Loaded model & tokenizer in {time.time() - t0:.1f}s")

    t1 = time.time()
    logger.info("⏳ Building ONNX config for seq2seq-LM export")
    _, onnx_config_class = FeaturesManager.check_supported_model_or_raise(
        model, feature="seq2seq-lm"
    )
    onnx_config = onnx_config_class(model.config)
    dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework="pt")
    logger.info(f"✅ Prepared ONNX config & dummy inputs in {time.time() - t1:.1f}s")

    # Prepare names & axes
    input_names = list(onnx_config.inputs.keys())
    output_names = list(onnx_config.outputs.keys())
    dynamic_axes = {**onnx_config.inputs, **onnx_config.outputs}
    example_inputs = tuple(dummy_inputs[n] for n in input_names)

    t2 = time.time()
    logger.info(f"⏳ Exporting seq2seq model to ONNX (opset {opset}) → {onnx_output_path}")
    torch.onnx.export(
        model,
        example_inputs,
        onnx_output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset,
        do_constant_folding=True
    )
    logger.info(f"✅ Exported ONNX in {time.time() - t2:.1f}s")
    logger.info(f"🏁 Total summarization export time: {time.time() - t0:.1f}s")


In [None]:
export_embed_model(
    model_dir="/home/pb/projects/course/sem2/mlops/project/mlops/models/artifacts/model/model.sentence_transformer",
    onnx_output_path="models/bert.onnx",
)
export_summarization_model(
    hf_name_or_dir="facebook/bart-large",
    onnx_output_path="models/bart_summarize.onnx",
)


19:32:55 INFO ⏳ Loading DistilBERT model from `/home/pb/projects/course/sem2/mlops/project/mlops/models/artifacts/model/model.sentence_transformer`
19:32:55 INFO ✅ Loaded model & tokenizer in 0.1s
19:32:55 INFO ⏳ Building ONNX config for encoder-only export
19:32:55 INFO ✅ Prepared ONNX config & dummy inputs in 0.0s
19:32:55 INFO ⏳ Exporting to ONNX (opset 17) → models/bert.onnx
  torch.onnx.export(
19:32:55 INFO ✅ Exported ONNX in 0.4s
19:32:55 INFO 🏁 Total embed export time: 0.5s


In [15]:
# Cell: Register ONNX models in MLflow

import os
os.environ["MLFLOW_TRACKING_URI"] = "http://129.114.27.112:8000"
os.environ["MLFLOW_TRACKING_USERNAME"] = "admin"
os.environ["MLFLOW_TRACKING_PASSWORD"] = "password"


import mlflow
import mlflow.onnx
import onnx

# point MLflow at your tracking server (or rely on env vars you’ve already set)
mlflow.set_experiment("onnx-model-registration")

def make_input_example(model_cls, model_dir_or_name, feature):
    # load HF model & tokenizer to construct dummy inputs
    if feature == "default":
        model = model_cls.from_pretrained(model_dir_or_name)
        tokenizer = DistilBertTokenizer.from_pretrained(model_dir_or_name)
    else:  # "seq2seq-lm"
        model = model_cls.from_pretrained(model_dir_or_name)
        tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
    model.eval()
    _, config_cls = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
    onnx_config = config_cls(model.config)
    dummy = onnx_config.generate_dummy_inputs(tokenizer, framework="pt")
    # convert to numpy for MLflow
    return {k: v.cpu().numpy() for k, v in dummy.items()}


In [16]:
# 1) distilbert embedding ONNX
distilbert_input = make_input_example(
    DistilBertModel,
    "/home/pb/projects/course/sem2/mlops/project/mlops/models/artifacts/model/model.sentence_transformer",
    "default"
)
lf_onnx = onnx.load("models/distilbert.onnx")
with mlflow.start_run(run_name="distilbert-onnx-registration"):
    mlflow.onnx.log_model(
        onnx_model=lf_onnx,
        artifact_path="model",
        registered_model_name="distilbert-embedding-onnx",
        input_example=distilbert_input
    )
    print("✅ Registered distilbert ONNX with input_example")

Successfully registered model 'distilbert-embedding-onnx'.
2025/05/11 19:37:53 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: distilbert-embedding-onnx, version 1
Created version '1' of model 'distilbert-embedding-onnx'.


✅ Registered distilbert ONNX with input_example
🏃 View run distilbert-onnx-registration at: http://129.114.27.112:8000/#/experiments/8/runs/57398c53d71b435fbcd69113af0f141e
🧪 View experiment at: http://129.114.27.112:8000/#/experiments/8


In [None]:
# 2) BART summarization ONNX
bart_input = make_input_example(
    AutoModelForSeq2SeqLM,
    "facebook/bart-large",
    "seq2seq-lm"
)
bart_onnx = onnx.load("models/bart_summarize.onnx")
with mlflow.start_run(run_name="bart-summarize-onnx-registration"):
    mlflow.onnx.log_model(
        onnx_model=bart_onnx,
        artifact_path="model",
        registered_model_name="bart-summarize-onnx",
        input_example=bart_input
    )
    print("✅ Registered BART summarization ONNX with input_example")