In [11]:
# ── 0. Setup ──────────────────────────────────────────────────────────────
# pip install --upgrade torch onnx onnxruntime

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
import onnxruntime as ort
from pathlib import Path

# ── 1. Model definition (unchanged) ───────────────────────────────────────
class AudioTransformer(nn.Module):
    """Small Transformer encoder that maps a fixed-length waveform to 4 scores.

    The waveform is cut into non-overlapping patches of ``patch_size`` samples,
    each patch is linearly projected to ``embed_dim``, a learned positional
    embedding is added, and the sequence is passed through a stack of
    ``nn.TransformerEncoderLayer``s.  The encoded patches are mean-pooled and
    projected to 4 values, each squashed into the open interval (0, 10) by
    ``sigmoid(...) * 10``.

    Args:
        patch_size: Number of samples per patch; must divide ``num_samples``.
        embed_dim: Width of the patch embedding / Transformer model dimension.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads per encoder layer.
        mlp_dim: Hidden width of each encoder layer's feed-forward block.
        num_samples: Exact number of samples every input clip must contain.

    Raises:
        ValueError: If ``patch_size`` or ``num_samples`` is not positive, or if
            ``patch_size`` does not evenly divide ``num_samples``.
    """

    def __init__(
        self,
        patch_size: int = 400,
        embed_dim: int = 32,
        num_layers: int = 8,
        num_heads: int = 8,
        mlp_dim: int = 32,
        num_samples: int = 16_000,
    ):
        super().__init__()
        if patch_size <= 0 or num_samples <= 0:
            raise ValueError(
                f"patch_size and num_samples must be positive, "
                f"got patch_size={patch_size}, num_samples={num_samples}"
            )
        # Fail early: a non-dividing patch size would otherwise floor the patch
        # count here and only blow up later inside forward() when reshaping.
        if num_samples % patch_size != 0:
            raise ValueError(
                f"patch_size ({patch_size}) must evenly divide "
                f"num_samples ({num_samples})"
            )
        self.patch_size = patch_size
        self.num_samples = num_samples
        num_patches = num_samples // patch_size
        self.patch_embed = nn.Linear(patch_size, embed_dim)
        # Learned positional embedding, one vector per patch; broadcast over batch.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        enc_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers)
        self.head = nn.Linear(embed_dim, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of waveforms.

        Args:
            x: 2-D tensor of shape ``(batch, num_samples)``.

        Returns:
            Tensor of shape ``(batch, 4)`` with values in (0, 10).

        Raises:
            AssertionError: If the second dimension is not ``num_samples``.
        """
        B, N = x.shape
        # NOTE(review): when the model is traced for export this check appears to
        # be evaluated once at trace time (it triggers a tracer warning), so the
        # exported graph should not be relied upon to enforce it.
        assert N == self.num_samples, (
            f"Input must be exactly {self.num_samples} samples, got {N}"
        )
        x = x.view(B, -1, self.patch_size)        # (B, num_patches, patch_size)
        x = self.patch_embed(x) + self.pos_embed  # (B, num_patches, embed_dim)
        x = self.encoder(x)
        x = x.mean(dim=1)                         # average over patches
        x = torch.sigmoid(self.head(x)) * 10      # squash each score into (0, 10)
        return x

# ── 2. Instantiate & dummy input ──────────────────────────────────────────
# Seed the global RNG so the randomly initialised weights, the positional
# embedding and the dummy waveform are identical on every Restart & Run All;
# without it the outputs recorded below cannot be reproduced.
torch.manual_seed(0)

model = AudioTransformer()
model.eval()  # inference mode (turns off dropout in the encoder layers)

dummy = torch.randn(1, 16_000)  # (batch=1, samples)

# ── 3. Export to ONNX ─────────────────────────────────────────────────────
# Destination file for the exported graph (relative to the working directory).
onnx_path = Path("small_audio_transformer_static.onnx")

# Exporter options gathered in one place.  No dynamic axes are requested, so
# the graph is exported for the shape of `dummy`.
export_options = dict(
    export_params=True,          # store the model weights inside the file
    opset_version=17,            # >=17 recommended for latest runtimes
    do_constant_folding=True,
    input_names=["waveform"],
    output_names=["scores"],
)
torch.onnx.export(model, dummy, onnx_path.as_posix(), **export_options)

print(f"✔️  Exported to {onnx_path.resolve()}")


  assert N == 16_000, "Input must be exactly 16 000 samples"


✔️  Exported to /media/dadatron/squirrel/notebooks/zk/small_audio_transformer_static.onnx


In [17]:
import math
import onnx

# Total number of scalar elements stored across all initializers of the
# FP32 export (product of each initializer's dims, summed).
fp32_graph = onnx.load("small_audio_transformer_static.onnx").graph
fp32_element_count = 0
for tensor in fp32_graph.initializer:
    fp32_element_count += math.prod(tensor.dims)
print(fp32_element_count)


20612


In [18]:

# ── 4. Quick correctness check ────────────────────────────────────────────
# Run the exported graph with ONNX Runtime on the CPU.
ort_sess = ort.InferenceSession(onnx_path.as_posix(), providers=["CPUExecutionProvider"])
out   = ort_sess.run(None, {"waveform": dummy.numpy()})[0]
print("ONNX output:", out, out.shape)   # (1, 4) in [0,10]

# Actually verify correctness: the ONNX Runtime result must agree with the
# PyTorch module on the same input.  `model` must still be the PyTorch
# AudioTransformer instance when this cell runs.
with torch.no_grad():
    torch_out = model(dummy)
torch.testing.assert_close(torch.from_numpy(out), torch_out, rtol=1e-3, atol=1e-5)


ONNX output: [[5.570454  6.7199993 3.6120117 5.174857 ]] (1, 4)


In [19]:
# Re-draw the random example waveform: a batch of one clip with 16 000 samples.
dummy = torch.randn((1, 16_000))


In [20]:
from onnxruntime.quantization import quantize_dynamic, QuantType

# Input (FP32 export) and output (dynamically quantised) model files.
float_model = "small_audio_transformer_static.onnx"
int8_model  = "small_audio_transformer_static_int8_dyn.onnx"

# Quantisation settings.  Further optional knobs, left at their defaults here:
#   per_channel=True   -> finer granularity
#   reduce_range=True  -> 7-bit weights
quantization_options = {
    "weight_type": QuantType.QInt8,   # signed 8-bit weights
}
quantize_dynamic(
    model_input=float_model,
    model_output=int8_model,
    **quantization_options,
)
print("✅ dynamic-INT8 model saved to", int8_model)




✅ dynamic-INT8 model saved to small_audio_transformer_static_int8_dyn.onnx


In [21]:
from collections import Counter

import onnx
from onnx import TensorProto

# Load the quantised graph under its own name.  Reusing `model` here would
# rebind the PyTorch module defined earlier to an ONNX ModelProto and break
# any earlier cell that is re-run afterwards.
int8_proto = onnx.load("small_audio_transformer_static_int8_dyn.onnx")

# Human-readable labels for the tensor element types of interest; any other
# type falls back to its numeric enum value.
dtype_map = {
    TensorProto.INT8:  "INT8",
    TensorProto.UINT8: "UINT8",
    TensorProto.FLOAT: "FP32",
    TensorProto.FLOAT16: "FP16",
    # add others if needed
}

# Histogram of initializer element types (number of tensors per dtype),
# kept as a plain dict in first-seen order.
counts = dict(
    Counter(
        dtype_map.get(init.data_type, str(init.data_type))
        for init in int8_proto.graph.initializer
    )
)

print("Initializer dtypes:", counts)


Initializer dtypes: {'FP32': 18, 'INT8': 12}


In [23]:
import math
import onnx

# Same measurement as for the FP32 file, now on the quantised model: sum of
# the element counts (product of dims) over every initializer in the graph.
int8_initializers = onnx.load("small_audio_transformer_static_int8_dyn.onnx").graph.initializer
int8_element_counts = [math.prod(tensor.dims) for tensor in int8_initializers]
print(sum(int8_element_counts))


26768
