# FunctionGemma → TFLite Conversion

Conversion notebook for the [flutter_gemma](https://github.com/DenisovAV/flutter_gemma) plugin.

Converts a fine-tuned FunctionGemma model from PyTorch/SafeTensors format to TFLite for on-device inference.

**Pipeline:**
1. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_finetuning.ipynb) [functiongemma_finetuning.ipynb](https://github.com/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_finetuning.ipynb) - Fine-tune model ✅
2. **This notebook** - Convert to TFLite
3. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_tflite_to_task.ipynb) [functiongemma_tflite_to_task.ipynb](https://github.com/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_tflite_to_task.ipynb) - Bundle as .task for Flutter

**What this notebook does:**
1. Loads the fine-tuned model from Google Drive
2. Validates that the model works BEFORE conversion (HuggingFace transformers test)
3. Converts it to TFLite with dynamic int8 quantization
4. Saves the TFLite file and tokenizer to Google Drive

**⚠️ CRITICAL loading parameters (HuggingFace transformers, Step 3):**
- `torch_dtype=torch.bfloat16` (NOT float16!)
- `attn_implementation="eager"`

**Requirements:**
- A100 or L4 GPU runtime (Runtime → Change runtime type)
- Fine-tuned model on Google Drive (from pipeline step 1, `functiongemma_finetuning.ipynb`)

**Output:** `.tflite` file (plus `tokenizer.model`) saved to Google Drive

## Step 1: Install Dependencies

**What we're installing:**
- `ai-edge-torch` - Google's library for converting PyTorch models to TFLite
- `transformers` (pinned to 4.57.3) - HuggingFace library for loading/testing the model
- `huggingface_hub` - HuggingFace Hub client (used by the optional upload step)
- `numpy<2.1` - Required for compatibility with ai-edge-torch
- `sentencepiece` - SentencePiece library used by the Gemma tokenizer

**Why specific versions:**
- `numpy<2.1` - ai-edge-torch breaks with numpy 2.1+
- `Pillow` reinstall - ai-edge-torch may corrupt Colab's Pillow

**⚠️ RESTART RUNTIME** after running this cell!

In [None]:
# =============================================================================
# Step 1: Install ai-edge-torch
# =============================================================================
!pip uninstall -y tensorflow 2>/dev/null || true

# Install ai-edge-torch
!pip install ai-edge-torch --force-reinstall -q

# CRITICAL: Install numpy<2.1 AFTER ai-edge-torch (it may override)
!pip install "numpy<2.1" --force-reinstall -q

# Install transformers with pinned version
!pip install transformers==4.57.3 huggingface_hub sentencepiece -q

# Restore Colab's native Pillow (ai-edge-torch may break it)
!pip install Pillow --force-reinstall -q

print("\nInstalled:")
!pip show ai-edge-torch | grep Version
!pip show transformers | grep Version
!pip show numpy | grep Version
!pip show Pillow | grep Version

print("\n⚠️  RESTART RUNTIME after this step! (Runtime → Restart session)")
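Because later installs can silently pull numpy back above the pin, it is worth confirming the version after the restart. A small standard-library sketch; the helper names `version_tuple` and `is_below` are hypothetical, and the comparison only looks at numeric components (pre-release suffixes are ignored), which is enough for this pin but is not a full PEP 440 comparison.

```python
# Run after restarting the runtime: confirm the numpy<2.1 pin survived the reinstalls.
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v: str) -> tuple:
    """Turn '2.0.2' into (2, 0, 2); stops at the first part without leading digits."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # drop suffixes such as 'rc1'
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def is_below(installed: str, limit: str) -> bool:
    """True if `installed` sorts before `limit` (numeric components only)."""
    return version_tuple(installed) < version_tuple(limit)


try:
    numpy_version = version("numpy")
    if is_below(numpy_version, "2.1"):
        print(f"numpy {numpy_version}: OK")
    else:
        print(f"numpy {numpy_version}: too new for ai-edge-torch, re-run Step 1")
except PackageNotFoundError:
    print("numpy is not installed, re-run Step 1")
```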

## Step 2: Load Model from Google Drive

Loads the fine-tuned model from Google Drive.

**Expected location:**
- Folder: `My Drive/functiongemma-flutter-demo-final/`
- Or ZIP: `My Drive/functiongemma-flutter-demo-final.zip`

**Required files in the folder:**
- `model.safetensors` - Model weights (~540MB)
- `config.json` - Model configuration
- `tokenizer.model` - SentencePiece tokenizer
- `tokenizer_config.json` - Tokenizer settings

**Customize:** Change `MODEL_NAME` if your model has a different name.

In [None]:
# =============================================================================
# Step 2: Load fine-tuned model from Google Drive
# =============================================================================
from google.colab import drive
import os

drive.mount('/content/drive')

MODEL_NAME = "functiongemma-flutter-demo-final"
MODEL_DIR = MODEL_NAME
DRIVE_MODEL_DIR = f"/content/drive/MyDrive/{MODEL_NAME}"
DRIVE_ZIP = f"/content/drive/MyDrive/{MODEL_NAME}.zip"

if os.path.exists(DRIVE_MODEL_DIR):
    print(f"Found folder: {DRIVE_MODEL_DIR}")
    !cp -r "{DRIVE_MODEL_DIR}" .
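elif os.path.exists(DRIVE_ZIP):
    print(f"Found ZIP: {DRIVE_ZIP}")
    !unzip -o -q "{DRIVE_ZIP}"
    # The ZIP must contain a top-level folder named MODEL_NAME
    if not os.path.isdir(MODEL_DIR):
        raise FileNotFoundError(f"ZIP did not extract to ./{MODEL_DIR}/ - check the folder name inside {DRIVE_ZIP}")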
else:
    raise FileNotFoundError(f"Model not found!\nUpload to: {DRIVE_MODEL_DIR}/ or {DRIVE_ZIP}")

print(f"\nModel ready:")
!ls -la "{MODEL_DIR}/"
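The `ls` output has to be checked by eye. A minimal sketch that checks the four required files listed for this step automatically, treating a zero-byte file (for example an interrupted copy) as missing. `missing_model_files` is a hypothetical helper, and `MODEL_DIR` is redefined with the Step 2 value so the cell runs on its own.

```python
import os

# Files listed under "Required files in the folder" for Step 2.
REQUIRED_FILES = [
    "model.safetensors",
    "config.json",
    "tokenizer.model",
    "tokenizer_config.json",
]


def missing_model_files(model_dir: str) -> list:
    """Return the required files that are absent or empty in model_dir."""
    missing = []
    for name in REQUIRED_FILES:
        path = os.path.join(model_dir, name)
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            missing.append(name)
    return missing


MODEL_DIR = "functiongemma-flutter-demo-final"  # same value as in Step 2
if os.path.isdir(MODEL_DIR):
    missing = missing_model_files(MODEL_DIR)
    if missing:
        raise FileNotFoundError(f"{MODEL_DIR} is missing: {', '.join(missing)}")
    print("All required model files present.")
else:
    print(f"{MODEL_DIR} not found; run Step 2 first.")
```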

## Step 3: Test Model Before Conversion

**CRITICAL**: Verify the model works BEFORE converting to TFLite.

We load the model using HuggingFace transformers and test it with a sample prompt. If it outputs garbage here, the problem is in fine-tuning, not conversion.

**What we check:**
- Model outputs `<start_function_call>` tag
- No `<pad>` tokens in output (indicates wrong loading params)
- No Chinese characters or garbage (indicates broken fine-tuning)
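The test prompt in the next cell must match the training format exactly, and hand-editing the long declaration strings is error-prone. The sketch below uses hypothetical helpers (`esc`, `format_declaration`, `build_prompt`, not part of any library) that reproduce, character for character, the declaration and turn syntax used in this notebook's test prompts. It only covers flat schemas with STRING parameters inside an OBJECT, like the three functions here.

```python
def esc(text: str) -> str:
    """Wrap a string value in the <escape> delimiters used by the declarations."""
    return f"<escape>{text}<escape>"


def format_declaration(name, description, properties, required):
    """Build one declaration line in this notebook's prompt syntax.

    properties: dict of param name -> (description, TYPE), e.g. ("The color name", "STRING")
    required:   list of required param names
    """
    props = ",".join(
        f"{pname}:{{description:{esc(pdesc)},type:{esc(ptype)}}}"
        for pname, (pdesc, ptype) in properties.items()
    )
    req = ",".join(esc(r) for r in required)
    return (
        "<start_function_declaration>"
        f"declaration:{name}{{description:{esc(description)},"
        f"parameters:{{properties:{{{props}}},required:[{req}],type:{esc('OBJECT')}}}}}"
        "<end_function_declaration>"
    )


def build_prompt(declarations, user_message):
    """Assemble developer/user/model turns in the same layout as the Step 3 test prompt."""
    return (
        "<start_of_turn>developer\n"
        "You are a model that can do function calling with the following functions\n"
        + "\n".join(declarations)
        + "\n<end_of_turn>\n"
        + "<start_of_turn>user\n"
        + user_message
        + "\n<end_of_turn>\n"
        + "<start_of_turn>model\n"
    )


background = format_declaration(
    "change_background_color",
    "Changes the app background color",
    {"color": ("The color name (red, green, blue, yellow, purple, orange)", "STRING")},
    ["color"],
)
prompt = build_prompt([background], "make it red")
print(prompt)
```

Generating declarations from one Python definition keeps the prompt used here, the fine-tuning data, and the app's tool list from drifting apart.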

In [None]:
# =============================================================================
# Step 3: Test model BEFORE conversion (using HuggingFace transformers)
# =============================================================================
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

print(f"Loading model from {MODEL_DIR} via HuggingFace transformers...")

# CRITICAL: Must use same parameters as training!
# - bfloat16 (NOT float16!)
# - attn_implementation="eager"
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.bfloat16,           # CRITICAL: same as training!
    device_map="auto",
    attn_implementation="eager"            # CRITICAL: same as training!
)
hf_model.eval()
print(f"HuggingFace model loaded on {hf_model.device}, dtype={hf_model.dtype}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# FunctionGemma test prompt - MUST match training format EXACTLY
test_prompt = """<start_of_turn>developer
You are a model that can do function calling with the following functions
<start_function_declaration>declaration:change_background_color{description:<escape>Changes the app background color<escape>,parameters:{properties:{color:{description:<escape>The color name (red, green, blue, yellow, purple, orange)<escape>,type:<escape>STRING<escape>}},required:[<escape>color<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration>
<start_function_declaration>declaration:change_app_title{description:<escape>Changes the application title text in the AppBar<escape>,parameters:{properties:{title:{description:<escape>The new title text to display<escape>,type:<escape>STRING<escape>}},required:[<escape>title<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration>
<start_function_declaration>declaration:show_alert{description:<escape>Shows an alert dialog with a custom message and title<escape>,parameters:{properties:{title:{description:<escape>The title of the alert dialog<escape>,type:<escape>STRING<escape>},message:{description:<escape>The message content of the alert dialog<escape>,type:<escape>STRING<escape>}},required:[<escape>title<escape>,<escape>message<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration>
<end_of_turn>
<start_of_turn>user
make it red
<end_of_turn>
<start_of_turn>model
"""

print("\n" + "=" * 50)
print("TESTING FINE-TUNED MODEL (HuggingFace)")
print("=" * 50)
print(f"Input: 'make it red'")

inputs = tokenizer(test_prompt, return_tensors="pt").to(hf_model.device)

with torch.no_grad():
    outputs = hf_model.generate(
        **inputs,                      # passes input_ids AND attention_mask
        max_new_tokens=50,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
print(f"\nModel output:")
print(response)
print("=" * 50)

# Check if output looks valid
if "change_background_color" in response or "call:" in response:
    print("✅ Fine-tuned model outputs function call - GOOD!")
    print("   Proceeding with conversion...")
elif "<pad>" in response[:50]:
    print("❌ Model outputs <pad> - wrong loading parameters!")
    print("   Make sure: torch_dtype=bfloat16, attn_implementation='eager'")
    raise ValueError("STOP: Wrong model loading parameters")
elif "apologize" in response.lower() or "sorry" in response.lower():
    print("❌ Model refuses to call function - fine-tuning didn't work!")
    raise ValueError("STOP: Model not fine-tuned correctly")
elif any(c in response for c in "为足球收消气"):
    print("❌ Model outputs garbage - fine-tuning is broken!")
    raise ValueError("STOP: Model outputs garbage")
else:
    print("⚠️ Unexpected output - review manually before proceeding")

# Clean up HF model to free memory
del hf_model
torch.cuda.empty_cache()
print("\nHuggingFace model unloaded.")
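The checks in the cell above only look for substrings, so a malformed call can still pass. A stricter option is to parse the call and compare the function name and arguments. This is a minimal sketch assuming the model emits calls as `<start_function_call>call:NAME{arg:<escape>value<escape>}<end_function_call>`, mirroring the `<escape>` syntax of the declarations; `parse_function_calls` is a hypothetical helper and handles string arguments only.

```python
import re

# Assumed output shape, e.g.:
#   <start_function_call>call:change_background_color{color:<escape>red<escape>}<end_function_call>
CALL_RE = re.compile(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", re.DOTALL)
# String arguments only: name:<escape>value<escape>
ARG_RE = re.compile(r"(\w+):<escape>(.*?)<escape>", re.DOTALL)


def parse_function_calls(response: str) -> list:
    """Return [(function_name, {arg: value}), ...] for every well-formed call."""
    calls = []
    for name, body in CALL_RE.findall(response):
        calls.append((name, dict(ARG_RE.findall(body))))
    return calls


sample = "<start_function_call>call:change_background_color{color:<escape>red<escape>}<end_function_call>"
print(parse_function_calls(sample))
```

In Step 3 this would replace the substring test with something like `calls = parse_function_calls(response)` followed by a check that `calls[0] == ("change_background_color", {"color": "red"})`.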

## Step 4: Convert to TFLite

This is the main conversion step using `ai-edge-torch`.

**What happens:**
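1. The model is built with `gemma3.build_model_270m()`, which re-authors Gemma 3 270M with ai-edge-torch layers and loads the fine-tuned weights from `MODEL_DIR`
2. It is converted to TFLite with `dynamic_int8` quantization
3. The export uses a transposed KV-cache layout (`KV_LAYOUT_TRANSPOSED`) and takes the attention mask as a model input (`mask_as_input=True`)

**Conversion parameters (Google's official values):**

| Parameter | Value | Description |
|-----------|-------|-------------|
| `prefill_seq_len` | 256 | Prefill signature length (prompt tokens processed per prefill call) |
| `kv_cache_max_len` | 1024 | KV-cache size, i.e. maximum context (prompt + generated tokens) |
| `quantize` | dynamic_int8 | int8 weights with dynamic activation quantization; reduces size ~50% |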

**Time:** ~5-10 min on A100, ~10-15 min on L4/T4
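The table's two length parameters bound what the converted model can do at runtime: prompt and generated tokens share one KV cache of `kv_cache_max_len` entries. The sketch below (hypothetical helpers) checks a token budget and estimates an upper bound on KV-cache memory. It assumes float32 cache values and a full-length cache on every layer (ignoring any sliding-window savings), and reads the standard Hugging Face Gemma 3 config keys `num_hidden_layers`, `num_key_value_heads` and `head_dim`; the runtime may store the cache differently.

```python
import json
import os


def fits_context(prompt_tokens: int, max_new_tokens: int, kv_cache_max_len: int = 1024) -> bool:
    """Prompt and generated tokens share one KV cache, so their sum must fit."""
    return prompt_tokens + max_new_tokens <= kv_cache_max_len


def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   kv_cache_max_len: int = 1024, bytes_per_value: int = 4) -> int:
    """Upper bound: one K and one V tensor per layer, each
    [kv_cache_max_len, num_kv_heads, head_dim], at bytes_per_value each."""
    return 2 * num_layers * kv_cache_max_len * num_kv_heads * head_dim * bytes_per_value


config_path = os.path.join("functiongemma-flutter-demo-final", "config.json")  # MODEL_DIR from Step 2
if os.path.exists(config_path):
    with open(config_path) as fh:
        cfg = json.load(fh)
    text_cfg = cfg.get("text_config", cfg)  # multimodal configs nest the text settings
    size = kv_cache_bytes(text_cfg["num_hidden_layers"],
                          text_cfg["num_key_value_heads"],
                          text_cfg["head_dim"])
    print(f"KV cache upper bound at 1024 tokens: {size / 1e6:.1f} MB")
```

A long system prompt with many function declarations eats into the same 1024-token budget as the model's reply, so keep `max_new_tokens` on device consistent with `fits_context`.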

In [None]:
# =============================================================================
# Step 4: Convert to TFLite
# =============================================================================
import os
from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.export_config import ExportConfig
from ai_edge_torch.generative.layers import kv_cache

TFLITE_OUTPUT_DIR = "tflite_output"
os.makedirs(TFLITE_OUTPUT_DIR, exist_ok=True)

# Load model using ai-edge-torch
print(f"Loading model from {MODEL_DIR} via ai-edge-torch...")
pytorch_model = gemma3.build_model_270m(MODEL_DIR)
pytorch_model.eval()
print("Model loaded!")

# Configure export
export_config = ExportConfig()
export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
export_config.mask_as_input = True

print("\n" + "=" * 50)
print("Converting to TFLite...")
print("=" * 50)

# Convert with official Google parameters
converter.convert_to_tflite(
    pytorch_model,
    output_path=TFLITE_OUTPUT_DIR,
    output_name_prefix="functiongemma-flutter",
    prefill_seq_len=256,       # Official Google parameter
    kv_cache_max_len=1024,     # Official Google parameter
    quantize="dynamic_int8",
    export_config=export_config,
)

print("\n✅ Conversion complete!")
!ls -lah {TFLITE_OUTPUT_DIR}/

In [None]:
# =============================================================================
# Step 4.5: Sanity-check the reauthored PyTorch model that Step 4 converted
# =============================================================================
# NOTE: End-to-end text generation with the .tflite file is tested with
# MediaPipe on device (Android/iOS/Web). Here we run one forward pass of the
# reauthored PyTorch model, which validates the reauthoring and weight
# loading, not the .tflite file itself.

import torch
from transformers import AutoTokenizer

print("=" * 50)
print("TESTING CONVERTED MODEL (via ai-edge-torch)")
print("=" * 50)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# Same test prompt (shorter - one tool only for faster test)
test_prompt = """<start_of_turn>developer
You are a model that can do function calling with the following functions
<start_function_declaration>declaration:change_background_color{description:<escape>Changes the app background color<escape>,parameters:{properties:{color:{description:<escape>The color name (red, green, blue, yellow, purple, orange)<escape>,type:<escape>STRING<escape>}},required:[<escape>color<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration>
<end_of_turn>
<start_of_turn>user
make it red
<end_of_turn>
<start_of_turn>model
"""

print(f"Input: 'make it red'")

# Tokenize
input_ids = tokenizer.encode(test_prompt, return_tensors="pt")
print(f"Input tokens: {input_ids.shape[1]}")

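# Test with pytorch_model (the reauthored model loaded for conversion in Step 4).
# Reauthored ai-edge-torch models are called as (tokens, input_pos, kv_cache)
# and return a dict with "logits" and "kv_cache", not a bare logits tensor.
from ai_edge_torch.generative.layers import kv_cache as kv_utils

print("\n--- Testing PyTorch model (post-reauthoring) ---")
try:
    with torch.no_grad():
        tokens = input_ids.to(torch.int)
        input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
        # Assumption: KVCache.from_model_config(kv_cache_max, config), as in recent
        # ai-edge-torch releases (1024 = kv_cache_max_len from Step 4); older
        # releases take only (config). A mismatch is reported by the except below.
        kv = kv_utils.KVCache.from_model_config(1024, pytorch_model.config)
        output = pytorch_model(tokens, input_pos, kv)  # prefill over the whole prompt
        logits = output["logits"]
        print(f"Output logits shape: {logits.shape}")
        
        # Get the most likely next token
        next_token_logits = logits[0, -1, :]
        next_token_id = torch.argmax(next_token_logits).item()
        next_token = tokenizer.decode([next_token_id])
        print(f"Next predicted token: '{next_token}' (id={next_token_id})")
        
        # Check if it's a reasonable FunctionGemma token
        if next_token.strip() in ['<', 'call', '<start', '<start_function_call>']:
            print("✅ Model predicts function call start - GOOD!")
        else:
            print(f"⚠️ Unexpected first token: '{next_token}'")
            
except Exception as e:
    print(f"⚠️ Forward pass failed: {e}")
    print("   Check the forward/KVCache signatures for your ai-edge-torch version")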

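# Summary
import glob
tflite_created = glob.glob(f"{TFLITE_OUTPUT_DIR}/*.tflite")
print("\n" + "=" * 50)
print("TEST SUMMARY")
print("=" * 50)
if tflite_created:
    print(f"✅ TFLite file created: {', '.join(tflite_created)}")
else:
    print(f"❌ No .tflite file found in {TFLITE_OUTPUT_DIR}/")
print("⚠️ Full text generation test requires MediaPipe on device")
print("   The .task bundling and MediaPipe inference are the next step")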
print("")
print("If model outputs garbage on device, check:")
print("1. BundleConfig in functiongemma_tflite_to_task.ipynb")
print("   - prompt_prefix should be EMPTY for FunctionGemma")
print("   - prompt_suffix should be EMPTY for FunctionGemma")
print("2. Quantization settings (int8 may affect quality)")
print("3. MediaPipe version compatibility")
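The cell above checks only the first predicted token. Generating a full answer drives the model in two phases: one prefill call over the whole prompt at positions `0..n-1`, then one call per generated token at the next position, with the KV cache carrying earlier context. A plain-Python sketch of that greedy loop; `step_fn` is a hypothetical stand-in for a call such as `pytorch_model(tokens, input_pos, kv_cache)` (which would also thread the updated cache between calls), `stop_id` would be an end token such as `<end_of_turn>`, and the toy model exists only to make the loop testable.

```python
from typing import Callable, List, Sequence

# step_fn(tokens, positions) -> logits (list of floats) for the LAST token it was given.
StepFn = Callable[[Sequence[int], Sequence[int]], List[float]]


def greedy_generate(step_fn: StepFn, prompt_ids: Sequence[int],
                    max_new_tokens: int, stop_id: int) -> List[int]:
    """Greedy decoding: prefill the whole prompt once, then feed one token per step."""
    logits = step_fn(list(prompt_ids), list(range(len(prompt_ids))))  # prefill
    pos = len(prompt_ids)
    generated = []
    for _ in range(max_new_tokens):
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        if next_id == stop_id:
            break
        generated.append(next_id)
        logits = step_fn([next_id], [pos])  # decode step: only the new token is fed
        pos += 1
    return generated


# Toy stand-in for the model: always predicts (last token + 1) over a 5-token vocabulary.
calls = []


def toy_step(tokens, positions):
    calls.append((list(tokens), list(positions)))
    vocab_size = 5
    logits = [0.0] * vocab_size
    logits[(tokens[-1] + 1) % vocab_size] = 1.0
    return logits


print(greedy_generate(toy_step, [0, 1], max_new_tokens=10, stop_id=4))  # [2, 3]
```

The position bookkeeping is the part that matters: each decode call must pass the absolute position of the new token so it is written to the right KV-cache slot.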

In [None]:
# =============================================================================
# Step 5: Save to Google Drive
# =============================================================================
import glob
import shutil

DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/flutter_gemma_models"
os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)

# Save TFLite file(s)
tflite_files = glob.glob(f"{TFLITE_OUTPUT_DIR}/*.tflite")
if not tflite_files:
    raise FileNotFoundError(f"No .tflite files in {TFLITE_OUTPUT_DIR}/ - run Step 4 first")
for f in tflite_files:
    size = os.path.getsize(f) / 1e6
    dest = f"{DRIVE_OUTPUT_DIR}/{os.path.basename(f)}"
    shutil.copy(f, dest)
    print(f"✅ Saved: {dest} ({size:.1f} MB)")

# Save tokenizer (needed for bundling)
tokenizer_src = f"{MODEL_DIR}/tokenizer.model"
if os.path.exists(tokenizer_src):
    tokenizer_dest = f"{DRIVE_OUTPUT_DIR}/tokenizer.model"
    shutil.copy(tokenizer_src, tokenizer_dest)
    print(f"✅ Saved: {tokenizer_dest}")
else:
    print(f"⚠️ {tokenizer_src} not found - the .task bundling step needs it")

print("\n" + "=" * 50)
print("TFLite saved to Google Drive!")
print("=" * 50)
print(f"\nLocation: {DRIVE_OUTPUT_DIR}/")
print("\nNext: Run functiongemma_tflite_to_task.ipynb to create .task file")
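As a cheap extra check before disconnecting the runtime, hashing both copies catches a truncated or partial copy. A minimal sketch; `sha256_of` and `copy_verified` are hypothetical helpers built on `hashlib` and `shutil`, and the file is read in 1 MiB chunks so large `.tflite` files never have to fit in memory.

```python
import hashlib
import shutil


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def copy_verified(src: str, dest: str) -> str:
    """Copy src to dest, confirm the bytes match, and return the SHA-256 digest."""
    shutil.copy(src, dest)
    src_hash, dest_hash = sha256_of(src), sha256_of(dest)
    if src_hash != dest_hash:
        raise OSError(f"Checksum mismatch after copying {src} -> {dest}")
    return src_hash
```

In Step 5, `shutil.copy(f, dest)` can be swapped for `digest = copy_verified(f, dest)`; printing the digest also gives you a value to compare against after downloading the file elsewhere.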

## Optional: Upload to HuggingFace Hub

Upload the TFLite model to HuggingFace for easy sharing and versioning.

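**Setup:**
1. Create a new model repository on [huggingface.co](https://huggingface.co/new)
2. Add a HuggingFace access token with write permission to Colab Secrets as `HF_TOKEN`
3. Uncomment the code below and change `HUB_REPO_ID` to your repository
4. Run the cell (after Step 5, which defines `tokenizer_src`)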

In [None]:
# =============================================================================
# Optional: Upload to HuggingFace Hub
# =============================================================================
# Uncomment the code below to upload

# from huggingface_hub import login, HfApi
# from google.colab import userdata
#
# # Login (uses token from Colab Secrets)
# HF_TOKEN = userdata.get('HF_TOKEN')
# login(token=HF_TOKEN)
#
# # Upload to HuggingFace
# HUB_REPO_ID = "your-username/functiongemma-flutter-tflite"  # Change this!
#
# api = HfApi()
# api.create_repo(repo_id=HUB_REPO_ID, exist_ok=True)
#
# # Upload TFLite files
# for f in glob.glob(f"{TFLITE_OUTPUT_DIR}/*.tflite"):
#     api.upload_file(
#         path_or_fileobj=f,
#         path_in_repo=os.path.basename(f),
#         repo_id=HUB_REPO_ID,
#     )
#     print(f"✅ Uploaded: {os.path.basename(f)}")
#
# # Upload tokenizer
# if os.path.exists(tokenizer_src):
#     api.upload_file(
#         path_or_fileobj=tokenizer_src,
#         path_in_repo="tokenizer.model",
#         repo_id=HUB_REPO_ID,
#     )
#     print("✅ Uploaded: tokenizer.model")
#
# print(f"\n🎉 Model uploaded to: https://huggingface.co/{HUB_REPO_ID}")
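Before uncommenting the upload cell, a dry run shows exactly what would be sent. A minimal sketch with a hypothetical helper `upload_plan`; it assumes the directory names from Steps 2 and 4, and each pair maps onto the `path_or_fileobj` / `path_in_repo` arguments that the cell above passes to `HfApi.upload_file`.

```python
import glob
import os


def upload_plan(tflite_dir: str, tokenizer_path: str) -> list:
    """List (local_path, path_in_repo) pairs that the upload cell would send."""
    plan = [(f, os.path.basename(f))
            for f in sorted(glob.glob(os.path.join(tflite_dir, "*.tflite")))]
    if os.path.exists(tokenizer_path):
        plan.append((tokenizer_path, "tokenizer.model"))
    return plan


for local_path, repo_path in upload_plan("tflite_output",
                                         "functiongemma-flutter-demo-final/tokenizer.model"):
    print(f"{local_path} -> {repo_path}")
```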