***OpenAI RAG LLM setup***

**Loading packages, libraries and secrets into notebook**

In [1]:
%run ../Setup.ipynb

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


***Generating the embedding***

In [2]:
# Embedding model setup
# Sentence-embedding model, loaded by its model identifier.
embedding_model = SentenceTransformer(model_name_or_path="thenlper/gte-large")

class CustomEmbeddingFunction:
    """Adapter exposing `embed_documents` / `embed_query` on top of a model's `encode`.

    The wrapped model only needs an `encode` method whose return value
    supports `.tolist()`.
    """

    def __init__(self, model):
        # Kept as a public attribute so the underlying model stays reachable.
        self.model = model

    def embed_documents(self, texts):
        """Encode a list of texts and return the result converted via `.tolist()`."""
        return self.model.encode(texts).tolist()

    def embed_query(self, text):
        """Encode a single query text and return the result converted via `.tolist()`."""
        return self.model.encode(text).tolist()

# Wrap the SentenceTransformer model so it offers embed_documents / embed_query
embedding_function = CustomEmbeddingFunction(model=embedding_model)

In [3]:
## MongoDB setup
# SQL Vector: handle to the Atlas collection searched for similar SQL statements
client_SQL = MongoClient(MONGO_URI_SQL)
dbName_SQL = "MVector"
collectionName_SQL = "MTSQL"
collection_SQL = client_SQL[dbName_SQL][collectionName_SQL]
index_name_SQL = "vector_index_SQL"

## SQL Vector setup
# Vector store setup
# Only the collection handle is passed: it is already derived from
# client_SQL[dbName_SQL] above, and `client` / `database` are not documented
# constructor parameters of MongoDBAtlasVectorSearch (depending on the installed
# release they are either silently ignored or rejected with a TypeError).
vector_store_SQL = MongoDBAtlasVectorSearch(
    collection=collection_SQL,
    embedding=embedding_function,
    index_name=index_name_SQL,
    text_key="Query",  # document field used as the text of each result
)

# Retriever setup: ask for the 4 nearest matches per lookup
retriever_SQL = vector_store_SQL.as_retriever(search_kwargs={"k": 4})

# Define a custom logging retriever to see what the retriever is passing on
class LoggingRetrieverSQL:
    """Callable wrapper that prints whatever the wrapped retriever returns.

    Calling an instance forwards the query to `retriever_SQL.invoke`, prints a
    header line followed by one line per retrieved item, and returns the
    retrieved items unchanged.
    """

    def __init__(self, retriever_SQL):
        self.retriever_SQL = retriever_SQL

    def __call__(self, query):
        """Run the retrieval for `query`, log the hits to stdout, and return them."""
        hits = self.retriever_SQL.invoke(query)

        # Header first, then each hit on its own line (nothing extra when empty).
        print("\n".join(["Retrieved Documents:", *map(str, hits)]))

        return hits

# Wrap the SQL retriever so every lookup is echoed to the cell output
logging_retriever_SQL = LoggingRetrieverSQL(retriever_SQL=retriever_SQL)

In [4]:
# Schema Vector: handle to the Atlas collection searched for schema entries
client_schema = MongoClient(MONGO_URI_schema)
dbName_schema = "MVector"
collectionName_schema = "MTSchemaAll"
collection_schema = client_schema[dbName_schema][collectionName_schema]
index_name_schema = "vector_index_schema_all"

## Schema Vector setup
# Vector store setup
# Only the collection handle is passed: it is already derived from
# client_schema[dbName_schema] above, and `client` / `database` are not
# documented constructor parameters of MongoDBAtlasVectorSearch (depending on
# the installed release they are either silently ignored or rejected with a
# TypeError).
vector_store_schema = MongoDBAtlasVectorSearch(
    collection=collection_schema,
    embedding=embedding_function,
    index_name=index_name_schema,
    text_key="Lookup_name",  # document field used as the text of each result
)

# Retriever setup: ask for the 5 nearest matches per lookup
retriever_schema = vector_store_schema.as_retriever(search_kwargs={"k": 5})

# Define a custom logging retriever to see what the retriever is passing on
class LoggingRetrieverSchema:
    """Callable wrapper that prints whatever the wrapped schema retriever returns.

    Calling an instance forwards the query to `retriever_schema.invoke`, prints
    a header line followed by one line per retrieved item, and returns the
    retrieved items unchanged.
    """

    def __init__(self, retriever_schema):
        self.retriever_schema = retriever_schema

    def __call__(self, query):
        """Run the retrieval for `query`, log the hits to stdout, and return them."""
        matches = self.retriever_schema.invoke(query)

        # Header first, then each match on its own line (nothing extra when empty).
        print("\n".join(["Retrieved Schema:", *map(str, matches)]))

        return matches

# Wrap the schema retriever so every lookup is echoed to the cell output
logging_retriever_schema = LoggingRetrieverSchema(retriever_schema=retriever_schema)

***Chain setup***

In [5]:
# Model and parsing setup
# Chat model that produces the answer, plus the output parser applied to its reply.
model = ChatOpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY)
parser = StrOutputParser()

# Define prompt template
# The {context}, {schema} and {question} placeholders are supplied by the
# mapping at the head of the chain below.
template = """
Provide first a natural language Translation followed by an Explanation of the SQL statement. Go through it step by step and use the information of the Context as examples and the schema to get the names of tables and columns. If you can't answer the question, reply "I don't know".

Context: {context}

Schema: {schema}

Question: {question}

"""

prompt = ChatPromptTemplate.from_template(template)

# Chain setup
# The same input is handed to both logging retrievers and passed through
# untouched as the question; the result flows prompt -> model -> parser.
chain_1b = (
    {
        "context": logging_retriever_SQL,
        "schema": logging_retriever_schema,
        "question": RunnablePassthrough(),
    }
    | prompt
    | model
    | parser
)

# Execute the chain with the logging retriever
chain_1b.invoke(
    "SELECT T1.Name ,  sum(T2.Sales) FROM singer AS T1 JOIN song AS T2 ON T1.Singer_ID  =  T2.Singer_ID GROUP BY T1.Name"
)

Retrieved Documents:
page_content='SELECT T1.artist_name ,  count(*) FROM artist AS T1 JOIN song AS T2 ON T1.artist_name  =  T2.artist_name WHERE T2.languages  =  "english" GROUP BY T2.artist_name HAVING count(*)  >=  1' metadata={'_id': '66cf12c13c2173e47d7af5dd', 'Question': 'Find the names and number of works of all artists who have at least one English songs.'}
page_content='SELECT T1.artist_name ,  count(*) FROM artist AS T1 JOIN song AS T2 ON T1.artist_name  =  T2.artist_name WHERE T2.languages  =  "english" GROUP BY T2.artist_name HAVING count(*)  >=  1' metadata={'_id': '66cf12c13c2173e47d7af5de', 'Question': 'What are the names and number of works for all artists who have sung at least one song in English?'}
page_content='SELECT count(DISTINCT title) FROM vocals AS T1 JOIN songs AS T2 ON T1.songid  =  T2.songid WHERE TYPE  =  "shared"' metadata={'_id': '66cf12c13c2173e47d7afc4a', 'Question': 'How many songs have a shared vocal?'}
page_content='SELECT count(DISTINCT title) FROM

"Translation: Find the names of all artists and the total number of works they have.\n\nExplanation: This SQL statement is selecting the artist's name (referred to as T1.Name) and the count of works they have (represented by the sum of T2.Sales) from the tables 'singer' and 'song'. It then joins these tables on the condition that the Singer_ID in both tables is equal. The result is grouped by the artist's name to display each artist's name along with the total count of their works."

***Chat interface setup***

In [6]:
# Define the chain_invoke function
def chain_1b_invoke(question):
    """Run `chain_1b` on `question` and hand back whatever the chain returns."""
    return chain_1b.invoke(question)

# Create a web interface for the app, using Gradio
# Layout: heading, SQL input box, Submit button in its own row, result box.
with gr.Blocks(title="Question Answering App using Vector Search + RAG", theme=Base()) as demo:
    gr.Markdown(
        """
        # Question Answering App using Atlas Vector Search + RAG Architecture
        """)
    textbox = gr.Textbox(label="Enter your SQL statement:")
    with gr.Row():
        button = gr.Button("Submit", variant="primary")
    output = gr.Textbox(label="Natural language translation and explanation:", lines=1, max_lines=30)

    # Call chain_invoke function upon clicking the Submit button
    button.click(fn=chain_1b_invoke, inputs=textbox, outputs=output)

# Start the local web server for the interface
demo.launch()

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


