## Small-to-big Retrieval-Augmented Generation 

Keith Ballinger and Megan O'Keefe, 2024

Small-to-big RAG is a form of recursive RAG in which small chunks are linked to larger ones (e.g., a passage of a document is linked to the full document). Initially, only the small chunks are retrieved. If the model needs more, it can request additional context, and the larger chunk is added at that point.
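The core mechanism fits in a few lines of plain Python. The sketch below is only an illustration, not the notebook's implementation. It uses word overlap as a stand-in for vector similarity, and `SmallToBigIndex`, `retrieve_small`, and `expand_to_big` are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class SmallToBigIndex:
    """Toy index: each small chunk points at the id of its larger parent document."""
    parents: dict[str, str] = field(default_factory=dict)        # parent_id -> full text
    chunks: list[tuple[str, str]] = field(default_factory=list)  # (chunk_text, parent_id)

    def add_document(self, parent_id: str, text: str, chunk_words: int = 8) -> None:
        # Store the full ("big") document, then split it into small word-count chunks.
        self.parents[parent_id] = text
        words = text.split()
        for start in range(0, len(words), chunk_words):
            chunk = " ".join(words[start:start + chunk_words])
            self.chunks.append((chunk, parent_id))

    def retrieve_small(self, query: str) -> tuple[str, str]:
        # Stand-in for vector search: the chunk sharing the most words with the query.
        query_words = set(query.lower().split())
        return max(self.chunks, key=lambda c: len(query_words & set(c[0].lower().split())))

    def expand_to_big(self, parent_id: str) -> str:
        # The "big" step: follow the link from a small chunk to its parent document.
        return self.parents[parent_id]


index = SmallToBigIndex()
index.add_document("cart", "The cart service stores items in Redis and exposes AddItem and GetCart over gRPC.")
index.add_document("ads", "The ad service returns contextual ads for a page and falls back to random ads.")
chunk, parent = index.retrieve_small("which service uses redis")
full_text = index.expand_to_big(parent)
print(parent, "->", full_text)
```

Retrieval only ever scores the small chunks; the parent text is fetched only when the caller asks for it.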

Small-to-big RAG is a way to work around the limits of dense, meaning-rich vector embeddings: passing too much text into a single embedding can dilute its meaning in the vector representation. (`textembedding-gecko` has an input limit of 3,072 tokens.)

Small-to-big is also a good option if you are working with long or complex documents that require context exceeding the ideal chunk size. 
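One way to keep each small chunk within an embedding model's input limit is to split by a fixed budget. This is a rough sketch: it counts whitespace-separated words, which only approximates tokens (source code in particular often tokenizes into more pieces than it has words), so it budgets half of the limit as headroom. `split_for_embedding` and the 0.5 margin are illustrative assumptions, not values from the Vertex AI documentation.

```python
MAX_TOKENS = 3072    # input limit quoted above for textembedding-gecko
SAFETY_MARGIN = 0.5  # assumption: words are only a rough proxy for tokens


def split_for_embedding(text: str, max_tokens: int = MAX_TOKENS, margin: float = SAFETY_MARGIN) -> list[str]:
    """Split text into word-based chunks that should stay under max_tokens.

    Only max_tokens * margin words are allowed per chunk, to leave room for
    words that the real tokenizer splits into several tokens.
    """
    budget = max(1, int(max_tokens * margin))
    words = text.split()
    return [" ".join(words[i:i + budget]) for i in range(0, len(words), budget)]


long_text = "token " * 4000
chunks = split_for_embedding(long_text)
print(len(chunks), [len(c.split()) for c in chunks])
```

For an exact count you would need the model's own tokenizer; the fixed margin just trades some chunk size for safety.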

Lastly, Small-to-big can help you optimize inference costs, especially if you're paying per token (e.g., the Gemini API on Vertex AI). If a small amount of input context is enough for the model to generate an answer, you don't have to send the whole large document as context. In other words, you only pay for the context tokens you actually need to get an accurate response from Gemini.
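A back-of-the-envelope sketch of that trade-off follows. The token counts and the fraction of questions answerable from summaries are made up for illustration (they are not measurements from this notebook), and the sketch ignores the extra filename-lookup prompt that the workflow below sends before the big step.

```python
def expected_input_tokens(small_tokens: int, big_tokens: int, p_small_enough: float) -> float:
    """Expected input tokens per question under Small-to-big.

    Every question pays for the small (summary) prompt; only the fraction
    (1 - p_small_enough) that needs full context also pays for the big prompt.
    """
    if not 0.0 <= p_small_enough <= 1.0:
        raise ValueError("p_small_enough must be between 0 and 1")
    return small_tokens + (1 - p_small_enough) * big_tokens


# Hypothetical sizes: a 200-token summary prompt and a 6,000-token code file.
always_big = 6000
small_to_big_tokens = expected_input_tokens(200, 6000, p_small_enough=0.75)
print(always_big, small_to_big_tokens)  # 6000 1700.0
```

Under these made-up numbers, Small-to-big sends about 28% of the tokens that always sending the full file would. If most questions need full context, though, the extra summary prompt makes it slightly more expensive than going straight to the big document.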

In this example, we'll walk through Small-to-big RAG using a GitHub codebase called [Online Boutique](https://github.com/GoogleCloudPlatform/microservices-demo), a polyglot microservices sample application. We'll help a new Online Boutique user (or contributor) navigate this large codebase with a Q&A chatbot grounded in the codebase.


![](architecture.png)

Prerequisites:
- Google Cloud account
- Google Cloud project with billing enabled 
- Vertex AI API enabled

This notebook uses the following products and tools:
- Vertex AI - Gemini API 
- Vertex AI - Text Embeddings API 
- Chroma (in-memory vector database)  




### Setup 

In [24]:
EMBEDDING_MODEL="textembedding-gecko@003"
GENERATIVE_MODEL="gemini-1.0-pro"
PROJECT_ID="your-project-id"  # replace with your Google Cloud project ID
REGION="us-central1"

In [None]:
! pip install "google-cloud-aiplatform>=1.38"

In [28]:
from google.cloud import aiplatform
import vertexai
from vertexai.language_models import TextEmbeddingModel
from vertexai.generative_models import GenerativeModel, ChatSession
import os 


In [18]:
def get_text_embedding(doc) -> list:
    model = TextEmbeddingModel.from_pretrained(EMBEDDING_MODEL)
    embeddings = model.get_embeddings([doc])
    if len(embeddings) > 1:
        raise ValueError("More than one embedding returned.")
    if len(embeddings) == 0:
        raise ValueError("No embedding returned.")
    return embeddings[0].values
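`get_text_embedding` sends one text per request, but `get_embeddings` accepts a list, so several summaries can share a request. Below is a hedged sketch of batching in which `embed_batch` is any callable that returns one vector per input. With the Vertex AI SDK, it could be a function that calls `TextEmbeddingModel.get_embeddings` on the batch and returns each result's `.values`. The default `batch_size` of 5 is an arbitrary placeholder; check the per-request limit for your embedding model. `fake_embed_batch` is a hypothetical offline stand-in.

```python
from typing import Callable, Iterator, Sequence


def batched(items: Sequence[str], size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def embed_in_batches(
    texts: Sequence[str],
    embed_batch: Callable[[Sequence[str]], list[list[float]]],
    batch_size: int = 5,
) -> list[list[float]]:
    """Embed many texts with one call per batch, preserving input order."""
    vectors: list[list[float]] = []
    for batch in batched(texts, batch_size):
        result = embed_batch(batch)
        if len(result) != len(batch):
            raise ValueError("Embedding count does not match batch size.")
        vectors.extend(result)
    return vectors


# Offline stand-in for the real model call: records batch sizes, returns 2-d vectors.
calls: list[int] = []


def fake_embed_batch(batch: Sequence[str]) -> list[list[float]]:
    calls.append(len(batch))
    return [[float(len(t)), 1.0] for t in batch]


vectors = embed_in_batches(["a", "bb", "ccc", "dddd", "eeeee", "ffffff", "g"], fake_embed_batch, batch_size=3)
print(calls, vectors[:2])
```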

In [43]:
vertexai.init(project=PROJECT_ID, location=REGION)
model = GenerativeModel(GENERATIVE_MODEL)
chat = model.start_chat()

def gemini_inference(chat: ChatSession, prompt: str) -> str:
    # with stream=False, send_message returns a single GenerationResponse;
    # its .text property holds the generated text
    response = chat.send_message(prompt, stream=False)
    return response.text
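Note that `chat` is a single `ChatSession`, and a chat session sends its accumulated history with every new message. Each call to `gemini_inference` therefore also carries every earlier prompt and answer in this notebook, including previously summarized code files. If you want each prompt to be independent, a minimal alternative is to call `generate_content` on the model directly. In this sketch, `stateless_inference` and `FakeModel` are hypothetical names; `FakeModel` only exists so the snippet runs offline.

```python
from typing import Protocol


class SupportsGenerateContent(Protocol):
    # Simplified shape of GenerativeModel.generate_content: returns an object with .text
    def generate_content(self, prompt: str): ...


def stateless_inference(model: SupportsGenerateContent, prompt: str) -> str:
    """Send a single prompt with no chat history and return the response text."""
    return model.generate_content(prompt).text


class _FakeResponse:
    def __init__(self, text: str) -> None:
        self.text = text


class FakeModel:
    """Hypothetical offline stand-in that reports the prompt length."""

    def generate_content(self, prompt: str) -> _FakeResponse:
        return _FakeResponse(f"received {len(prompt)} characters")


print(stateless_inference(FakeModel(), "Summarize this file."))  # received 20 characters
```

With the SDK, `stateless_inference(model, prompt)` would use the `GenerativeModel` created in the cell above.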

In [45]:
code = """
func init() {
	log = logrus.New()
	log.Level = logrus.DebugLevel
	log.Formatter = &logrus.JSONFormatter{
		FieldMap: logrus.FieldMap{
			logrus.FieldKeyTime:  "timestamp",
			logrus.FieldKeyLevel: "severity",
			logrus.FieldKeyMsg:   "message",
		},
		TimestampFormat: time.RFC3339Nano,
	}
	log.Out = os.Stdout
}
"""
prompt = """
Summarize the following code: 
{}
""".format(code)
result = gemini_inference(chat, prompt)
print(result)

This code sets up a logging mechanism using the Logrus library. It configures the logging level to debug, uses a custom JSONFormatter with specific field names and timestamps, and sets the logging output to the standard output stream.


### Get summaries of code files 

First, we'll use Gemini on Vertex AI to get short summaries of each code file. 

In [46]:
# for every file in onlineboutique-codefiles/, read it in, and get a summary 
summaries = {} 
for file in os.listdir("onlineboutique-codefiles/"):
    temp = {}
    with open("onlineboutique-codefiles/" + file, "r") as f:
        print("Processing file: ", file)
        content = f.read()     
        temp["content"] = content 
        prompt = """ 
        You are a helpful code summarizer. Here is a source code file. Please identify the programming language and summarize it in three sentences or less. Give as much detail as possible, including function names and libraries used. Code: 
        {}
        """.format(content)
        try: 
            summary = gemini_inference(chat, prompt)
            temp["summary"] = summary
            summaries[file] = temp
        except Exception as e:
            print("Error processing file: ", file)
            print(e)

Processing file:  checkoutservice.go
Processing file:  shippingservice.go
Processing file:  adservice.java
Processing file:  paymentservice.js
Processing file:  recommendationservice.py
Processing file:  emailservice.py
Processing file:  frontend.go
Processing file:  productcatalog.go
Processing file:  cartservice.cs
Processing file:  currencyservice.js
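The loop above can be factored into a function that takes the summarizer as a parameter, which makes it easy to plug in `gemini_inference` (or a stateless call) and to test offline. This is a hedged sketch: `summarize_directory` and `PROMPT_TEMPLATE` are hypothetical names, the prompt text is copied from the cell above, and the sorting, subdirectory skipping, and explicit UTF-8 encoding are additions of this sketch.

```python
import os
import tempfile
from typing import Callable

PROMPT_TEMPLATE = (
    "You are a helpful code summarizer. Here is a source code file. Please identify "
    "the programming language and summarize it in three sentences or less. Give as much "
    "detail as possible, including function names and libraries used. Code:\n{}"
)


def summarize_directory(directory: str, summarize: Callable[[str], str]) -> dict[str, dict[str, str]]:
    """Return {filename: {"content": ..., "summary": ...}} for every file in `directory`.

    `summarize` receives the full prompt; files whose summary call raises are skipped.
    """
    summaries: dict[str, dict[str, str]] = {}
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            continue  # skip subdirectories
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
        try:
            summary = summarize(PROMPT_TEMPLATE.format(content))
        except Exception as exc:  # keep going if one file fails
            print(f"Error processing file {name}: {exc}")
            continue
        summaries[name] = {"content": content, "summary": summary}
    return summaries


# Offline usage with a temporary directory and a fake summarizer.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "hello.py"), "w", encoding="utf-8") as f:
        f.write("print('hi')\n")
    result = summarize_directory(tmp, lambda prompt: "A one-line Python script.")
print(result)
```

In the notebook, the equivalent call would be something like `summarize_directory("onlineboutique-codefiles/", lambda p: gemini_inference(chat, p))`.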


In [72]:
import pandas as pd
df = pd.DataFrame.from_dict(summaries, orient='index')

In [None]:
df.head()

In [74]:
# the first column should be named "filename"
df = df.reset_index()
df = df.rename(columns = {'index':'filename'})
df.head()

|   | filename | content | summary |
|---|----------|---------|---------|
| 0 | checkoutservice.go | `// Copyright 2018 Google LLC\n//\n// Licensed ...` | This Go program implements an order checkout m... |
| 1 | shippingservice.go | `// Copyright 2018 Google LLC\n//\n// Licensed ...` | This Go program implements a shipping service ... |
| 2 | adservice.java | `/*\n * Copyright 2018, Google LLC.\n *\n * Lic...` | This Java program implements an ad service usi... |
| 3 | paymentservice.js | `// Copyright 2018 Google LLC\n//\n// Licensed ...` | This JavaScript code defines a class `HipsterS... |
| 4 | recommendationservice.py | `#!/usr/bin/python\n#\n# Copyright 2018 Google ...` | This Python code implements a gRPC server for ... |


In [76]:
# write to csv
df.to_csv("code_summaries.csv", index=False)

### Convert summaries to embeddings

In [None]:
! pip install chromadb



In [115]:
import chromadb
chroma_client = chromadb.Client()

In [117]:
# get_or_create_collection avoids an error if this cell is re-run in the same session
collection = chroma_client.get_or_create_collection(name="code_summaries")



In [None]:

# iterate over the dataframe: convert each summary into an embedding and insert it into the collection
for index, row in df.iterrows():
    fn = row["filename"]
    print("Getting embedding for: ", fn)
    summ = row["summary"]
    print(summ)
    e = get_text_embedding(summ)
    # add vector embedding to in-memory Chroma database. 
    # the "small" summary embedding is linked to the "big" raw code file through the metadata key, "filename." 
    collection.add(
        embeddings=[e],
        documents=[summ],
        metadatas=[{"filename": fn}],
        ids=[fn])
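Because each summary's metadata carries its `filename`, the link from small to big can also be followed without another model call. The sketch below assumes the result shape returned by Chroma's `collection.query` (a dict of lists, with one inner list per query embedding). The helper names are hypothetical, and the result values are made up rather than taken from this notebook's output.

```python
def summaries_from_query_result(result: dict) -> dict[str, str]:
    """Map filename -> summary for the first (and only) query in a Chroma-style result."""
    documents = result["documents"][0]
    metadatas = result["metadatas"][0]
    return {meta["filename"]: doc for doc, meta in zip(documents, metadatas)}


def full_files_for(filenames, contents_by_filename: dict[str, str]) -> dict[str, str]:
    """The "big" step: follow each filename link back to its raw code file."""
    return {name: contents_by_filename[name] for name in filenames}


# Hand-written result in the nested-list shape that collection.query returns
# for a single query embedding (made-up values).
fake_result = {
    "ids": [["adservice.java"]],
    "documents": [["This Java program implements an ad service."]],
    "metadatas": [[{"filename": "adservice.java"}]],
    "distances": [[0.42]],
}
contents = {"adservice.java": "public final class AdService {}"}

small = summaries_from_query_result(fake_result)
big = full_files_for(small, contents)
print(small)
print(big)
```

In the notebook, `contents_by_filename` could be built with `dict(zip(df["filename"], df["content"]))`. Compared with asking Gemini to name a file, this trusts the top retrieval hit, which saves a model call but inherits any retrieval mistakes.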
    

### Complete the Small-to-big RAG workflow 

In [130]:
def small_to_big(user_prompt):
    # SMALL: first, run RAG with the summary embeddings to try to get a response
    query_emb = get_text_embedding(user_prompt)
    result = collection.query(
        query_embeddings=[query_emb], 
        n_results=1
    )
        
    # process results into something we can pass to Gemini 
    processed_result = {}
    d = result["documents"][0]
    for i in range(0, len(d)):
        summary = d[i]
        filename = result["metadatas"][0][i]["filename"]
        processed_result[filename] = summary
  
    prompt_with_small = """
    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Do your best to generate an answer with the short summaries.  
    If you feel confident in your answer, respond with "The answer is:" followed by your answer.
    If you need more information, respond with "Need full context," and you will be prompted again with the 
    full relevant code file. 

    The user query is: {} 

    The summaries are: {}
    """.format(user_prompt, str(processed_result))
    print(prompt_with_small)
    small_result = gemini_inference(chat, prompt_with_small) 
    print("🐝 SMALL RESULT: " + small_result)
    # we're done if Gemini is confident with just the summaries as context... 
    if "need full context" not in small_result.lower():
        return "Gemini responded with just the small summaries as context: " + small_result + " 🎉"
        
    # otherwise, move on to BIG: 
    # IF we need the full context, get the filename that most closely matches the user's question
    prompt_to_get_filename = """ 
    You are a codebase helper. The list of code files that you know about: 
    - checkoutservice.go
    - shippingservice.go
    - adservice.java
    - paymentservice.js
    - recommendationservice.py
    - emailservice.py
    - frontend.go
    - productcatalog.go
    - cartservice.cs
    - currencyservice.js

    The user asks the following question about the codebase: {}

    Please respond with the filename that most closely matches the user's question. Respond with ONLY the filename. 
    """.format(user_prompt)
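    # strip surrounding whitespace (e.g. a trailing newline) so the exact lookup below can match
    filename = gemini_inference(chat, prompt_to_get_filename).strip()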
    print("Prompted for a specific filename, Gemini said: " + filename)
    # is the filename in the dataframe? 
    if filename not in df["filename"].values:
        return "⚠️ Error: filename {} not found in dataframe".format(filename)
    
    # get the full code file 
    full_code = df[df["filename"] == filename]["content"].values[0]
    prompt_with_big = """ 
    You are a codebase helper. You will be given a user's question about the codebase, along with a complete source code file. Respond to the user's question with as much detail as possible.

    The user query is: {}
    
    The full code file is: {}
    """.format(user_prompt, full_code) 

    return gemini_inference(chat, prompt_with_big)
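Model replies do not always come back as a bare filename. They may be wrapped in backticks or embedded in a short sentence, which an exact membership check against `df["filename"]` will reject. Below is a hedged sketch of a more forgiving matcher. `match_filename` is a hypothetical helper, and `KNOWN_FILES` mirrors the list hard-coded in `prompt_to_get_filename`.

```python
from typing import Optional

# Mirrors the file list hard-coded in prompt_to_get_filename above.
KNOWN_FILES = [
    "checkoutservice.go", "shippingservice.go", "adservice.java", "paymentservice.js",
    "recommendationservice.py", "emailservice.py", "frontend.go", "productcatalog.go",
    "cartservice.cs", "currencyservice.js",
]


def match_filename(reply: str, known_files: list[str] = KNOWN_FILES) -> Optional[str]:
    """Map a free-form model reply onto one of the known filenames, or return None."""
    # Exact match after trimming whitespace, backticks, and quotes.
    cleaned = reply.strip().strip("`'\"").strip().lower()
    for name in known_files:
        if cleaned == name.lower():
            return name
    # Fallback: the first known filename (in list order) mentioned anywhere in the reply.
    lowered = reply.lower()
    for name in known_files:
        if name.lower() in lowered:
            return name
    return None


print(match_filename("`productcatalog.go`\n"))  # productcatalog.go
```

Inside `small_to_big`, you could pass Gemini's reply through `match_filename` and return the error message only when it gives `None`.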

### Test it out 

In [129]:
# an example of a query where only the small (summary) step is needed
small_to_big("How does the ad service work?")


    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Do your best to generate an answer with the short summaries.  
    If you feel confident in your answer, respond with "The answer is:" followed by your answer.
    If you need more information, respond with "Need full context," and you will be prompted again with the 
    full relevant code file. 

    The user query is: How does the ad service work? 

    The summaries are: {'adservice.java': 'This JavaScript code sets up a gRPC server for a currency conversion service using the `@grpc/grpc-js` library. It includes functions for:\n- Initializing OpenTelemetry tracing (if enabled) and Stackdriver Profiler (if enabled).\n- Loading protocol buffer definitions and adding corresponding service definitions to the gRPC server.\n- Handling requests for supported currencies and currency conversions.\n- Handling health checks.\n- Starting the gRPC server on a

'Gemini responded with just the small summaries as context: The Ad service retrieves ads for a user, targeted based on the context provided in the request. If no context is provided, it serves random ads. It uses a Guava library for data structures and collections. The main class `AdService` starts and stops the gRPC server and blocks the main thread until server shutdown to prevent daemon threads from terminating the program prematurely. It also includes methods for retrieving ads by category or serving random ads if no category is specified. Additionally, it initializes OpenTelemetry stats and tracing (if enabled). The `AdServiceImpl` class implements the `AdServiceGrpc.AdServiceImplBase` interface, which defines the `getAds` method that processes ad requests and returns an `AdResponse` with a list of ads. 🎉'

In [131]:
# an example of a detailed query that requires the full code file (big)
small_to_big("exactly how does the SearchProducts() function work in the product catalog service?")


    You are a codebase helper. You will be given a user's question about the codebase, along with 
    summaries of relevant code files. Do your best to generate an answer with the short summaries.  
    If you feel confident in your answer, respond with "The answer is:" followed by your answer.
    If you need more information, respond with "Need full context," and you will be prompted again with the 
    full relevant code file. 

    The user query is: exactly how does the SearchProducts() function work in the product catalog service? 

    The summaries are: {'productcatalog.go': 'This JavaScript code sets up a gRPC server for a currency conversion service using the `@grpc/grpc-js` library. It includes functions for:\n- Initializing OpenTelemetry tracing (if enabled) and Stackdriver Profiler (if enabled).\n- Loading protocol buffer definitions and adding corresponding service definitions to the gRPC server.\n- Handling requests for supported currencies and currency conversions.\n-

'The `SearchProducts` function in the product catalog service takes a `SearchProductsRequest` that contains a query string and returns a `SearchProductsResponse` containing a list of products that match the query. The function first sleeps for a specified latency period to simulate network latency. It then iterates through the catalog of products and checks if the product name or description contains the query string in a case-insensitive manner. If a match is found, the product is added to the list of results. Finally, the function returns the list of results.'

In [None]:
# for reference: the actual SearchProducts() implementation in productcatalog.go, which matches Gemini's response above
"""
func (p *productCatalog) SearchProducts(ctx context.Context, req *pb.SearchProductsRequest) (*pb.SearchProductsResponse, error) {
	time.Sleep(extraLatency)

	var ps []*pb.Product
	for _, product := range p.parseCatalog() {
		if strings.Contains(strings.ToLower(product.Name), strings.ToLower(req.Query)) ||
			strings.Contains(strings.ToLower(product.Description), strings.ToLower(req.Query)) {
			ps = append(ps, product)
		}
	}

	return &pb.SearchProductsResponse{Results: ps}, nil
}
"""