# Enhancing SQL Agent with ChromaDB and LangChain

This notebook builds on the base SQL Agent by integrating **ChromaDB** and **LangChain** to enable retrieval-augmented generation (RAG). With this setup, the agent can first search documentation or table context using semantic similarity before forming the SQL query.

---

## Goals

- Add context-aware capabilities to the SQL Agent using ChromaDB.
- Store and search SQL documentation chunks using embeddings.
- Use LangChain to orchestrate retrieval and response generation.

---

## Pipeline Overview

1. Load SQL documentation.
2. Split into semantic chunks.
3. Embed and store in ChromaDB.
4. On query, retrieve most relevant chunks.
5. Feed context into the LLM for SQL generation.
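
The five steps above can be sketched end to end with the standard library only. This is a minimal illustration, not the notebook's implementation: the `embed` function is a toy word-count stand-in for a real embedding model, cosine similarity is just one possible similarity measure, and the two tables are hypothetical.

```python
import math
import re
from collections import Counter

def embed(text: str) -> Counter:
    """Toy stand-in for an embedding model: lowercase word counts."""
    return Counter(re.findall(r"[a-z_]+", text.lower()))

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def retrieve(question: str, chunks: list[str], k: int = 1) -> list[str]:
    """Return the k chunks most similar to the question."""
    q = embed(question)
    return sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)[:k]

# Steps 1-2: documentation already split into chunks (hypothetical tables).
chunks = [
    "Table: users\n id INTEGER\n name TEXT",
    "Table: orders\n id INTEGER\n user_id INTEGER\n total_amount REAL",
]

# Steps 3-4: embed the chunks and retrieve the best match for the question.
question = "What is the average total_amount of all orders?"
context = retrieve(question, chunks, k=1)

# Step 5: place the retrieved context into the prompt sent to the LLM.
prompt = f"Schema:\n{context[0]}\n\nQuestion: {question}"
print(prompt)
```

In the notebook itself, the embedding model and ChromaDB take over the roles of `embed` and `retrieve`.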




# Prerequisites

Load the prerequisite functions from the *SQL Agent Notebook*. We will use:
 1. `load_schema_from_json` to read the `.json` file and build the schema string that the LLM needs for the prompt

 2. `prompt_builder` to build the prompt by inserting the schema and the user question

In [None]:
import json
from google.colab import files  # Used only to make uploading the .json file easier

def load_schema_from_json():
  """Ask the user to upload a JSON file and build the schema
  text that is later inserted into the prompt.
  Inputs:
    None
  Outputs:
    schema_data: the raw parsed JSON that the user uploaded
    schema_info: the schema as text, ready to pass to the prompt"""

  # Upload the .json file
  print('Upload the .json file:')
  uploaded = files.upload()

  # Check for errors
  if not uploaded:
    print('No file uploaded from the user.')
    return None, None

  file_name = list(uploaded.keys())[0]
  try:
    with open(file_name, 'r') as f:
      schema_data = json.load(f)
  except json.JSONDecodeError as e:
    print('Invalid JSON format.')
    return None, None
  except Exception as e:
    print(f"Error while reading the file: {e}")
    return None, None

  # Convert the parsed dict into a string that can be passed to the LLM
  schema_info = ""
  if 'tables' in schema_data:
    for table in schema_data['tables']:
      schema_info += f"Table: {table['name']}\n"
      for column in table.get('columns', []):
        schema_info += f" {column['name']} {column['type']}\n"
      schema_info += "\n"  # one blank line after each table, not after each column
  else:
    print("JSON file does not have the expected 'tables' structure.")
    return None, None

  return schema_data, schema_info
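
To make the expected input concrete, here is a small sketch of a JSON file with the `tables` / `columns` layout the loader reads, and the text it is flattened into. The table names and the helper `schema_to_text` are hypothetical and exist only for this illustration; the helper emits one blank line between tables.

```python
import json

# Hypothetical example of the JSON structure the loader expects.
sample_json = """
{
  "tables": [
    {"name": "users",
     "columns": [{"name": "id", "type": "INTEGER"},
                 {"name": "name", "type": "TEXT"}]},
    {"name": "orders",
     "columns": [{"name": "id", "type": "INTEGER"},
                 {"name": "total_amount", "type": "REAL"}]}
  ]
}
"""

def schema_to_text(schema_data: dict) -> str:
    """Flatten the parsed schema into one 'Table: ...' block per table."""
    lines = []
    for table in schema_data["tables"]:
        lines.append(f"Table: {table['name']}")
        for column in table.get("columns", []):
            lines.append(f" {column['name']} {column['type']}")
        lines.append("")  # blank line between tables
    return "\n".join(lines)

schema_text = schema_to_text(json.loads(sample_json))
print(schema_text)
```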

In [None]:
def prompt_builder(schema_info: str, user_question: str) -> str:
  """ Build the prompt that we will pass to the LLM
      Inputs:
        schema_info: The JSON schema as a string
        user_question: The natural-language question from the user
      Outputs:
        prompt: The prompt that we will pass to the LLM
  """

  prompt = f"""
  You are a SQL expert. Given the following database schema:
  {schema_info}
  answer the following question:
  {user_question}
  by returning a well-formed SQL query as a raw string, not as Markdown.
  You will only return the SQL query, nothing else.
  """
  return prompt


# 1. chroma_loader.py

Compute vector embeddings for each chunk and store them in ChromaDB for fast similarity search.

In [None]:
# install the libraries
!pip install --upgrade chromadb sentence-transformers langchain-community langchain-core langchain huggingface_hub



In [None]:
!pip install -U langchain-google-genai


In [None]:
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

def initialize_chroma(collection_name = 'schema_collection'):
  """Get or create a ChromaDB collection on the default in-memory client
    Input:
      collection_name: The name of the collection that we will use
    Output:
      collection: The collection that we will use
  """

  # chromadb.Client() is in-memory: the data does not survive a runtime restart
  client = chromadb.Client(Settings(anonymized_telemetry = False))

  # Reuse the collection if it already exists, otherwise create it
  collection = client.get_or_create_collection(name = collection_name)
  return collection

def split_schema(schema_info: str):
  """ Splits the full schema string into individual table chunks
    Input:
      schema_info: The full schema string
    Output:
      table_chunks: A list of table chunks
  """
  chunks = schema_info.strip().split('Table: ')
  table_chunks = [f'Table: {chunk.strip()}' for chunk in chunks if chunk.strip()]
  return table_chunks

def embed_and_store(schema_chunks, collection, embed_model = 'all-MiniLM-L6-v2'):
  """ Embeds each schema chunk and stores it in Chroma
    Input:
      schema_chunks: A list of table chunks
      collection: The collection that we will use
      embed_model: The embedding model that we will use
    Output:
      None
  """

  embedder = SentenceTransformer(embed_model)
  embeddings = embedder.encode(schema_chunks).tolist()

  for i, (chunk, emb) in enumerate(zip(schema_chunks, embeddings)):
    collection.add(
        documents = [chunk],
        embeddings = [emb],
        ids = [f"schema_chunk_{i}"]
    )

  print(f'Stored {len(schema_chunks)} schema chunks in Chroma.')
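
The positional IDs above (`schema_chunk_0`, `schema_chunk_1`, ...) start again from zero on every upload, so a second schema added to the same collection reuses IDs that are already present. One possible alternative, sketched here as an assumption rather than the notebook's method, is to derive each ID from the chunk text. The helper `chunk_id` is hypothetical, and `split_schema` is repeated only to keep the example self-contained.

```python
import hashlib

def split_schema(schema_info: str) -> list[str]:
    """Same logic as above: one chunk per 'Table: ' block."""
    chunks = schema_info.strip().split("Table: ")
    return [f"Table: {chunk.strip()}" for chunk in chunks if chunk.strip()]

def chunk_id(chunk: str) -> str:
    """Hypothetical helper: derive a stable ID from the chunk text."""
    digest = hashlib.sha256(chunk.encode("utf-8")).hexdigest()
    return "schema_" + digest[:16]

schema_info = (
    "Table: users\n id INTEGER\n\n"
    "Table: orders\n id INTEGER\n total_amount REAL\n"
)
chunks = split_schema(schema_info)
ids = [chunk_id(c) for c in chunks]

print(chunks)
print(ids)
```

IDs built this way could be passed as `ids=` to `collection.add` or `collection.upsert`; identical chunks then always map to the same ID, while different tables get different ones.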

In [None]:
def upload_and_update_chroma_from_json(
    collection_name='schema_collection',
    embed_model='all-MiniLM-L6-v2'
):
    """
    Uploads a JSON schema, converts it to text chunks,
    embeds them and stores into ChromaDB collection.

    Uses:
    - load_schema_from_json
    - initialize_chroma
    - split_schema
    - embed_and_store
    """

    # Load schema from JSON upload
    schema_data, schema_info = load_schema_from_json()

    if schema_data is None or schema_info is None:
      print('Failed to load the schema from JSON. Skipping ChromaDB update.')
      return

    # Initialize ChromaDB collection
    collection = initialize_chroma(collection_name=collection_name)

    # Split the schema into table-level chunks
    schema_chunks = split_schema(schema_info)

    # Embed and store chunks
    embed_and_store(schema_chunks, collection, embed_model=embed_model)

    print("✅ Schema successfully added to ChromaDB.")


# 2. langchain_agent.py

LangChain’s `RetrievalQA` wraps the retrieval and LLM components together, allowing question-answering with context fetched from the vector store.
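
Conceptually, the `stuff` chain type used in this section amounts to joining the retrieved documents into one block and substituting it for `{context}` in the prompt. The sketch below imitates that step in plain Python. It is a simplified model of the behaviour, not LangChain's code, and the function name `stuff_prompt` and the sample tables are hypothetical.

```python
def stuff_prompt(template: str, documents: list[str], question: str) -> str:
    """Join the retrieved documents and fill them into the template."""
    context = "\n\n".join(documents)
    return template.format(context=context, question=question)

template = (
    "Based only on the database schema below, write a SQL query.\n"
    "{context}\n"
    "User question = {question}\n"
)

# Pretend these chunks came back from the retriever.
retrieved = [
    "Table: orders\n id INTEGER\n total_amount REAL",
    "Table: users\n id INTEGER\n name TEXT",
]

final_prompt = stuff_prompt(template, retrieved, "How many orders are there?")
print(final_prompt)
```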

In [None]:
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings  # community package installed above
from langchain_core.runnables import Runnable
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import RetrievalQA
from google.colab import userdata

import os



# Load the retriever from existing Chroma collection

def load_retriever(persist_directory = 'chroma', collection_name = 'schema_collection'):
  """ Loads the retriever from existing Chroma collection
    Input:
      persist_directory: The directory where the Chroma collection is stored
      collection_name: The name of the collection that we will use
    Output:
      retriever: The retriever that we will use
  """

  embedding_model = HuggingFaceEmbeddings(model_name = 'sentence-transformers/all-MiniLM-L6-v2')
  db = Chroma(
      collection_name = collection_name,
      embedding_function= embedding_model,
      persist_directory = persist_directory,
      client = chromadb.Client(Settings(anonymized_telemetry=False)))
  retriever = db.as_retriever(search_kwargs = {'k': 3})
  return retriever

# Define the prompt template to use it in LLM

prompt_template = PromptTemplate(
    input_variables = ['context', 'question'],
    template = """
    You are an expert SQL assistant.
    Based only on the database schema below,
    write a SQL query that answers the user
    question.
    {context}
    User question = {question}
    Give only the SQL query, no explanation.
    """
)



# Load LLM using ChatGoogleGenerativeAI
def load_llm():
    # Ensure the API key is set before initializing the model
    api_key = userdata.get('GOOGLE_API_KEY')
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in Colab userdata.")

    # Define environment variable
    os.environ["GOOGLE_API_KEY"] = api_key

    return ChatGoogleGenerativeAI(
        model="models/gemini-2.0-flash",
        temperature=0.2
    )

# Combine everything with RetrievalQA chain

def initialize_agent():
  retriever = load_retriever()
  llm = load_llm()

  qa_chain = RetrievalQA.from_chain_type(
      llm = llm,
      chain_type = 'stuff',
      retriever = retriever,
      return_source_documents = True,
      chain_type_kwargs = {'prompt': prompt_template}
  )

  return qa_chain

In [None]:
# Toy example
upload_and_update_chroma_from_json()

agent = initialize_agent()

query = 'Get the average transaction amount per user in the last month'
response = agent.invoke({"query": query})
print(response['result'])

Upload the .json file:


Saving toy.json to toy (1).json
Stored 4 schema chunks in Chroma.
✅ Schema successfully added to ChromaDB.
```sql
SELECT AVG(total_amount)
FROM orders
WHERE order_date BETWEEN date('now', '-1 month') AND date('now');
```
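
The result above still arrives wrapped in a Markdown fence even though the prompt asks for the SQL query only. If the query is meant to be executed directly, a small post-processing step can remove the fence. The helper below is a hypothetical sketch of that idea and is not part of the notebook.

```python
import re

FENCE = "`" * 3  # built this way to avoid writing a literal fence inside the example

# Match an optional "sql" language tag, then capture everything up to the closing fence.
_FENCED = re.compile(FENCE + r"(?:sql)?\s*\n(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)

def strip_sql_fences(text: str) -> str:
    """Return the SQL inside a Markdown code fence, or the text unchanged."""
    match = _FENCED.search(text)
    return (match.group(1) if match else text).strip()

raw = FENCE + "sql\nSELECT AVG(total_amount)\nFROM orders;\n" + FENCE
print(strip_sql_fences(raw))
```

It could be applied to `response['result']` before printing or running the query.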


# 3. agent_runner.py

In [None]:
import os

def run_agent():
    # Upload a new json.
    print('Welcome to the SQL Agent. Do you want to add a new schema to the database? Please answer with "Yes" or "No"')
    answer = input().lower()

    if answer == 'yes':
        upload_and_update_chroma_from_json()

        print('If you need to add another schema in the database, type "add" as a requested question.')

    # Create agent
    agent = initialize_agent()

    print("This is the SQL Agent. Type your question or 'exit' to quit.")
    while True:
        query = input("Enter your question: ").strip()
        if query.lower() == 'exit':
            print("Goodbye!")
            break
        if query.lower() == 'add':
            upload_and_update_chroma_from_json()
            continue

        try:
            response = agent.invoke({"query": query})
            print("SQL Query:\n", response['result'])
            print("\n" + "-"*40 + "\n")
        except Exception as e:
            print(f"Error processing query: {e}")

if __name__ == "__main__":
    run_agent()


Welcome to the SQL Agent. Do you want to add a new schema to the database? Please answer with "Yes" or "No"
Yes
Upload the .json file:


No file uploaded from the user.
Failed to load the schema from JSON. Skipping ChromaDB update.
If you need to add another schema in the database, type "add" as a requested question.
This is the SQL Agent. Type your question or 'exit' to quit.
Enter your question: exit
Goodbye!


# Conclusion

With ChromaDB and LangChain integrated, the SQL Agent now performs **context-aware** query generation: it retrieves the most relevant table definitions before writing SQL, which should help on schemas too large to fit into a single prompt. Future improvements could include:

- Support for multi-hop retrieval
- Fine-tuned embedding models
- API deployment with FastAPI