<a href="https://colab.research.google.com/github/IoT-gamer/t5gemma2-onnx/blob/main/notebooks/t5gemma2_decoder_onnx_export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# T5Gemma 2 Decoder ONNX Export
- Exports only the decoder (together with the LM head, so the ONNX model outputs logits)
- Uses Hugging Face `transformers` and `torch.onnx.export`

## References/Acknowledgements
- [T5Gemma 2 Hugging Face](https://huggingface.co/docs/transformers/model_doc/t5gemma2)
- [Google Blog](https://blog.google/innovation-and-ai/technology/developers-tools/t5gemma-2/)

## Install Dependencies

In [None]:
# Install transformers straight from the GitHub repository (default branch) rather than a PyPI release.
!pip install --upgrade git+https://github.com/huggingface/transformers.git
# NOTE(review): onnxscript is presumably needed by `torch.onnx` at export time — confirm for the installed torch version.
!pip install onnxscript

## Load Model, Create Wrapper and Export

In [None]:
import torch
import transformers.models.t5gemma2.modeling_t5gemma2 as model_module
from transformers import T5Gemma2ForConditionalGeneration

# ---------------------------------------------------------
# LOAD MODEL
# ---------------------------------------------------------
# Checkpoint to export; the decoder and LM head of this model are wrapped further below.
model_id = "google/t5gemma-2-270m-270m"
print(f"Loading {model_id}...")
# `dtype` is the current name of this keyword in transformers; `torch_dtype` is deprecated
# and only kept as an alias that emits a warning. Weights are loaded in float32.
model = T5Gemma2ForConditionalGeneration.from_pretrained(model_id, dtype=torch.float32)
# Switch to inference mode before the export traces the model.
model.eval()

# ---------------------------------------------------------
# CACHE CLASSES
# ---------------------------------------------------------
class SimpleGrowingCache:
    """Minimal per-layer key/value cache whose entries grow on every update.

    ``key`` and ``val`` are plain Python lists indexed by layer. Updating a
    layer that already has an entry concatenates the new tensors onto the
    stored ones along dimension 2, the same axis ``get_seq_length`` reports.
    """

    def __init__(self):
        self.key = []
        self.val = []

    def update(self, k, v, i, cache_kwargs=None, **kwargs):
        """Store ``k``/``v`` for layer ``i`` and return the layer's full tensors.

        ``cache_kwargs`` and any extra keyword arguments are accepted and ignored.
        """
        if i < len(self.key):
            # Layer already present: extend it along dim 2.
            self.key[i] = torch.cat((self.key[i], k), dim=2)
            self.val[i] = torch.cat((self.val[i], v), dim=2)
        else:
            # No entry for this layer yet: start one.
            self.key.append(k)
            self.val.append(v)
        return self.key[i], self.val[i]

    def get_seq_length(self, i=0):
        """Return the size of dim 2 of layer ``i``'s key, or 0 if the layer has no entry."""
        if i >= len(self.key):
            return 0
        return self.key[i].shape[2]

    def get_max_length(self):
        """Return ``None``: this cache reports no maximum length."""
        return None

class SimpleStaticCache:
    """Minimal per-layer key/value cache whose entries are overwritten on update.

    ``key`` and ``val`` are plain Python lists indexed by layer. Unlike a
    growing cache, updating an existing layer replaces the stored tensors
    instead of concatenating onto them.
    """

    def __init__(self):
        self.key = []
        self.val = []

    def update(self, k, v, i, cache_kwargs=None, **kwargs):
        """Store ``k``/``v`` for layer ``i`` and return what is now stored there.

        ``cache_kwargs`` and any extra keyword arguments are accepted and ignored.
        """
        if i < len(self.key):
            # Layer already present: replace its contents outright.
            self.key[i] = k
            self.val[i] = v
        else:
            # No entry for this layer yet: start one.
            self.key.append(k)
            self.val.append(v)
        return self.key[i], self.val[i]

    def get_seq_length(self, i=0):
        """Return the size of dim 2 of layer ``i``'s key, or 0 if the layer has no entry."""
        if i >= len(self.key):
            return 0
        return self.key[i].shape[2]

    def get_max_length(self):
        """Return ``None``: this cache reports no maximum length."""
        return None

class WrapperCache:
    """Container pairing a self-attention cache with a cross-attention cache.

    Exposes the two caches under the attribute names ``self_attention_cache``
    and ``cross_attention_cache`` plus an initially empty ``is_updated`` dict.
    """

    def __init__(self, self_c, cross_c):
        self.self_attention_cache = self_c
        self.cross_attention_cache = cross_c
        # Starts empty; this class never writes to it itself.
        self.is_updated = dict()

    def get_seq_length(self, i=0):
        """Return the self-attention cache's sequence length for layer ``i``."""
        self_cache = self.self_attention_cache
        return self_cache.get_seq_length(i)

# Monkey Patch: rebind the name `EncoderDecoderCache` inside the imported
# `modeling_t5gemma2` module to `WrapperCache`, so any code in that module that
# resolves the name at call time gets the lightweight wrapper instead.
# NOTE(review): presumably required so the decoder accepts/handles `WrapperCache`
# during export (e.g. type checks on the cache) — confirm against the transformers source.
model_module.EncoderDecoderCache = WrapperCache

# ---------------------------------------------------------
# WRAPPER
# ---------------------------------------------------------
class FinalWrapper(torch.nn.Module):
    """Decoder + LM head with a flat key/value-cache interface.

    The wrapper takes the cache as one flat sequence of tensors and returns
    the updated cache flattened the same way, so the module's inputs and
    outputs are plain tensors only.
    """

    def __init__(self, model):
        """Take the decoder, LM head and decoder layer count from ``model``."""
        super().__init__()
        self.decoder = model.get_decoder()
        # The head is kept so the wrapper returns logits, not hidden states.
        self.lm_head = model.lm_head
        self.num_layers = model.config.decoder.num_hidden_layers

    def forward(
        self,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        position_ids,
        past_key_values,
    ):
        """Run one decoder call and return ``(logits, *flattened_cache)``.

        ``past_key_values`` is indexed as four entries per layer, in the order
        self-attention key, self-attention value, cross-attention key,
        cross-attention value. The returned cache tensors use the same layout.
        """
        layer_ids = range(self.num_layers)

        # Rebuild the per-layer cache objects from the flat input.
        self_cache = SimpleGrowingCache()
        cross_cache = SimpleStaticCache()
        self_cache.key = [past_key_values[4 * n] for n in layer_ids]
        self_cache.val = [past_key_values[4 * n + 1] for n in layer_ids]
        cross_cache.key = [past_key_values[4 * n + 2] for n in layer_ids]
        cross_cache.val = [past_key_values[4 * n + 3] for n in layer_ids]

        decoder_out = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=WrapperCache(self_cache, cross_cache),
            position_ids=position_ids,
            use_cache=True,
        )

        # Project hidden states to vocabulary logits.
        logits = self.lm_head(decoder_out.last_hidden_state)

        # Flatten the returned cache back to four tensors per layer.
        present = decoder_out.past_key_values
        self_kv = present.self_attention_cache
        cross_kv = present.cross_attention_cache
        per_layer = (
            (self_kv.key[n], self_kv.val[n], cross_kv.key[n], cross_kv.val[n])
            for n in layer_ids
        )
        return (logits, *(tensor for group in per_layer for tensor in group))

# ---------------------------------------------------------
# EXPORT
# ---------------------------------------------------------
# Decoder hyper-parameters used to size the dummy inputs.
d_conf = model.config.decoder
num_layers = d_conf.num_hidden_layers
head_dim = d_conf.head_dim
hidden_size = d_conf.hidden_size
num_kv_heads = d_conf.num_key_value_heads

# Dummy Inputs: one decoding step (seq=1) with an empty cache (length 0 on dim 2).
batch, seq, enc_seq = 1, 1, 128
dummy_input_ids = torch.ones((batch, seq), dtype=torch.long)
dummy_mask = torch.ones((batch, 1, seq, seq), dtype=torch.float32)
dummy_enc_mask = torch.ones((batch, 1, 1, enc_seq), dtype=torch.float32)
dummy_enc_hidden = torch.randn((batch, enc_seq, hidden_size))
dummy_pos_ids = torch.zeros((batch, seq), dtype=torch.long)
flat_dummy_kv = [torch.randn(batch, num_kv_heads, 0, head_dim) for _ in range(num_layers * 4)]

# Four cache tensors per layer; FinalWrapper reads j=0/1 as the self-attention
# key/value and j=2/3 as the cross-attention key/value.
kv_names = [f"past_{i}_{j}" for i in range(num_layers) for j in range(4)]
out_kv_names = [f"present_{i}_{j}" for i in range(num_layers) for j in range(4)]
input_names = ["input_ids", "attention_mask", "encoder_hidden_states", "encoder_attention_mask", "position_ids"] + kv_names

# Self- and cross-attention caches are independent, so they get separate
# symbolic names for their length axis instead of sharing "past_seq".
self_kv_names = [f"past_{i}_{j}" for i in range(num_layers) for j in (0, 1)]
cross_kv_names = [f"past_{i}_{j}" for i in range(num_layers) for j in (2, 3)]

dynamic_axes = {
    "input_ids": {0: "batch", 1: "seq"},
    "position_ids": {0: "batch", 1: "seq"},
    "encoder_hidden_states": {0: "batch", 1: "enc_seq"},
    # The dummy mask is square only because the dummy cache is empty; give the
    # last axis its own symbol so the model does not declare it equal to "seq".
    "attention_mask": {0: "batch", 2: "seq", 3: "kv_seq"},
    "encoder_attention_mask": {0: "batch", 3: "enc_seq"},
    **{n: {0: "batch", 2: "past_seq"} for n in self_kv_names},
    **{n: {0: "batch", 2: "cross_past_seq"} for n in cross_kv_names},
}

print("Exporting...")
try:
    torch.onnx.export(
        FinalWrapper(model),
        (dummy_input_ids, dummy_mask, dummy_enc_hidden, dummy_enc_mask, dummy_pos_ids, flat_dummy_kv),
        "t5gemma2_decoder.onnx",
        opset_version=14,
        input_names=input_names,
        output_names=["logits"] + out_kv_names,  # Note: Output is LOGITS
        dynamic_axes=dynamic_axes,
        dynamo=False,
    )
except Exception as e:
    # Still best-effort (no re-raise), but print the full traceback as well:
    # the exception message alone rarely shows where the export failed.
    import traceback

    print(f"Export failed: {e}")
    traceback.print_exc()
else:
    print("SUCCESS: The decoder has been exported.")