# Text-to-SQL Using Retrieval Augmented Generation (RAG)
Use of RAG to improve performance of Text-to-SQL use cases

---

## Suggested SageMaker Environment
SageMaker Image: sagemaker-distribution-cpu

Kernel: Python 3

Instance Type: ml.m5.large

---

## Contents

1. [Install Dependencies](#step-1-install-dependencies)
1. [Configure Bedrock Embedding Model and LLM](#step-2-configure-bedrock-embedding-model-and-llm)
1. [Configure Athena and Bedrock Client](#step-3-configure-athena-and-bedrock-client)
1. [Create Helper Functions](#step-4-create-helper-functions)
1. [Configure Bedrock Embedding Model](#step-5-configure-bedrock-embedding-model)
1. [Fetch TPC-DS Metadata](#step-6-fetch-tpc-ds-dataset-tables-and-columns-information)
1. [Embed Questions and Metadata](#step-7-embed-all-the-questions-and-metadata)
1. [Build Prompt and Generate Query](#step-8-build-prompt-and-generate-sql-query)

---

## Objective
This notebook provides code snippets for one approach to converting a natural language question into a SQL query that answers it.

---

## The Approach to the Text-to-SQL Problem


We'll walk through setting up a Bedrock embedding model and an LLM, and use them to embed the table metadata.

First, we'll get the metadata from Athena.

Second, we'll use the metadata and ask the LLM to generate possible questions that could be answered with each table.

Third, we'll embed all the metadata and generated questions in a vector store. We'll use an in-memory vector store called [FAISS](https://faiss.ai/index.html), but one could also use a persistent store such as Amazon OpenSearch. We'll use semantic similarity to retrieve tables and columns that could help us answer the question being asked.

Last, we'll design a robust prompt to incorporate our retrieved schema, instructions, few-shot examples, and of course our question.

![Alt text](content/rag.png)
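
To make the retrieval idea concrete before we set anything up, here is a toy sketch that runs without AWS. It assumes a bag-of-words vector as a stand-in for a Bedrock embedding and a sorted list as a stand-in for FAISS; the `index` entries and the `embed`, `cosine`, and `retrieve` helpers are hypothetical and are not used elsewhere in this notebook.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy embedding: word counts (a stand-in for a real embedding model)."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[token] * b[token] for token in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


# Hypothetical generated questions, each tagged with the table it came from.
index = [
    ("customer", "Which customers live in a given city?"),
    ("store_sales", "What is the total sales amount per store?"),
    ("item", "Which items belong to a given category?"),
]


def retrieve(question, k=1):
    """Return the table names of the k stored questions most similar to `question`."""
    q = embed(question)
    ranked = sorted(index, key=lambda entry: cosine(q, embed(entry[1])), reverse=True)
    return [table for table, _ in ranked[:k]]


print(retrieve("What is the total sales amount for each store?"))
```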

### Tools
LangChain, Amazon Bedrock SDK (boto3)

---

### Step 1: Install Dependencies

Here we will install all the required dependencies to run this notebook. 

In [None]:
!python -m ensurepip --upgrade
!pip install "sqlalchemy" --quiet
!pip install "boto3~=1.34"  --quiet
!pip install "jinja2" --quiet
!pip install "botocore" --quiet
!pip install "pandas" --quiet
!pip install "PyAthena" --quiet
!pip install "faiss-cpu" --quiet
!pip install langchain --quiet
!pip install jq --quiet

### Step 2: Configure Bedrock Embedding Model and LLM

In [None]:
import os
import sys
import json
from langchain.document_loaders.json_loader import JSONLoader
from langchain.docstore.document import Document
import boto3
from botocore.config import Config
import re
from langchain.vectorstores import FAISS
from langchain.embeddings import BedrockEmbeddings
from functools import reduce
from langchain.prompts import PromptTemplate
from sqlalchemy import MetaData
from sqlalchemy import create_engine

sys.path.append('../')
from libs.din_sql import din_sql_lib as dsl

### Step 3: Configure Athena and Bedrock Client

In [None]:
ATHENA_RESULTS_S3_LOCATION = "<workshop bucket name>" # available in cloudformation outputs
ATHENA_CATALOG_NAME = "<athena catalog name>" # available in cloudformation outputs
DB_NAME = "tpcds1"
DB_FAISS_PATH = './vectorstore/db_faiss'

In [None]:
bedrock_region = athena_region = boto3.session.Session().region_name

In [None]:
retry_config = Config(retries = {'max_attempts': 100})
session = boto3.Session(region_name=bedrock_region)
bedrock = session.client('bedrock-runtime', region_name=bedrock_region, config=retry_config)

### Step 4: Create Helper Functions

Here we wrap our Bedrock call in a function that accepts a question and a model id and returns the parsed response body.

In [None]:
def ask_llm(question, modelId):
    """Send a prompt to a Bedrock text model and return the parsed JSON response body."""
    if 'titan' in modelId:
        # Amazon Titan request format
        body = json.dumps({
            "inputText": question,
            "textGenerationConfig": {"maxTokenCount": 200, "temperature": 0.001},
        })
    else:
        # Anthropic Claude text-completion request format
        body = json.dumps({
            "prompt": question,
            "max_tokens_to_sample": 4096,
            "temperature": 0.5,
            "top_k": 250,
            "top_p": 0.5,
        })

    accept = 'application/json'
    contentType = 'application/json'

    response = bedrock.invoke_model(body=body, modelId=modelId, accept=accept, contentType=contentType)
    response_body = json.loads(response.get('body').read())
    print(response_body)
    return response_body
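
Because the two model families use different request and response shapes, it can help to see the payload handling in isolation. The sketch below makes no Bedrock call. `build_request_body` and `extract_text` are hypothetical helpers, and the response shapes they assume (`completion` for Claude text completions, `results[0].outputText` for Titan text) should be checked against the Bedrock documentation for the model you use.

```python
import json


def build_request_body(question, model_id):
    """Return the JSON request body for the model family implied by model_id."""
    if "titan" in model_id:
        payload = {
            "inputText": question,
            "textGenerationConfig": {"maxTokenCount": 200, "temperature": 0.001},
        }
    else:
        payload = {
            "prompt": question,
            "max_tokens_to_sample": 4096,
            "temperature": 0.5,
            "top_k": 250,
            "top_p": 0.5,
        }
    return json.dumps(payload)


def extract_text(response_body, model_id):
    """Pull the generated text out of an already-parsed response body.

    Assumed shapes: Titan returns {"results": [{"outputText": ...}]},
    Claude text completions return {"completion": ...}.
    """
    if "titan" in model_id:
        return response_body["results"][0]["outputText"]
    return response_body["completion"]


print(build_request_body("List three questions.", "anthropic.claude-v2"))
```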

Once the LLM returns the list of questions that can be answered using the given table and columns, our `write_questions_to_file` function stores them locally in a JSON file.

In [None]:
def write_questions_to_file(question_list_filename, table_name, table_schema, answer):
    data_list = []
    questions_list = answer.splitlines()
    print(questions_list)
    # Open the file in write mode
    with open(question_list_filename, mode="w", newline="") as file:
        for question in questions_list:

            # Skip lines that do not contain a question
            if "?" not in question:
                continue

            # Strip a leading list marker such as "1. ", "12) " or "- ", if present
            question = re.sub(r"^\s*(?:\d{1,5}[.)]|[-*])\s*", "", question).strip()
            data = {
                "tableName": table_name,
                "question": question,
                "tableSchema": table_schema.lstrip(" "),
            }
            data_list.append(data)

        json.dump(data_list, file)
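
The line-cleaning step is easy to try on its own. The sketch below is one way to strip list markers from a model answer, assuming the answer is a numbered or bulleted list whose first line may arrive without a marker; `parse_questions` and the `sample` text are hypothetical and only for illustration.

```python
import re

# Matches a leading "1.", "12)", "-" or "*" list marker plus surrounding whitespace.
LIST_MARKER = re.compile(r"^\s*(?:\d{1,5}[.)]|[-*])\s*")


def parse_questions(answer):
    """Return the question lines of `answer` with any list marker removed."""
    questions = []
    for line in answer.splitlines():
        # Lines without a question mark are treated as commentary and skipped.
        if "?" not in line:
            continue
        questions.append(LIST_MARKER.sub("", line).strip())
    return questions


sample = (
    " What is the customer's name?\n"
    "2. Which store sold the most?\n"
    "- How many items were returned?\n"
    "Here is some commentary."
)
print(parse_questions(sample))
```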

We'll need a function that accepts a list of documents and returns a list of the same documents with their metadata attached. We'll also need a function that loads a JSON file and returns its entries as LangChain documents.

In [None]:
# Create new docs with the right metadata we need for indexing
def create_docs_with_correct_metadata(documents):
    # We are going to return a list of new documents
    new_docs = []

    # For each document
    for doc in documents:
        # Get its metadata and contents
        metadata = doc.metadata
        contents = json.loads(doc.page_content)

        # Now calculate the new metadata that we want to add
        new_metadata = {
            "tableName": contents["tableName"],
            "question": contents["question"],
            "tableSchema": contents["tableSchema"],
        }

        # Print out the new metadata for our documents
        # print(new_metadata)

        new_docs.append(
            Document(page_content=new_metadata["question"], metadata=new_metadata)
        )

    return new_docs

def load_json_file(filename):
    loader = JSONLoader(file_path=filename, jq_schema=".[]", text_content=False)

    # This is our internal Langchain document data structure
    docs = loader.load()
    return docs
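
To see what these two helpers do to a record, here is a sketch that mirrors them without LangChain. `SimpleDoc` is a hypothetical stand-in for LangChain's `Document`, and `load_records` only imitates what we assume the `.[]` jq schema does: one document per top-level array element, with the element kept as JSON text.

```python
import json
from dataclasses import dataclass, field


@dataclass
class SimpleDoc:
    """Hypothetical stand-in for LangChain's Document, for illustration only."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def load_records(json_text):
    """One document per element of the top-level JSON array, content kept as JSON text."""
    return [SimpleDoc(page_content=json.dumps(record)) for record in json.loads(json_text)]


def reindex(docs):
    """Make the question the searchable text and keep the full record as metadata."""
    new_docs = []
    for doc in docs:
        record = json.loads(doc.page_content)
        new_docs.append(SimpleDoc(page_content=record["question"], metadata=record))
    return new_docs


raw = '[{"tableName": "item", "question": "Which items are red?", "tableSchema": "i_id|i_color|"}]'
docs = reindex(load_records(raw))
print(docs[0])
```

Embedding only the question while carrying the table name and schema as metadata is what lets a similarity hit on a question lead back to the table that can answer it.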

The next function asks the LLM to inspect a table schema and generate questions that could be answered by that schema, stores those questions in a file, and then adds them to a single vector database.
The prompt below asks for the questions in natural language, and the function calls the helper functions above to embed the questions together with the table metadata.

In [None]:
# This function asks the LLM to inspect a table schema, generate some questions which could be answered
# by that schema, and then it stores those questions to file, loads them all into a single vectorDB
def add_new_table(schema, table_name, model_id, is_incremental, bedrock_embeddings):
    """
    :schema             : pipe-delimited column names of the table
    :table_name         : name of the table the questions are generated for
    :model_id           : Bedrock model id used to generate the questions
    :is_incremental     : if True, merge into the existing vector DB when one exists
    :bedrock_embeddings : embedding model used to build the FAISS index
    """
    print(f"Adding table {table_name} with schema {schema}")

    question = f"""\n\nHuman:
    Only return a numbered list of unique and detailed questions that could be answered by this table called {table_name} with schema:
    {schema}.
    Instructions:
        Use natural language descriptions only.
        Do not use SQL.
        Produce a varied list of questions, but the questions should be unique and detailed.
        The questions should be in a format that is easy to understand and answer.
        Ask about as much of the information in the table as possible.
        You can ask about more than one aspect of the data at a time.
        Questions should begin with 'What', 'Which', 'How', 'When' or 'Can'. Use variable names.
        The questions should use relevant business vocabulary and terminology only.
        Do not use column names in your output - use relevant natural language descriptions only.
        Do not output any numeric values.
        Output the questions as a numbered list.

        Questions: 1.
        \n\nAssistant:"""
       

    response = ask_llm(question,model_id)
    answer = response['completion']
    question_list_filename = f"../questionList{table_name}.json"

    # # Get rid of anything before the 1.
    # if re.match(r"^[^\d+]\. ", answer) and re.search(r"\d+\. ", answer):
    #     answer = "1. " + answer.split("1. ")[1]
    # else:
    #     answer = "1. " + answer

    print(
        f"Writing questions to {question_list_filename}, with schema {schema}, with table name {table_name} and answer {answer}.\n\n"
    )

    write_questions_to_file(question_list_filename, table_name, schema, answer)
    docs = load_json_file(question_list_filename)
    docs = create_docs_with_correct_metadata(docs)
    new_questions = FAISS.from_documents(docs, bedrock_embeddings)
    db_exists = os.path.exists(f"{DB_FAISS_PATH}/index.faiss")
    # Add new tables to the existing vector DB
    if is_incremental and db_exists:
        question_db = FAISS.load_local(DB_FAISS_PATH, bedrock_embeddings, allow_dangerous_deserialization=True)
        question_db.merge_from(new_questions)
        question_db.save_local(DB_FAISS_PATH)

    # Create the vector DB for the first time
    else:
        print(f"is_incremental set to {is_incremental} and/or no vector db found. Creating...")
        new_questions.save_local(DB_FAISS_PATH)

### Step 5: Configure Bedrock Embedding Model
Here we'll create a LangChain embedding model to be used for converting text to vector embeddings.

In [None]:
bedrock_embeddings = BedrockEmbeddings(client=bedrock)

### Step 6: Fetch TPC-DS Dataset Tables and Columns Information

In [None]:
def get_sqlalchemy_athena(database,catalog,s3stagingathena,region):

    athena_connection_str = f'awsathena+rest://:@athena.{region}.amazonaws.com:443/{database}?s3_staging_dir={s3stagingathena}&catalog_name={catalog}'
    # Create Athena engine
    athena_engine = create_engine(athena_connection_str) 
    return athena_engine


def get_tpc_ds_dataset(database, catalog, s3stagingathena, region):
    """Reflect the database schema and return a list of (pipe-delimited columns, table name) tuples."""
    column_table = []
    metadata = MetaData()
    engine = get_sqlalchemy_athena(database, catalog, s3stagingathena, region)
    metadata.reflect(bind=engine)

    # Get list of table names
    print(metadata.tables.keys())

    # Loop through tables
    for table_name, table in metadata.tables.items():
        column_names = table.columns.keys()
        print(f"Table: {table_name}")
        print(f"Schema: {table.schema}")
        print(f"Columns: {column_names}")

        # Build the pipe-delimited column string, e.g. "col_a|col_b|"
        columns_str = "".join(f"{column}|" for column in column_names)
        column_table.append((columns_str, table_name))
    return column_table
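
To see the shape of the data this step produces without connecting to Athena, here is a sketch that builds the same `(columns, table name)` tuples from an in-memory SQLite database. SQLite and the two example tables are stand-ins chosen only so the snippet runs locally; `get_columns_by_table` is a hypothetical helper.

```python
import sqlite3


def get_columns_by_table(conn):
    """Return a list of (pipe-delimited column names, table name) tuples."""
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    column_table = []
    for (table_name,) in tables:
        # PRAGMA table_info yields one row per column; index 1 holds the column name.
        columns = [row[1] for row in conn.execute(f'PRAGMA table_info("{table_name}")')]
        columns_str = "".join(f"{column}|" for column in columns)
        column_table.append((columns_str, table_name))
    return column_table


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customer (c_customer_sk INTEGER, c_first_name TEXT)")
conn.execute("CREATE TABLE item (i_item_sk INTEGER, i_category TEXT)")
print(get_columns_by_table(conn))
```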



In [None]:
tpc_ds = get_tpc_ds_dataset(DB_NAME,ATHENA_CATALOG_NAME,ATHENA_RESULTS_S3_LOCATION,athena_region)

### Step 7: Embed all the questions and metadata 
Here we'll use our helper functions to generate plausible questions for each table and embed them together with the table metadata.

In [None]:
model_id = 'anthropic.claude-v2'

for x in tpc_ds:
    print(x)
    add_new_table(
        schema=x[0], 
        table_name=x[1],
        model_id=model_id,
        is_incremental=True, 
        bedrock_embeddings=bedrock_embeddings
    )

In [None]:
question_db = FAISS.load_local(DB_FAISS_PATH, bedrock_embeddings, allow_dangerous_deserialization=True)

In [None]:
query = "Find the top 10 customer name by total dollars spent"

### Step 8: Build Prompt and Generate SQL Query
First we'll retrieve table and column information with a similarity search, which pulls in possible matches based on the semantic meaning of our question. A keyword match on table names follows later in this step.

In [None]:
schema =  {}

results_with_scores = question_db.similarity_search_with_score(query)
for doc, score in results_with_scores:
    print(doc.metadata['question'])
    schema[doc.metadata['tableName']] = doc.metadata['tableSchema']
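
The loop above keeps every hit regardless of its score. If weak matches pull in unrelated tables, one option is to filter on the score first. The sketch below assumes the score is a distance where smaller means more similar, which should be verified for your index configuration; `build_schema`, `max_distance`, and the sample hits are hypothetical.

```python
def build_schema(results_with_scores, max_distance=None):
    """Collect tableName -> tableSchema from (metadata, distance) pairs.

    Hits farther away than `max_distance` are dropped; None keeps everything.
    """
    schema = {}
    for metadata, distance in results_with_scores:
        if max_distance is not None and distance > max_distance:
            continue
        # The first (closest) hit for a table wins; later duplicates are ignored.
        schema.setdefault(metadata["tableName"], metadata["tableSchema"])
    return schema


# Hypothetical hits, in the order a similarity search would return them.
hits = [
    ({"tableName": "customer", "tableSchema": "c_customer_sk|c_first_name|"}, 0.21),
    ({"tableName": "store_sales", "tableSchema": "ss_customer_sk|ss_net_paid|"}, 0.35),
    ({"tableName": "customer", "tableSchema": "c_customer_sk|c_first_name|"}, 0.40),
    ({"tableName": "warehouse", "tableSchema": "w_warehouse_sk|w_city|"}, 0.93),
]
print(build_schema(hits, max_distance=0.5))
```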

Initialize our `DIN_SQL` class with the Anthropic Claude v2 model.

In [None]:
din_sql = dsl.DIN_SQL(bedrock_model_id=model_id)

Connect to Athena to prepare for executing a query.

In [None]:
din_sql.athena_connect(catalog_name=ATHENA_CATALOG_NAME, 
               db_name=DB_NAME, 
               s3_prefix=ATHENA_RESULTS_S3_LOCATION)

Now we'll augment our `schema` object with the metadata of any table whose name appears as a word in our question, to catch obvious matches that the similarity search missed.

In [None]:
list_tables = din_sql.find_tables(DB_NAME)
list_words = query.split(" ")

# Keep each table whose name appears as a word in the question (no duplicates)
intersection = [table for table in dict.fromkeys(list_tables) if table in list_words]
for table in intersection:
    if table in schema:
        print("exists")
    else:
        schema_name = din_sql.get_schema(DB_NAME, table)
        schema[table] = schema_name
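
Splitting on spaces misses table names that appear capitalized, pluralized, or next to punctuation. As a sketch of a slightly more forgiving match, the hypothetical `match_tables` helper below lowercases the question, drops punctuation, and also tries a naive singular form made by removing a trailing "s". That singular rule is a simplifying assumption and will not handle irregular plurals.

```python
import re


def match_tables(question, table_names):
    """Return the table names that appear as words in the question."""
    words = set(re.findall(r"[a-z0-9_]+", question.lower()))
    # Naive singularization: "customers" also counts as "customer".
    words |= {word[:-1] for word in words if word.endswith("s")}
    return [table for table in table_names if table.lower() in words]


print(match_tables("Find the top 10 Customers, by total dollars spent.",
                   ["customer", "item", "store_sales"]))
```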

Let's take a look at what's in our `schema` object now.

In [None]:
schema

With our schema information ready for use, we're now ready for a prompt that can get us quality results. Take a look at the following prompt, which uses Claude prompting best practices, and see how our schema is incorporated into the instructions.

In [None]:
prompt_template = PromptTemplate.from_template(
    """\n\nHuman: 
        <Instructions>
            Read database schema inside the <database_schema></database_schema> tags which contains a json list of table names and their pipe-delimited schemas to do the following:
            1. Create a syntactically correct awsathena query to answer the question.
            2. Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
            3. Pay attention to use only the column names that you can see in the schema description. 
            4. Be careful to not query for columns that do not exist. 
            5. Pay attention to which column is in which table. 
            6. Qualify column names with the table name when needed. You are required to use the following format, each taking one line:
            7. Return the SQL query inside the <SQL></SQL> tags.
        </Instructions>
          
        <database_schema>{schema}</database_schema>
        
        <examples>
        <Question>"How many users do we have?"</Question>
        <SQL>SELECT SUM(users) FROM customers</SQL>

        <Question>"How many users do we have for Mobile?"</Question>
        <SQL>SELECT SUM(users) FROM customer WHERE source_medium='Mobile'</SQL>
        </examples>
          
        <Question>{input_question}</Question>
        \n\nAssistant:"""
)
prompt_data= prompt_template.format(schema=schema,input_question = query)
print(prompt_data)
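
The instructions describe the schema as JSON, but formatting a Python dict into a string produces its `repr`, which uses single quotes. If you want the prompt to contain real JSON, one option is to serialize the dict first. The sketch below shows the difference with plain `str.format`; `TEMPLATE` is a shortened, hypothetical stand-in for the full prompt and `example_schema` is made-up data.

```python
import json

# Shortened stand-in for the full prompt template above.
TEMPLATE = (
    "<database_schema>{schema}</database_schema>\n"
    "<Question>{input_question}</Question>"
)

example_schema = {"customer": "c_customer_sk|c_first_name|"}
example_question = "How many customers do we have?"

# Passing the dict directly embeds its Python repr (single quotes).
as_repr = TEMPLATE.format(schema=example_schema, input_question=example_question)

# Serializing first embeds valid JSON (double quotes).
as_json = TEMPLATE.format(schema=json.dumps(example_schema), input_question=example_question)

print(as_repr)
print(as_json)
```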

With our full prompt ready, let's submit to Claude to see what it comes up with.

In [None]:
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 1500,"temperature":0.0})
accept = 'application/json'
content_type = 'application/json'

response = bedrock.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())

sql = response_body['completion']

In [None]:
print(sql)

Let's run the generated query against our data.
First we need to strip the `<SQL>` tags.

In [None]:
cleaned_sql = sql.split('<SQL>')[1].split('</SQL>')[0]
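
The split above raises an `IndexError` if the model omits the tags. A sketch of a more defensive variant, using a regular expression and an explicit error, is shown below; `extract_sql` is a hypothetical helper and not part of the `din_sql` library.

```python
import re


def extract_sql(completion):
    """Return the text inside the first <SQL>...</SQL> pair, stripped of whitespace."""
    match = re.search(r"<SQL>(.*?)</SQL>", completion, flags=re.DOTALL | re.IGNORECASE)
    if match is None:
        raise ValueError("No <SQL>...</SQL> block found in the model output")
    return match.group(1).strip()


print(extract_sql(" Here is the query:\n<SQL>\nSELECT 1\n</SQL>"))
```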

In [None]:
results = din_sql.query(cleaned_sql)
results