Let's export the trained model to the HuggingFace Hub in safetensors format for compatibility with downstream inference engines. First, we'll define some variables.

In [None]:
model_name = "ProtHash-V2-512-Tiny"
checkpoint_path = "./checkpoints/checkpoint.pt"
exports_path = "./exports"

Then, we'll load the base model checkpoint into memory from disk.

In [None]:
import torch

from src.prothash.model import ProtHash

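# weights_only=False unpickles arbitrary Python objects, so only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

model = ProtHash(**checkpoint["model_args"])

model.load_state_dict(checkpoint["model"])

# Merge the LoRA adapters and strip the fake-quantized tensors before exporting.
model.merge_lora_adapters()

model.remove_fake_quantized_tensors()

model.eval()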

print("Base checkpoint loaded successfully")
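
The `merge_lora_adapters()` call above presumably folds the low-rank adapter weights into the base weights, although the ProtHash implementation is not shown here. As a minimal NumPy sketch of the general idea, assuming the standard LoRA formulation `W' = W + (alpha / r) * B @ A` and purely illustrative shapes and values, merging leaves the layer's output unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes only; these are assumptions, not ProtHash's real dimensions.
d_in, d_out, rank, alpha = 8, 6, 2, 4.0

W = rng.normal(size=(d_out, d_in))  # frozen base weight
A = rng.normal(size=(rank, d_in))   # LoRA down-projection
B = rng.normal(size=(d_out, rank))  # LoRA up-projection
scale = alpha / rank

x = rng.normal(size=(3, d_in))      # a small batch of inputs

# Unmerged forward pass: base path plus the low-rank adapter path.
y_unmerged = x @ W.T + scale * (x @ A.T) @ B.T

# Merged forward pass: fold the adapter into a single dense weight.
W_merged = W + scale * (B @ A)
y_merged = x @ W_merged.T

max_abs_diff = float(np.abs(y_unmerged - y_merged).max())
print(f"max abs difference: {max_abs_diff:.2e}")
```

After merging, the exported weights need no adapter-specific code at inference time, which is what makes a plain safetensors or ONNX export possible.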

Now, let's export the model in HuggingFace format so that it can be used with the HuggingFace ecosystem.

In [None]:
from os import path

hf_path = path.join(exports_path, model_name)

model.save_pretrained(hf_path)

print(f"Model saved to {hf_path}")

Log in to the HuggingFace Hub and upload the model under our account.

In [None]:
from huggingface_hub import notebook_login

notebook_login()

model.push_to_hub(model_name)

Lastly, we'll export the model in ONNX format for use with ONNX Runtime. We'll output two separate computational graphs: one that outputs embeddings in the model's native dimensionality and another that outputs them in the teacher's dimensionality.

In [None]:
from os import path
from functools import partial

from torch.onnx import export as export_onnx

from torch.export.dynamic_shapes import Dim

from src.prothash.model import ONNXModelNative, ONNXModelTeacher

dynamic_shapes = {
    "x": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
}

new_onnx_graph = partial(
    export_onnx,
    dynamic_shapes=dynamic_shapes,
    dynamo=True,
    input_names=["sequences"],
    output_names=["embeddings"],
)

x = torch.randint(0, 32, (1, 1000), dtype=torch.int64)

onnx_path = path.join(exports_path, model_name, "model_native.onnx")

onnx_model = ONNXModelNative(model)

onnx_graph = new_onnx_graph(onnx_model, (x,))

onnx_graph.save(onnx_path)

onnx_path = path.join(exports_path, model_name, "model_teacher.onnx")

onnx_model = ONNXModelTeacher(model)

onnx_graph = new_onnx_graph(onnx_model, (x,))

onnx_graph.save(onnx_path)
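
The `partial` call in the cell above pre-binds the export options that both graphs share, so each export only has to supply the model wrapper and the example input. A stdlib-only sketch of the same pattern follows; `export_stub` is a hypothetical stand-in for `torch.onnx.export` that just records its arguments:

```python
from functools import partial


def export_stub(model, args, *, dynamo=False, input_names=None, output_names=None):
    """Hypothetical stand-in for an export function; it only records its arguments."""
    return {
        "model": model,
        "args": args,
        "dynamo": dynamo,
        "input_names": input_names,
        "output_names": output_names,
    }


# Bind the shared keyword arguments once.
export_with_defaults = partial(
    export_stub,
    dynamo=True,
    input_names=["sequences"],
    output_names=["embeddings"],
)

# Each call now only passes what differs between the two graphs.
native = export_with_defaults("native-model", ("x",))
teacher = export_with_defaults("teacher-model", ("x",))

print(native["dynamo"], teacher["input_names"])
```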

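Given how haphazardly ONNX support is implemented in PyTorch, it's wise to do a quick sanity check on the newly exported ONNX model. Since `onnx_path` now points to the teacher graph, we compare its output against `model.embed_teacher`.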

In [None]:
import onnxruntime

from numpy.testing import assert_allclose

pytorch_logits = model.embed_teacher(x).detach().numpy()

session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

onnx_input = {"sequences": x.numpy()}

onnx_logits = session.run(None, onnx_input)

onnx_logits = onnx_logits[0]

assert_allclose(pytorch_logits, onnx_logits, rtol=1e-2, atol=1e-03)

print("Looks good!")
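
For reference, `assert_allclose(actual, desired, rtol, atol)` passes when `|actual - desired| <= atol + rtol * |desired|` holds for every element. The sketch below applies that rule by hand to a few made-up values (not real model outputs) using the same tolerances as the check above:

```python
import numpy as np
from numpy.testing import assert_allclose


def within_tolerance(actual, desired, rtol, atol):
    """Elementwise form of the condition that assert_allclose enforces."""
    return np.abs(actual - desired) <= atol + rtol * np.abs(desired)


# Made-up values for illustration only.
desired = np.array([1.0, 0.001, -2.0])
actual = np.array([1.005, 0.0015, -2.5])

mask = within_tolerance(actual, desired, rtol=1e-2, atol=1e-3)

# The first two elements are inside the tolerance band, so this passes.
assert_allclose(actual[:2], desired[:2], rtol=1e-2, atol=1e-3)

# The third element is off by 0.5, far outside the band, so this raises.
try:
    assert_allclose(actual, desired, rtol=1e-2, atol=1e-3)
    raised = False
except AssertionError:
    raised = True

print(mask, raised)
```

Note that `rtol` scales with the magnitude of the reference values while `atol` sets a floor for values near zero, so both matter when comparing embeddings with small components.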