<a href="https://colab.research.google.com/github/IoT-gamer/t5gemma2-onnx/blob/main/notebooks/test_multimodal_t5gemma2_onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Test Multimodal T5Gemma-2 ONNX Export
- **prerequisite:** *multimodal* encoder and decoder ONNX models

## Setup Device
- CPU or GPU

In [None]:
import torch

# Pick the compute device once; later cells read `device` (e.g. to decide
# which onnxruntime package to install).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Install Dependencies

In [None]:
!pip install --upgrade git+https://github.com/huggingface/transformers.git
if device.type == "cuda":
  print("Installing onnxruntime-gpu...")
  !pip install onnxruntime-gpu
else:
  print("Installing onnxruntime...")
  !pip install onnxruntime

## Mount Google Drive (only needed if the ONNX models are stored in Google Drive and you are running on Colab)

In [None]:
# Mounting Drive is optional and only possible on Colab. Guard the import so
# that "Restart & Run All" still works in a local Jupyter environment, where
# the `google.colab` package does not exist.
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

## Define ONNX model paths

In [None]:
# Directory that holds the exported ONNX files. Change this single value if
# your models live somewhere else (for example a local folder when you are not
# running on Colab with Google Drive mounted).
MODEL_DIR = "/content/drive/MyDrive/t5gemma2"

# Both paths are read by the generation function defined further below.
encoder_path = f"{MODEL_DIR}/t5gemma2_encoder_multimodal.onnx"
decoder_path = f"{MODEL_DIR}/t5gemma2_decoder.onnx"

## Setup ONNX Models

In [None]:
import numpy as np
import onnxruntime as ort
import requests
import gc
from PIL import Image
from transformers import AutoProcessor

# Only the processor is loaded here, not the PyTorch model weights: inference
# itself is done by the ONNX sessions created inside the generation function.
model_id = "google/t5gemma-2-270m-270m"
processor = AutoProcessor.from_pretrained(
    model_id,
    use_fast=True,
)

def _load_rgb_image(image_url, timeout):
    """Download ``image_url`` and return it as an RGB ``PIL.Image``."""
    # The context manager releases the HTTP connection even if decoding fails.
    with requests.get(image_url, stream=True, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # Ask urllib3 to undo gzip/deflate content encoding on the raw stream.
        response.raw.decode_content = True
        # convert() forces the image to be fully read while the stream is open.
        return Image.open(response.raw).convert("RGB")


def _create_session(model_path):
    """Create an ONNX Runtime session, preferring CUDA when it is available."""
    preferred = ("CUDAExecutionProvider", "CPUExecutionProvider")
    available = set(ort.get_available_providers())
    # Only request providers this onnxruntime build actually offers.
    providers = [name for name in preferred if name in available]
    return ort.InferenceSession(model_path, providers=providers or ["CPUExecutionProvider"])


def _apply_repetition_penalty(next_token_logits, generated_tokens, penalty):
    """Penalize, in place, the row-0 logits of tokens that were already generated.

    Positive logits are divided by ``penalty`` and non-positive logits are
    multiplied by it. Returns ``next_token_logits`` for convenience.
    """
    if not generated_tokens:
        return next_token_logits
    # np.unique removes duplicates, so each token is penalized exactly once.
    seen = np.unique(np.asarray(generated_tokens, dtype=np.int64))
    seen_logits = next_token_logits[0, seen]
    next_token_logits[0, seen] = np.where(
        seen_logits > 0, seen_logits / penalty, seen_logits * penalty
    )
    return next_token_logits


def generate_multimodal_text_optimized(prompt, image_url=None, max_length=50,
                                       repetition_penalty=1.2, request_timeout=30):
    """Greedy-decode a text answer for ``prompt`` (and an optional image) with ONNX.

    The encoder session is created, run once and released before the decoder
    session is created, so both sessions are never held at the same time.

    Args:
        prompt: Text prompt passed to the processor.
        image_url: Optional URL of an image to download; when falsy, an
            all-zero placeholder is fed as ``pixel_values``.
        max_length: Maximum number of decoder steps (generated tokens).
        repetition_penalty: Factor applied to the logits of tokens that were
            already generated (positive logits are divided by it, others are
            multiplied by it).
        request_timeout: Timeout in seconds passed to ``requests.get`` for the
            image download.

    Returns:
        The decoded text (special tokens skipped) as a string.

    Raises:
        requests.HTTPError: If the image download returns an error status.
        ValueError: If the decoder model exposes no ``past_*`` inputs.
    """
    print(f"\nPrompt: '{prompt}'")

    # Process Inputs
    image = None
    if image_url:
        print("Fetching image...")
        image = _load_rgb_image(image_url, request_timeout)

    # Get PyTorch tensors first
    inputs_pt = processor(text=prompt, images=image, return_tensors="pt")

    # Helper to convert to NumPy
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor is not None else None

    pixel_values = to_numpy(inputs_pt.get("pixel_values"))
    if pixel_values is None:
        # Handle missing image case.
        # NOTE(review): assumes the exported encoder accepts a 1x3x224x224
        # placeholder; confirm against the model's `pixel_values` input shape.
        pixel_values = np.zeros((1, 3, 224, 224), dtype=np.float32)
    else:
        pixel_values = pixel_values.astype(np.float32)

    # Create the NumPy inputs dictionary
    enc_inputs = {
        "input_ids": to_numpy(inputs_pt["input_ids"]),
        "attention_mask": to_numpy(inputs_pt["attention_mask"]),
        "pixel_values": pixel_values,
    }

    # Run Encoder (Load -> Run -> Delete)
    print("Loading Encoder Session...")
    sess_encoder = _create_session(encoder_path)
    try:
        print("Running Encoder...")
        encoder_hidden_states = sess_encoder.run(None, enc_inputs)[0]
    finally:
        # Drop the session even if inference raised.
        del sess_encoder
        gc.collect()
    print("Encoder unloaded.")

    # Run Decoder (Load -> Generate)
    print("Loading Decoder Session...")
    sess_decoder = _create_session(decoder_path)
    generated_tokens = []
    try:
        # Detect KV Cache dimensions dynamically from the first `past_*` input.
        past_inputs = [node for node in sess_decoder.get_inputs() if "past_" in node.name]
        if not past_inputs:
            raise ValueError("Decoder model has no 'past_*' inputs; cannot build the KV cache.")
        batch_size = encoder_hidden_states.shape[0]
        past_shape_example = past_inputs[0].shape
        num_kv_heads = past_shape_example[1]
        head_dim = past_shape_example[3]

        # Start with an empty cache: the sequence axis (index 2) has length 0.
        kv_cache = {
            node.name: np.zeros((batch_size, num_kv_heads, 0, head_dim), dtype=np.float32)
            for node in past_inputs
        }

        # Map each `present_*` output index to the `past_*` input it feeds.
        # This is loop-invariant, so it is computed once up front.
        present_to_past = []
        for out_index, out_node in enumerate(sess_decoder.get_outputs()):
            if "present_" in out_node.name:
                past_name = out_node.name.replace("present_", "past_")
                if past_name in kv_cache:
                    present_to_past.append((out_index, past_name))

        # The encoder mask never changes during decoding; broadcast it once.
        cross_attn_mask = enc_inputs["attention_mask"].astype(np.float32)[:, None, None, :]
        eos_token_id = processor.tokenizer.eos_token_id

        # BOS token. NOTE(review): id 2 is hard-coded; confirm it matches the
        # decoder start token of the exported model.
        decoder_input_ids = np.array([[2]], dtype=np.int64)

        print("Generating", end="", flush=True)
        # One token is fed per step, so the step index is also the position.
        for current_length in range(max_length):
            inputs_feed = {
                "input_ids": decoder_input_ids,
                "attention_mask": np.ones((batch_size, 1, 1, current_length + 1), dtype=np.float32),
                "encoder_attention_mask": cross_attn_mask,
                "position_ids": np.array([[current_length]], dtype=np.int64),
                "encoder_hidden_states": encoder_hidden_states,
            }
            inputs_feed.update(kv_cache)

            outputs = sess_decoder.run(None, inputs_feed)

            # Copy so the penalty never mutates the session's output buffer.
            next_token_logits = _apply_repetition_penalty(
                outputs[0][:, -1, :].copy(), generated_tokens, repetition_penalty
            )
            next_token_id = np.argmax(next_token_logits, axis=-1).item()

            if next_token_id == eos_token_id:
                break

            generated_tokens.append(next_token_id)
            print(".", end="", flush=True)

            decoder_input_ids = np.array([[next_token_id]], dtype=np.int64)

            # Update KV Cache with this step's `present_*` outputs.
            for out_index, past_name in present_to_past:
                kv_cache[past_name] = outputs[out_index]
    finally:
        # Drop the session even if generation raised.
        del sess_decoder
        gc.collect()

    print("\n")
    return processor.batch_decode([generated_tokens], skip_special_tokens=True)[0]

## Prompt Model
- `max_length` caps the number of generated tokens; increase it for longer answers

In [None]:
# Example run: ask the model to describe a sample cat photo.
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/cat.jpg"

# The "<start_of_image>" marker is part of the prompt text given to the processor.
prompt_text = "<start_of_image> What is in this image?"

result = generate_multimodal_text_optimized(
    prompt=prompt_text,
    image_url=img_url,
    max_length=10,
)
print(f"Final Result: {result}")