**OPENAI RAG LLM setup**

In [2]:
# Install the packages this notebook imports (run once per environment).
# pip's own output notes that the kernel may need a restart before updated packages can be used.
%pip install langchain pymongo openai gradio requests langchain_community langchain-openai langchain-mongodb sentence_transformers python-dotenv

Collecting python-dotenv
  Using cached python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Using cached python_dotenv-1.0.1-py3-none-any.whl (19 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.0.1
Note: you may need to restart the kernel to use updated packages.


In [1]:
# Importing the required libraries
from pymongo import MongoClient
from langchain_openai import OpenAIEmbeddings
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain.document_loaders import DirectoryLoader
from langchain_openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
import gradio as gr
from gradio.themes.base import Base
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from sentence_transformers import SentenceTransformer # https://huggingface.co/thenlper/gte-large
import os
from dotenv import load_dotenv

***Accessing secrets***

In [2]:
# In Google Colab, read the secret from the Colab secret store instead:
#from google.colab import userdata
#MONGO_URI = userdata.get('MONGO_URI')

# In a local environment, load the secrets from a .env file / the process environment
load_dotenv()
MONGO_URI = os.getenv("MONGO_URI")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Fail fast: os.getenv returns None for an unset variable, and a None value
# would only cause an error later, in a cell far away from the actual cause.
_missing_secrets = [
    name
    for name, value in (("MONGO_URI", MONGO_URI), ("OPENAI_API_KEY", OPENAI_API_KEY))
    if not value
]
if _missing_secrets:
    raise RuntimeError(
        "Missing required secret(s): "
        + ", ".join(_missing_secrets)
        + ". Define them in a .env file or in the environment before running this notebook."
    )

***Embedding model, vector store, and retriever setup***

In [3]:
# Names of the database, collection and vector search index used by this notebook
dbName = "MTGemma"
collectionName = "MTGemma"
index_name = "vector_index"

# Connect to MongoDB and take a handle on the collection
client = MongoClient(MONGO_URI)
collection = client[dbName][collectionName]

# Load the sentence-transformers model used to compute embeddings
embedding_model = SentenceTransformer("thenlper/gte-large")

class CustomEmbeddingFunction:
    """Adapter exposing a model's ``encode`` through ``embed_documents`` / ``embed_query``.

    The wrapped model only needs an ``encode`` method whose result has a
    ``tolist()`` method; both adapter methods return that list form.
    """

    def __init__(self, model):
        # Keep a reference to the wrapped model; it is used as-is on every call
        self.model = model

    def embed_query(self, text):
        """Encode one query string and return the embedding as a list."""
        return self.model.encode(text).tolist()

    def embed_documents(self, texts):
        """Encode a list of documents and return the embeddings as a list.

        NOTE(review): nothing in the visible notebook calls this method directly.
        It is presumably required because the vector store expects an object with
        both embedding methods -- confirm before removing it.
        """
        return self.model.encode(texts).tolist()

# Wrap the SentenceTransformer model so it offers embed_query / embed_documents
embedding_function = CustomEmbeddingFunction(embedding_model)

# Vector store setup.
# The store is bound to the pymongo collection handle, which already carries
# the client and database, so no separate client/database arguments are passed
# (they are not documented constructor parameters of MongoDBAtlasVectorSearch).
vector_store = MongoDBAtlasVectorSearch(
    collection=collection,
    embedding=embedding_function,
    index_name=index_name,
    text_key="Query",  # document field used as the text of each retrieved document
)

# Retriever setup: return the 4 closest matches for each query
retriever = vector_store.as_retriever(search_kwargs={"k": 4})

# Define a custom logging retriever to see what the retriever is passing on
class LoggingRetriever:
    """Callable wrapper that prints whatever the wrapped retriever returns.

    The wrapped object only needs an ``invoke(query)`` method.
    """

    def __init__(self, retriever):
        self.retriever = retriever

    def __call__(self, query):
        """Run the wrapped retriever on ``query``, print the hits, and return them unchanged."""
        hits = self.retriever.invoke(query)

        # One header line followed by one line per retrieved document
        print("Retrieved Documents:", *hits, sep="\n")

        return hits

# Wrap the retriever so that every lookup also prints the documents it returned;
# the wrapped object is called like a function with the query as its argument
logging_retriever = LoggingRetriever(retriever)

***Chain setup***

In [10]:
# Prompt template: asks for a step-by-step explanation of the SQL statement
# ({question}), with the retrieved documents ({context}) offered as examples
template = """
Provide a natural language translation and explanation of the SQL statement. Go through it step by step and use the information of the Context as examples. If you can't answer the question, reply "I don't know".

Context: {context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# Chat model and the output parser applied to its response
model = ChatOpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY)
parser = StrOutputParser()

# Chain setup: the incoming SQL statement is sent to the logging retriever
# (filling "context") and passed through unchanged (filling "question"),
# then flows through prompt -> model -> parser
chain = (
    RunnableParallel(context=logging_retriever, question=RunnablePassthrough())
    | prompt
    | model
    | parser
)

# Run the chain once on a sample statement; the retriever wrapper prints what it found.
# NOTE(review): the sample ends with a trailing "?", which is not part of the SQL
# statement itself -- confirm whether it is intended.
chain.invoke(
    "SELECT T2.name ,  T2.capacity FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id  =  T2.stadium_id WHERE T1.year  >=  2014 GROUP BY T2.stadium_id ORDER BY count(*) DESC LIMIT 1?"
)

Retrieved Documents:
page_content='SELECT t1.name FROM stadium AS t1 JOIN event AS t2 ON t1.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bd8d', 'Question': 'What is the name of the stadium which held the most events?'}
page_content='SELECT t3.name FROM record AS t1 JOIN event AS t2 ON t1.event_id  =  t2.id JOIN stadium AS t3 ON t3.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bd96', 'Question': 'Find the names of stadiums that the most swimmers have been to.'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1.bandmate  =  T2.id WHERE stageposition  =  "back" GROUP BY lastname ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bbe7', 'Question': 'What is the last name of the musician that has been at the back position the most?'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1

Retrieved Documents:
page_content='SELECT t1.name FROM stadium AS t1 JOIN event AS t2 ON t1.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bd8d', 'Question': 'What is the name of the stadium which held the most events?'}
page_content='SELECT t3.name FROM record AS t1 JOIN event AS t2 ON t1.event_id  =  t2.id JOIN stadium AS t3 ON t3.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bd96', 'Question': 'Find the names of stadiums that the most swimmers have been to.'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1.bandmate  =  T2.id WHERE stageposition  =  "back" GROUP BY lastname ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bbe7', 'Question': 'What is the last name of the musician that has been at the back position the most?'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1

***Chat interface setup***

In [12]:
# Define the chain_invoke function
def chain_invoke(question):
    """Run the chain on the text submitted from the UI.

    Args:
        question: The SQL statement entered by the user.

    Returns:
        Whatever ``chain.invoke(question)`` returns, or a short hint when the
        input is missing, empty, or whitespace-only.
    """
    # Guard: a blank submission would otherwise still trigger a retrieval and a
    # model call for an input that cannot be explained
    if question is None or not str(question).strip():
        return "Please enter a SQL statement."

    # Execute the chain with the logging retriever
    return chain.invoke(question)

# Build the Gradio web UI: a heading, one input box, a submit button and an output box
with gr.Blocks(theme=Base(), title="Question Answering App using Vector Search + RAG") as demo:
    gr.Markdown(
        """
        # Question Answering App using Atlas Vector Search + RAG Architecture
        """)
    textbox = gr.Textbox(label="Enter your SQL statement:")
    with gr.Row():
        button = gr.Button("Submit", variant="primary")
    output = gr.Textbox(lines=1, max_lines=30, label="Natural language translation and explanation:")

    # Wire the Submit button: textbox value -> chain_invoke -> output box
    button.click(fn=chain_invoke, inputs=textbox, outputs=output)

# Start serving the app
demo.launch()

Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.




Retrieved Documents:
page_content='SELECT t1.name FROM stadium AS t1 JOIN event AS t2 ON t1.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bd8d', 'Question': 'What is the name of the stadium which held the most events?'}
page_content='SELECT t3.name FROM record AS t1 JOIN event AS t2 ON t1.event_id  =  t2.id JOIN stadium AS t3 ON t3.id  =  t2.stadium_id GROUP BY t2.stadium_id ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bd96', 'Question': 'Find the names of stadiums that the most swimmers have been to.'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1.bandmate  =  T2.id WHERE stageposition  =  "back" GROUP BY lastname ORDER BY count(*) DESC LIMIT 1' metadata={'_id': '66b60159361489e02771bbe7', 'Question': 'What is the last name of the musician that has been at the back position the most?'}
page_content='SELECT T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1