Let's export the trained model to the HuggingFace Hub in safetensors format for compatibility with downstream inference engines. First, we'll define a few variables for the model name and file paths.

In [10]:
model_name = "ProtHash-512"
checkpoint_path = "./checkpoints/checkpoint.pt"
exports_path = "./exports"

Then, we'll load the base model checkpoint into memory from disk.

In [11]:
import torch

from src.prothash.model import ProtHash

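# weights_only=False allows arbitrary pickled objects, so only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Rebuild the architecture from the hyperparameters stored in the checkpoint.
model = ProtHash(**checkpoint["model_args"])

# Attach the adapter head first so the model's parameter names match the
# keys of the saved state dict.
model.add_adapter_head(checkpoint["teacher_embedding_dimensions"])

model.load_state_dict(checkpoint["model"])

# We only export the base model, so drop the adapter head again.
model.remove_adapter_head()

model.eval()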

print("Base checkpoint loaded successfully")

Base checkpoint loaded successfully
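
Why add the adapter head only to remove it again? `load_state_dict` is strict by default, so the model's parameter names must match the checkpoint's keys exactly, and a checkpoint presumably saved with the adapter attached carries extra adapter keys. The sketch below illustrates this with a hypothetical `ToyEncoder` standing in for `ProtHash` (its layers, the adapter implementation, and the sizes are assumptions, not the real model). It writes a checkpoint with the same three keys used above to an in-memory buffer and loads it back. Because this toy checkpoint holds only tensors and plain Python types, it also loads with `weights_only=True`; whether the real checkpoint does depends on what `model_args` contains.

```python
import io

import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for ProtHash with an optional adapter head."""

    def __init__(self, vocab_size: int, embedding_dimensions: int):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dimensions)
        self.adapter_head = None  # only attached while training (assumption)

    def add_adapter_head(self, teacher_dimensions: int) -> None:
        # Assumed design: project student embeddings to the teacher's width.
        self.adapter_head = nn.Linear(self.token_embeddings.embedding_dim, teacher_dimensions)

    def remove_adapter_head(self) -> None:
        # Setting the submodule to None drops its keys from state_dict().
        self.adapter_head = None


# Simulate a training checkpoint that was saved with the adapter head attached.
trained = ToyEncoder(vocab_size=32, embedding_dimensions=8)
trained.add_adapter_head(16)

buffer = io.BytesIO()
torch.save(
    {
        "model_args": {"vocab_size": 32, "embedding_dimensions": 8},
        "teacher_embedding_dimensions": 16,
        "model": trained.state_dict(),
    },
    buffer,
)
buffer.seek(0)

# Only tensors and plain Python types are stored, so weights_only=True suffices here.
checkpoint = torch.load(buffer, map_location="cpu", weights_only=True)

# Without the adapter head, a strict load fails on the unexpected adapter keys.
model = ToyEncoder(**checkpoint["model_args"])
try:
    model.load_state_dict(checkpoint["model"])
    strict_load_failed = False
except RuntimeError:
    strict_load_failed = True

# Attach the head so the keys line up, load, then drop it for inference.
model.add_adapter_head(checkpoint["teacher_embedding_dimensions"])
model.load_state_dict(checkpoint["model"])
model.remove_adapter_head()
model.eval()

print(strict_load_failed, sorted(model.state_dict().keys()))
```

This prints `True ['token_embeddings.weight']`: the first load is rejected, and after removing the head only the base model's parameters remain.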


Now, let's save the model locally in HuggingFace format so that it can be used with the rest of the HuggingFace ecosystem.

In [12]:
from os import path

hf_path = path.join(exports_path, model_name)

model.save_pretrained(hf_path)

print(f"Model saved to {hf_path}")

Model saved to ./exports/ProtHash-512


Log in to the HuggingFace Hub and upload the model under our account.

In [13]:
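from huggingface_hub import notebook_login

# Displays a widget for entering a HuggingFace access token. The call returns
# immediately, so if no token is cached yet, log in before the push below runs
# (e.g., by moving push_to_hub into a separate cell).
notebook_login()

# A bare repo name is created under the logged-in account
# (here, andrewdalpino/ProtHash-512).
model.push_to_hub(model_name)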

CommitInfo(commit_url='https://huggingface.co/andrewdalpino/ProtHash-512/commit/129c0f2bc33712c0bebd96981dcc90955de0e2b8', commit_message='Push model using huggingface_hub.', commit_description='', oid='129c0f2bc33712c0bebd96981dcc90955de0e2b8', pr_url=None, repo_url=RepoUrl('https://huggingface.co/andrewdalpino/ProtHash-512', endpoint='https://huggingface.co', repo_type='model', repo_id='andrewdalpino/ProtHash-512'), pr_revision=None, pr_num=None)

Lastly, we'll export the model in ONNX format for use with ONNX Runtime.

In [14]:
from os import path

from torch.onnx import export as export_onnx

from torch.export.dynamic_shapes import Dim

from src.prothash.model import ONNXModel

onnx_path = path.join(exports_path, model_name, "model.onnx")

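# ONNXModel wraps the base model for export; its output should match model.embed().
onnx_model = ONNXModel(model)

# Example input used for tracing: 1 sequence of 1,000 random token IDs in [0, 32).
x = torch.randint(0, 32, (1, 1000), dtype=torch.int64)

# Mark the batch (dim 0) and sequence length (dim 1) as dynamic so the exported
# graph accepts inputs of any shape, not just (1, 1000).
dynamic_shapes = {
    "x": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
}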

onnx_graph = export_onnx(
    onnx_model,
    (x,),
    dynamic_shapes=dynamic_shapes,
    dynamo=True,
    input_names=["x"],
    output_names=["output"],
)

onnx_graph.save(onnx_path)

[torch.onnx] Obtain model graph for `ONNXModel([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ONNXModel([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 35 of general pattern rewrite rules.
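
`ONNXModel` comes from `src/prothash/model.py`, which isn't shown here. Since the exporter traces a module's `forward`, and the sanity check that follows compares the ONNX output with `model.embed(x)`, it presumably routes `forward` to `embed`. A minimal sketch of such a wrapper, under that assumption, using a hypothetical `ToyEncoder` and a hypothetical `EmbedForExport` class:

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for ProtHash with a non-forward embed() method."""

    def __init__(self, vocab_size: int = 32, embedding_dimensions: int = 8):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dimensions)

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        # Mean-pool token embeddings into one vector per sequence.
        return self.token_embeddings(x).mean(dim=1)


class EmbedForExport(nn.Module):
    """Hypothetical wrapper: exporters trace forward(), so route it to embed()."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.embed(x)


model = ToyEncoder().eval()
wrapper = EmbedForExport(model).eval()

# The wrapper adds no parameters of its own and works for any (batch, length).
with torch.no_grad():
    for shape in [(1, 1000), (4, 37)]:
        x = torch.randint(0, 32, shape, dtype=torch.int64)
        assert torch.equal(wrapper(x), model.embed(x))
        print(shape, "->", tuple(wrapper(x).shape))
```

Keeping the wrapper separate leaves the original model's `forward` untouched while giving the exporter a single-tensor-in, single-tensor-out graph.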


Given how uneven ONNX export support in PyTorch can be, it's wise to run a quick sanity check: feed the same input to both the PyTorch model and the exported ONNX model and make sure their outputs agree.

In [15]:
import onnxruntime

from numpy.testing import assert_allclose

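# Compute reference embeddings with PyTorch; no_grad avoids building an autograd graph.
with torch.no_grad():
    pytorch_embeddings = model.embed(x).numpy()

session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# The input name must match the one given to the exporter (input_names=["x"]).
onnx_input = {"x": x.numpy()}

# session.run returns a list with one array per output; the model has a single output.
onnx_embeddings = session.run(None, onnx_input)[0]

assert_allclose(pytorch_embeddings, onnx_embeddings, rtol=1e-2, atol=1e-3)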

print("Looks good!")

Looks good!
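
For reference, `assert_allclose(actual, desired, rtol, atol)` checks every element against `|actual - desired| <= atol + rtol * |desired|`, so the relative term dominates for large values and `atol` governs values near zero. The toy arrays below (made-up numbers, not ProtHash outputs) show the per-element tolerance that `rtol=1e-2, atol=1e-3` allows.

```python
import numpy as np
from numpy.testing import assert_allclose

rtol, atol = 1e-2, 1e-3

desired = np.array([1.0, 100.0, 0.0])  # plays the role of the ONNX Runtime output
actual = np.array([1.005, 100.9, 0.0009])  # plays the role of the PyTorch output

# Per-element bound used by assert_allclose (and np.isclose).
tolerance = atol + rtol * np.abs(desired)
within = np.abs(actual - desired) <= tolerance
print(tolerance, within)

assert_allclose(actual, desired, rtol=rtol, atol=atol)  # passes: every element is within bounds

# Near zero only atol applies, so a 0.002 deviation from 0.0 fails.
try:
    assert_allclose(np.array([0.002]), np.array([0.0]), rtol=rtol, atol=atol)
    near_zero_failed = False
except AssertionError:
    near_zero_failed = True

print("near-zero mismatch detected:", near_zero_failed)
```

Note that the bound scales with `desired` (the second argument), so swapping the argument order can change the result for borderline elements.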
