**OpenAI RAG LLM setup**

In [1]:
# Installing the required packages
# NOTE(review): no versions are pinned, so a fresh install may resolve to different
# releases than the ones this notebook was last run with — consider pinning.
%pip install langchain pymongo openai gradio requests langchain_community langchain-openai langchain-mongodb sentence_transformers python-dotenv




In [2]:
# Importing the required libraries
from pymongo import MongoClient
from langchain_openai import OpenAIEmbeddings
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain.document_loaders import DirectoryLoader
from langchain_openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
import gradio as gr
from gradio.themes.base import Base
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from sentence_transformers import SentenceTransformer # https://huggingface.co/thenlper/gte-large
import os
from dotenv import load_dotenv

***Accessing secrets***

In [3]:
# In Google Colab, you can use the following code to access the secret
#from google.colab import userdata
#MONGO_URI_SQL = userdata.get('MONGO_URI_SQL')
#MONGO_URI_schema = userdata.get('MONGO_URI_Schema')
#OPENAI_API_KEY = userdata.get("OPENAI_API_KEY")

# In your local environment, you can use the following code to access the secret
# (values come from a .env file and/or the process environment).
load_dotenv()
MONGO_URI_SQL = os.getenv("MONGO_URI_SQL")
MONGO_URI_schema = os.getenv("MONGO_URI_Schema")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Fail fast when a secret is absent: os.getenv returns None for an unset variable,
# and passing None on would only surface later as a confusing connection or
# authentication error far away from the real cause.
_missing_secrets = [
    name
    for name, value in (
        ("MONGO_URI_SQL", MONGO_URI_SQL),
        ("MONGO_URI_Schema", MONGO_URI_schema),
        ("OPENAI_API_KEY", OPENAI_API_KEY),
    )
    if not value
]
if _missing_secrets:
    raise RuntimeError(
        "Missing required secret(s): " + ", ".join(_missing_secrets)
        + ". Define them in your .env file or environment."
    )

***Generating the embedding***

In [4]:
# Embedding model setup
# Loads the "thenlper/gte-large" sentence-transformers model
# (model card: https://huggingface.co/thenlper/gte-large).
# NOTE(review): presumably this must be the same model that produced the vectors
# stored in MongoDB — confirm before swapping it.
embedding_model = SentenceTransformer("thenlper/gte-large")

class CustomEmbeddingFunction:
    """Adapter exposing ``embed_documents`` / ``embed_query`` on top of a model's ``encode``.

    The wrapped model only needs an ``encode`` method whose return value
    offers ``tolist()``.
    """

    def __init__(self, model):
        # Keep a reference to the encoder; every call is delegated to it.
        self.model = model

    def embed_documents(self, texts):
        """Encode a batch of texts and return the result converted via ``tolist()``."""
        return self.model.encode(texts).tolist()

    def embed_query(self, text):
        """Encode one query text and return the result converted via ``tolist()``."""
        return self.model.encode(text).tolist()

# Wrap the SentenceTransformer model so it offers the embed_documents / embed_query
# methods; this wrapper is what gets handed to the vector stores as `embedding`.
embedding_function = CustomEmbeddingFunction(embedding_model)

In [5]:
## MongoDB setup
# SQL Vector
client_SQL = MongoClient(MONGO_URI_SQL)
dbName_SQL = "MVector"
collectionName_SQL = "MTSQL"
collection_SQL = client_SQL[dbName_SQL][collectionName_SQL]
index_name_SQL = "vector_index_SQL"

## SQL Vector setup
# Vector store setup
# Only the collection handle is passed: it already identifies the client and the
# database. `client=` / `database=` are not documented constructor parameters of
# MongoDBAtlasVectorSearch, so supplying them is at best a no-op and can raise a
# TypeError depending on the installed langchain-mongodb version.
vector_store_SQL = MongoDBAtlasVectorSearch(
    collection=collection_SQL,
    embedding=embedding_function,
    index_name=index_name_SQL,
    text_key="Query",  # document field used as the text of each result
)

# Retriever setup (k=4 results requested per query)
retriever_SQL = vector_store_SQL.as_retriever(search_kwargs={"k": 4})

# Callable wrapper that echoes what the SQL retriever returns before passing it on
class LoggingRetrieverSQL:
    """Runs the wrapped retriever and prints every result it returned."""

    def __init__(self, retriever_SQL):
        self.retriever_SQL = retriever_SQL

    def __call__(self, query):
        """Invoke the wrapped retriever with ``query``, print the hits, and return them unchanged."""
        hits = self.retriever_SQL.invoke(query)

        # Header first, then one line per retrieved item.
        print("Retrieved Documents:")
        for hit in hits:
            print(hit)

        return hits

# Wrap the SQL retriever with the logging wrapper; calling the resulting object
# prints the retrieved documents and returns them.
logging_retriever_SQL = LoggingRetrieverSQL(retriever_SQL)

In [6]:
# Schema Vector
client_schema = MongoClient(MONGO_URI_schema)
dbName_schema = "MVector"
collectionName_schema = "MTSchema"
collection_schema = client_schema[dbName_schema][collectionName_schema]
index_name_schema = "vector_index_schema"

## Schema Vector setup
# Vector store setup
# Only the collection handle is passed: it already identifies the client and the
# database. `client=` / `database=` are not documented constructor parameters of
# MongoDBAtlasVectorSearch, so supplying them is at best a no-op and can raise a
# TypeError depending on the installed langchain-mongodb version.
vector_store_schema = MongoDBAtlasVectorSearch(
    collection=collection_schema,
    embedding=embedding_function,
    index_name=index_name_schema,
    text_key="Table_name",  # document field used as the text of each result
)

# Retriever setup (k=5 results requested per query)
retriever_schema = vector_store_schema.as_retriever(search_kwargs={"k": 5})

# Callable wrapper that echoes what the schema retriever returns before passing it on
class LoggingRetrieverSchema:
    """Runs the wrapped retriever and prints every schema entry it returned."""

    def __init__(self, retriever_schema):
        self.retriever_schema = retriever_schema

    def __call__(self, query):
        """Invoke the wrapped retriever with ``query``, print the hits, and return them unchanged."""
        hits = self.retriever_schema.invoke(query)

        # Header first, then one line per retrieved item.
        print("Retrieved Schema:")
        for hit in hits:
            print(hit)

        return hits

# Wrap the schema retriever with the logging wrapper; calling the resulting object
# prints the retrieved schema entries and returns them.
logging_retriever_schema = LoggingRetrieverSchema(retriever_schema)

***Chain setup***

In [8]:
# Output parser and chat model used at the end of the chain
parser = StrOutputParser()
model = ChatOpenAI(api_key=OPENAI_API_KEY, model="gpt-3.5-turbo")

# Prompt text: {context} and {schema} are filled by the two retrievers,
# {question} receives the chain input unchanged.
template = """
Provide a natural language translation and explanation of the SQL statement. Go through it step by step and use the information of the Context as examples and the schema to get the names of tables and columns. If you can't answer the question, reply "I don't know".

Context: {context}

Schema: {schema}

Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)

# Chain setup: the three inputs are computed from the same incoming string,
# then piped through prompt -> model -> parser.
chain = (
    RunnableParallel(
        {
            "context": logging_retriever_SQL,
            "schema": logging_retriever_schema,
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | model
    | parser
)

# Smoke-test the chain on a sample statement (the logging retrievers print what they fetched)
chain.invoke("SELECT T2.name ,  T2.capacity FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id  =  T2.stadium_id WHERE T1.year  >=  2014 GROUP BY T2.stadium_id ORDER BY count(*) DESC LIMIT 1?")

Retrieved Documents:
page_content='SELECT t1.name FROM stadium AS t1 JOIN event AS t2 ON t1.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66cf12c13c2173e47d7afdbc', 'Question': 'What is the name of the stadium which held the most events?'}
page_content='SELECT t3.name FROM record AS t1 JOIN event AS t2 ON t1.event_id  =  t2.id JOIN stadium AS t3 ON t3.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66cf12c13c2173e47d7afdc5', 'Question': 'Find the names of stadiums that the most swimmers have been to.'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1.bandmate  =  T2.id WHERE stageposition  =  "back" GROUP BY lastname ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66cf12c13c2173e47d7afc16', 'Question': 'What is the last name of the musician that has been at the back position the most?'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1

'Translation: Find the name and capacity of the stadium that held the most concerts starting from the year 2014.\n\nExplanation:\n1. SELECT T2.name, T2.capacity: We are selecting the name and capacity columns from the stadium table, which we are aliasing as T2.\n2. FROM concert AS T1: We are specifying that we are getting data from the concert table and aliasing it as T1.\n3. JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id: We are joining the concert table with the stadium table based on the stadium_id column to match the stadiums where concerts were held.\n4. WHERE T1.year >= 2014: We are filtering the data to only consider concerts that happened in the year 2014 or after.\n5. GROUP BY T2.stadium_id: We are grouping the results by the stadium_id column to count how many concerts were held at each stadium.\n6. ORDER BY count(*) DESC: We are ordering the results by the count of concerts held at each stadium in descending order, so the stadium that held the most concerts will be at t

***Chat interface setup***

In [9]:
# Define the chain_invoke function
def chain_invoke(question):
    """Run the RAG chain on a SQL statement and return the model's explanation.

    Args:
        question: The SQL statement typed into the UI textbox.

    Returns:
        The chain's output for the statement, or a short hint asking for input
        when ``question`` is empty, ``None`` or only whitespace.
    """
    # Guard against empty submissions so a stray click on "Submit" does not
    # trigger an embedding lookup and a paid LLM call with no content.
    if not question or not str(question).strip():
        return "Please enter a SQL statement."

    # Execute the chain with the logging retrievers; the input is passed on unmodified.
    return chain.invoke(question)

# Build the Gradio web UI for the app
with gr.Blocks(theme=Base(), title="Question Answering App using Vector Search + RAG") as demo:
    # Page heading
    gr.Markdown(
        """
        # Question Answering App using Atlas Vector Search + RAG Architecture
        """)
    # Input field, submit button (in its own row) and result field
    textbox = gr.Textbox(label="Enter your SQL statement:")
    with gr.Row():
        button = gr.Button(value="Submit", variant="primary")
    output = gr.Textbox(lines=1, max_lines=30, label="Natural language translation and explanation:")

    # Clicking Submit feeds the textbox content to chain_invoke and shows its return value
    button.click(fn=chain_invoke, inputs=textbox, outputs=output)

# Start the local web server for the UI
demo.launch()

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


