In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [1]:
#| export
from fastcore.all import *
import json
import numpy as np
import onnxruntime as ort
import os
from tokenizers import Tokenizer, AddedToken

Let's define some default model configurations that work well out of the box for various tasks.

In [2]:
#| export
# Prompt templates keyed by task; `{text}` is replaced with each input string
# by `FastEncode.encode_document` / `FastEncode.encode_query`.
# NOTE(review): the EmbeddingGemma model card documents prompts such as
# 'task: search result | query: {text}' and 'title: none | text: {text}';
# confirm the instruct-style templates below are intentional.
embedding_gemma_prompt = AttrDict(dict(
    document='Instruct: document \n document: {text}',
    query='Instruct: query \n query: {text}',
))
modernbert_prompt = AttrDict(dict(
    document='search_document: {text}',
    query='search_query: {text}',
))

# Default model configs: HF repo id (FastEncode also uses it as the local
# download dir), path of the ONNX graph inside the repo, and prompt templates.
embedding_gemma = AttrDict(
    model='onnx-community/embeddinggemma-300m-ONNX',
    onnx_path='onnx/model.onnx',
    prompt=embedding_gemma_prompt,
)
modernbert = AttrDict(
    model='nomic-ai/modernbert-embed-base',
    onnx_path='onnx/model.onnx',
    prompt=modernbert_prompt,
)

`FastEncode` is an ONNX-based embedding-model wrapper that works with most ONNX models that ship a Hugging Face tokenizer. (The Qwen models are a bit tricky because of how they handle the padding token, so they need a custom wrapper, which we will add later.)

In [13]:
#| export
class FastEncode:
	'''ONNX embedding model + Hugging Face `tokenizers` tokenizer.

	Texts are tokenized (padded/truncated), run through an onnxruntime session,
	mean-pooled over the attention mask and optionally L2-normalised.
	'''
	# keyword arguments accepted by `tokenizers.AddedToken`; other keys found in
	# serialised token dicts (e.g. '__type') are dropped before construction
	_added_tok_keys = ('content', 'single_word', 'lstrip', 'rstrip', 'normalized', 'special')
	def __init__(self,
		model_dict=None,     # model dict with model repo, onnx path and prompt templates (defaults to `embedding_gemma` when `repo_id` is not given)
		repo_id=None,        # model repo on HF. needs to have onnx model file
		md=None,             # local model dir
		md_nm=None,          # onnx model file name
		normalize=True,      # normalize embeddings
		dtype=np.float16,    # output dtype
		tti=False,           # use token type ids
		prompt=None,         # prompt templates
		hf_token=None,       # HF token. you can also set HF_TOKEN env variable
		max_len=512,         # upper bound on tokens per text (truncation length)
	):
		'''Fast ONNX-based text encoder'''
		# Resolve the default lazily so `FastEncode(repo_id=...)` works without also passing `model_dict=None`;
		# `FastEncode()` still uses `embedding_gemma`.
		if model_dict is None and repo_id is None: model_dict = embedding_gemma
		assert (model_dict is None) != (repo_id is None), 'Either model_dict or repo_id must be provided and not both'
		repo_id = model_dict.model if model_dict else repo_id
		md = md or repo_id
		md_nm = md_nm or (model_dict.onnx_path if model_dict else 'onnx/model.onnx')
		prompt = prompt or (model_dict.prompt if model_dict else AttrDictDefault())
		store_attr()
		# best effort: on failure keep `self.md` as given and let `_load_enc` report what is missing
		try: self.md = download_model(repo_id=repo_id, md=md, token=hf_token)
		except Exception as ex: print(f'model download failed: {ex}. hint: is hf_token set')
		self._load_enc()
	def _load_enc(self):
		'Load the tokenizer and build the ONNX session; on failure print the error and set `self.sess = None`.'
		try:
			onnx_p = str(Path(self.md)/self.md_nm)
			sess_opt = ort.SessionOptions()
			sess_opt.intra_op_num_threads = os.cpu_count() or 1
			# NOTE(review): ORT_PARALLEL only pays off for graphs with parallel branches; ORT_SEQUENTIAL is the
			# usual choice for transformer encoders — worth benchmarking.
			sess_opt.execution_mode = ort.ExecutionMode.ORT_PARALLEL
			sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
			self._load_tok()
			# use CUDA / CoreML when this onnxruntime build provides them, always keeping CPU as a fallback
			xtra = [p for p in ort.get_available_providers() if p in ('CUDAExecutionProvider', 'CoreMLExecutionProvider')]
			try: self.sess = ort.InferenceSession(onnx_p, sess_opt, providers=xtra+['CPUExecutionProvider'])
			# an accelerator provider can fail to initialise; retry on CPU only
			except Exception: self.sess = ort.InferenceSession(onnx_p, sess_opt, providers=['CPUExecutionProvider'])
		except Exception as ex:
			print(f'Encoding setup errored out with exception: {ex}')
			self.sess = None
	@staticmethod
	def _read_json(path):
		'Parse a JSON file, closing the handle afterwards.'
		with open(path, encoding='utf-8') as f: return json.load(f)
	def _load_tok(self):
		'Load `tokenizer.json` from `self.md` and configure padding, truncation and special tokens.'
		rd = lambda nm: self._read_json(os.path.join(self.md, nm))
		cfg, tok_cfg, tok_map = rd('config.json'), rd('tokenizer_config.json'), rd('special_tokens_map.json')
		self.tok = Tokenizer.from_file(os.path.join(self.md, 'tokenizer.json'))
		pad_tok = tok_cfg['pad_token']
		# some tokenizer configs store tokens as serialised AddedToken dicts rather than plain strings
		if isinstance(pad_tok, dict): pad_tok = pad_tok['content']
		pad_id = cfg.get('pad_token_id')
		if pad_id is None: pad_id = self.tok.token_to_id(pad_tok)
		self.tok.enable_padding(pad_id=pad_id, pad_token=pad_tok)
		# `model_max_length` may be missing or a huge sentinel value, so cap it at `max_len`
		self.tok.enable_truncation(max_length=int(min(tok_cfg.get('model_max_length', self.max_len), self.max_len)))
		for t in tok_map.values():
			# each value is a str, an AddedToken-style dict, or a list of those (e.g. `additional_special_tokens`)
			for tt in (t if isinstance(t, list) else [t]):
				if isinstance(tt, str): self.tok.add_special_tokens([tt])
				elif isinstance(tt, dict):
					kw = {k: v for k, v in tt.items() if k in self._added_tok_keys}
					self.tok.add_special_tokens([AddedToken(**kw)])
	def _enc(self, txts:list, dtype=np.int64):
		'Tokenize `txts` into padded `(input_ids, attention_mask)` arrays of `dtype`, one row per text.'
		txts = [txts] if isinstance(txts, str) else list(txts)
		encs = self.tok.encode_batch(txts, add_special_tokens=True)
		ids = np.array([e.ids for e in encs], dtype=dtype)
		msk = np.array([e.attention_mask for e in encs], dtype=dtype)
		return ids, msk
	def _mp(self, mout: np.ndarray, msk: np.ndarray):
		'Mean-pool `mout` (batch, seq, dim) over the sequence axis, counting only positions where `msk` (batch, seq) is 1.'
		token_embeddings = mout
		input_mask_expanded = np.broadcast_to(np.expand_dims(msk, axis=-1), token_embeddings.shape)
		sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
		# clip avoids division by zero for rows whose mask is all zeros
		sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
		return sum_embeddings / sum_mask
	def encode(self, lns:list, **kw):
		'''Embed `lns` (an iterable of strings, or a single string) into a `(n, dim)` array of `self.dtype`.

		Returns None (after printing a message) when the ONNX session failed to initialise.
		'''
		if not self.sess: print('ONNX session not initialized properly. Fix error during initialisation'); return None
		# materialise fastcore `L`s, generators, etc., and treat a bare string as a single text
		lns = [lns] if isinstance(lns, str) else list(lns)
		if not lns:
			dim = self.sess.get_outputs()[0].shape[-1]
			# dynamic ONNX dims are reported as names/None rather than ints
			return np.zeros((0, dim if isinstance(dim, int) else 0), dtype=self.dtype)
		ids, msk = self._enc(lns)
		inp = dict(input_ids=ids, attention_mask=msk)
		if self.tti: inp['token_type_ids'] = np.zeros(ids.shape, dtype=np.int64)
		# NOTE(review): output 0 is treated as per-token hidden states and mean-pooled here; some exports (possibly
		# the embeddinggemma one) also expose an already-pooled `sentence_embedding` output — confirm which is intended.
		o = self._mp(self.sess.run(None, inp)[0], msk)
		if self.normalize: o = o / np.clip(np.linalg.norm(o, ord=2, axis=1, keepdims=True), 1e-12, None)
		return o.astype(self.dtype)
	def _encode_with(self, kind, lns, prompt=None, **kw):
		'Format each text with `prompt`, or the `kind` template from `self.prompt` when `prompt` is None, then `encode`.'
		if prompt is None: prompt = self.prompt.get(kind, None)
		return self.encode(L(lns).map(lambda l: prompt.format(text=l) if prompt else l), **kw)
	def encode_document(self, lns, prompt:str=None, **kw):
		'Embed documents, applying the document prompt template.'
		return self._encode_with('document', lns, prompt, **kw)
	def encode_query(self, lns, prompt:str=None, **kw):
		'Embed queries, applying the query prompt template.'
		return self._encode_with('query', lns, prompt, **kw)

def download_model(repo_id=embedding_gemma.model, # model repo on HF
                   md=embedding_gemma.model,      # local model dir
                   token=None,                    # HF token. you can also set HF_TOKEN env variable
):
	'''Download model from HF hub.

	If the path `md` already exists it is returned as-is: no network access and
	no check that the directory is complete. Otherwise `repo_id` is snapshotted
	into `md`, and the path returned by `snapshot_download` is passed back.
	'''
	local_dir = Path(md)
	if not local_dir.exists():
		# imported lazily: huggingface_hub is only needed when a download actually happens
		from huggingface_hub import snapshot_download
		return snapshot_download(repo_id=repo_id, local_dir=md, token=token or os.getenv('HF_TOKEN'))
	return md

Let's quickly check that the encoder works:

In [14]:
# default config: EmbeddingGemma (downloaded on first use)
enc = FastEncode(model_dict=embedding_gemma)

[0;93m2025-12-16 11:27:09.334066665 [W:onnxruntime:, transformer_memcpy.cc:111 ApplyImpl] 736 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.[m


In [15]:
doc_embs = enc.encode_document(['This is a test', 'Another test'])
# one 768-d row per input; normalize=True by default, so each row has unit L2 norm
assert doc_embs.shape == (2, 768)
assert np.allclose(np.linalg.norm(doc_embs.astype(np.float32), axis=1), 1, atol=1e-2)
doc_embs

array([[ 0.05774 ,  0.001704,  0.002562, ..., -0.06177 , -0.00661 ,
         0.03174 ],
       [ 0.02939 , -0.008194, -0.00918 , ..., -0.02846 , -0.002222,
         0.02847 ]], shape=(2, 768), dtype=float16)

In [16]:
# ModernBERT embedding model with its search_document / search_query prompts
modern_enc = FastEncode(model_dict=modernbert)

In [17]:
query_embs = modern_enc.encode_query(['This is a test', 'Another test'])
# one 768-d row per input; normalize=True by default, so each row has unit L2 norm
assert query_embs.shape == (2, 768)
assert np.allclose(np.linalg.norm(query_embs.astype(np.float32), axis=1), 1, atol=1e-2)
query_embs

array([[-0.05026 , -0.04352 , -0.0171  , ..., -0.04974 ,  0.01598 ,
        -0.07056 ],
       [-0.05093 , -0.02133 , -0.0368  , ..., -0.10736 , -0.000944,
        -0.01177 ]], shape=(2, 768), dtype=float16)

In [None]:
#| hide
# write the `#| export` cells of this notebook to the library module
from nbdev import nbdev_export
nbdev_export()