# FunctionGemma Base → TFLite Conversion

Converts the base FunctionGemma model to TFLite format (no fine-tuning).

**Pipeline (Base Model - No Fine-tuning):**
1. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_base_download.ipynb) [functiongemma_base_download.ipynb](https://github.com/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_base_download.ipynb) - Download base model ✅
2. **This notebook** - Convert to TFLite
3. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_base_tflite_to_task.ipynb) [functiongemma_base_tflite_to_task.ipynb](https://github.com/DenisovAV/flutter_gemma/blob/main/colabs/functiongemma_base_tflite_to_task.ipynb) - Bundle as .task for Flutter

**Requirements:**
- A100 or L4 GPU runtime
- Base model on Google Drive (from Step 1)
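A minimal sketch of a pre-flight check for the GPU requirement above. It assumes you pass in the device name PyTorch reports (`torch.cuda.get_device_name(0)`). `is_supported_gpu` and `SUPPORTED_GPUS` are hypothetical names, and matching on name tokens is just one simple heuristic.

```python
# Hypothetical pre-flight helper; the accepted list mirrors the Requirements above.
SUPPORTED_GPUS = ("A100", "L4")


def is_supported_gpu(device_name: str) -> bool:
    """Return True if the CUDA device name mentions one of the supported GPUs."""
    # Device names typically look like "NVIDIA A100-SXM4-40GB" or "NVIDIA L4".
    # Splitting on spaces and dashes keeps "L40S" from matching "L4".
    tokens = device_name.upper().replace("-", " ").split()
    return any(gpu in tokens for gpu in SUPPORTED_GPUS)
```

On a T4 runtime this returns `False`, which is a cue to switch the runtime type before spending time on the steps below.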

**Output:** `.tflite` file and `tokenizer.model` saved to Google Drive (`MyDrive/flutter_gemma_models`)

## Step 1: Install Dependencies

**RESTART RUNTIME** after running this cell!

In [None]:
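# Remove Colab's preinstalled TensorFlow to avoid version conflicts with ai-edge-torch's dependencies
!pip uninstall -y tensorflow 2>/dev/null || true

!pip install ai-edge-torch --force-reinstall -q
# Re-pin NumPy afterwards: the force-reinstall above may have pulled in a newer release
!pip install "numpy<2.1" --force-reinstall -q
!pip install transformers==4.57.3 huggingface_hub sentencepiece -q
!pip install Pillow --force-reinstall -q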

print("\nInstalled:")
!pip show ai-edge-torch | grep Version
!pip show numpy | grep Version

print("\nRESTART RUNTIME after this step! (Runtime -> Restart session)")
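The restart matters because packages the kernel has already imported stay loaded until it restarts. A minimal, standard-library-only sketch to confirm after the restart that the `numpy<2.1` pin took effect. `parse_release`, `is_below` and `check_pin` are hypothetical helpers, and the parser compares only the numeric release part (suffixes such as `rc1` are dropped).

```python
import re
from importlib import metadata


def parse_release(version: str) -> tuple:
    """Turn '2.0.2' into (2, 0, 2); stops at suffixes, so '2.1.0rc1' -> (2, 1, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # piece had a suffix such as 'rc1'
            break
    return tuple(parts)


def is_below(version: str, bound: str) -> bool:
    """True if version < bound, comparing release numbers only."""
    v, b = parse_release(version), parse_release(bound)
    width = max(len(v), len(b))
    # Pad with zeros so (2, 1) and (2, 1, 0) compare as equal.
    v += (0,) * (width - len(v))
    b += (0,) * (width - len(b))
    return v < b


def check_pin(package: str, bound: str) -> str:
    """Report whether the installed package satisfies '< bound'."""
    installed = metadata.version(package)  # raises PackageNotFoundError if absent
    status = "OK" if is_below(installed, bound) else "TOO NEW"
    return f"{package} {installed}: {status} (need < {bound})"
```

Running `print(check_pin("numpy", "2.1"))` in a fresh cell should report `OK` for a 2.0.x install.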

## Step 2: Load Model from Google Drive

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

MODEL_NAME = "functiongemma-base"
MODEL_DIR = MODEL_NAME
DRIVE_MODEL_DIR = f"/content/drive/MyDrive/{MODEL_NAME}"

if os.path.exists(DRIVE_MODEL_DIR):
    print(f"Found folder: {DRIVE_MODEL_DIR}")
    !cp -r "{DRIVE_MODEL_DIR}" .
else:
    raise FileNotFoundError(f"Model not found at {DRIVE_MODEL_DIR}!\nRun functiongemma_base_download.ipynb first!")

print("\nModel ready:")
!ls -la "{MODEL_DIR}/"
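`cp -r` copies whatever is in the Drive folder, so an interrupted Step 1 only surfaces later as a confusing load error. A minimal sketch of a sanity check on the copied folder. The expected file names are an assumption based on a typical Hugging Face Gemma checkpoint (`config.json`, `tokenizer.model`, `*.safetensors` weights). Adjust them to what Step 1 actually saved; `missing_model_files` is a hypothetical helper.

```python
import glob
import os

# Assumed checkpoint layout (typical Hugging Face Gemma export); adjust as needed.
REQUIRED_FILES = ("config.json", "tokenizer.model")
WEIGHT_PATTERN = "*.safetensors"


def missing_model_files(model_dir: str) -> list:
    """Return the expected files missing from model_dir (empty list means OK)."""
    missing = [
        name for name in REQUIRED_FILES
        if not os.path.isfile(os.path.join(model_dir, name))
    ]
    # Weights may be sharded, so any matching file counts.
    if not glob.glob(os.path.join(model_dir, WEIGHT_PATTERN)):
        missing.append(WEIGHT_PATTERN)
    return missing
```

For example, raising `FileNotFoundError` when `missing_model_files(MODEL_DIR)` is non-empty stops the notebook before the slower steps.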

## Step 3: Test Model Before Conversion

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

print(f"Loading model from {MODEL_DIR}...")

hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager"
)
hf_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

test_prompt = """<start_of_turn>developer
You are a model that can do function calling with the following functions
<start_function_declaration>declaration:change_background_color{description:<escape>Changes the app background color<escape>,parameters:{properties:{color:{description:<escape>The color name (red, green, blue, yellow, purple, orange)<escape>,type:<escape>STRING<escape>}},required:[<escape>color<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration>
<end_of_turn>
<start_of_turn>user
make it red
<end_of_turn>
<start_of_turn>model
"""

print("\n" + "=" * 50)
print("TESTING BASE MODEL")
print("=" * 50)

inputs = tokenizer(test_prompt, return_tensors="pt").to(hf_model.device)

with torch.no_grad():
    outputs = hf_model.generate(
        **inputs,  # passes both input_ids and attention_mask
        max_new_tokens=50,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
print(f"Input: 'make it red'")
print(f"Output: {response}")
print("=" * 50)

del hf_model
torch.cuda.empty_cache()
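The test prompt follows a fixed layout: a `developer` turn holding the function declarations, a `user` turn, then an open `model` turn. A minimal sketch that rebuilds that exact string from a parameter list, so other functions can be tried without hand-editing the `<escape>` markers. The helper names are hypothetical, and the serializer only covers required `STRING` parameters like the one above. The parser assumes replies of the form `<start_function_call>call:name{arg:<escape>value<escape>}<end_function_call>`. That format is an assumption here, so compare it against the `Output:` line the cell prints.

```python
import re

ESC = "<escape>"  # string delimiter used inside declarations


def declare_function(name: str, description: str, params: dict) -> str:
    """Serialize a function whose parameters are all required STRINGs.

    params maps parameter name -> description (hypothetical helper).
    """
    props = ",".join(
        f"{p}:{{description:{ESC}{d}{ESC},type:{ESC}STRING{ESC}}}"
        for p, d in params.items()
    )
    required = ",".join(f"{ESC}{p}{ESC}" for p in params)
    return (
        f"<start_function_declaration>declaration:{name}"
        f"{{description:{ESC}{description}{ESC},"
        f"parameters:{{properties:{{{props}}},"
        f"required:[{required}],type:{ESC}OBJECT{ESC}}}}}"
        "<end_function_declaration>"
    )


def build_prompt(declarations: list, user_message: str) -> str:
    """Rebuild the developer/user/model turn layout used in the test cell."""
    lines = [
        "<start_of_turn>developer",
        "You are a model that can do function calling with the following functions",
        "".join(declarations),  # assumption: several declarations go back to back
        "<end_of_turn>",
        "<start_of_turn>user",
        user_message,
        "<end_of_turn>",
        "<start_of_turn>model",
    ]
    return "\n".join(lines) + "\n"


# Assumed reply format (see lead-in).
CALL_RE = re.compile(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", re.S)
ARG_RE = re.compile(r"(\w+):<escape>(.*?)<escape>", re.S)


def parse_function_call(response: str):
    """Return (function_name, {arg: value}), or None if no call is found."""
    match = CALL_RE.search(response)
    if match is None:
        return None
    name, body = match.groups()
    return name, dict(ARG_RE.findall(body))
```

With the declaration from the test cell, `build_prompt([declare_function(...)], "make it red")` reproduces `test_prompt` character for character. If the base model answers in the assumed format, `parse_function_call(response)` yields `("change_background_color", {"color": "red"})`.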

## Step 4: Convert to TFLite

**Time:** ~5-10 min on A100

In [None]:
import os
from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.export_config import ExportConfig
from ai_edge_torch.generative.layers import kv_cache

TFLITE_OUTPUT_DIR = "tflite_output"
os.makedirs(TFLITE_OUTPUT_DIR, exist_ok=True)

print(f"Loading model via ai-edge-torch...")
pytorch_model = gemma3.build_model_270m(MODEL_DIR)
pytorch_model.eval()
print("Model loaded!")

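export_config = ExportConfig()
# Use the transposed KV-cache layout instead of the default one
export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
# Take the attention mask as an explicit model input
export_config.mask_as_input = True

print("\n" + "=" * 50)
print("Converting to TFLite...")
print("=" * 50)

converter.convert_to_tflite(
    pytorch_model,
    output_path=TFLITE_OUTPUT_DIR,
    output_name_prefix="functiongemma-base",
    prefill_seq_len=256,      # token length of the prefill signature
    kv_cache_max_len=1024,    # max context: prompt + generated tokens
    quantize="dynamic_int8",  # int8 weights, float activations
    export_config=export_config,
)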

print("\nConversion complete!")
!ls -lah {TFLITE_OUTPUT_DIR}/
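The two length parameters trade context against memory. `prefill_seq_len=256` sets the token length of the exported prefill signature, and `kv_cache_max_len=1024` caps the total context (prompt plus generated tokens). A back-of-the-envelope sketch of what that implies. The architecture numbers (18 layers, 1 KV head, head dimension 256) and the float32 cache are assumptions for the 270M model, so check them against `config.json`. As a rough rule of thumb, `dynamic_int8` stores weights at about one byte per parameter, so compare the `ls -lah` listing against that rather than against bf16 sizes.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   max_len: int, bytes_per_value: int = 4) -> int:
    """Size of the key and value caches across all layers.

    bytes_per_value=4 assumes a float32 cache.
    """
    # Factor 2: one tensor for keys, one for values, per layer.
    return 2 * num_layers * num_kv_heads * head_dim * max_len * bytes_per_value


def generation_budget(prefill_seq_len: int, kv_cache_max_len: int) -> int:
    """Tokens left for generation after a prompt that fills the prefill window."""
    if prefill_seq_len > kv_cache_max_len:
        raise ValueError("prefill_seq_len must not exceed kv_cache_max_len")
    return kv_cache_max_len - prefill_seq_len


# Assumed Gemma 3 270M shape: 18 layers, 1 KV head, head_dim 256 (check config.json).
print(f"KV cache: {kv_cache_bytes(18, 1, 256, 1024) / 1e6:.1f} MB")
print(f"Tokens left after a full 256-token prefill: {generation_budget(256, 1024)}")
```

Doubling `kv_cache_max_len` doubles the cache estimate linearly. The weight file size is unaffected, which is why the context length can be tuned separately from quantization.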

## Step 5: Save to Google Drive

In [None]:
import glob
import shutil

DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/flutter_gemma_models"
os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)

# Save TFLite file
tflite_files = glob.glob(f"{TFLITE_OUTPUT_DIR}/*.tflite")
for f in tflite_files:
    size = os.path.getsize(f) / 1e6
    dest = f"{DRIVE_OUTPUT_DIR}/{os.path.basename(f)}"
    shutil.copy(f, dest)
    print(f"Saved: {dest} ({size:.1f} MB)")

# Save tokenizer
tokenizer_src = f"{MODEL_DIR}/tokenizer.model"
if os.path.exists(tokenizer_src):
    tokenizer_dest = f"{DRIVE_OUTPUT_DIR}/tokenizer.model"
    shutil.copy(tokenizer_src, tokenizer_dest)
    print(f"Saved: {tokenizer_dest}")
else:
    print(f"Warning: {tokenizer_src} not found; the .task bundling step needs a tokenizer.")

print("\n" + "=" * 50)
print("DONE!")
print("=" * 50)
print("\nNext: Run functiongemma_base_tflite_to_task.ipynb")
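Writes to the mounted Drive folder are synced in the background, and a truncated `.tflite` only fails later in the bundling notebook. A minimal sketch that copies a file and confirms the destination's SHA-256 matches the source. `copy_and_verify` and `sha256_of` are hypothetical helpers, not part of the pipeline.

```python
import hashlib
import shutil


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks so large .tflite files don't fill memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def copy_and_verify(src: str, dest: str) -> str:
    """Copy src to dest, then confirm both files hash the same (hypothetical helper)."""
    shutil.copy(src, dest)
    src_hash = sha256_of(src)
    if sha256_of(dest) != src_hash:
        raise IOError(f"Checksum mismatch: {src} -> {dest}")
    return src_hash
```

In Step 5, `shutil.copy(f, dest)` could be swapped for `copy_and_verify(f, dest)`. This only shows that the copy is byte-identical as seen from the VM. To push pending writes to Drive before the session ends, `google.colab.drive` provides `drive.flush_and_unmount()`.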