<a href="https://colab.research.google.com/github/IoT-gamer/t5gemma2-onnx/blob/main/notebooks/t5gemma2_multimodal_encoder_onnx_export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# T5Gemma 2 Multimodal Encoder ONNX Export
- Exports only the encoder.
- The encoder accepts text and/or image input.
- Uses Hugging Face `transformers` and `torch.onnx.export`.

## References/Acknowledgements
- [T5Gemma 2 Hugging Face](https://huggingface.co/docs/transformers/model_doc/t5gemma2)
- [Google Blog](https://blog.google/innovation-and-ai/technology/developers-tools/t5gemma-2/)

## Install Dependencies

In [None]:
# Install (or upgrade to) the current development version of `transformers`
# directly from its GitHub repository.
!pip install --upgrade git+https://github.com/huggingface/transformers.git
# Install `onnxscript`.
# NOTE(review): presumably pulled in for `torch.onnx`; confirm it is still
# required when exporting with `dynamo=False`.
!pip install onnxscript

## Load Model, Create Wrapper and Export

In [None]:
import torch
import torch.nn as nn
import transformers.masking_utils
from transformers import T5Gemma2ForConditionalGeneration

# ---------------------------------------------------------
# PATCHING
# ---------------------------------------------------------
# We patch the mask generation to avoid vmap/complex logic during tracing.
def patched_create_bidirectional_mask(config, input_embeds=None, attention_mask=None, **kwargs):
    """Build a trace-friendly additive bidirectional attention mask.

    The returned tensor uses the additive convention: ``0`` where attention is
    allowed and the most negative finite value of the mask dtype where it is
    blocked. The mask dtype follows ``input_embeds`` so that a half-precision
    model is not handed a float32 mask.

    Args:
        config: Accepted for signature compatibility; not used.
        input_embeds: Embedding tensor whose first two dimensions are
            ``(batch, seq)``. The alternative keyword spelling
            ``inputs_embeds`` is accepted as well, in case the caller uses it.
        attention_mask: One of
            * ``None`` - nothing is masked (all-zero mask);
            * a 2-D ``(batch, seq)`` padding mask where ``0`` marks padding;
            * a 4-D mask, either additive (floating point) or boolean with
              ``True`` meaning "may attend". It is passed through after being
              converted to the mask dtype.
        **kwargs: Any other keyword arguments are accepted and ignored.
            NOTE(review): if the caller passes mask-modifying callbacks here
            they are NOT applied - confirm that is acceptable for this export.

    Returns:
        A floating point mask broadcastable to ``(batch, 1, seq, seq)``.

    Raises:
        TypeError: If no embedding tensor was supplied.
        ValueError: If ``attention_mask`` is neither 2-D nor 4-D.
    """
    if input_embeds is None:
        input_embeds = kwargs.get("inputs_embeds")
    if input_embeds is None:
        raise TypeError("patched_create_bidirectional_mask() requires `input_embeds`")

    batch_size, seq_length = input_embeds.shape[:2]
    device = input_embeds.device
    if input_embeds.is_floating_point():
        mask_dtype = input_embeds.dtype
    else:
        mask_dtype = torch.get_default_dtype()
    # Most negative *finite* value: avoids -inf (and the NaNs it can cause in softmax).
    floor = torch.finfo(mask_dtype).min

    if attention_mask is None:
        return torch.zeros(
            (batch_size, 1, seq_length, seq_length), dtype=mask_dtype, device=device
        )

    attention_mask = attention_mask.to(device)

    if attention_mask.dim() == 4:
        if attention_mask.dtype == torch.bool:
            blocked = ~attention_mask
            return torch.zeros_like(attention_mask, dtype=mask_dtype).masked_fill(blocked, floor)
        # Already additive: clamp in float32 first so that a float32 minimum
        # does not overflow to -inf when narrowed to a smaller dtype.
        return torch.clamp(attention_mask.float(), min=floor).to(mask_dtype)

    if attention_mask.dim() == 2:
        # Padding is a property of the key position, so broadcast over queries.
        is_padding = attention_mask[:, None, None, :] == 0
        full = torch.zeros(
            (batch_size, 1, seq_length, seq_length), dtype=mask_dtype, device=device
        )
        return full.masked_fill(is_padding, floor)

    raise ValueError(
        f"attention_mask must be 2-D or 4-D, got {attention_mask.dim()} dimensions"
    )

# Rebind the attribute on the `transformers.masking_utils` module so that later
# lookups made through that module resolve to the patched function.
# NOTE(review): any module that already executed
# `from ...masking_utils import create_bidirectional_mask` keeps its own
# reference to the original function - confirm the model code really picks up
# this patch.
transformers.masking_utils.create_bidirectional_mask = patched_create_bidirectional_mask

# ---------------------------------------------------------
# WRAPPER
# ---------------------------------------------------------
class T5Gemma2EncoderMultimodalWrapper(nn.Module):
    """Export wrapper exposing the encoder through three positional tensors.

    The wrapper turns a ``(batch, seq)`` keep/pad mask into a 4-D additive
    float32 mask, forwards everything to the wrapped encoder by keyword and
    returns only the first element of the encoder's tuple output.
    """

    def __init__(self, encoder):
        """Store the encoder module that ``forward`` delegates to."""
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids, attention_mask, pixel_values):
        """Run the encoder and return the first element of its output tuple.

        Args:
            input_ids: 2-D tensor; its shape supplies ``(batch, seq)``.
            attention_mask: Tensor viewable as ``(batch, 1, 1, seq)`` where
                ``1`` means "keep" and ``0`` means "mask out".
            pixel_values: Passed to the encoder unchanged.

        Returns:
            ``outputs[0]`` of the encoder call (labelled ``last_hidden_state``
            by the original author).
        """
        n_rows, n_tokens = input_ids.shape
        most_negative = torch.finfo(torch.float32).min

        # (batch, seq) -> (batch, 1, 1, seq) -> (batch, 1, seq, seq): every
        # query row shares the same per-key keep/pad pattern.
        keep = attention_mask.view(n_rows, 1, 1, n_tokens)
        keep = keep.expand(n_rows, 1, n_tokens, n_tokens).float()

        # Additive form: kept positions become 0, masked ones the float32 minimum.
        additive_mask = (1.0 - keep) * most_negative

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=additive_mask,
            pixel_values=pixel_values,
            return_dict=False,
        )
        return encoder_outputs[0]

# ---------------------------------------------------------
# LOAD MODEL
# ---------------------------------------------------------
# Checkpoint to export; loaded on CPU in half precision.
model_id = "google/t5gemma-2-270m-270m"
print(f"Loading {model_id}...")
model = T5Gemma2ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cpu" # Use float16 to save 50% RAM
)
# Inference mode for modules such as dropout.
model.eval()

# Only the encoder half of the model is exported.
encoder = model.get_encoder()

# Disable kernel wrappers if they exist to simplify the trace.
for layer in encoder.layers:
    attn = layer.self_attn
    undecorated = getattr(attn.forward, "__wrapped__", None)
    if undecorated is None:
        continue
    # `__wrapped__` is the plain, undecorated function. Storing it directly as
    # an instance attribute would skip method binding, so the module would be
    # called without `self`. Bind it explicitly through the descriptor protocol
    # (an already-bound method is returned unchanged by `__get__`).
    bind = getattr(undecorated, "__get__", None)
    attn.forward = bind(attn, type(attn)) if bind is not None else undecorated

# ---------------------------------------------------------
# PREPARE DUMMY INPUTS
# ---------------------------------------------------------
# Dynamically pull constants from the config to prevent "tensor a vs tensor b" errors
# Values read from the model config rather than hard-coded.
mm_tokens_per_image = model.config.encoder.mm_tokens_per_image
image_token_index = model.config.encoder.image_token_index
vision_cfg = model.config.encoder.vision_config

batch = 1
text_seq_len = 10
# Each row must contain exactly mm_tokens_per_image instances of image_token_index.
total_seq_len = mm_tokens_per_image + text_seq_len

# Row layout: [BOS, image placeholders..., filler text tokens...].
# Every position starts as the filler id 1; BOS and the image placeholders are
# then written into *every* row, so the layout stays valid if `batch` is raised.
dummy_input_ids = torch.full((batch, total_seq_len), 1, dtype=torch.long)
dummy_input_ids[:, 0] = 2 # BOS
dummy_input_ids[:, 1 : 1 + mm_tokens_per_image] = image_token_index

# All ones: no position is treated as padding in the traced example.
dummy_attention_mask = torch.ones((batch, total_seq_len), dtype=torch.long)

# Random 3-channel square images; the side length comes from the vision config.
dummy_pixel_values = torch.randn(
    batch,
    3,
    vision_cfg.image_size,
    vision_cfg.image_size,
)

# Positional order must match the wrapper's forward(input_ids, attention_mask, pixel_values).
dummy_inputs = (dummy_input_ids, dummy_attention_mask, dummy_pixel_values)

# ---------------------------------------------------------
# EXPORT
# ---------------------------------------------------------
print("Exporting Multimodal Encoder...")
try:
    # Autograd state is not needed to trace a forward pass.
    with torch.no_grad():
        torch.onnx.export(
            T5Gemma2EncoderMultimodalWrapper(encoder),
            dummy_inputs,
            "t5gemma2_encoder_multimodal.onnx",
            # NOTE(review): opset picked by the original author; confirm every
            # operator the encoder uses is available at this opset.
            opset_version=17,
            input_names=["input_ids", "attention_mask", "pixel_values"],
            output_names=["last_hidden_state"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "seq"},
                "attention_mask": {0: "batch", 1: "seq"},
                "pixel_values": {0: "batch"},
                # Name the output's variable dimensions too, so consumers see
                # the same symbolic sizes as on the inputs.
                "last_hidden_state": {0: "batch", 1: "seq"},
            },
            dynamo=False,
        )
except Exception as e:
    import traceback

    print(f"Export failed: {e}")
    # The one-line message alone rarely locates the failing op; show the stack too.
    traceback.print_exc()
else:
    print("SUCCESS: The multimodal encoder has been exported.")