<a href="https://colab.research.google.com/github/IoT-gamer/t5gemma2-onnx/blob/main/notebooks/test_t5gemma2_onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Test T5Gemma 2 ONNX Export
- **Prerequisite:** exported encoder and decoder ONNX models

## Install Dependencies

In [None]:
%pip install -q --upgrade git+https://github.com/huggingface/transformers.git
%pip install -q onnxruntime

## Mount Google Drive (only needed on Colab when the ONNX models are stored in Google Drive)

In [None]:
# Mount Google Drive so the ONNX paths below can point into /content/drive.
# Outside Colab the google.colab package is not installed; skip the mount instead of
# raising, so "Restart & Run All" still works when the models live on a local path.
try:
    from google.colab import drive
except ImportError:
    print("google.colab not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

## Select ONNX model paths

In [None]:
# Directory holding the exported ONNX files -- edit this single value if the models live elsewhere.
onnx_dir = "/content/drive/MyDrive/t5gemma2"
decoder_path = f"{onnx_dir}/t5gemma2_decoder.onnx"
encoder_path = f"{onnx_dir}/t5gemma2_encoder.onnx"

## Set Up the Tokenizer, ONNX Sessions, and Generation Function

In [None]:
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Setup: the tokenizer is fetched from this Hugging Face checkpoint
# (presumably the same one the ONNX files were exported from -- confirm if you swap models).
model_id = "google/t5gemma-2-270m-270m"

print(f"Loading Tokenizer from {model_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)

print("Loading ONNX Models...")
# One inference session per exported graph, created encoder first, then decoder.
sess_encoder, sess_decoder = [ort.InferenceSession(onnx_file) for onnx_file in (encoder_path, decoder_path)]

def _init_kv_cache(session, batch_size, default_num_heads=1, default_head_dim=256):
    """Create an empty cache tensor for every decoder graph input whose name contains ``past_``.

    Each tensor is float32 with shape (batch_size, heads, 0, head_dim), i.e. zero cached
    positions (presumably axes are batch, kv-heads, cached length, head dim -- TODO confirm
    against the exporter). ``heads`` and ``head_dim`` are taken from the graph input's
    declared shape when those dims are static ints; otherwise the defaults (1 and 256)
    are used.
    """
    cache = {}
    for node in session.get_inputs():
        if "past_" not in node.name:
            continue
        dims = list(node.shape or [])
        has_rank4 = len(dims) == 4
        num_heads = dims[1] if has_rank4 and isinstance(dims[1], int) else default_num_heads
        head_dim = dims[3] if has_rank4 and isinstance(dims[3], int) else default_head_dim
        cache[node.name] = np.zeros((batch_size, num_heads, 0, head_dim), dtype=np.float32)
    return cache


def _penalize_token(logits, token_id, penalty):
    """Lower ``token_id``'s last-position logit in place by the factor ``penalty``.

    Positive logits are divided and non-positive logits are multiplied, so for
    penalty > 1 the score always moves down. Plain division would raise a negative
    logit toward zero and make the token *more* likely.
    """
    score = logits[:, -1, token_id]
    logits[:, -1, token_id] = np.where(score > 0, score / penalty, score * penalty)


def _update_kv_cache(kv_cache, output_names, outputs):
    """Replace each ``past_*`` cache entry with the matching ``present_*`` decoder output."""
    for name, value in zip(output_names, outputs):
        if "present_" in name:
            past_name = name.replace("present_", "past_")
            if past_name in kv_cache:
                kv_cache[past_name] = value


def generate_text(prompt, max_length=50, repetition_penalty=2.0, decoder_start_token_id=2):
    """Greedily generate text for ``prompt`` with the exported encoder/decoder ONNX models.

    Uses the module-level ``tokenizer``, ``sess_encoder`` and ``sess_decoder``.

    Args:
        prompt: Input text fed to the encoder.
        max_length: Maximum number of new tokens to generate.
        repetition_penalty: Factor applied to the previously generated token's logit
            before each greedy pick (> 1 discourages immediate repeats; 1.0 disables it).
        decoder_start_token_id: First decoder input token (defaults to 2, the BOS id).

    Returns:
        The decoded generated text with special tokens skipped. Generation stops early
        when ``tokenizer.eos_token_id`` is produced; that EOS token is not included.
    """
    print(f"\nPrompt: '{prompt}'")

    # ENCODE
    inputs = tokenizer(prompt, return_tensors="np")
    input_ids = inputs["input_ids"]
    mask_int64 = inputs["attention_mask"]

    enc_outs = sess_encoder.run(None, {
        "input_ids": input_ids,
        "attention_mask": mask_int64
    })
    encoder_hidden_states = enc_outs[0]
    batch_size = encoder_hidden_states.shape[0]

    # DECODER PREP -- values that do not change between steps are computed once.
    # Cross-attention mask: the encoder padding mask as float, broadcast to 4-D (batch, 1, 1, src_len).
    cross_attn_mask = mask_int64.astype(np.float32)[:, None, None, :]
    out_names = [n.name for n in sess_decoder.get_outputs()]
    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
    kv_cache = _init_kv_cache(sess_decoder, batch_size)

    generated_tokens = []
    print("Generating", end="", flush=True)

    # current_length == number of decoder tokens already fed (and held in the KV cache).
    for current_length in range(max_length):
        position_ids = np.array([[current_length]], dtype=np.int64)
        # Self-attention may see every cached position plus the current token.
        self_attn_mask = np.ones((batch_size, 1, 1, current_length + 1), dtype=np.float32)

        inputs_feed = {
            "input_ids": decoder_input_ids,
            "attention_mask": self_attn_mask,
            "encoder_attention_mask": cross_attn_mask,
            "position_ids": position_ids,
            "encoder_hidden_states": encoder_hidden_states,
        }
        inputs_feed.update(kv_cache)

        outputs = sess_decoder.run(None, inputs_feed)
        logits = outputs[0]

        # Discourage repeating the exact previous token.
        if generated_tokens and repetition_penalty != 1.0:
            _penalize_token(logits, generated_tokens[-1], repetition_penalty)

        next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()

        # Stop condition
        if next_token_id == tokenizer.eos_token_id:
            break

        generated_tokens.append(next_token_id)
        print(".", end="", flush=True)

        decoder_input_ids = np.array([[next_token_id]], dtype=np.int64)
        _update_kv_cache(kv_cache, out_names, outputs)

    print("\n")
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)

## Prompt the Model
- Adjust `max_length` to change the maximum number of generated tokens.

In [None]:
# Example input: a "Summarize:"-prefixed passage; max_length=10 caps how many new tokens are generated.
text = "Summarize: Jupiter is the fifth planet from the Sun and the largest in the Solar System. It is a gas giant with a mass one-thousandth that of the Sun, but two-and-a-half times that of all the other planets in the Solar System combined."
output = generate_text(text, max_length=10)
print("Result:", output)