In [5]:
import torch
import torch.onnx as onnx
from pathlib import Path

TS_PATH = "checkpoint.pt"
ONNX_PATH = "audiobox_aesthetics_op18.onnx"

# 1 – load the ScriptModule
ts = torch.jit.load(TS_PATH).eval()

# 2 – representative input (any length is fine)
dummy = torch.randn(1, 16_000 * 10)

# 3 – dynamo_export does the heavy lifting
prog = onnx.dynamo_export(
    ts,
    dummy,
    # opset 18 has the cleanest type/broadcast semantics
    opset_version = 18,
    # tell the exporter the time axis is variable length
    dynamic_shapes = {"audio": {1: "n_samples"}},
)
prog.save(ONNX_PATH)
print(f"✅  wrote {Path(ONNX_PATH).resolve()}")


RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found

In [11]:
import torch, torchaudio, audiobox_aesthetics  # pip install audiobox_aesthetics torchaudio onnx onnxruntime

from audiobox_aesthetics.model.aes import AesMultiOutput

device = "cpu"          # stay on CPU for export
ckpt   = "checkpoint.pt"  # or HF download path

core = AesMultiOutput.from_pretrained("facebook/audiobox-aesthetics").to(device).eval()

class Wrapper(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner
    def forward(self, wav):                    # wav: (B, T) 16 kHz mono
        out = self.inner(wav)[0]               # dict → tensor [B,4]
        return out

model = Wrapper(core)

example = torch.zeros(1, 160000, dtype=torch.float32)  # 1 s dummy

torch.onnx.export(
    model, example, "audiobox_aesthetics.onnx",

    input_names=["audio"],
    output_names=["scores"],
    dynamic_axes={
        "audio":  {1: "n_samples"},   # allow arbitrary length waveforms
        "scores": {0: "n_axes"}       # (CE, CU, PC, PQ)
    },
    opset_version=17,  # ONNX >=1.17, PyTorch 2.2
    export_params=True,
)
print("Exported OK!")


IndexError: too many indices for tensor of dimension 2

In [14]:
import torch

ckpt_path = "audiobox_aesthetics_ts_cpu.pt"

model = torch.jit.load(ckpt_path, map_location="cpu")

model.eval()          # important: disables dropout, etc.


RecursiveScriptModule(original_name=AesWrapper)

In [15]:
dummy = torch.randn(1, 16000 * 10)   # shape [batch, samples]


In [28]:

# 3.  ONNX export (no dynamic_axes → input/output shapes are frozen)
torch.onnx.export(
    model,
    dummy,
    "audiobox_aesthetics_fixed.onnx",
    opset_version=17,           # 13+ is usually fine; keep <= the runtime you’ll use
    input_names=["waveform"],
    output_names=["scores"],
    do_constant_folding=True,   # fold weights where possible
    verbose=False
)


In [29]:
sess = ort.InferenceSession("audiobox_aesthetics_fixed.onnx", providers=["CPUExecutionProvider"])
out = sess.run(None, {"waveform": dummy.numpy()})[0]
print(out)   # should match ≈ tensor([[2.1967, 4.3483, 2.1122, 4.6560]])


Fail: [ONNXRuntimeError] : 1 : FAIL : Node (/Where) Op (Where) [ShapeInferenceError] Incompatible dimensions

In [35]:
import onnx, netron  # <-- open Netron in browser

m = onnx.load("audiobox_aesthetics_fixed.onnx")
for i, n in enumerate(m.graph.node):
    if n.op_type == "Where":
        print(i, n.input)      # look at the 3 tensor names

243 ['/Cast_2_output_0', '/Constant_168_output_0', '/Add_8_output_0']
282 ['/Less_output_0', '/Abs_output_0', '/Min_output_0']
338 ['/Cast_6_output_0', '/Constant_205_output_0', '/ConstantOfShape_1_output_0']
369 ['/Equal_output_0', '/ConstantOfShape_2_output_0', '/Constant_219_output_0']
470 ['/Cast_10_output_0', '/Constant_253_output_0', '/ConstantOfShape_3_output_0']
501 ['/Equal_1_output_0', '/ConstantOfShape_4_output_0', '/Constant_267_output_0']
602 ['/Cast_14_output_0', '/Constant_301_output_0', '/ConstantOfShape_5_output_0']
633 ['/Equal_2_output_0', '/ConstantOfShape_6_output_0', '/Constant_315_output_0']
734 ['/Cast_18_output_0', '/Constant_349_output_0', '/ConstantOfShape_7_output_0']
765 ['/Equal_3_output_0', '/ConstantOfShape_8_output_0', '/Constant_363_output_0']
866 ['/Cast_22_output_0', '/Constant_397_output_0', '/ConstantOfShape_9_output_0']
897 ['/Equal_4_output_0', '/ConstantOfShape_10_output_0', '/Constant_411_output_0']
998 ['/Cast_26_output_0', '/Constant_445_outp

In [37]:
import torch.onnx as onnx

# 3. Export (without the dynamic_axes argument)
onnx.export(
    ts,
    dummy,
    "audiobox_aesthetics_fixed0.onnx", # Use a new name to avoid confusion
    export_params=True,
    opset_version=17,
    input_names=["audio"],
    output_names=["scores"],
) # <--- dynamic_axes argument is removed




In [45]:
import onnxruntime as rt
import numpy as np

# The shape of this input MUST match the shape of the 'dummy' tensor used for export.
inference_input = np.random.randn(1, 16000 * 180).astype(np.float32)

# This would now cause an error because the shape is different
# wrong_input = np.random.randn(1, 80000).astype(np.float32) 

sess = rt.InferenceSession("audiobox_aesthetics_fixed0.onnx")

# This should now work without a runtime error
out = sess.run(None, {"audio": inference_input})[0]

print(torch.tensor(out))

Fail: [ONNXRuntimeError] : 1 : FAIL : Node (/Where) Op (Where) [ShapeInferenceError] Incompatible dimensions

In [32]:
from torch.onnx import dynamo_export

onnx_program = dynamo_export(
    model,
    dummy,
    dynamic_shapes=True,         # toggles the same “batch / time” flexibility
    opset_version=17,
)
onnx_program.save("audiomodel.onnx")
# ``` :contentReference[oaicite:1]{index=1}

# Use this if your model has data-dependent branching, loops, or ops unsupported by the legacy tracer.

# ---

# ## 5  Quick verification with ONNX Runtime

# ```python
# import onnxruntime as ort
# import numpy as np

# ort_sess = ort.InferenceSession("audiomodel.onnx", providers=["CPUExecutionProvider"])
# out = ort_sess.run(None, {"audio": dummy_audio.numpy()})[0]
# print(out)           # should be shape [1, 4] ~ tensor([[2.19, 4.34, …]])


OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues