ONNX-GoMLX: from ONNX to GoMLX and back


📖 Overview

ONNX-GoMLX converts ONNX models (.onnx suffix) to GoMLX (an accelerated machine learning framework for Go) and, optionally, back to ONNX.

The main use cases so far are:

  1. Fine-tuning: import an inference-only ONNX model into GoMLX, and use GoMLX's auto-differentiation and training loop to fine-tune it. The fine-tuned model can be saved as a GoMLX checkpoint, or the fine-tuned weights can be exported back to the ONNX model. This can also be used to expand or combine models.
  2. Inference: run an ONNX file from Go without having to include ONNX Runtime (or Python), at the cost of including XLA/PJRT (currently the only backend for GoMLX). It also lets you extend the model with extra ML pre/post-processing in GoMLX (image transformations, normalization, combining models, building ensembles, etc.). This may be interesting for large or expensive models, or for high throughput on large batches.
    • Note: if you simply want pure Go inference of ONNX models, see pure Go libraries such as gonnx. They will be slower than XLA inference (or ONNX Runtime) for large projects (~8x, measured on a BERT-based SentenceEncoder model using gonnx vs ONNX Runtime), but for many use cases that doesn't matter, and they are a much smaller, pure Go dependency. They are CPU-only (no GPU support).

Coverage of ONNX Ops Set

At least 10 or so models are working so far.

But not all operations ("ops") are converted yet. If you try it and find an op that is not supported, please let us know (create an "issue") and we will be happy to try to convert it. Generally, all the required scaffolding and tooling is already there, and converting ops has been very easy.

🎓 Example

We download (and cache) the all-MiniLM-L6-v2 model from HuggingFace and run it on two example sentences.

The tokens are hardcoded for now; eventually the library should also handle tokenization for various models.

import (
	"fmt"
	"os"
	// Plus the third-party packages used below: the GoMLX backends, context, tensors and
	// graph packages (graph dot-imported, so Node is unqualified), this repository's onnx
	// package, the HuggingFace hub package and the must helper package.
)

// Download and cache ONNX model from HuggingFace.
hfAuthToken := os.Getenv("HF_TOKEN")
hfModelID := "sentence-transformers/all-MiniLM-L6-v2"
repo := hub.New(hfModelID).WithAuth(hfAuthToken)
modelPath := must.M1(repo.DownloadFile("onnx/model.onnx"))

// Parse ONNX model.
model := must.M1(onnx.ReadFile(modelPath))

// Convert ONNX variables (model weights) to GoMLX Context -- which stores variables and can be checkpointed (saved):
ctx := context.New()
must.M(model.VariablesToContext(ctx)) // Assumes VariablesToContext returns only an error.

// Execute it with GoMLX/XLA:
sentences := []string{
	"This is an example sentence",
	"Each sentence is converted"}
_ = sentences // Only documents what the hardcoded tokens below encode.
//... tokenize ...
inputIDs := [][]int64{
	{101, 2023, 2003, 2019, 2742, 6251, 102},
	{101, 2169, 6251, 2003, 4991, 102, 0}}
tokenTypeIDs := [][]int64{
	{0, 0, 0, 0, 0, 0, 0},
	{0, 0, 0, 0, 0, 0, 0}}
attentionMask := [][]int64{
	{1, 1, 1, 1, 1, 1, 1},
	{1, 1, 1, 1, 1, 1, 0}}

// Optional names of the ONNX outputs to compute. Left empty here, which is assumed
// to select the model's default outputs.
var targetOutputs []string

var embeddings []*tensors.Tensor
embeddings = context.ExecOnceN( // Execute a GoMLX computation graph with a context.
	backends.New(), // GoMLX backend to use (defaults to XLA).
	ctx,            // Context stores the model variables/weights and optional hyperparameters.
	func(ctx *context.Context, inputs []*Node) []*Node {
		// Convert the ONNX model (in `model`) to a GoMLX computation graph.
		// It returns a slice of values (only one for this model).
		return model.CallGraph(ctx, inputs[0].Graph(), map[string]*Node{
			"input_ids":      inputs[0],
			"attention_mask": inputs[1],
			"token_type_ids": inputs[2]}, targetOutputs...)
	},
	inputIDs, attentionMask, tokenTypeIDs) // Inputs to the GoMLX function, in the order of `inputs`.
fmt.Printf("Embeddings: %s\n", embeddings)

The output looks like:

Embeddings: [2][7][384]float32{
 {{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
  {-0.0200, -0.0014, -0.0177, ..., 0.0204, 0.0522, 0.1991},
  {-0.0196, -0.0336, -0.0319, ..., 0.0203, 0.0709, 0.0644},
  {-0.0253, 0.0408, 0.0125, ..., -0.0270, 0.0377, 0.1133},
  {-0.0140, -0.0275, 0.0796, ..., -0.0748, 0.0774, -0.0657},
  {0.0318, -0.0032, -0.0210, ..., 0.0387, 0.0191, -0.0059}},
 {{-0.0886, -0.0368, 0.0180, ..., 0.0261, 0.0912, -0.0152},
  {0.0304, 0.0531, -0.0238, ..., -0.1011, 0.0218, 0.0473},
  {-0.0027, -0.0508, 0.0805, ..., -0.0777, 0.0881, -0.0560},
  {0.0928, 0.0165, -0.0976, ..., 0.0449, 0.0390, -0.0182},
  {0.0231, 0.0090, -0.0213, ..., 0.0232, 0.0191, -0.0066},
  {-0.0213, 0.0019, 0.0043, ..., 0.0561, 0.0170, 0.0256}}}
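
The hardcoded inputs in the example follow the usual padding convention: shorter sentences are padded with token ID 0, the attention mask is 1 for real tokens and 0 for padding, and the token type IDs are all 0 for single-segment inputs. As a rough sketch in plain Go (no GoMLX involved), a helper like the one below could build the three inputs from already-tokenized sentences. The name padBatch and the padID parameter are hypothetical and not part of this library; the padding ID 0 is taken from the example.

```go
package main

import "fmt"

// padBatch pads every tokenized sentence to the length of the longest one.
// It returns the padded IDs, the attention mask (1 = real token, 0 = padding)
// and all-zero token type IDs (single-segment inputs).
// padBatch is a hypothetical helper, not part of ONNX-GoMLX.
func padBatch(tokenized [][]int64, padID int64) (inputIDs, attentionMask, tokenTypeIDs [][]int64) {
	maxLen := 0
	for _, ids := range tokenized {
		if len(ids) > maxLen {
			maxLen = len(ids)
		}
	}
	for _, ids := range tokenized {
		padded := make([]int64, maxLen)
		mask := make([]int64, maxLen)
		for i := range padded {
			if i < len(ids) {
				padded[i] = ids[i]
				mask[i] = 1
			} else {
				padded[i] = padID
			}
		}
		inputIDs = append(inputIDs, padded)
		attentionMask = append(attentionMask, mask)
		tokenTypeIDs = append(tokenTypeIDs, make([]int64, maxLen))
	}
	return
}

func main() {
	// Token IDs copied from the example, without the trailing padding.
	tokenized := [][]int64{
		{101, 2023, 2003, 2019, 2742, 6251, 102},
		{101, 2169, 6251, 2003, 4991, 102},
	}
	inputIDs, attentionMask, tokenTypeIDs := padBatch(tokenized, 0)
	fmt.Println(inputIDs)
	fmt.Println(attentionMask)
	fmt.Println(tokenTypeIDs)
}
```

The three returned slices have the same shape and can be passed as the inputs of the example in place of the hardcoded literals.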

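The example output holds one 384-dimensional vector per token. For sentence-encoder models like this one, a common post-processing step is to average the token vectors of each sentence while ignoring padded positions, and then L2-normalize the result. The sketch below shows that arithmetic in plain Go for a single sentence. The function name meanPool is hypothetical; in a real pipeline the same computation could instead be expressed in the GoMLX graph, and whether this pooling is the right one depends on the model.

```go
package main

import (
	"fmt"
	"math"
)

// meanPool averages the token embeddings at positions where mask is non-zero
// and L2-normalizes the average. tokens is [seqLen][dim], mask is [seqLen].
// meanPool is a hypothetical helper, not part of ONNX-GoMLX.
func meanPool(tokens [][]float32, mask []int64) []float32 {
	if len(tokens) == 0 {
		return nil
	}
	dim := len(tokens[0])
	sum := make([]float64, dim)
	count := 0
	for i, tok := range tokens {
		if mask[i] == 0 {
			continue // Skip padded positions.
		}
		count++
		for j, v := range tok {
			sum[j] += float64(v)
		}
	}
	out := make([]float32, dim)
	if count == 0 {
		return out // No real tokens: return the zero vector.
	}
	var sqNorm float64
	for j := range sum {
		sum[j] /= float64(count)
		sqNorm += sum[j] * sum[j]
	}
	norm := math.Sqrt(sqNorm)
	if norm == 0 {
		return out
	}
	for j := range sum {
		out[j] = float32(sum[j] / norm)
	}
	return out
}

func main() {
	// Toy data: three tokens of dimension 2, the last one is padding.
	tokens := [][]float32{{3, 0}, {0, 4}, {100, 100}}
	mask := []int64{1, 1, 0}
	fmt.Println(meanPool(tokens, mask))
}
```

Accumulating in float64 keeps the average and the norm stable for long sequences before converting back to float32.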

To fine-tune an ONNX model:

  1. Extract the ONNX model's weights to a GoMLX Context: see Model.VariablesToContext().
  2. Use Model.CallGraph() in your GoMLX model function (see the example above).
  3. Train the model as usual in GoMLX.
  4. Depending on how you are going to use the model:
    1. Save the model as a GoMLX checkpoint, as usual.
    2. Save the model by updating the ONNX model: after training, use Model.ContextToONNX() to copy the updated variable values from the GoMLX Context back to the (in-memory) ONNX model, and then use Model.Write() or Model.SaveToFile() to save the updated ONNX model to disk.


We have some GoMLX/XLA and ONNX Runtime (Microsoft) benchmarks in this spreadsheet, measured on the sentence encoder model we were interested in. The spreadsheet was used during development and records how performance improved over time; the numbers at the bottom of each sheet are the current ones.

See docs/ for more information.

🤝 Collaborating

Collaboration is very welcome, whether as code or simply as ideas with real applicability. Don't hesitate to start a discussion or open an issue in the repository.

If you are interested, we have two notebooks that we use to compare results. They are a good starting point for anyone curious.

🥳 Thanks