In [None]:
from pathlib import Path
import torch
from sentence_transformers import SentenceTransformer

# All artifacts live in a "models" folder that sits next to the directory
# this notebook is run from (../models); create it if it is missing.
notebook_dir = Path.cwd()
models_dir = notebook_dir.parent.joinpath("models")
models_dir.mkdir(parents=True, exist_ok=True)

# Source checkpoint (written by train.py) and destination ONNX file.
trained_model_dir = models_dir.joinpath("contrastive-minilm")
onnx_path = models_dir.joinpath("arxiv-all-MiniLM-L6-v2.onnx")

# Load the trained model, put it in eval mode and keep it on the CPU.
model = SentenceTransformer(str(trained_model_dir))
model.eval()
model.to("cpu")

# Tokenize one throwaway sentence to obtain example input tensors.
# "token_type_ids" is looked up with .get() because the tokenizer output
# may not contain it; in that case the variable is None.
dummy = model.tokenize(["This is a dummy input"])
input_ids, attention_mask = dummy["input_ids"], dummy["attention_mask"]
token_type_ids = dummy.get("token_type_ids")

class OnnxWrapper(torch.nn.Module):
    """Expose a features-dict model through a positional-tensor ``forward``.

    The wrapped model is called with a single dict of named tensors and must
    return a mapping containing ``"sentence_embedding"``. This wrapper builds
    that dict from separate tensor arguments and returns only the embedding,
    which is the call shape used when the module is handed to
    ``torch.onnx.export`` with a tuple of inputs.
    """

    def __init__(self, st_model: torch.nn.Module, use_token_type_ids: bool):
        """Store the wrapped model and the token-type switch.

        Args:
            st_model: Callable model (typically a ``SentenceTransformer``)
                that accepts a features dict and returns a mapping with a
                ``"sentence_embedding"`` entry.
            use_token_type_ids: When true, ``forward`` forwards
                ``token_type_ids`` to the model; when false it is ignored.
        """
        super().__init__()
        self.st_model = st_model
        self.use_token_type_ids = use_token_type_ids

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the wrapped model and return its sentence embedding.

        Args:
            input_ids: Token id tensor passed as ``features["input_ids"]``.
            attention_mask: Mask tensor passed as
                ``features["attention_mask"]``.
            token_type_ids: Optional segment id tensor; only used when the
                wrapper was built with ``use_token_type_ids=True``.

        Returns:
            The ``"sentence_embedding"`` entry of the wrapped model's output.

        Raises:
            ValueError: If the wrapper was built with
                ``use_token_type_ids=True`` but ``token_type_ids`` is ``None``.
        """
        features = {"input_ids": input_ids, "attention_mask": attention_mask}
        if self.use_token_type_ids:
            # Fail here with a clear message instead of handing None to the
            # wrapped model, which would break far from the actual mistake.
            if token_type_ids is None:
                raise ValueError(
                    "token_type_ids is required when use_token_type_ids=True"
                )
            features["token_type_ids"] = token_type_ids
        return self.st_model(features)["sentence_embedding"]

# Only feed token_type_ids to the graph when the tokenizer produced them.
use_token_type_ids = token_type_ids is not None
wrapper = OnnxWrapper(model, use_token_type_ids)
wrapper.eval()

# Mandatory graph inputs. Each one gets named dynamic axes for dimension 0
# ("batch") and dimension 1 ("sequence"); the output only names dimension 0.
input_names = ["input_ids", "attention_mask"]
inputs = (input_ids, attention_mask)
dynamic_axes = {name: {0: "batch", 1: "sequence"} for name in input_names}
dynamic_axes["sentence_embedding"] = {0: "batch"}

# Optional third input, registered the same way as the mandatory ones.
if use_token_type_ids:
    input_names += ["token_type_ids"]
    inputs += (token_type_ids,)
    dynamic_axes.update(token_type_ids={0: "batch", 1: "sequence"})

print(f"Exporting ONNX model to: {onnx_path}")

# Export the wrapper with the example tensors as its inputs.
torch.onnx.export(
    wrapper,
    inputs,
    str(onnx_path),
    input_names=input_names,
    output_names=["sentence_embedding"],
    dynamic_axes=dynamic_axes,
)

print("Done. ONNX model saved.")