In [None]:
!pip install lerobot transformers num2words onnx onnxruntime



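Export SmolVLA's VLM backbone (SmolVLM2) components to ONNX
- Portable: CPU, float32, opset 17, no custom ops
- Vision: `pixel_values` [B,3,H,W] → vision encoder + connector → `image_hidden_states` [B, S_img, D]; H×W is the vision encoder's native square input size (512×512 here)
- Text: `input_ids` [B,S] (int64) → token-embedding lookup → `text_embeddings` [B,S,D] (tokenization stays outside ONNX)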

In [None]:
import os
from pathlib import Path
import torch
import onnx
from onnx import checker, shape_inference
import onnxruntime as ort
from transformers import AutoModelForImageTextToText, AutoProcessor

MODEL_ID = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
# MODEL_ID = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
ONNX_VISION_OUT = Path("smolvla_vision_connector.onnx")
ONNX_TEXT_OUT = Path("smolvla_embedding.onnx")
OPSET = 17  # safest widely-supported set; increase later if you need newer ops

# %% Load processor & model
# The ONNX export below does not use the processor; presumably it is kept for the
# preprocessing / tokenization that happens outside the ONNX graphs (TODO confirm).
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load only once, CPU/float32 for portability
vlm = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)

# Export at the vision encoder's native square input size, read from the model
# config rather than hard-coded, so switching MODEL_ID cannot silently mismatch the
# position-embedding grid (SmolVLM2-256M: 512 px / 16 px patches = 32x32 = 1024
# positions, matching the printed `position_embedding: Embedding(1024, 768)`).
H = W = int(vlm.config.vision_config.image_size)
print(f"Exporting for image size: H={H}, W={W}")

vlm.eval().to("cpu")



Exporting for image size: H=512, W=512


SmolVLMForConditionalGeneration(
  (model): SmolVLMModel(
    (vision_model): SmolVLMVisionTransformer(
      (embeddings): SmolVLMVisionEmbeddings(
        (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=valid)
        (position_embedding): Embedding(1024, 768)
      )
      (encoder): SmolVLMEncoder(
        (layers): ModuleList(
          (0-11): 12 x SmolVLMEncoderLayer(
            (self_attn): SmolVLMVisionAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (mlp): SmolVLMVisionMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=768, out_

In [None]:
# %% Wrap vision backbone + connector
class VisionBackboneConnector(torch.nn.Module):
    """Vision encoder followed by the modality connector, packaged for ONNX export.

    Maps ``pixel_values`` [B, 3, H, W] through ``vision_model`` and then
    ``connector``, returning the connector output (image_hidden_states,
    [B, S_img, D_llm]).

    Args:
        vlm_model: loaded VLM whose ``.model.vision_model`` and ``.model.connector``
            submodules are reused by reference (not copied).
    """

    def __init__(self, vlm_model):
        super().__init__()
        inner = vlm_model.model
        self.vision_model = inner.vision_model
        self.connector = inner.connector

    def forward(self, pixel_values: torch.Tensor):
        # patch_attention_mask stays None (the upstream default), so the exported
        # graph needs no mask input.
        encoded = self.vision_model(pixel_values=pixel_values, patch_attention_mask=None)
        return self.connector(encoded.last_hidden_state)

wrapper = VisionBackboneConnector(vlm).eval()
wrapper.requires_grad_(False)  # export only: freeze every parameter

# %% Dummy input (CPU/float32), using model's expected HxW
dummy = torch.zeros(1, 3, H, W, dtype=torch.float32)

# %% Export to ONNX
# Static shapes on purpose (no dynamic_axes). The upstream vision embeddings loop
# over the batch in Python (`for batch_idx, p_attn_mask in
# enumerate(patch_attention_mask)`, the line the warnings printed by this cell
# point at), and the tracer unrolls that loop for the dummy's batch of 1. A graph
# advertising a dynamic batch would accept B>1 without error, but samples after the
# first would presumably not get the same per-sample processing, giving silently
# wrong outputs. H/W are baked in as well (`height = width = int(seq**0.5)` is
# traced to a constant), so the sequence axis is fixed too.
# Feed one image per call, or re-export with a dummy of the batch size you need.
torch.onnx.export(
    wrapper,
    (dummy,),
    str(ONNX_VISION_OUT),
    input_names=["pixel_values"],
    output_names=["image_hidden_states"],
    do_constant_folding=True,
    opset_version=OPSET,
    training=torch.onnx.TrainingMode.EVAL,
)

  for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
  inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask
  height = width = int(seq**0.5)
  x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))


In [None]:
# Inspect dims (just informative)
text_cfg = vlm.config.text_config  # informative only: embedding width and vocabulary size
hidden, vocab = text_cfg.hidden_size, text_cfg.vocab_size
print(f"Text hidden size: {hidden}, vocab: {vocab}")

# %% Wrap text embedder only
class TextEmbedder(torch.nn.Module):
    """Token-id -> embedding lookup, packaged for ONNX export.

    ``input_ids`` [B, S] (int64) -> embeddings [B, S, D], cast to float32.
    The embedding module comes from ``text_model.get_input_embeddings()``, the
    same access path used in smolvlm_with_expert.
    """

    def __init__(self, vlm_model):
        super().__init__()
        text_model = vlm_model.model.text_model
        self.embed = text_model.get_input_embeddings()

    def forward(self, input_ids: torch.LongTensor):
        # The explicit cast keeps the exported output float32 whatever dtype the
        # embedding weights were loaded in.
        return self.embed(input_ids).to(dtype=torch.float32)

wrapper = TextEmbedder(vlm).eval()
wrapper.requires_grad_(False)  # freeze all parameters; the module is only traced

# %% Dummy IDs (tokenization happens outside ONNX; ONNX only maps ids->vectors)
dummy_ids = torch.zeros((1, 16), dtype=torch.long)  # [B=1, S=16]

# Batch and sequence length both stay symbolic for the input and the output.
dynamic_axes = {name: {0: "batch", 1: "seq"} for name in ("input_ids", "text_embeddings")}


Text hidden size: 576, vocab: 49280


In [None]:
torch.onnx.export(
    wrapper,
    (dummy_ids,),
    str(ONNX_TEXT_OUT),
    input_names=["input_ids"],
    output_names=["text_embeddings"],
    do_constant_folding=True,
    opset_version=OPSET,
    dynamic_axes=dynamic_axes,
    training=torch.onnx.TrainingMode.EVAL,
)

# Validate + shape inference (helps downstream tooling)
m = onnx.load(str(ONNX_TEXT_OUT))
checker.check_model(m)
m = shape_inference.infer_shapes(m)
onnx.save(m, str(ONNX_TEXT_OUT))
print("Saved:", ONNX_TEXT_OUT.resolve())

# %% Sanity check with ORT (CPU)
sess = ort.InferenceSession(str(ONNX_TEXT_OUT), providers=["CPUExecutionProvider"])
out = sess.run(["text_embeddings"], {"input_ids": dummy_ids.numpy()})[0]
print("Output shape:", out.shape)   # expect [1, 16, D]

# %% Parity check vs. PyTorch
# The all-zero dummy maps every position to the same embedding row, so the check
# above only proves the graph runs. Compare ORT against the PyTorch module on
# distinct ids and a different (batch, seq) shape, which also exercises the
# dynamic axes declared at export time.
check_ids = (torch.arange(2 * 7, dtype=torch.long).reshape(2, 7) * 997) % vocab
with torch.no_grad():
    expected = wrapper(check_ids).numpy()
actual = sess.run(["text_embeddings"], {"input_ids": check_ids.numpy()})[0]
assert actual.shape == expected.shape, (actual.shape, expected.shape)
torch.testing.assert_close(torch.from_numpy(actual), torch.from_numpy(expected))
print("ORT matches PyTorch for input_ids of shape", tuple(check_ids.shape))

Saved: /content/smolvla_embedding.onnx
Output shape: (1, 16, 576)


In [None]:
from onnx import helper, TensorProto, checker, shape_inference
def patch_gather_idx_to_int64_simple(path: str):
    """Make every main-graph ``Gather`` consume int64 indices, then validate and save.

    The model at ``path`` is loaded, patched, shape-inferred, checked with
    ``onnx.checker`` and written back to the same ``path``.

    Each distinct index tensor (``Gather`` input 1) may not already be known to be
    INT64. For each such tensor, one ``Cast(to=INT64)`` node is inserted, and every
    ``Gather`` reading that tensor is rewired to the cast output.

    Some index tensors are already typed INT64: graph inputs, initializers, or
    intermediates typed by shape inference. These are left alone, so re-running
    this on an already-patched file adds no further Casts.

    Subgraphs of control-flow nodes (If/Loop/Scan bodies) are not visited.

    Args:
        path: ONNX model file; overwritten in place.

    Returns:
        Number of Cast nodes inserted.
    """
    # Infer first so value_info carries element types for intermediate tensors.
    m = shape_inference.infer_shapes(onnx.load(path))
    g = m.graph

    # Element type of every tensor whose dtype is visible in the graph.
    elem_types = {
        vi.name: vi.type.tensor_type.elem_type
        for vi in list(g.input) + list(g.output) + list(g.value_info)
        if vi.type.HasField("tensor_type")
    }
    elem_types.update((init.name, init.data_type) for init in g.initializer)

    # Every name already in use, so generated tensor/node names never collide.
    existing = {vi.name for vi in list(g.input) + list(g.output) + list(g.value_info)}
    existing.update(init.name for init in g.initializer)
    for node in g.node:
        existing.update(node.output)
        if node.name:
            existing.add(node.name)

    def uniq(base: str) -> str:
        """Return ``base`` or ``base_<k>``, whichever is first unused, and reserve it."""
        name, k = base, 0
        while name in existing:
            k += 1
            name = f"{base}_{k}"
        existing.add(name)
        return name

    cast_out_for = {}  # index tensor name -> name of its int64 Cast output
    new_nodes = []

    for node in g.node:
        # Detached copy: keeps op_type, attributes, domain and doc_string intact and
        # stays valid after the repeated field is cleared below.
        patched = onnx.NodeProto()
        patched.CopyFrom(node)

        is_gather = patched.op_type == "Gather" and patched.domain in ("", "ai.onnx")
        if is_gather and len(patched.input) >= 2:
            idx_in = patched.input[1]
            if elem_types.get(idx_in) != TensorProto.INT64:
                cast_out = cast_out_for.get(idx_in)
                if cast_out is None:
                    # One shared Cast per index tensor, emitted just before its
                    # first consumer; the original node order is topological, so
                    # the index producer already precedes this point.
                    cast_out = uniq(idx_in + "_to_i64")
                    new_nodes.append(helper.make_node(
                        "Cast", [idx_in], [cast_out],
                        name=uniq(idx_in + "_CastToInt64"), to=TensorProto.INT64,
                    ))
                    cast_out_for[idx_in] = cast_out
                patched.input[1] = cast_out
        new_nodes.append(patched)

    # replace nodes (no slice assignment on protobuf repeated fields)
    g.ClearField("node")
    g.node.extend(new_nodes)

    # infer & validate
    m = shape_inference.infer_shapes(m)
    checker.check_model(m)
    onnx.save(m, path)
    print("Patched & saved:", path)
    return len(cast_out_for)

In [None]:
# Cast the vision graph's Gather indices to int64, then shape-infer, validate with
# onnx.checker, and overwrite the file in place (all inside the patch function).
patch_gather_idx_to_int64_simple(str(ONNX_VISION_OUT))
print("Saved:", ONNX_VISION_OUT.resolve())



Patched & saved: smolvla_vision_connector.onnx
Saved: /content/smolvla_vision_connector.onnx


In [None]:
import onnxruntime as ort
import torch
from pathlib import Path

ONNX_OUT = Path("smolvla_vision_connector.onnx")
# Must match the H, W used at export time (the vision graph has fixed input shapes).
dummy = torch.zeros(1, 3, 512, 512, dtype=torch.float32)
# %% Quick runtime sanity check (CPU)
# CPU provider, matching this comment, the text-embedding check, and the notebook's
# CPU/float32 portability goal. Requesting only CUDAExecutionProvider with the
# CPU-only `onnxruntime` wheel just makes ORT skip it with a warning.
sess = ort.InferenceSession(str(ONNX_OUT), providers=["CPUExecutionProvider"])
out = sess.run(["image_hidden_states"], {"pixel_values": dummy.numpy()})[0]
assert out.ndim == 3 and out.shape[0] == dummy.shape[0], out.shape
print("Output shape:", out.shape)

Output shape: (1, 64, 576)
