<a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/mediapipe-samples/blob/main/codelabs/litert_inference/gemma3_1b_tflite.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

#Install dependencies

In [1]:
%%capture
! pip install faiss-cpu
! pip install datasets
! pip install ai_edge_torch
! pip install ai-edge-litert
! pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

In [2]:
from ai_edge_litert import interpreter as interpreter_lib
from transformers import AutoTokenizer
import numpy as np
from collections.abc import Sequence
import sys

#Prerequisite

- Create HuggingFace token with permission access to

  - litert-community/Gemma3-1B-IT

  - google/gemma-3-1b-it

  This is needed to download the tflite model and tokenizer.

- Open Colab Secrets: In your Google Colab notebook, locate the Secrets icon in the left-hand sidebar and click on it.
- Add a new secret: Click the "Add Secret" button.
- Name your secret: Enter "HF_TOKEN" for your token in the "Name" field.
- Paste your token: In the "Value" field, paste the actual token you want to store.

# Download model files

In [3]:
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(repo_id="litert-community/Gemma3-1B-IT", filename="gemma3-1b-it-int4.tflite")

# Create LiteRT interpreter and tokenizer

In [4]:
interpreter = interpreter_lib.InterpreterWithCustomOps(
    custom_op_registerers=["pywrap_genai_ops.GenAIOpsRegisterer"],
    model_path=model_path,
    num_threads=2,
    experimental_default_delegate_latest_features=True)

In [5]:
from transformers import AutoTokenizer

model_id = 'google/gemma-3-1b-it'
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Create pipeline with LiteRT models

In [6]:
def _get_mask(shape: Sequence[int], k: int):
  """Gets the mask for the input to the model.

  Args:
    shape: The shape of the mask input to the model.
    k: all elements below the k-th diagonal are set to 0.

  Returns:
    The mask for the input to the model. All the elements in the mask are set
    to -inf except that all the elements below the k-th diagonal are set to 0.
  """
  mask = np.ones(shape, dtype=np.float32) * float("-inf")
  mask = np.triu(mask, k=k)
  return mask

class LiteRTLlmPipeline:

  def __init__(self, interpreter, tokenizer):
    """Initializes the pipeline."""
    self._interpreter = interpreter
    self._tokenizer = tokenizer

    self._prefill_runner = None
    self._decode_runner = self._interpreter.get_signature_runner("decode")


  def _init_prefill_runner(self, num_input_tokens: int):
    """Initializes all the variables related to the prefill runner.

    This method initializes the following variables:
      - self._prefill_runner: The prefill runner based on the input size.
      - self._max_seq_len: The maximum sequence length supported by the model.

    Args:
      num_input_tokens: The number of input tokens.
    """
    if not self._interpreter:
      raise ValueError("Interpreter is not initialized.")

    # Prefill runner related variables will be initialized in `predict_text` and
    # `compute_log_likelihood`.
    self._prefill_runner = self._get_prefill_runner(num_input_tokens)
    # input_token_shape has shape (batch, max_seq_len)
    input_token_shape = self._prefill_runner.get_input_details()["tokens"][
        "shape"
    ]
    if len(input_token_shape) == 1:
      self._max_seq_len = input_token_shape[0]
    else:
      self._max_seq_len = input_token_shape[1]

    # kv cache input has shape [batch=1, num_kv_heads, cache_size, head_dim].
    kv_cache_shape = self._prefill_runner.get_input_details()["kv_cache_k_0"][
        "shape"
    ]
    self._max_kv_cache_seq_len = kv_cache_shape[2]

  def _init_kv_cache(self) -> dict[str, np.ndarray]:
    if self._prefill_runner is None:
      raise ValueError("Prefill runner is not initialized.")
    kv_cache = {}
    for input_key in self._prefill_runner.get_input_details().keys():
      if "kv_cache" in input_key:
        kv_cache[input_key] = np.zeros(
            self._prefill_runner.get_input_details()[input_key]["shape"],
            dtype=np.float32,
        )
        kv_cache[input_key] = np.zeros(
            self._prefill_runner.get_input_details()[input_key]["shape"],
            dtype=np.float32,
        )
    return kv_cache

  def _get_prefill_runner(self, num_input_tokens: int) :
    """Gets the prefill runner with the best suitable input size.

    Args:
      num_input_tokens: The number of input tokens.

    Returns:
      The prefill runner with the smallest input size.
    """
    best_signature = None
    delta = sys.maxsize
    max_prefill_len = -1
    for key in self._interpreter.get_signature_list().keys():
      if "prefill" not in key:
        continue
      input_pos = self._interpreter.get_signature_runner(key).get_input_details()[
          "input_pos"
      ]
      # input_pos["shape"] has shape (max_seq_len, )
      seq_size = input_pos["shape"][0]
      max_prefill_len = max(max_prefill_len, seq_size)
      if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:
        delta = seq_size - num_input_tokens
        best_signature = key
    if best_signature is None:
      raise ValueError(
          "The largest prefill length supported is %d, but we have %d number of input tokens"
          %(max_prefill_len, num_input_tokens)
      )
    return self._interpreter.get_signature_runner(best_signature)

  def _run_prefill(
      self, prefill_token_ids: Sequence[int],
  ) -> dict[str, np.ndarray]:
    """Runs prefill and returns the kv cache.

    Args:
      prefill_token_ids: The token ids of the prefill input.

    Returns:
      The updated kv cache.
    """
    if not self._prefill_runner:
      raise ValueError("Prefill runner is not initialized.")
    prefill_token_length = len(prefill_token_ids)
    if prefill_token_length == 0:
      return self._init_kv_cache()

    # Prepare the input to be [1, max_seq_len].
    input_token_ids = [0] * self._max_seq_len
    input_token_ids[:prefill_token_length] = prefill_token_ids
    input_token_ids = np.asarray(input_token_ids, dtype=np.int32)
    input_token_ids = np.expand_dims(input_token_ids, axis=0)

    # Prepare the input position to be [max_seq_len].
    input_pos = [0] * self._max_seq_len
    input_pos[:prefill_token_length] = range(prefill_token_length)
    input_pos = np.asarray(input_pos, dtype=np.int32)

    # Initialize kv cache.
    prefill_inputs = self._init_kv_cache()
    # Prepare the tokens and input position inputs.
    prefill_inputs.update({
        "tokens": input_token_ids,
        "input_pos": input_pos,
    })
    if "mask" in self._prefill_runner.get_input_details().keys():
      # For prefill, mask has shape [batch=1, 1, seq_len, kv_cache_size].
      # We want mask[0, 0, i, j] = 0 for j<=i and -inf otherwise.
      prefill_inputs["mask"] = _get_mask(
          shape=self._prefill_runner.get_input_details()["mask"]["shape"],
          k=1,
      )
    prefill_outputs = self._prefill_runner(**prefill_inputs)
    if "logits" in prefill_outputs:
      # Prefill outputs includes logits and kv cache. We only output kv cache.
      prefill_outputs.pop("logits")

    return prefill_outputs

  def _greedy_sampler(self, logits: np.ndarray) -> int:
    return int(np.argmax(logits))


  def _run_decode(
      self,
      start_pos: int,
      start_token_id: int,
      kv_cache: dict[str, np.ndarray],
      max_decode_steps: int,
  ) -> str:
    """Runs decode and outputs the token ids from greedy sampler.

    Args:
      start_pos: The position of the first token of the decode input.
      start_token_id: The token id of the first token of the decode input.
      kv_cache: The kv cache from the prefill.
      max_decode_steps: The max decode steps.

    Returns:
      The token ids from the greedy sampler.
    """
    next_pos = start_pos
    next_token = start_token_id
    decode_text = []
    decode_inputs = kv_cache

    for _ in range(max_decode_steps):
      decode_inputs.update({
          "tokens": np.array([[next_token]], dtype=np.int32),
          "input_pos": np.array([next_pos], dtype=np.int32),
      })
      if "mask" in self._decode_runner.get_input_details().keys():
        # For decode, mask has shape [batch=1, 1, 1, kv_cache_size].
        # We want mask[0, 0, 0, j] = 0 for j<=next_pos and -inf otherwise.
        decode_inputs["mask"] = _get_mask(
            shape=self._decode_runner.get_input_details()["mask"]["shape"],
            k=next_pos + 1,
        )
      decode_outputs = self._decode_runner(**decode_inputs)
      # Output logits has shape (batch=1, 1, vocab_size). We only take the first
      # element.
      logits = decode_outputs.pop("logits")[0][0]
      next_token = self._greedy_sampler(logits)
      if next_token == self._tokenizer.eos_token_id:
        break
      decode_text.append(self._tokenizer.decode(next_token, skip_special_tokens=True))
      if len(decode_text[-1]) == 0:
        # Break out the loop if we hit the special token.
        break

      print(decode_text[-1], end='', flush=True)
      # Decode outputs includes logits and kv cache. We already poped out
      # logits, so the rest is kv cache. We pass the updated kv cache as input
      # to the next decode step.
      decode_inputs = decode_outputs
      next_pos += 1

    print() # print a new line at the end.
    return ''.join(decode_text)

  def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:
    messages=[{"role": "system", "content": "You are a sales person"},
            {"role": "user", "content": prompt}
        ]
    token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
    # Initialize the prefill runner with the suitable input size.
    self._init_prefill_runner(len(token_ids))

    # Run prefill.
    # Prefill up to the seond to the last token of the prompt, because the last
    # token of the prompt will be used to bootstrap decode.
    prefill_token_length = len(token_ids) - 1

    print('Running prefill')
    kv_cache = self._run_prefill(token_ids[:prefill_token_length])
    # Run decode.
    print('Running decode')
    actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1
    if max_decode_steps is not None:
      actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)
    decode_text = self._run_decode(
        prefill_token_length,
        token_ids[prefill_token_length],
        kv_cache,
        actual_max_decode_steps,
    )
    return decode_text

# Generate text from model

In [7]:
# Disclaimer: Model performance demonstrated with the Python API in this notebook is not representative of performance on a local device.
pipeline = LiteRTLlmPipeline(interpreter, tokenizer)

In [8]:
from typing import List
import faiss
import numpy as np

# ------------- STEP 1: Create a simple Retriever (indexing documents) -------------
class SimpleRetriever:
    def __init__(self, documents: List[str], embedding_model):
        self.documents = documents
        self.embedding_model = embedding_model
        self.index = None
        self.embeddings = None
        self._build_index()

    def _build_index(self):
        self.embeddings = np.array([self._embed(doc) for doc in self.documents]).astype('float32')
        self.index = faiss.IndexFlatL2(self.embeddings.shape[1])
        self.index.add(self.embeddings)

    def _embed(self, text: str) -> np.ndarray:
        inputs = self.embedding_model(text, max_decode_steps=5)  # Light hack: just for simple embedding
        return np.random.rand(384)  # ðŸ‘‰ Replace this with *real* embeddings later

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        query_emb = np.expand_dims(self._embed(query), axis=0).astype('float32')
        _, indices = self.index.search(query_emb, top_k)
        return [self.documents[i] for i in indices[0]]

# ------------- STEP 2: Update your pipeline to use retrieved docs -------------
class LiteRAGPipeline(LiteRTLlmPipeline):

    def __init__(self, interpreter, tokenizer, retriever):
        super().__init__(interpreter, tokenizer)
        self.retriever = retriever

    def generate_with_rag(self, user_query: str, max_decode_steps: int = 200) -> str:
        # Step 1: Retrieve relevant docs
        retrieved_docs = self.retriever.retrieve(user_query)

        # Step 2: Build the context
        context = "\n".join(retrieved_docs)

        # Step 3: Combine context and user query
        prompt = f"Use the following information to answer the question.\n\n{context}\n\nQuestion: {user_query}\nAnswer:"

        # Step 4: Run the model's generation
        return self.generate(prompt, max_decode_steps=max_decode_steps)

# ------------- STEP 3: Put it together! -------------
# Load your sales documents
documents = [
    "Our coffee beans are sourced from Ethiopia, Colombia, and Brazil.",
    "We offer discounts on orders over 10kg.",
    "Our best-selling perfume is 'Rose Blossom', made from organic ingredients."
]

# Assume you have a simple dummy embedding model (replace later)
def dummy_embedder(text, max_decode_steps=5):
    # ðŸ‘‰ You can replace this with a real sentence embedding model later
    return np.random.rand(384)

retriever = SimpleRetriever(documents=documents, embedding_model=dummy_embedder)

# Instantiate the new RAG-enabled pipeline
rag_pipeline = LiteRAGPipeline(interpreter, tokenizer, retriever)

# Now you can use:
answer = rag_pipeline.generate_with_rag("What kind of perfume fits to uni-sex person?")
print("Answer:", answer)


Running prefill
Running decode
Okay! Let's talk about perfumes that would be perfect for an unisex person â€“ someone who appreciates quality and natural ingredients like 'Rose Blossom' you mentioned!

We offer an extensive range of perfumes tailored specifically for those who want something sophisticated and beautifully scented without any gender boundaries.  While 'Rose Blossom' certainly evokes romance and beauty â€“ which can be appealing â€“ weâ€™ve curated collections designed for everyone who wants an elegant and uplifting fragrance experience.

Weâ€™re currently offering discounts on orders over ten kilograms! So you can have multiple bottles or have your favorite scent on hand without breaking your budget.

Weâ€™re confident youâ€™ll find something you love!
Answer: Okay! Let's talk about perfumes that would be perfect for an unisex person â€“ someone who appreciates quality and natural ingredients like 'Rose Blossom' you mentioned!

We offer an extensive range of perfumes tai

In [9]:
prompt = "Hi, I am really hungry, do you have any suggestions to eat?"
print("Answer Assistant Q1:")
output = {pipeline.generate(prompt)}
print("================================================================")

prompt = "What kind of perfume fits to uni-sex person?"
print("Answer Assistant Q2:")
output = {pipeline.generate(prompt)}
print("================================================================")

prompt = "When I feel so tired, is it better to drink coffee?"
print("Answer Assistant Q3:")
output = {pipeline.generate(prompt)}
print("================================================================")

prompt = "Why does the perfume with citrus smell so nice?"
print("Answer Assistant Q4:")
output = {pipeline.generate(prompt)}
print("================================================================")

prompt = "Do you know somewhere to buy humburger?"
print("Answer Assistant Q5:")
output = {pipeline.generate(prompt)}


Answer Assistant Q1:
Running prefill
Running decode
Okay! Let's see what can help you satisfy your hunger. I understand you're feeling really hungry! Let's brainstorm some options â€“ I want you feeling good and energized!  I can give you some suggestions based on whatâ€™s available right now â€“ but letâ€™s consider what youâ€™re craving!

Here are some ideas:

*   **Quick & Easy:**
    *   **Fruit:** A banana or apple can provide some energy boost!
    *   **Nuts:** A handful can provide healthy fats and protein
    *   **Yogurt:** It can provide some protein and probiotics
*   **Something Slightly More substantial:**
    *   **Sandwich:** A simple sandwich can provide some carbs and protein
    *   **Granola Bar:** A quick snack that can provide some energy
    *   **Leftovers:** What have you had recently? A quick meal can provide energy!

**Do you want me give you some more specific suggestions?** For example:

*   Do you want something sweet or savory?
*   Are you craving somethi

# Fine-tuning

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
import gc
import torch

gc.collect()
torch.cuda.empty_cache()
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

# Prepare dataset
data = [
    {"prompt": "What's the best perfumes you have that highlight oud?",
     "response": "Absolutely! Here are 10 exquisite perfumes featuring oud: Oud Royale, Midnight Velvet, Desert Whisper, Royal Cedar Oud, Black Mirage, Oud Blossom, Sandstorm Oud, Amethyst Woods, Eternal Night, and Mystic Oud Bloom."},

    {"prompt": "I'm craving something fresh and citrusy. What do you recommend?",
     "response": "Youâ€™ll adore these refreshing picks: Citrus Grove, Amalfi Breeze, Lemon Muse, Golden Mandarin, Zesty Reverie, Sunrise Coast, Capri Gardens, Ocean Bloom, Sunlit Citrus, and Sparkling Lime Essence."},

    {"prompt": "Give me perfumes that feel romantic and floral.",
     "response": "Love is in the air with these floral treasures: Rose Reverie, Jasmine Serenade, Blooming Dream, Garden Affair, Magnolia Muse, Peony Whisper, Petal Dusk, Lavender Mirage, Secret Garden, and Scarlet Blossom."},

    {"prompt": "I want something bold and smoky. What do you have?",
     "response": "Boldness awaits you in these scents: Ember Legend, Black Smoke, Smoldering Nights, Ashen Oud, Tobacco Mystique, Charcoal Dream, Leather Inferno, Obsidian Spirit, Shadowfire, and Wild Smoke Trails."},

    {"prompt": "Recommend me perfumes that have vanilla notes.",
     "response": "Sweet indulgence incoming! Try Vanilla Reverie, Golden Amber Vanilla, Dreamy Whispers, Midnight Vanilla, Silk Essence, Creamy Sandalwood, Amber Frost, Velvet Nectar, Whispered Woods, and Frosted Velvet."},

    {"prompt": "Iâ€™m searching for something green and earthy.",
     "response": "Nature lovers rejoice with: Forest Bloom, Green Reverie, Moss Whisper, Woodland Echoes, Fresh Canopy, Earthborn Essence, Hidden Grove, Sage Mirage, Verdant Spirit, and Meadowlight."},

    {"prompt": "Show me your perfumes that are aquatic and fresh.",
     "response": "Dive into these oceanic wonders: Aqua Myth, Deep Sea Whisper, Sapphire Tide, Coastal Dream, Ocean Drift, Crystal Waters, Sea Breeze Essence, Marine Mirage, Tidal Bloom, and Azure Horizons."},

    {"prompt": "I want something cozy and woody for winter.",
     "response": "Youâ€™ll feel wrapped in warmth with: Winter Cedar, Fireside Muse, Whispered Pine, Ebony Woods, Chestnut Dream, Wooded Mirage, Smoky Oak, Hearthwood, Frosted Timber, and Darkwood Essence."},

    {"prompt": "Suggest perfumes with strong patchouli vibes.",
     "response": "Patchouli power awaits in: Patchouli Mystique, Midnight Forest, Dark Reverie, Enchanted Soil, Patchouli Royale, Earthy Dreams, Mystic Patchouli, Black Earth, Patchouli Bloom, and Grounded Spirit."},

    {"prompt": "What perfumes feel luxurious and golden?",
     "response": "Golden elegance is captured in: Aurum Essence, Sunlit Muse, Golden Horizon, Amberlight, Honeyed Mirage, Solar Reverie, Gleaming Woods, Radiant Sands, Bright Bloom, and Luxe Whisper."},

    {"prompt": "Give me perfumes that smell like a mysterious forest.",
     "response": "Adventure calls with: Midnight Woods, Foggy Grove, Twilight Pine, Enchanted Canopy, Shadowed Ferns, Whispering Bark, Mystic Timber, Dusk Grove, Hidden Woods, and Sylvan Secrets."},

    {"prompt": "Recommend something daring and spicy.",
     "response": "Set your senses ablaze with: Fiery Spice, Saffron Mirage, Scarlet Flame, Cinnamon Dream, Cardamom Spirit, Peppery Muse, Exotic Ember, Heatwave Essence, Spiced Reverie, and Volcanic Bloom."},

    {"prompt": "I want perfumes that smell expensive and elegant.",
     "response": "Pure sophistication lives in: Opulent Bloom, Cashmere Dreams, Velvet Oud, Amber Royale, Supreme Whisper, Ivory Sandalwood, Pearl Blossom, Platinum Muse, Timeless Essence, and Celestial Bloom."},

    {"prompt": "What perfumes are best if I want something powdery and soft?",
     "response": "For gentle elegance, explore: Powdered Dream, Soft Rose, Velvet Cloud, Lilac Whisper, Cotton Bloom, Dusty Petals, Featherlight Essence, Dreamy Magnolia, Cloudy Reverie, and Tender Blossom."},

    {"prompt": "Recommend perfumes that feel dark and sensual.",
     "response": "Dive deep into: Black Velvet, Dark Orchid, Noir Mystique, Sultry Amber, Moonlit Bloom, Midnight Flame, Smoked Rose, Obsidian Nights, Dark Oud Mirage, and Crimson Eclipse."},

    {"prompt": "Show me perfumes perfect for a summer day.",
     "response": "Sunny days deserve: Citrus Whisper, Sun Kissed Bloom, Tropical Muse, Summer Mirage, Lemon Breeze, Ocean Bloom, Island Dreams, Coral Coast, Peachy Reverie, and Sunset Glow."},

    {"prompt": "I want something comforting, like a cozy hug.",
     "response": "Wrap yourself up with: Cashmere Cloud, Warm Vanilla Bloom, Sweet Sandalwood, Amber Embrace, Cozy Cedar, Whispered Amber, Hearth Glow, Frosted Petals, Cuddled Bloom, and Velvet Hearth."},

    {"prompt": "Recommend something sharp and energizing for mornings.",
     "response": "Start your day with a spark: Electric Citrus, Lemon Zest Spirit, Morning Bloom, Crisp Lavender, Sparkling Mint, Awakening Rose, Fresh Horizon, Radiant Citrus, Dewdrop Essence, and Uplifting Sage."},

    {"prompt": "Give me perfumes that feel ancient and mystical.",
     "response": "Unlock timeless secrets with: Mystic Oud, Ancient Amber, Saffron Spirit, Forgotten Grove, Eternal Sands, Sacred Smoke, Mythical Bloom, Enchanted Resin, Timeless Reverie, and Desert Mirage."},

    {"prompt": "I want something delicate but unforgettable.",
     "response": "Subtle beauty shines in: Whispered Peony, Soft Magnolia, Dreamy Jasmine, Silken Rose, Lilac Muse, Dusk Petals, Feathered Bloom, Serenity Woods, Twilight Essence, and Delicate Dawn."},
]

# Create a Huggingface Dataset
dataset = Dataset.from_list(data)

# Tokenize
def tokenize_function(example):
    input_text = f"User: {example['prompt']}\nAssistant: {example['response']}"
    tokenized = tokenizer(input_text, truncation=True, padding="max_length", max_length=512)
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


tokenized_dataset = dataset.map(tokenize_function, batched=False)

# Setup Trainer
training_args = TrainingArguments(
    output_dir="./gemma-3-finetuned",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=10,
    fp16=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Fine-tune
trainer.train()

# Save the model
model.save_pretrained("./gemma-3-finetuned")
tokenizer.save_pretrained("./gemma-3-finetuned")


Map:   0%|          | 0/20 [00:00<?, ? examples/s]

It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,22.2248
20,2.16
30,1.0095
40,0.3341
50,0.1731
60,0.2012


('./gemma-3-finetuned/tokenizer_config.json',
 './gemma-3-finetuned/special_tokens_map.json',
 './gemma-3-finetuned/tokenizer.model',
 './gemma-3-finetuned/added_tokens.json',
 './gemma-3-finetuned/tokenizer.json')

In [4]:
import torch
torch.cuda.empty_cache()

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./gemma-3-finetuned")
tokenizer = AutoTokenizer.from_pretrained("./gemma-3-finetuned")

# Example prompt
inputs = tokenizer("Can you suggest 5 luxury perfumes that are elegant and sophisticated?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Can you suggest 5 luxury perfumes that are elegant and sophisticated?
Absolutely! Here are 5 elegant floral creations: Cashmere, Ivory Reverie, Supreme Bloom, Silk Essence, Velvet Cloud, Soft Magnolia, Platinum Muse, Serenade Whispers, Timeless Essence, and Luxe Whisper.
Golden Elegance


In [6]:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

# Example prompt
inputs = tokenizer("Can you suggest 5 luxury perfumes that are elegant and sophisticated?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Can you suggest 5 luxury perfumes that are elegant and sophisticated?

Here are 5 recommendations that I think fit your description:

1.  **Dior J'adore:** A classic, rosy floral perfume with notes of pink jasmine, rose, and a touch of amber. It's endlessly elegant.


In [7]:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

# Example prompt
inputs = tokenizer("Can you suggest 5 luxury perfumes that are elegant and sophisticated?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Can you suggest 5 luxury perfumes that are elegant and sophisticated?

Here are my suggestions:

1.  **Chanel Coco Mademoiselle:**  A timeless classic, known for its bright orange top notes and a warm, fruity base. It's a wonderful, complex fragrance that is always popular.


In [8]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load your fine-tuned Gemma model
model = AutoModelForCausalLM.from_pretrained('./gemma-3-finetuned')
tokenizer = AutoTokenizer.from_pretrained('./gemma-3-finetuned')

# Set model to evaluation mode (important for inference and tracing)
model.eval()

# Create a wrapper to ONLY return logits (no past_key_values etc.)
class GemmaLiteWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits  # Only return logits, not the full CausalLMOutput

# Wrap the model
lite_model = GemmaLiteWrapper(model)


In [9]:
# # Dummy input for tracing
# inputs = tokenizer("Hello, how are you?", return_tensors="pt")

# # Trace the model
# traced_model = torch.jit.trace(lite_model, (inputs["input_ids"], inputs["attention_mask"]))

# # Save the traced model
# traced_model.save("gemma_lite_traced.pt")

In [10]:
# scripted_model = torch.jit.script(lite_model)

# # Save the scripted model
# scripted_model.save("gemma_lite_scripted.pt")


In [11]:
import torch
import ai_edge_torch
from transformers import AutoModelForCausalLM

# 1. Wrap the model
class GemmaLiteWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        output = self.model(input_ids=input_ids, use_cache=False)
        return output.logits

# 2. Load fine-tuned Gemma
model = AutoModelForCausalLM.from_pretrained("./gemma-3-finetuned").eval()
wrapped_model = GemmaLiteWrapper(model)
print("Defining wrapped model....")

# 3. Dummy input
sample_input = (torch.randint(0, model.config.vocab_size, (1, 5)),)

print("generating dummy input....")
# 4. Convert (basic, without optimize/quantize for now)
edge_model = ai_edge_torch.convert(
    wrapped_model,
    sample_input,
)

# 5. Export to TFLite (still full precision)
edge_model.export('gemma_lite_full.tflite')




InternalError: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run Identity: Dst tensor is not initialized. [Op:Identity] name: 

In [None]:
import tensorflow as tf
import numpy as np
from transformers import AutoTokenizer

# Load model
interpreter = tf.lite.Interpreter(model_path="/content/gemma_lite_full.tflite")
interpreter.allocate_tensors()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("/content/gemma-3-finetuned")

# Prepare input text
text = "Hello how are you"

# Get the expected shape for input
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Expected input shape [1, seq_len] (e.g., [1, 5])
expected_input_shape = input_details[0]['shape']
seq_len = expected_input_shape[1]  # Get the expected sequence length

# Tokenize the input text and ensure it fits the expected sequence length
inputs = tokenizer(
    text,
    return_tensors="np",
    padding="max_length",  # Pad to max length if needed
    truncation=True,       # Truncate if the length exceeds max_length
    max_length=seq_len    # Ensure that the sequence length matches
)

# Ensure that the shape of the input matches the expected shape [1, seq_len]
input_ids = inputs['input_ids']

# Start with the input sequence and predict a continuation
generated_sequence = input_ids[0].tolist()  # Start with the input sequence as the base
num_tokens_to_generate = 20  # Adjust the number of tokens to generate

for _ in range(num_tokens_to_generate):
    # Set the input tensor to the interpreter
    # interpreter.set_tensor(input_details[0]['index'], np.array([generated_sequence]))
    # Take only the last `seq_len` tokens
    input_for_inference = np.array([generated_sequence[-seq_len:]])

    interpreter.set_tensor(input_details[0]['index'], input_for_inference)

    # Run inference
    interpreter.invoke()

    # Get the output tensor
    output = interpreter.get_tensor(output_details[0]['index'])

    # Extract the predicted token (take the last token prediction)
    predicted_token_id = np.argmax(output[0, -1])  # Take the last token for prediction

    # Add the predicted token to the generated sequence
    generated_sequence.append(predicted_token_id)

# Decode the generated sequence
generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)

print("\nGenerated Text:", generated_text)
