## Small-to-big Retrieval-Augmented Generation 

Keith Ballinger and Megan O'Keefe, 2024

Small-to-big RAG is a form of recursive RAG in which small chunks are linked to larger chunks (e.g. a passage of a document is linked to the full document). Initially, only the small chunks are retrieved. If the model needs more context, it can request it, at which point the linked larger chunk is added to the prompt.
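
As a rough illustration of the linking idea, the sketch below stores a parent key on every small chunk so the larger context can be looked up later. The names `SmallChunk`, `big_docs`, and `expand_to_parent`, and the toy corpus itself, are hypothetical and not part of this notebook.

```python
from dataclasses import dataclass


@dataclass
class SmallChunk:
    """A short retrievable unit that points back to its larger parent."""
    chunk_id: str
    text: str
    parent_id: str  # key into the store of "big" documents


# Hypothetical toy corpus: one big document and two small chunks cut from it.
big_docs = {
    "doc-1": "Full text of a service README, including setup, configuration, and API notes.",
}
small_chunks = [
    SmallChunk("doc-1#0", "Setup instructions for the service.", "doc-1"),
    SmallChunk("doc-1#1", "API notes for the service.", "doc-1"),
]


def expand_to_parent(chunk: SmallChunk, docs: dict) -> str:
    """Follow the small-to-big link and return the larger context."""
    return docs[chunk.parent_id]


first_hit = small_chunks[1]      # pretend retrieval returned this chunk
small_context = first_hit.text   # try to answer with this first
big_context = expand_to_parent(first_hit, big_docs)  # fetch only if needed
```

The only structural requirement is that each small unit carries a stable reference to its parent.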

Small-to-big RAG is a way to work around the limits of dense, meaning-rich vector embeddings, where passing in too much context can result in a loss of meaning in the vector representation.  
(textembedding-gecko has a text input limit of 3,072 tokens.)  
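
One rough way to stay under an input limit is to split long text into smaller pieces before embedding. The sketch below uses whitespace-separated words as a stand-in for tokens. That is an assumption: real token counts depend on the model's tokenizer and are usually higher than word counts, so a conservative `max_words` is advisable. `chunk_by_words` is a hypothetical helper, not something this notebook defines.

```python
def chunk_by_words(text: str, max_words: int = 200) -> list[str]:
    """Split text into consecutive pieces of at most `max_words` words.

    Words are only a proxy for tokens (assumption), so keep max_words
    well below the embedding model's actual token limit.
    """
    if max_words <= 0:
        raise ValueError("max_words must be positive")
    words = text.split()
    return [
        " ".join(words[i:i + max_words])
        for i in range(0, len(words), max_words)
    ]


chunks = chunk_by_words("one two three four five", max_words=2)
print(chunks)
```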

Small-to-big is also a good option if you are working with long or complex documents that require context exceeding the ideal chunk size. 

Lastly, Small-to-big can help you optimize inference costs, especially if you're paying per token (e.g. Gemini API on Vertex AI). If a small amount of input context is enough for the model to generate an answer, you don't have to provide the whole large document as context. Said another way, you only pay for the context tokens you actually need to generate an accurate response from Gemini.
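
To make the cost argument concrete, the expected number of input tokens per query can be estimated from how often a query has to escalate to the larger context. The sketch below is back-of-the-envelope arithmetic only. The token counts and the escalation rate are made-up numbers, and `expected_input_tokens` is a hypothetical helper. Any extra prompts on the escalation path would add to the "big" term.

```python
def expected_input_tokens(small_tokens: int, big_tokens: int,
                          escalation_rate: float) -> float:
    """Every query pays for the small prompt; only the escalated
    fraction of queries also pays for the big prompt."""
    if not 0.0 <= escalation_rate <= 1.0:
        raise ValueError("escalation_rate must be between 0 and 1")
    return small_tokens + escalation_rate * big_tokens


# Illustrative numbers only (assumptions, not measurements).
always_big_tokens = 8000
expected_tokens = expected_input_tokens(
    small_tokens=500, big_tokens=8000, escalation_rate=0.25
)
print(expected_tokens)  # 500 + 0.25 * 8000 = 2500.0
```

Under these assumed numbers the small-first approach sends fewer input tokens on average. The advantage shrinks as the escalation rate approaches 1.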

In this notebook, we'll walk through Small-to-big RAG using a GitHub codebase called [Online Boutique](https://github.com/GoogleCloudPlatform/microservices-demo), a polyglot microservices sample application. We'll help a new Online Boutique user (or contributor) navigate this large codebase with a Q&A chatbot grounded on the code.


![](architecture.png)

Prerequisites:
- Google Cloud account
- Google Cloud project with billing enabled 
- Vertex AI API enabled in that project

This notebook uses the following products and tools:
- Vertex AI - Gemini API 
- Vertex AI - Text Embeddings API 
- Chroma (in-memory vector database)  




### Setup 

In [1]:
EMBEDDING_MODEL="textembedding-gecko@003"
GENERATIVE_MODEL="gemini-1.0-pro"
PROJECT_ID="mokeefe-genai-test"
REGION="us-central1"

In [15]:
! pip install "google-cloud-aiplatform>=1.38"




In [16]:
from google.cloud import aiplatform
import vertexai
from vertexai.language_models import TextEmbeddingModel
from vertexai.generative_models import GenerativeModel, ChatSession
import os 


### Create helper functions

We'll create two helper functions: one that calls the Vertex AI text embedding model (`textembedding-gecko`), and another that runs inference against Gemini Pro on Vertex AI.

In [17]:
def get_text_embedding(doc) -> list:
    model = TextEmbeddingModel.from_pretrained(EMBEDDING_MODEL)
    embeddings = model.get_embeddings([doc])
    if len(embeddings) > 1:
        raise ValueError("More than one embedding returned.")
    if len(embeddings) == 0:
        raise ValueError("No embedding returned.")
    return embeddings[0].values

In [20]:
vertexai.init(project=PROJECT_ID, location=REGION)
model = GenerativeModel(GENERATIVE_MODEL)
chat = model.start_chat()

In [25]:
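def gemini_inference(chat: ChatSession, prompt: str) -> str:
    # Note: `chat` is accepted but not used. Each call goes through the global
    # `model` and is stateless, so no conversation history carries over between prompts.
    text_response = model.generate_content(prompt)
    return text_response.text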

In [26]:
gemini_inference(chat, "hello world!")

'Bonjour le monde!'

### Get summaries of code files 

First, we'll use Gemini on Vertex AI to get short summaries of each code file.  We'll do this by recursively traversing the files in `onlineboutique-codefiles`. 

In [27]:
# for every file in onlineboutique-codefiles/, read it in, and get the full tree filename, and a code summary 
summaries = {} 
for root, dirs, files in os.walk("onlineboutique-codefiles/"):
    for file in files:
        temp = {}
        full_file_path = os.path.join(root, file)
        with open(full_file_path, "r") as f:
            print("Processing file: ", full_file_path)
            try:
                content = f.read()     
                temp["content"] = content 
                prompt = """ 
                You are a helpful code summarizer. Here is a source code file. Please identify the programming language and summarize it in three sentences or less. Give as much detail as possible, including function names and libraries used. Code: 
                {}
                """.format(content)
                summary = gemini_inference(chat, prompt)
                temp["summary"] = summary
                summaries[full_file_path] = temp
            except Exception as e:
                print("⚠️ Error processing file: {} - {}".format(full_file_path, e))

Processing file:  onlineboutique-codefiles/.DS_Store
⚠️ Error processing file: onlineboutique-codefiles/.DS_Store - 'utf-8' codec can't decode byte 0xff in position 578: invalid start byte
Processing file:  onlineboutique-codefiles/LICENSE
Processing file:  onlineboutique-codefiles/cloudbuild.yaml
Processing file:  onlineboutique-codefiles/README.md
Processing file:  onlineboutique-codefiles/skaffold.yaml
Processing file:  onlineboutique-codefiles/terraform/output.tf
Processing file:  onlineboutique-codefiles/terraform/main.tf
Processing file:  onlineboutique-codefiles/terraform/terraform.tfvars
Processing file:  onlineboutique-codefiles/terraform/providers.tf
Processing file:  onlineboutique-codefiles/terraform/README.md
Processing file:  onlineboutique-codefiles/terraform/memorystore.tf
Processing file:  onlineboutique-codefiles/terraform/variables.tf
Processing file:  onlineboutique-codefiles/protos/demo.proto
Processing file:  onlineboutique-codefiles/protos/grpc/health/v1/health.p
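
The `.DS_Store` error in the output above comes from trying to read a binary file as UTF-8 text. One possible way to skip such files before summarizing them is a small heuristic check like the sketch below. `is_probably_text` and its `sample_size` parameter are hypothetical, and the heuristic (no NUL bytes, decodes as UTF-8) is an assumption that will misclassify some files, such as text in other encodings.

```python
import codecs
import os
import tempfile


def is_probably_text(path: str, sample_size: int = 2048) -> bool:
    """Heuristic: treat a file as text if its first bytes contain no NUL
    byte and decode cleanly as UTF-8."""
    with open(path, "rb") as f:
        sample = f.read(sample_size)
    if b"\x00" in sample:
        return False
    # An incremental decoder tolerates a multi-byte character that was
    # cut off at the end of the sample, but still rejects invalid bytes.
    decoder = codecs.getincrementaldecoder("utf-8")()
    try:
        decoder.decode(sample, final=False)
    except UnicodeDecodeError:
        return False
    return True


# Small self-contained demo using temporary files.
with tempfile.TemporaryDirectory() as tmp:
    text_path = os.path.join(tmp, "main.tf")
    binary_path = os.path.join(tmp, ".DS_Store")
    with open(text_path, "w", encoding="utf-8") as f:
        f.write('resource "null_resource" "example" {}\n')
    with open(binary_path, "wb") as f:
        f.write(b"\x00\x00\x00\x01Bud1\xff\xfe")
    text_ok = is_probably_text(text_path)
    binary_ok = is_probably_text(binary_path)

print(text_ok, binary_ok)
```

In the loop above, a check like this could run before `open(full_file_path, "r")` so that binary files are skipped rather than caught by the `except` branch.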

In [28]:
import pandas as pd
df =  pd.DataFrame.from_dict(summaries, orient='index')

In [29]:
df.head()

| | content | summary |
|---|---|---|
| onlineboutique-codefiles/LICENSE | `\n Apache Lice...` | This is a text file describing the Apache Lice... |
| onlineboutique-codefiles/cloudbuild.yaml | `# Copyright 2020 Google LLC\n#\n# Licensed und...` | This configuration file uses Google Cloud Buil... |
| onlineboutique-codefiles/README.md | `<p align="center">\n<img src="/src/frontend/st...` | Online Boutique is a web-based microservices d... |
| onlineboutique-codefiles/skaffold.yaml | `# Copyright 2021 Google LLC\n#\n# Licensed und...` | This Skaffold configuration file defines build... |
| onlineboutique-codefiles/terraform/output.tf | `# Copyright 2022 Google LLC\n#\n# Licensed und...` | This Terraform configuration outputs data from... |


In [31]:
# number of file summaries 
print("Number of rows: ", df.shape[0])

Number of rows:  109


In [32]:
# the first column should be named "filename"
df = df.reset_index()
df = df.rename(columns = {'index':'filename'})
df.head()

| | filename | content | summary |
|---|---|---|---|
| 0 | onlineboutique-codefiles/LICENSE | `\n Apache Lice...` | This is a text file describing the Apache Lice... |
| 1 | onlineboutique-codefiles/cloudbuild.yaml | `# Copyright 2020 Google LLC\n#\n# Licensed und...` | This configuration file uses Google Cloud Buil... |
| 2 | onlineboutique-codefiles/README.md | `<p align="center">\n<img src="/src/frontend/st...` | Online Boutique is a web-based microservices d... |
| 3 | onlineboutique-codefiles/skaffold.yaml | `# Copyright 2021 Google LLC\n#\n# Licensed und...` | This Skaffold configuration file defines build... |
| 4 | onlineboutique-codefiles/terraform/output.tf | `# Copyright 2022 Google LLC\n#\n# Licensed und...` | This Terraform configuration outputs data from... |


In [39]:
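# print 10 randomly chosen summaries
import random
for i in range(10):
    # draw one valid row index so the filename and summary come from the same row
    # (randrange excludes the upper bound, so it cannot index past the last row)
    idx = random.randrange(df.shape[0])
    print("Filename: ", df.iloc[idx]["filename"])
    print("Summary: ", df.iloc[idx]["summary"])
    print("\n")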

Filename:  onlineboutique-codefiles/src/currencyservice/proto/demo.proto
Summary:  This Go code defines a gRPC service for a product catalog. It includes methods for listing products, getting a specific product, searching products, and health checks. The service uses the gRPC health API for health checks and the `github.com/GoogleCloudPlatform/microservices-demo/src/productcatalogservice/genproto` library for the gRPC service definition.


Filename:  onlineboutique-codefiles/src/productcatalogservice/server.go
Summary:  This code is written in Dockerfile and defines multi-stage Docker image builds. It creates a Go application and packages it into a container. Key functions include `go mod download`, `go build`, and `COPY`. The image uses the `golang:1.22.1-alpine` and `alpine:3.19.1` base images and exposes port 3550.


Filename:  onlineboutique-codefiles/src/productcatalogservice/genproto.sh
Summary:  This JavaScript code defines a function called `charge` that validates and processes

In [33]:
# write to csv
df.to_csv("code_summaries.csv", index=False)

### Convert summaries to embeddings

In [None]:
! pip install chromadb



In [35]:
import chromadb
chroma_client = chromadb.Client()

In [37]:
collection = chroma_client.create_collection(name="code_summaries")



In [None]:

# iterate over dataframe. convert summary into embeddings. insert summary into collection. 
for index, row in df.iterrows():
    fn = row["filename"]
    print("Getting embedding for: ", fn)
    summ = row["summary"] 
    print(summ)
    e = get_text_embedding(summ) 
    print(e)
    # add vector embedding to in-memory Chroma database. 
    # the "small" summary embedding is linked to the "big" raw code file through the metadata key, "filename." 
    collection.add(
        embeddings=[e],
        documents=[summ],
        metadatas=[{"filename": fn}],
        ids=[fn])
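
To build intuition for the nearest-neighbor lookup a vector database performs over these embeddings, here is a minimal pure-Python sketch that ranks stored vectors by cosine similarity to a query vector. It is a conceptual illustration only: Chroma does this internally, and its distance metric and indexing strategy may differ from what is shown. The three-dimensional vectors and the keys in `toy_index` are made up.

```python
import math


def cosine_similarity(a: list, b: list) -> float:
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query: list, index: dict, k: int = 3) -> list:
    """Return the keys of the k stored vectors most similar to the query."""
    scored = [(cosine_similarity(query, vec), key) for key, vec in index.items()]
    scored.sort(reverse=True)  # highest similarity first
    return [key for _, key in scored[:k]]


# Toy index: key -> embedding (real embeddings have hundreds of dimensions).
toy_index = {
    "ad_service_summary": [0.9, 0.1, 0.0],
    "terraform_summary": [0.0, 0.2, 0.9],
    "email_service_summary": [0.5, 0.5, 0.1],
}
nearest = top_k([1.0, 0.0, 0.0], toy_index, k=2)
print(nearest)  # ['ad_service_summary', 'email_service_summary']
```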
    

### Complete the Small-to-big RAG workflow 

In [None]:
all_files = []
for root, dirs, files in os.walk("onlineboutique-codefiles/"):
    for file in files:
        all_files.append(os.path.join(root, file))
print(all_files)

In [57]:
def small_to_big(user_prompt):
    # SMALL: first, run RAG with the summary embeddings to try to get a response
    query_emb = get_text_embedding(user_prompt)
    result = collection.query(
        query_embeddings=[query_emb], 
        n_results=3
    )
    # process nearest-neighbors 
    processed_result = {}
    d = result["documents"][0]
    for i in range(0, len(d)):
        summary = d[i]
        filename = result["metadatas"][0][i]["filename"]
        processed_result[filename] = summary
    prompt_with_small = """
    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Attempt to answer the question and only respond if you're confident in the answer. 
    If you need any more information, respond with ONLY the phrase "need more context". 

    The user query is: {} 

    The summaries are: {}
    """.format(user_prompt, str(processed_result))
    print(prompt_with_small)
    small_result = gemini_inference(chat, prompt_with_small) 
    # we're done if Gemini is confident with just the summaries as context... 
    if "need more context" not in small_result.lower():
        return "🐝 Completed at small, Gemini had enough context to respond. RESPONSE: \n" + small_result
        
    # otherwise, move on to BIG: 
    # IF we need the full context, get the filename that most closely matches the user's question
    prompt_to_get_filename = """ 
    You are a codebase helper. The list of code files that you know about: 
    {}

    The user asks the following question about the codebase: {}

    Please respond with the filename that most closely matches the user's question. Respond with ONLY the filename. 
    """.format(all_files, user_prompt)
    # strip whitespace or newlines the model may add around the filename
    filename = gemini_inference(chat, prompt_to_get_filename).strip()
    print("Prompted for a specific filename, Gemini said: " + filename)
    # is the filename in the dataframe? 
    if filename not in df["filename"].values:
        # attempt to try again, appending "onlineboutique-codefiles"  
        filename = "onlineboutique-codefiles/" + filename
        if filename not in df["filename"].values:
            return "⚠️ Error: filename {} not found in dataframe".format(filename)
    
    # get the full code file 
    full_code = df[df["filename"] == filename]["content"].values[0]
    prompt_with_big = """ 
    You are a codebase helper. You will be given a user's question about the codebase, along with a complete source code file. Respond to the user's question with as much detail as possible.

    The user query is: {}
    
    The full code file is: {}
    """.format(user_prompt, full_code) 

    big_response = gemini_inference(chat, prompt_with_big) 
    return "🦖 Completed at big. RESPONSE: \n" + big_response
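
The escalation logic in `small_to_big` can be exercised without calling Gemini by passing the model in as a plain function. The sketch below is a simplified, hypothetical variant of that control flow (`answer_small_to_big`, `fake_llm`, and the toy inputs are not part of the notebook). It returns which stage produced the answer so the two paths are easy to check.

```python
from typing import Callable, Dict, Tuple

NEED_MORE = "need more context"


def answer_small_to_big(
    question: str,
    summaries: Dict[str, str],
    full_files: Dict[str, str],
    llm: Callable[[str], str],
    pick_file: Callable[[str], str],
) -> Tuple[str, str]:
    """Try the summaries first; escalate to a full file only if the model asks."""
    small_prompt = (
        f"Question: {question}\nSummaries: {summaries}\n"
        f"If unsure, reply only '{NEED_MORE}'."
    )
    small_answer = llm(small_prompt)
    if NEED_MORE not in small_answer.lower():
        return ("small", small_answer)

    # Escalate: choose a file, then answer with its full contents.
    filename = pick_file(question).strip()
    if filename not in full_files:
        return ("error", f"filename {filename} not found")
    big_prompt = f"Question: {question}\nFull file: {full_files[filename]}"
    return ("big", llm(big_prompt))


def fake_llm(prompt: str) -> str:
    """Stub model: only 'confident' when it sees a full file."""
    if "Full file:" in prompt:
        return "The timeout is 280 seconds."
    return "need more context"


summaries = {"main.tf": "Terraform configuration."}
full_files = {"main.tf": "kubectl wait ... --timeout=280s"}

stage_big, answer_big = answer_small_to_big(
    "How long is the timeout?", summaries, full_files, fake_llm, lambda q: " main.tf\n"
)
stage_small, answer_small = answer_small_to_big(
    "What does the ad service do?", summaries, full_files,
    lambda p: "It serves ads over gRPC.", lambda q: "main.tf",
)
print(stage_big, stage_small)
```

Injecting the model this way keeps the branching testable. A real run would pass a function that wraps the Gemini call instead of the stub.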

### Test it out 

In [55]:
# an example of a query where only the small (summary) step is needed
small_to_big("How does the ad service work?")


    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Attempt to answer the question and only respond if you're confident in the answer. 
    If you need any more information, respond with ONLY the phrase "need more context". 

    The user query is: How does the ad service work? 

    The summaries are: {'onlineboutique-codefiles/src/adservice/src/main/java/hipstershop/AdService.java': 'This is a Java gRPC server for an ad service.\nIt has a `getAds` function that takes a request containing context and returns a response with a list of ads.\nIt uses a map of categories to ads to retrieve the ads and a random function to retrieve random ads if no category is specified.', 'onlineboutique-codefiles/src/adservice/src/main/java/hipstershop/AdServiceClient.java': 'This Java code defines a client that interacts with an Ads Service using gRPC. The `AdServiceClient` class establishes a connection to the service

'🐝 Completed at small, Gemini had enough context to respond. RESPONSE: \nThe Ad Service is a gRPC service that provides advertisements based on context keys. It has a `getAds` function that takes a request containing a context key and returns a response with a list of ads.\nThe service uses a map of categories to ads to retrieve the ads, and a random function to retrieve random ads if no category is specified.\nThe Ads Service is implemented in Java and uses the Gradle build tool to compile and package the application. It can be built locally using the `./gradlew installDist` command or by building a Docker image using the `docker build ./` command from the `src/adservice/` directory.'

In [59]:
small_to_big("Exactly how long is the kubectl wait condition in the Terraform deployment of online boutique? Return the right number of seconds")


    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Attempt to answer the question and only respond if you're confident in the answer. 
    If you need any more information, respond with ONLY the phrase "need more context". 

    The user query is: Exactly how long is the kubectl wait condition in the Terraform deployment of online boutique? Return the right number of seconds 

    The summaries are: {'onlineboutique-codefiles/terraform/README.md': "This Terraform script deploys the Online Boutique sample application, a microservices-based application, on a Google Kubernetes Engine (GKE) cluster. It provisions resources such as a GKE cluster, Memorystore (Redis) instance (optional), and a Kubernetes deployment for the application. Functions like `terraform init` and `terraform apply` are used to initialize and create the resources. Libraries like `kubectl` are used to retrieve the frontend's external 

'🦖 Completed at big. RESPONSE: \nThe kubectl wait condition in the Terraform deployment of online boutique is 280 seconds. This is specified in the `command` argument of the `null_resource "wait_conditions"` resource:\n\n```\ncommand     = <<-EOT\n    kubectl wait --for=condition=AVAILABLE apiservice/v1beta1.metrics.k8s.io --timeout=180s\n    kubectl wait --for=condition=ready pods --all -n ${var.namespace} --timeout=280s\n    EOT\n```\n\nThe `--timeout` flag specifies the maximum amount of time to wait for the condition to be met. In this case, the condition is that all pods in the `var.namespace` namespace are ready. If the condition is not met within 280 seconds, the command will fail and the deployment will not proceed.'

# Solution terraform code in main.tf  - 280 seconds is correct  
"""

# Wait condition for all Pods to be ready before finishing
resource "null_resource" "wait_conditions" {
  provisioner "local-exec" {
    interpreter = ["bash", "-exc"]
    command     = <<-EOT
    kubectl wait --for=condition=AVAILABLE apiservice/v1beta1.metrics.k8s.io --timeout=180s
    kubectl wait --for=condition=ready pods --all -n ${var.namespace} --timeout=280s
    EOT
  }

  depends_on = [
    resource.null_resource.apply_deployment
  ]
}

"""

In [61]:
small_to_big("What tracing frameworks are used across the codebase?") 


    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Attempt to answer the question and only respond if you're confident in the answer. 
    If you need any more information, respond with ONLY the phrase "need more context". 

    The user query is: What tracing frameworks are used across the codebase? 

    The summaries are: {'onlineboutique-codefiles/src/emailservice/requirements.in': 'This Python code uses the `google-cloud-profiler` and `google-cloud-trace` libraries to profile and trace Google Cloud applications. It also uses the `opentelemetry-distro` library for distributed tracing and the `grpcio` library for handling gRPC requests. The `requests` library is employed for making HTTP requests.', 'onlineboutique-codefiles/src/recommendationservice/requirements.in': 'This Python code uses the Google Cloud Profiler library to perform profiling and monitoring operations. It utilizes the gRPC health

'🐝 Completed at small, Gemini had enough context to respond. RESPONSE: \nThe codebase uses `google-cloud-profiler`, `google-cloud-trace`, `opentelemetry-distro`, `grpcio`, and `requests` for profiling and tracing in Python, and `Cloud Profiler`, `grpc`, `grpc-gateway`, `OpenTelemetry`, and `Logrus` for profiling, tracing, logging, and health checking in Go.'

In [62]:
small_to_big("Describe in detail exactly how the ListRecommendations function works.")


    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Attempt to answer the question and only respond if you're confident in the answer. 
    If you need any more information, respond with ONLY the phrase "need more context". 

    The user query is: Describe in detail exactly how the ListRecommendations function works. 

    The summaries are: {'onlineboutique-codefiles/src/recommendationservice/client.py': 'This is a Python script that uses gRPC to communicate with a RecommendationService. It imports necessary libraries and sets up a server stub. The script forms a request and makes a call to the server, then logs the response using a custom JSON logger. It allows users to specify a port number as an argument.', 'onlineboutique-codefiles/src/productcatalogservice/.dockerignore': 'This is a directory listing. It is not a source code file, so I cannot provide a language identification or summary.', 'onl

"🦖 Completed at big. RESPONSE: \nThe `ListRecommendations` function in the `recommendationservice` retrieves a list of recommended products from the product catalog. It receives a request object containing a list of product IDs, and returns a response object containing a list of recommended product IDs.\n\nThe function works as follows:\n\n1. It fetches a list of all products from the product catalog using the `ListProducts` method of the `product_catalog_stub` gRPC client.\n\n2. It filters out any products that are already included in the request's list of product IDs, to avoid recommending products that the user has already seen.\n\n3. It randomly selects a subset of the remaining products to return as recommendations.\n\n4. It builds and returns a response object containing the list of recommended product IDs."