<a href="https://colab.research.google.com/github/IoT-gamer/t5gemma2-onnx/blob/main/notebooks/t5gemma2_encoder_onnx_export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# T5Gemma 2 Encoder ONNX Export
- Exports only the encoder.
- Uses Hugging Face `transformers` and `torch.onnx.export`.

## References/Acknowledgements
- [T5Gemma 2 Hugging Face](https://huggingface.co/docs/transformers/model_doc/t5gemma2)
- [Google Blog](https://blog.google/innovation-and-ai/technology/developers-tools/t5gemma-2/)

## Install Dependencies

In [None]:
# Install transformers from the default branch of its GitHub repository
# (the git+ URL below), plus onnxscript.
# NOTE(review): onnxscript is presumably needed only by the dynamo-based ONNX
# exporter, while the export cell passes dynamo=False — confirm it is required.
!pip install --upgrade git+https://github.com/huggingface/transformers.git
!pip install onnxscript

## Load Model, Create Wrapper and Export

In [None]:
import torch
import torch.nn as nn
import transformers.masking_utils
from transformers import T5Gemma2ForConditionalGeneration

# Patch: trace-friendly stand-in for transformers.masking_utils.create_bidirectional_mask.
def patched_create_bidirectional_mask(config, input_embeds, attention_mask, **kwargs):
    """Return an additive attention mask without the library's mask machinery.

    The returned mask holds 0 where attention is allowed and a very large
    negative number where it is not, so it can be added to attention scores.

    Args:
        config: Accepted for signature compatibility; not used.
        input_embeds: Embedded inputs. Only the first two dims (read as batch
            and sequence length), the dtype and the device are used.
        attention_mask: One of
            - ``None``: no masking; a zero mask of shape
              ``(batch, 1, seq, seq)`` is returned.
            - a 4D tensor: treated as an already-prepared additive mask (the
              encoder wrapper builds one) and returned unchanged.
            - a 2D ``(batch, seq)`` padding mask with 1/True for tokens to
              attend to and 0/False for padding; converted to an additive
              mask expanded to ``(batch, 1, seq, seq)``.
        **kwargs: Any further arguments supplied by the caller; ignored.
            NOTE(review): if a cross-attention caller passes encoder states
            here, the key length would differ from ``seq``; this stand-in
            does not handle that case.

    Returns:
        The attention mask tensor described above.
    """
    batch_size, seq_length = input_embeds.shape[:2]
    if attention_mask is None:
        return torch.zeros(
            (batch_size, 1, seq_length, seq_length),
            dtype=input_embeds.dtype,
            device=input_embeds.device,
        )
    if attention_mask.dim() == 4:
        # Already expanded by the caller: keep its padding information.
        # Returning zeros here would make the exported graph ignore padding.
        return attention_mask
    dtype = input_embeds.dtype
    # Cast before arithmetic so int and bool padding masks both work.
    keep = attention_mask.to(dtype=dtype, device=input_embeds.device)[:, None, None, :]
    additive = (1.0 - keep) * torch.finfo(dtype).min
    return additive.expand(batch_size, 1, seq_length, seq_length)

# Apply the patch globally before the tracer starts.
# NOTE(review): this rebinds the attribute on transformers.masking_utils only.
# A modeling module that did `from ...masking_utils import create_bidirectional_mask`
# when it was imported (the `from transformers import ...` above may already
# have triggered that import) keeps its own reference and is unaffected —
# confirm the patch actually takes effect for T5Gemma 2.
transformers.masking_utils.create_bidirectional_mask = patched_create_bidirectional_mask

# Load the model in float32 on CPU and switch to inference mode
# (eval() disables dropout so the traced graph is deterministic).
model_id = "google/t5gemma-2-270m-270m"
model = T5Gemma2ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float32, device_map="cpu"
)
model.eval()

# Patch kernels to avoid hidden ops: where an encoder layer's attention
# forward is decorated (exposes `__wrapped__`), replace it with the
# undecorated implementation so the tracer records the plain code path.
encoder = model.get_encoder()
for layer in encoder.layers:
    attn = layer.self_attn
    inner = getattr(attn.forward, "__wrapped__", None)
    if inner is None:
        continue
    if not hasattr(inner, "__self__"):
        # `__wrapped__` reached through a bound method is the plain, unbound
        # function. Stored as an instance attribute it would not be bound, so
        # nn.Module.__call__ would invoke it without `self`. Bind it to this
        # module instance explicitly.
        inner = inner.__get__(attn, type(attn))
    attn.forward = inner

class T5Gemma2EncoderWrapper(nn.Module):
    """Export-friendly wrapper: ``(input_ids, attention_mask) -> encoder output``.

    The wrapper builds a 4D additive attention mask itself, so the traced
    graph does not rely on the model's internal (vmap-based) mask generation.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def _mask_dtype(self):
        """Return the encoder's floating parameter dtype, or float32 if none."""
        param = next(self.encoder.parameters(), None)
        if param is not None and param.is_floating_point():
            return param.dtype
        return torch.float32

    def forward(self, input_ids, attention_mask):
        """Run the encoder with an externally built additive attention mask.

        Args:
            input_ids: Token ids shaped ``(batch, seq)``.
            attention_mask: Padding mask shaped ``(batch, seq)``: 1/True for
                tokens to attend to, 0/False for padding (int or bool).

        Returns:
            The first element of the encoder's tuple output
            (exported as ``last_hidden_state``).
        """
        batch, seq = input_ids.shape
        dtype = self._mask_dtype()
        # Cast first: `1.0 - mask` raises on bool tensors. Working in the
        # model dtype keeps finfo.min representable (float32's min would
        # overflow to -inf in float16). `.to()` also yields a tensor that
        # `.view` can reshape.
        keep = attention_mask.to(dtype).view(batch, 1, 1, seq).expand(batch, 1, seq, seq)
        # Standard additive 4D mask [batch, 1, seq, seq]: 0 = attend, min = masked key.
        extended_mask = (1.0 - keep) * torch.finfo(dtype).min

        return self.encoder(
            input_ids=input_ids,
            attention_mask=extended_mask,  # pre-computed mask
            return_dict=False,
        )[0]

# Export with the legacy TorchScript-based exporter (dynamo=False).
trace_len = 128
dummy_inputs = (
    torch.ones((1, trace_len), dtype=torch.long),  # input_ids
    torch.ones((1, trace_len), dtype=torch.long),  # attention_mask: attend everywhere
)

try:
    torch.onnx.export(
        T5Gemma2EncoderWrapper(encoder),
        dummy_inputs,
        "t5gemma2_encoder.onnx",
        opset_version=17,
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        # Declare batch/seq dynamic on the output as well as the inputs, so the
        # output's shape is not left fixed to the traced (1, 128, ...) values.
        # Assumes the output is (batch, seq, hidden) — TODO confirm.
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "last_hidden_state": {0: "batch", 1: "seq"},
        },
        dynamo=False,
    )
except Exception as e:
    print(f"Export failed: {e}")
    # Re-raise so the cell fails visibly with the full traceback instead of
    # looking like it ran successfully.
    raise
else:
    print("SUCCESS: The encoder has been exported.")