# Export HuggingFace Models to ONNX

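This notebook shows how to export:
- A DistilBERT embedding model (encoder-only)  
- A BART summarization model (seq2seq-LM)

…to ONNX format, ready for GPU inference (e.g. with TensorRT).  
It then graph-optimizes and quantizes the DistilBERT export with ONNX Runtime and registers the resulting ONNX models in MLflow.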


In [1]:
!pip install transformers torch onnxruntime



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [21]:
from pathlib import Path
import torch
from itertools import chain
from transformers import (
    DistilBertModel, DistilBertTokenizer,
    AutoModelForSeq2SeqLM, AutoTokenizer
)
from transformers.onnx import FeaturesManager


In [22]:
import logging
import time
from torch.jit._trace import TracerWarning
import warnings

# --- suppress tracer warnings if you like ---
# Hides TracerWarning messages so they do not drown out the progress logs.
warnings.filterwarnings("ignore", category=TracerWarning)

# --- basic logging setup ---
# force=True replaces any handlers already attached to the root logger.
# Without it, basicConfig() silently does nothing when logging was configured
# earlier (by a library or a previous cell), and the INFO-level progress
# messages emitted by the export helpers could be dropped or mis-formatted.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
    force=True,
)
logger = logging.getLogger(__name__)


In [23]:
# Cell: Export DistilBERT → ONNX
def export_embed_model(
    model_dir: str,
    onnx_output_path: str,
    opset: int = 17
) -> None:
    """Export a DistilBERT encoder to an ONNX file.

    Loads the model and tokenizer from ``model_dir``, asks ``transformers.onnx``
    for the ONNX config registered for the ``"default"`` feature, generates
    dummy inputs from it and runs ``torch.onnx.export``. The axes declared by
    the ONNX config for inputs and outputs are exported as dynamic axes.
    Progress and per-stage timings are logged at INFO level.

    Args:
        model_dir: Path or identifier passed to ``from_pretrained`` for both
            the model and the tokenizer.
        onnx_output_path: Destination file for the ONNX graph. Missing parent
            directories are created.
        opset: ONNX opset version passed to ``torch.onnx.export``.

    Returns:
        None. The ONNX file on disk is the only result.
    """
    # perf_counter() is monotonic, so durations cannot be skewed by clock changes.
    t0 = time.perf_counter()
    logger.info(f"⏳ Loading DistilBERT model from `{model_dir}`")
    model = DistilBertModel.from_pretrained(model_dir)
    tokenizer = DistilBertTokenizer.from_pretrained(model_dir)
    model.eval()
    logger.info(f"✅ Loaded model & tokenizer in {time.perf_counter() - t0:.1f}s")

    t1 = time.perf_counter()
    logger.info("⏳ Building ONNX config for encoder-only export")
    _, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(model, feature="default")
    onnx_config = onnx_config_cls(model.config)
    dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework="pt")
    logger.info(f"✅ Prepared ONNX config & dummy inputs in {time.perf_counter() - t1:.1f}s")

    input_names = list(onnx_config.inputs.keys())
    output_names = list(onnx_config.outputs.keys())
    dynamic_axes = {**onnx_config.inputs, **onnx_config.outputs}
    # Inputs are passed positionally in the order the ONNX config lists them.
    # NOTE(review): this relies on that order matching the leading parameters
    # of the model's forward() — confirm when changing model families.
    example_inputs = tuple(dummy_inputs[n] for n in input_names)

    # Create the destination folder up front so the export cannot fail at the
    # very end just because the directory is missing.
    Path(onnx_output_path).parent.mkdir(parents=True, exist_ok=True)

    t2 = time.perf_counter()
    logger.info(f"⏳ Exporting to ONNX (opset {opset}) → {onnx_output_path}")
    # Export is inference-only: disable autograd while the model is run.
    with torch.no_grad():
        torch.onnx.export(
            model,
            example_inputs,
            onnx_output_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=opset,
            do_constant_folding=True
        )
    logger.info(f"✅ Exported ONNX in {time.perf_counter() - t2:.1f}s")
    logger.info(f"🏁 Total embed export time: {time.perf_counter() - t0:.1f}s")

In [24]:
def export_summarization_model(
    hf_name_or_dir: str,
    onnx_output_path: str,
    opset: int = 17
) -> None:
    """Export a seq2seq language model to a single ONNX file.

    Loads the model and tokenizer via the ``Auto*`` classes, asks
    ``transformers.onnx`` for the ONNX config registered for the
    ``"seq2seq-lm"`` feature, generates dummy inputs from it and runs
    ``torch.onnx.export``. The axes declared by the ONNX config for inputs and
    outputs are exported as dynamic axes. Progress and per-stage timings are
    logged at INFO level.

    Args:
        hf_name_or_dir: Hub model name or local directory passed to
            ``from_pretrained`` for both the model and the tokenizer.
        onnx_output_path: Destination file for the ONNX graph. Missing parent
            directories are created.
        opset: ONNX opset version passed to ``torch.onnx.export``.

    Returns:
        None. The ONNX file on disk is the only result.
    """
    # perf_counter() is monotonic, so durations cannot be skewed by clock changes.
    t0 = time.perf_counter()
    logger.info(f"⏳ Loading seq2seq model from `{hf_name_or_dir}`")
    model = AutoModelForSeq2SeqLM.from_pretrained(hf_name_or_dir)
    tokenizer = AutoTokenizer.from_pretrained(hf_name_or_dir)
    model.eval()
    logger.info(f"✅ Loaded model & tokenizer in {time.perf_counter() - t0:.1f}s")

    t1 = time.perf_counter()
    logger.info("⏳ Building ONNX config for seq2seq-LM export")
    _, onnx_config_class = FeaturesManager.check_supported_model_or_raise(
        model, feature="seq2seq-lm"
    )
    onnx_config = onnx_config_class(model.config)
    dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework="pt")
    logger.info(f"✅ Prepared ONNX config & dummy inputs in {time.perf_counter() - t1:.1f}s")

    # Prepare names & axes
    input_names = list(onnx_config.inputs.keys())
    output_names = list(onnx_config.outputs.keys())
    dynamic_axes = {**onnx_config.inputs, **onnx_config.outputs}
    # Inputs are passed positionally in the order the ONNX config lists them.
    # NOTE(review): this relies on that order matching the leading parameters
    # of the model's forward() — confirm when changing model families.
    example_inputs = tuple(dummy_inputs[n] for n in input_names)

    # Create the destination folder up front so the export cannot fail at the
    # very end just because the directory is missing.
    Path(onnx_output_path).parent.mkdir(parents=True, exist_ok=True)

    t2 = time.perf_counter()
    logger.info(f"⏳ Exporting seq2seq model to ONNX (opset {opset}) → {onnx_output_path}")
    # Export is inference-only: disable autograd while the model is run.
    with torch.no_grad():
        torch.onnx.export(
            model,
            example_inputs,
            onnx_output_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=opset,
            do_constant_folding=True
        )
    logger.info(f"✅ Exported ONNX in {time.perf_counter() - t2:.1f}s")
    logger.info(f"🏁 Total summarization export time: {time.perf_counter() - t0:.1f}s")


In [25]:
# Export both models into the project's models folder.
# NOTE(review): this is a machine-specific absolute path — adjust it for your checkout.
ONNX_DIR = Path("/home/pb/projects/course/sem2/mlops/project/mlops/models")
# parents=True so a missing intermediate folder does not abort the cell.
ONNX_DIR.mkdir(parents=True, exist_ok=True)

# The fine-tuned DistilBERT checkpoint lives below the same models folder, so
# derive its location from ONNX_DIR instead of repeating the absolute path.
export_embed_model(
    model_dir=str(ONNX_DIR/"artifacts"/"model"/"model.sentence_transformer"),
    onnx_output_path=str(ONNX_DIR/"distilbert.onnx"),
)

# NOTE(review): "facebook/bart-large" is the base checkpoint; presumably a
# summarization fine-tune is what should eventually be exported — confirm.
export_summarization_model(
    hf_name_or_dir="facebook/bart-large",
    onnx_output_path=str(ONNX_DIR/"bart_summarize.onnx")
)


22:27:04 INFO ⏳ Loading DistilBERT model from `/home/pb/projects/course/sem2/mlops/project/mlops/models/artifacts/model/model.sentence_transformer`
22:27:04 INFO ✅ Loaded model & tokenizer in 0.1s
22:27:04 INFO ⏳ Building ONNX config for encoder-only export
22:27:04 INFO ✅ Prepared ONNX config & dummy inputs in 0.0s
22:27:04 INFO ⏳ Exporting to ONNX (opset 17) → /home/pb/projects/course/sem2/mlops/project/mlops/models/distilbert.onnx
  torch.onnx.export(
22:27:05 INFO ✅ Exported ONNX in 1.3s
22:27:05 INFO 🏁 Total embed export time: 1.4s
22:27:05 INFO ⏳ Loading seq2seq model from `facebook/bart-large`
22:27:07 INFO ✅ Loaded model & tokenizer in 2.2s
22:27:07 INFO ⏳ Building ONNX config for seq2seq-LM export
22:27:07 INFO ✅ Prepared ONNX config & dummy inputs in 0.0s
22:27:07 INFO ⏳ Exporting seq2seq model to ONNX (opset 17) → /home/pb/projects/course/sem2/mlops/project/mlops/models/bart_summarize.onnx
  torch.onnx.export(
22:27:15 INFO ✅ Exported ONNX in 7.6s
22:27:15 INFO 🏁 Total summa

In [27]:
from mlflow.tracking import MlflowClient
from mlflow.models import Model

# Client for the MLflow tracking / model-registry server.
client = MlflowClient()

# Look up version 1 of the registered model and show where its artifacts live.
mv = client.get_model_version("facebook-bart-large", "1")
print(f"Source URI: {mv.source}")

# Read the MLmodel metadata stored at that location and report its flavors.
model_conf = Model.load(mv.source)
print(f"Flavors available: {model_conf.flavors.keys()}")


Source URI: mlflow-artifacts:/7/72bdeff2ad79435fac87e63bd17da8cd/artifacts/model


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Flavors available: dict_keys(['python_function', 'transformers'])


In [6]:
# Cell: Register ONNX models in MLflow

import os
from getpass import getpass

# The tracking URI is not a secret, so it stays in the notebook.
os.environ["MLFLOW_TRACKING_URI"] = "http://129.114.27.112:8000"

# Credentials must not be committed with the notebook: use the values already
# present in the environment and prompt only for the ones that are missing.
if not os.environ.get("MLFLOW_TRACKING_USERNAME"):
    os.environ["MLFLOW_TRACKING_USERNAME"] = input("MLflow username: ")
if not os.environ.get("MLFLOW_TRACKING_PASSWORD"):
    # getpass() keeps the password out of the cell output.
    os.environ["MLFLOW_TRACKING_PASSWORD"] = getpass("MLflow password: ")


import mlflow
import mlflow.onnx
import onnx

# Select the experiment that subsequent runs are logged under. This call does
# not choose the tracking server; that comes from the MLFLOW_TRACKING_*
# environment variables set earlier in this cell.
mlflow.set_experiment("onnx-model-registration")

def make_input_example(model_cls, model_dir_or_name, feature):
    """Build an MLflow ``input_example`` from the ONNX config's dummy inputs.

    Args:
        model_cls: Class whose ``from_pretrained`` loads the model.
        model_dir_or_name: Directory or name passed to ``from_pretrained`` for
            both the model and the tokenizer.
        feature: ``transformers.onnx`` feature name. ``"default"`` selects
            ``DistilBertTokenizer``; any other value (e.g. ``"seq2seq-lm"``)
            selects ``AutoTokenizer``.

    Returns:
        Dict mapping each dummy-input name to a NumPy array.
    """
    # Only the tokenizer class depends on the feature; the model is loaded the
    # same way in both cases.
    tokenizer_cls = DistilBertTokenizer if feature == "default" else AutoTokenizer

    hf_model = model_cls.from_pretrained(model_dir_or_name)
    hf_tokenizer = tokenizer_cls.from_pretrained(model_dir_or_name)
    hf_model.eval()

    _, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(hf_model, feature=feature)
    dummy_tensors = onnx_config_cls(hf_model.config).generate_dummy_inputs(
        hf_tokenizer, framework="pt"
    )

    # Hand MLflow NumPy arrays rather than torch tensors.
    example = {}
    for input_name, tensor in dummy_tensors.items():
        example[input_name] = tensor.cpu().numpy()
    return example


In [8]:
# Cell: Register base ONNX in MLflow (fixed)
import onnx

# Read the exported, not-yet-optimized DistilBERT graph back from disk.
onnx_base = onnx.load(str(ONNX_DIR / "distilbert.onnx"))

# Dummy inputs for the encoder, returned by the helper as NumPy arrays.
example_inputs = make_input_example(
    model_cls=DistilBertModel,
    model_dir_or_name="/home/pb/projects/course/sem2/mlops/project/mlops/models/artifacts/model/model.sentence_transformer",
    feature="default",
)

# Log the graph inside its own run and register it in the model registry.
with mlflow.start_run(run_name="distilbert-base-onnx"):
    mlflow.onnx.log_model(
        onnx_model=onnx_base,
        artifact_path="model",
        input_example=example_inputs,
        registered_model_name="distilbert-embedding-onnx",
    )
    print("✅ Registered base DistilBERT ONNX")


🏃 View run distilbert-base-onnx at: http://129.114.27.112:8000/#/experiments/8/runs/20c4fab9168d471995e1f53b4ee2f77a
🧪 View experiment at: http://129.114.27.112:8000/#/experiments/8


KeyboardInterrupt: 

In [9]:
# Cell: Graph-optimize
import onnxruntime as ort

# Source graph and destination for its optimized copy.
base_path = str(ONNX_DIR / "distilbert.onnx")
opt_path = str(ONNX_DIR / "distilbert_opt.onnx")

# Ask ONNX Runtime for extended graph optimizations and to save the result.
sess_opts = ort.SessionOptions()
sess_opts.optimized_model_filepath = opt_path
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# Creating the session is what writes the optimized graph to opt_path; the
# session itself is not used afterwards.
_ = ort.InferenceSession(
    base_path,
    sess_options=sess_opts,
    providers=["CPUExecutionProvider"],
)
print(f"✅ Graph-optimized model written to {opt_path}")


✅ Graph-optimized model written to /home/pb/projects/course/sem2/mlops/project/mlops/models/distilbert_opt.onnx


In [10]:
# Cell: Dynamic Quantization (weight-only)
from onnxruntime.quantization import quantize_dynamic, QuantType

# Destination for the dynamically quantized copy of the base graph.
dyn_path = str(ONNX_DIR / "distilbert_dyn.onnx")

# Quantize the base graph using signed 8-bit weights.
quantize_dynamic(model_input=base_path, model_output=dyn_path, weight_type=QuantType.QInt8)
print(f"✅ Dynamic-quant ONNX written to {dyn_path}")


21:05:09 INFO Quantization parameters for tensor:"/embeddings/LayerNorm/LayerNormalization_output_0" not specified
21:05:09 INFO Ignore MatMul due to non constant B: /[/transformer/layer.0/attention/MatMul]
21:05:09 INFO Ignore MatMul due to non constant B: /[/transformer/layer.0/attention/MatMul_1]
21:05:09 INFO Quantization parameters for tensor:"/transformer/layer.0/attention/Reshape_3_output_0" not specified
21:05:09 INFO Quantization parameters for tensor:"/transformer/layer.0/sa_layer_norm/LayerNormalization_output_0" not specified
21:05:09 INFO Quantization parameters for tensor:"/transformer/layer.0/ffn/activation/Mul_1_output_0" not specified
21:05:09 INFO Quantization parameters for tensor:"/transformer/layer.0/output_layer_norm/LayerNormalization_output_0" not specified
21:05:09 INFO Ignore MatMul due to non constant B: /[/transformer/layer.1/attention/MatMul]
21:05:09 INFO Ignore MatMul due to non constant B: /[/transformer/layer.1/attention/MatMul_1]
21:05:09 INFO Quantiza

✅ Dynamic-quant ONNX written to /home/pb/projects/course/sem2/mlops/project/mlops/models/distilbert_dyn.onnx


In [17]:
# Cell: Static Quantization (heavy / moderate / low)
from onnxruntime.quantization import (
    quantize_static, QuantType, QuantFormat, CalibrationDataReader, CalibrationMethod
)

# Calibration reader that feeds tokenized sample texts to the static quantizer
class DistilBertDataReader(CalibrationDataReader):
    """Serve one tokenized batch of calibration texts, then report exhaustion.

    The reader is single-use: once the batch has been handed out, ``get_next``
    keeps returning ``None``. Create a new instance for every calibration run.
    """

    def __init__(self, tokenizer, texts):
        """Tokenize ``texts`` into a single padded, truncated NumPy batch."""
        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="np",
        )
        # Only input_ids and attention_mask are passed on to the quantizer.
        batch = {key: encoded[key] for key in ("input_ids", "attention_mask")}
        self.inputs = [batch]

    def get_next(self):
        """Return the next pending batch, or ``None`` when none is left."""
        if not self.inputs:
            return None
        return self.inputs.pop(0)

tokenizer = DistilBertTokenizer.from_pretrained("/home/pb/projects/course/sem2/mlops/project/mlops/models/artifacts/model/model.sentence_transformer")
# A handful of sample sentences (repeated) serve as calibration data.
cal_texts = ["This is a test.", "Another example.", "One more sentence."] * 5

# Output files of the two static variants.
heavy_path = str(ONNX_DIR / "distilbert_static_heavy.onnx")
moderate_path = str(ONNX_DIR / "distilbert_static_moderate.onnx")

# Both variants use int8 activations, int8 weights and MinMax calibration;
# they differ only in the quantization format:
#   1) Heavy    -> QOperator
#   2) Moderate -> QDQ
for label, out_path, quant_format in (
    ("Heavy", heavy_path, QuantFormat.QOperator),
    ("Moderate", moderate_path, QuantFormat.QDQ),
):
    # The reader is consumed by calibration, so each run gets a fresh one.
    dr = DistilBertDataReader(tokenizer, cal_texts)
    quantize_static(
        model_input=base_path,
        model_output=out_path,
        calibration_data_reader=dr,
        quant_format=quant_format,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        calibrate_method=CalibrationMethod.MinMax,
    )
    print(f"✅ {label} static quant ONNX → {out_path}")

  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_param: "batch"
    }
    dim {
      dim_param: "sequence"
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_param: "batch"
    }
    dim {
      dim_value: 1
    }
    dim {
      dim_param: "sequence"
    }
  }
}
.
  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 1
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 4
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_param: "unk__4"
    }
  }
}
.
  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 1
    }
  }
}
.
  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 1
    }
  }
}
.
  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 1
    }
  }
}
.
  elem_type: 7
  shape {
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value

✅ Heavy static quant ONNX → /home/pb/projects/course/sem2/mlops/project/mlops/models/distilbert_static_heavy.onnx




✅ Moderate static quant ONNX → /home/pb/projects/course/sem2/mlops/project/mlops/models/distilbert_static_moderate.onnx


In [19]:
# Cell: Register all optimized models in MLflow (with input_example)

import onnx

# One set of dummy DistilBERT inputs (NumPy arrays), shared by every variant.
EMBED_MODEL_DIR = "/home/pb/projects/course/sem2/mlops/project/mlops/models/artifacts/model/model.sentence_transformer"
embed_input_example = make_input_example(DistilBertModel, EMBED_MODEL_DIR, "default")

# Variant tag -> ONNX file produced by the optimization cells.
variant_paths = {
    "graph-opt": opt_path,
    "dynamic-quant": dyn_path,
    "static-heavy": heavy_path,
    "static-moderate": moderate_path,
}

# Each variant is logged in its own run and registered under its own name.
for tag, path in variant_paths.items():
    onnx_m = onnx.load(path)
    run_name = f"distilbert-{tag}-registration"
    registered_name = f"distilbert-embedding-onnx-{tag}"
    with mlflow.start_run(run_name=run_name):
        mlflow.onnx.log_model(
            onnx_model=onnx_m,
            artifact_path="model",
            registered_model_name=registered_name,
            input_example=embed_input_example,
        )
        print(f"✅ Registered distilbert-{tag} ONNX with input_example")


Successfully registered model 'distilbert-embedding-onnx-graph-opt'.
2025/05/11 21:17:25 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: distilbert-embedding-onnx-graph-opt, version 1
Created version '1' of model 'distilbert-embedding-onnx-graph-opt'.


✅ Registered distilbert-graph-opt ONNX with input_example
🏃 View run distilbert-graph-opt-registration at: http://129.114.27.112:8000/#/experiments/8/runs/f88c3d52746f45e8b746cbb9b43bc817
🧪 View experiment at: http://129.114.27.112:8000/#/experiments/8


Successfully registered model 'distilbert-embedding-onnx-dynamic-quant'.
2025/05/11 21:18:04 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: distilbert-embedding-onnx-dynamic-quant, version 1
Created version '1' of model 'distilbert-embedding-onnx-dynamic-quant'.


✅ Registered distilbert-dynamic-quant ONNX with input_example
🏃 View run distilbert-dynamic-quant-registration at: http://129.114.27.112:8000/#/experiments/8/runs/d045a3cae5c74830af7526933adf5243
🧪 View experiment at: http://129.114.27.112:8000/#/experiments/8


Successfully registered model 'distilbert-embedding-onnx-static-heavy'.
2025/05/11 21:18:43 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: distilbert-embedding-onnx-static-heavy, version 1
Created version '1' of model 'distilbert-embedding-onnx-static-heavy'.


✅ Registered distilbert-static-heavy ONNX with input_example
🏃 View run distilbert-static-heavy-registration at: http://129.114.27.112:8000/#/experiments/8/runs/c8f4f8f01d464e64a0756b3c445b3c23
🧪 View experiment at: http://129.114.27.112:8000/#/experiments/8


Successfully registered model 'distilbert-embedding-onnx-static-moderate'.
2025/05/11 21:19:25 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: distilbert-embedding-onnx-static-moderate, version 1
Created version '1' of model 'distilbert-embedding-onnx-static-moderate'.


✅ Registered distilbert-static-moderate ONNX with input_example
🏃 View run distilbert-static-moderate-registration at: http://129.114.27.112:8000/#/experiments/8/runs/31936eea118a4cc2990a2ba55a5bca72
🧪 View experiment at: http://129.114.27.112:8000/#/experiments/8


In [None]:
# 2) BART summarization ONNX
# Dummy inputs for the seq2seq model, returned by the helper as NumPy arrays.
bart_input = make_input_example(
    AutoModelForSeq2SeqLM,
    "facebook/bart-large",
    "seq2seq-lm"
)
# Resolve the file through ONNX_DIR, like the DistilBERT cells, so this cell
# always loads the export written earlier in the notebook instead of relying
# on a second hard-coded copy of the absolute path.
bart_onnx = onnx.load(str(ONNX_DIR/"bart_summarize.onnx"))
with mlflow.start_run(run_name="bart-summarize-onnx-registration"):
    mlflow.onnx.log_model(
        onnx_model=bart_onnx,
        artifact_path="model",
        registered_model_name="bart-summarize-onnx",
        input_example=bart_input
    )
    print("✅ Registered BART summarization ONNX with input_example")