In [None]:
import os
import torch
from transformers import AutoConfig, AutoTokenizer, AutoProcessor
from transformers.models.gemma3 import modeling_gemma3
import logging

In [None]:
# Module logger for this notebook. With no level/handler configured, INFO records are
# dropped (the effective level is WARNING and Python's last-resort handler only prints
# WARNING+), which would hide the progress messages logged in later cells.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.hasHandlers():  # idempotent on re-run; respects handlers configured upstream
    _log_handler = logging.StreamHandler()
    _log_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
    logger.addHandler(_log_handler)

In [None]:
class EmbeddingLayer(torch.nn.Module):
    """Gemma 3 input-embedding stage (text lookup + image-feature merge) as a standalone module.

    ``forward`` embeds ``input_ids`` with the full model's token-embedding table, then writes
    ``image_features`` into every position whose id equals the image placeholder token,
    consuming features in order (``Tensor.masked_scatter`` semantics).

    Args:
        full_model: loaded multimodal model. The embedding table is taken from
            ``full_model.language_model.model.embed_tokens`` when that attribute path exists,
            otherwise from ``full_model.get_input_embeddings()`` (the module layout differs
            between transformers releases).
        image_token_index: id of the image placeholder token. ``None`` (default) reads
            ``full_model.config.image_token_index`` and falls back to 262144.
    """

    DEFAULT_IMAGE_TOKEN_INDEX = 262144

    def __init__(self, full_model, image_token_index=None):
        super().__init__()
        self.embedding_layer = self._find_embed_tokens(full_model)

        if image_token_index is None:
            image_token_index = getattr(getattr(full_model, "config", None), "image_token_index", None)
        if image_token_index is None:
            image_token_index = self.DEFAULT_IMAGE_TOKEN_INDEX
        self.image_token_index = int(image_token_index)

        # If the image token id lies outside the embedding table, looking it up would raise
        # an index error. Such positions are looked up as id 0 instead; they are overwritten
        # by image features anyway. This is decided once here (not per call) so the traced
        # graph contains no data-dependent Python branch.
        vocab_size = getattr(self.embedding_layer, "num_embeddings", None)
        self.image_token_out_of_vocab = vocab_size is not None and self.image_token_index >= vocab_size

    @staticmethod
    def _find_embed_tokens(full_model):
        """Return the token-embedding module of ``full_model`` across transformers layouts."""
        try:
            return full_model.language_model.model.embed_tokens
        except AttributeError:
            return full_model.get_input_embeddings()

    def forward(self, input_ids, image_features):
        """Embed ``input_ids`` and scatter ``image_features`` into image-token positions.

        Args:
            input_ids: integer token ids; presumably (batch, seq) - TODO confirm.
            image_features: features to insert. Should supply exactly one hidden-size row per
                image token; masked_scatter ignores surplus elements and errors on too few.

        Returns:
            Embeddings shaped like ``embedding_layer(input_ids)`` with image positions replaced.
        """
        is_image_token = input_ids == self.image_token_index
        lookup_ids = input_ids.masked_fill(is_image_token, 0) if self.image_token_out_of_vocab else input_ids
        inputs_embeds = self.embedding_layer(lookup_ids)

        special_image_mask = is_image_token.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(special_image_mask, image_features)

In [None]:
# Configuration: Hugging Face download cache, export root, and the checkpoint to export.
cache_dir = '/tmp/cache_dir'
workfolder = os.getcwd()  # export artifacts of this notebook are saved under this folder

model_name = 'gemma_4b'
model_id = "google/gemma-3-4b-it"

# Gemma 3 ships natively with transformers (this notebook imports `modeling_gemma3`), so
# remote code is not needed. Leaving `trust_remote_code` off avoids executing arbitrary
# Python bundled with the checkpoint repository.
lmm_config = AutoConfig.from_pretrained(model_id, cache_dir=cache_dir)

In [None]:
# Load the full multimodal model; only its token-embedding table is exported below.
# The weight dtype is pinned to float32 because this notebook produces an fp32 export
# (`embed_fp32`) fed with float32 image features. Relying on the library's default dtype
# would make the exported precision depend on the installed transformers version.
# NOTE(review): newer transformers releases also accept `dtype=` as the preferred spelling.
full_model = modeling_gemma3.Gemma3ForConditionalGeneration.from_pretrained(
    model_id, config=lmm_config, cache_dir=cache_dir, torch_dtype=torch.float32
)
logger.info("Loaded full model: %s", full_model)

In [None]:
# Wrap just the embedding stage of the loaded model so it can be exported on its own.
embedding_layer = EmbeddingLayer(full_model=full_model)
logger.info("Created embedding-only model: %s", embedding_layer)

In [None]:
# Dummy inputs used only to trace the graph; values are arbitrary token ids / features.
# Feature sizes come from the model config so changing `model_id` keeps shapes consistent
# (fallbacks are the gemma-3-4b-it values this notebook previously hardcoded).
torch.manual_seed(0)  # reproducible dummy inputs under Restart & Run All
dummy_batch_size, dummy_seq_length = 1, 128
mm_tokens_per_image = getattr(lmm_config, "mm_tokens_per_image", 256)
text_hidden_size = getattr(getattr(lmm_config, "text_config", lmm_config), "hidden_size", 2560)

input_ids = torch.randint(3000, size=(dummy_batch_size, dummy_seq_length), dtype=torch.int64)
image_features = torch.randn((dummy_batch_size, mm_tokens_per_image, text_hidden_size), dtype=torch.float32)

input_names = ["input_ids", "image_features"]
# Kept verbatim: presumably the embedding-output tensor name expected by the downstream
# decoder graph - TODO confirm against the consumer of this export.
output_names = ["/model/embed_tokens/Mul/output_0"]
# Any input/output axis not listed here is fixed to the dummy shape at export time, so the
# output must declare the same batch/sequence axes as `input_ids`.
dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "seq_length"},
    "image_features": {0: "batch_size", 1: "image_tokens_length"},
    output_names[0]: {0: "batch_size", 1: "seq_length"},
}

In [None]:
# Dedicated output folder for the ONNX graph and any weight files written alongside it.
export_dir = os.path.join(workfolder, "embed_fp32")
os.makedirs(name=export_dir, exist_ok=True)

In [None]:
# Export the embedding stage to ONNX. Inputs are passed as a positional tuple in the order
# of EmbeddingLayer.forward(input_ids, image_features), the documented form of `args`,
# instead of relying on a bare dict being coerced into keyword arguments.
# The embedding table is large, so the exporter may write weights as external data files
# next to the .onnx file; keeping everything inside `export_dir` keeps them together.
# NOTE(review): recent PyTorch releases may default torch.onnx.export to the dynamo-based
# exporter; if `dynamic_axes`/custom output names behave differently, consider passing
# dynamo=False (only on versions that accept it) - confirm against the installed torch.
onnx_path = os.path.join(export_dir, "embed_fp32.onnx")
torch.onnx.export(
    embedding_layer,
    (input_ids, image_features),
    onnx_path,
    input_names=input_names,
    output_names=output_names,
    opset_version=19,
    dynamic_axes=dynamic_axes,
)