##### Copyright 2025 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# RAG with EmbeddingGemma

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_3]RAG_with_EmbeddingGemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

EmbeddingGemma is a lightweight, open embedding model designed for fast, high-quality retrieval on everyday devices like mobile phones. At only 308 million parameters, it's efficient enough to run advanced AI techniques, such as Retrieval Augmented Generation (RAG), directly on your local machine with no internet connection required.

## Setup

Before starting this tutorial, complete the following steps:

* Get access to EmbeddingGemma by logging into [Hugging Face](https://huggingface.co/google/embeddinggemma-300M) and selecting **Acknowledge license** for a Gemma model.
* Select a Colab runtime with sufficient resources to run
  the Gemma model size you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).
* Generate a Hugging Face [Access Token](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-token) and use it to log in from Colab.

This notebook will run on an NVIDIA T4 GPU.

### Install Python packages

Install the libraries required for running the EmbeddingGemma model and generating embeddings. Sentence Transformers is a Python framework for text and image embeddings. For more information, see the [Sentence Transformers](https://www.sbert.net/) documentation.

In [None]:
!pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma

After you have accepted the license, you need a valid Hugging Face Token to access the model.

In [None]:
# Log in to the Hugging Face Hub with your access token so that the
# license-gated Gemma models used below can be downloaded.
from huggingface_hub import login
login()

### Load language model

You will use Gemma 3 to generate responses.

In [None]:
# Build the Gemma 3 text-generation pipeline used later to answer questions.
# Note: the name `pipeline` is rebound from the factory function to the
# pipeline instance it returns; later cells call that instance directly.
from transformers import pipeline

pipeline = pipeline(
    "text-generation",
    model="google/gemma-3-4b-it",
    dtype="auto",
    device_map="auto",
)

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Device set to use cuda:0


### Load embedding model

Use the `sentence-transformers` library to create an instance of a model class with EmbeddingGemma.

In [None]:
import torch
from sentence_transformers import SentenceTransformer

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load EmbeddingGemma and move it to the selected device.
model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

# Report where the model lives, its module layout, and its parameter count.
print(f"Device: {model.device}")
print(model)
print(
    "Total number of parameters in the model:",
    sum(param.numel() for _, param in model.named_parameters()),
)

modules.json:   0%|          | 0.00/573 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/997 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/16.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/58.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.49k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.21G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/312 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/134 [00:00<?, ?B/s]

2_Dense/model.safetensors:   0%|          | 0.00/9.44M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/134 [00:00<?, ?B/s]

3_Dense/model.safetensors:   0%|          | 0.00/9.44M [00:00<?, ?B/s]

Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696


### Using Prompts with EmbeddingGemma

For RAG systems, use the following `prompt_name` values to create specialized embeddings for your queries and documents:

* **For Queries:** Use `prompt_name="Retrieval-query"`.<br>
    ```python
    query_embedding = model.encode(
        "How do I use prompts with this model?",
        prompt_name="Retrieval-query"
    )
    ```

* **For Documents:** Use `prompt_name="Retrieval-document"`. To further improve document embeddings, you can also include a title by using the `prompt` argument directly:<br>
  * **With a title:**<br>
    ```python
    doc_embedding = model.encode(
        "The document text...",
        prompt="title: Using Prompts in RAG | text: "
    )
    ```
  * **Without a title:**<br>
    ```python
    doc_embedding = model.encode(
        "The document text...",
        prompt="title: none | text: "
    )
    ```

### Further Reading

* For details on all available EmbeddingGemma prompts, see the [model card](http://ai.google.dev/gemma/docs/embeddinggemma/model_card#prompt_instructions).
* For general information on prompt templates, see the [Sentence Transformer documentation](https://sbert.net/examples/sentence_transformer/applications/computing-embeddings/README.html#prompt-templates).


In [None]:
# List every prompt name the loaded model exposes together with its prefix.
print("Available tasks:")
for task_name in model.prompts:
    task_prefix = model.prompts[task_name]
    print(f" {task_name}: \"{task_prefix}\"")

Available tasks:
 query: "task: search result | query: "
 document: "title: none | text: "
 BitextMining: "task: search result | query: "
 Clustering: "task: clustering | query: "
 Classification: "task: classification | query: "
 InstructionRetrieval: "task: code retrieval | query: "
 MultilabelClassification: "task: classification | query: "
 PairClassification: "task: sentence similarity | query: "
 Reranking: "task: search result | query: "
 Retrieval: "task: search result | query: "
 Retrieval-query: "task: search result | query: "
 Retrieval-document: "title: none | text: "
 STS: "task: sentence similarity | query: "
 Summarization: "task: summarization | query: "


## Simple RAG example

Retrieval is the task of finding the most relevant pieces of information from a large collection (a database, a set of documents, a website) based on the meaning of a query, not just keywords.

Imagine you work for a company, and you need to find information from the internal employee handbook, which is stored as a collection of hundreds of documents.

In [None]:
#@title Corp knowledge base
# Toy employee-handbook corpus used by the retrieval demo below.
# Structure: a list of category entries, each with a "category" name and a
# "documents" list whose items carry a "title" and a "content" string.
corp_knowledge_base = [
  {
    "category": "HR & Leave Policies",
    "documents": [
      {
        "title": "Procedure for Unscheduled Absence",
        "content": "In the event of an illness or emergency preventing you from working, please notify both your direct manager and the HR department via email by 9:30 AM JST. The subject line should be 'Sick Leave - [Your Name]'. If the absence extends beyond two consecutive days, a doctor's certificate (診断書) will be required upon your return."
      },
      {
        "title": "Annual Leave Policy",
        "content": "Full-time employees are granted 10 days of annual paid leave in their first year. This leave is granted six months after the date of joining and increases each year based on length of service. For example, an employee in their third year of service is entitled to 14 days per year. For a detailed breakdown, please refer to the attached 'Annual Leave Accrual Table'."
      },
    ]
  },
  {
    "category": "IT & Security",
    "documents": [
      {
        "title": "Account Password Management",
        "content": "If you have forgotten your password or your account is locked, please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions. For security reasons, the IT Help Desk cannot reset passwords over the phone or email. If you have not set up your security questions, please visit the IT support desk on the 12th floor of the Shibuya office with your employee ID card."
      },
      {
        "title": "Software Procurement Process",
        "content": "All requests for new software must be submitted through the 'IT Service Desk' portal under the 'Software Request' category. Please include a business justification for the request. All software licenses require approval from your department head before procurement can begin. Please note that standard productivity software is pre-approved and does not require this process."
      },
    ]
  },
  {
    "category": "Finance & Expenses",
    "documents": [
      {
        "title": "Expense Reimbursement Policy",
        "content": "To ensure timely processing, all expense claims for a given month must be submitted for approval no later than the 5th business day of the following month. For example, all expenses incurred in July must be submitted by the 5th business day of August. Submissions after this deadline may be processed in the next payment cycle."
      },
      {
        "title": "Business Trip Expense Guidelines",
        "content": "Travel expenses for business trips will, as a rule, be reimbursed based on the actual cost of the most logical and economical route. Please submit a travel expense application in advance when using the Shinkansen or airplanes. Taxis are permitted only when public transportation is unavailable or when transporting heavy equipment. Receipts are mandatory."
      },
    ]
  },
  {
    "category": "Office & Facilities",
    "documents": [
      {
        "title": "Conference Room Booking Instructions",
        "content": "All conference rooms in the Shibuya office can be reserved through your Calendar App. Create a new meeting invitation, add the attendees, and then use the 'Room Finder' feature to select an available room. Please be sure to select the correct floor. For meetings with more than 10 people, please book the 'Sakura' or 'Fuji' rooms on the 14th floor."
      },
      {
        "title": "Mail and Delivery Policy",
        "content": "The company's mail services are intended for business-related correspondence only. For security and liability reasons, employees are kindly requested to refrain from having personal parcels or mail delivered to the Shibuya office address. The front desk will not be able to accept or hold personal deliveries."
      },
    ]
  },
]


Now imagine you have a question like the one below.

In [None]:
# The user query to answer: pick one of the preset questions or type your own
# in the Colab form field.
question = "How do I reset my password?" # @param ["How many days of annual paid leave do I get?", "How do I reset my password?", "What travel expenses can be reimbursed for a business trip?", "Can I receive personal packages at the office?"] {type:"string", allow-input: true}

# Minimum similarity score a category or document must reach to be treated as
# a valid match by the search logic below.
similarity_threshold = 0.5 # @param {"type":"slider","min":0,"max":1,"step":0.1}

Search the corporate knowledge base for the most relevant document.

In [None]:
# --- Helper Functions for Semantic Search ---

def find_best_category(model, query, candidates):
    """
    Pick the category name that scores highest against the query.

    The query and all category names are embedded with the "Classification"
    prompt and compared with ``model.similarity``. As a side effect, the
    candidate names and the raw similarity scores are printed.

    Args:
        model: The SentenceTransformer model (must provide ``encode`` and
            ``similarity``).
        query: The user's query string.
        candidates: A list of category name strings.

    Returns:
        A tuple ``(index, score)`` for the best-scoring category, or
        ``(None, 0.0)`` when ``candidates`` is empty.
    """
    if candidates:
        task = "Classification"

        # Query first, then the candidates, both under the same task prompt.
        scores = model.similarity(
            model.encode(query, prompt_name=task),
            model.encode(candidates, prompt_name=task),
        )

        print(candidates)
        print(scores)

        # Scores form a single row (one query), so the flat argmax is the
        # column index of the winning candidate.
        top = scores.argmax().item()
        return top, scores[0, top].item()

    # Nothing to compare against.
    return None, 0.0

def find_best_doc(model, query, candidates):
    """
    Finds the most relevant document from a list of candidates.

    The query is embedded with the "Retrieval-query" prompt. Each document is
    embedded from a manually built "title: ... | text: ..." string, so no
    ``prompt_name`` is passed for the documents. As a side effect, the
    candidate titles and the raw similarity scores are printed.

    Args:
        model: The SentenceTransformer model.
        query: The user's query string.
        candidates: A list of document dictionaries, each with 'title' and
            'content'. A missing 'title' is treated as 'none' and a missing
            'content' as an empty string.

    Returns:
        A tuple containing the index of the best document and its similarity
        score, or ``(None, 0.0)`` when ``candidates`` is empty.
    """
    if not candidates:
        return None, 0.0

    # Encode the query for retrieval
    query_embedding = model.encode(query, prompt_name="Retrieval-query")

    # Resolve titles once so that the encoded text and the debug output agree,
    # and so a document without a 'title' key does not raise KeyError below.
    titles = [doc.get('title', 'none') for doc in candidates]

    # Encode the documents using the document prompt format with a title
    doc_texts = [
        f"title: {title} | text: {doc.get('content', '')}"
        for title, doc in zip(titles, candidates)
    ]
    candidate_embeddings = model.encode(doc_texts)

    # Calculate similarity between the query and every document
    similarities = model.similarity(query_embedding, candidate_embeddings)

    print(titles)
    print(similarities)

    # Find the index and value of the highest score
    best_index = similarities.argmax().item()
    best_score = similarities[0, best_index].item()

    return best_index, best_score

# --- Main Search Logic ---

# In your application, `best_document` would result from a search.
# We initialize it to None to ensure it always exists.
best_document = None

# 1. Find the most relevant category
print("Step 1: Finding the best category...")
categories = [item["category"] for item in corp_knowledge_base]
best_category_index, category_score = find_best_category(
    model, question, categories
)

# Check if the category score meets the threshold.
# The helper returns a None index when there were no candidates at all; that
# must count as a miss even when the threshold slider is set to 0, otherwise
# the knowledge base would be indexed with None.
if best_category_index is None or category_score < similarity_threshold:
    print(f" `-> 🤷 No relevant category found. The highest score was only {category_score:.2f}.")
else:
    best_category = corp_knowledge_base[best_category_index]
    print(f" `-> ✅ Category Found: '{best_category['category']}' (Score: {category_score:.2f})")

    # 2. Find the most relevant document ONLY if a good category was found
    print("\nStep 2: Finding the best document in that category...")
    best_document_index, document_score = find_best_doc(
        model, question, best_category["documents"]
    )

    # Check if the document score meets the threshold (same None guard as
    # above, for a category whose document list is empty).
    if best_document_index is None or document_score < similarity_threshold:
        print(f" `-> 🤷 No relevant document found. The highest score was only {document_score:.2f}.")
    else:
        best_document = best_category["documents"][best_document_index]
        # 3. Display the final successful result
        print(f" `-> ✅ Document Found: '{best_document['title']}' (Score: {document_score:.2f})")


Step 1: Finding the best category...
['HR & Leave Policies', 'IT & Security', 'Finance & Expenses', 'Office & Facilities']
tensor([[0.5063, 0.5937, 0.4702, 0.4221]])
 `-> ✅ Category Found: 'IT & Security' (Score: 0.59)

Step 2: Finding the best document in that category...
['Account Password Management', 'Software Procurement Process']
tensor([[0.5829, 0.1531]])
 `-> ✅ Document Found: 'Account Password Management' (Score: 0.58)


Next, generate the answer with the retrieved context.

In [None]:
# Prompt that restricts the model to the retrieved context and tells it to
# admit when the answer is not there.
qa_prompt_template = """Answer the following QUESTION based only on the CONTEXT provided. If the answer cannot be found in the CONTEXT, write "I don't know."

---
CONTEXT:
{context}
---
QUESTION:
{question}
"""

# The question is echoed in both outcomes, so print it once up front.
print("Question🙋‍♂️: " + question)

if not best_document or "content" not in best_document:
    # Retrieval found nothing usable: answer directly without calling the model.
    print("Answer🤖: I'm sorry, I could not find a relevant document to answer that question.")
else:
    # Ground the prompt in the retrieved document's content.
    context = best_document["content"]
    prompt = qa_prompt_template.format(context=context, question=question)

    # Single-turn chat: one user message holding the full prompt.
    messages = [
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]

    # The pipeline output is read as: first result -> "generated_text" chat
    # list -> entry at index 1 (the reply following the single user message).
    answer = pipeline(
        messages, max_new_tokens=256, disable_compile=True
    )[0]["generated_text"][1]["content"]
    print("Answer🤖: " + answer)


Question🙋‍♂️: How do I reset my password?
Answer🤖: Please use the self-service reset portal at https://reset.ourcompany. You will be prompted to answer your pre-configured security questions.


## Summary and next steps

You have now learned how to build a practical RAG system with EmbeddingGemma.

Explore what more you can do with EmbeddingGemma:

* [Generate embeddings with Sentence Transformers](https://ai.google.dev/gemma/docs/embeddinggemma/inference-embeddinggemma-with-sentence-transformers)
* [Fine-tune EmbeddingGemma](https://ai.google.dev/gemma/docs/embeddinggemma/fine-tuning-embeddinggemma-with-sentence-transformers)
* [Mood Palette Generator](https://huggingface.co/spaces/google/mood-palette), an interactive application using EmbeddingGemma