***OPENAI RAG LLM setup***

**Loading packages, libraries and secrets into notebook**

In [None]:
%run ../Setup.ipynb

***Generating the embedding***

In [4]:
# Embedding model setup
embedding_model = SentenceTransformer("thenlper/gte-large")

class CustomEmbeddingFunction:
    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts):
        """Embeds a list of documents."""
        embeddings = self.model.encode(texts)
        return embeddings.tolist()

    def embed_query(self, text):
        """Embeds a single query."""
        embedding = self.model.encode(text)
        return embedding.tolist()

# Wrap the SentenceTransformer model
embedding_function = CustomEmbeddingFunction(embedding_model)

In [5]:
## MongoDB setup
# SQL Vector
client_SQL = MongoClient(MONGO_URI_SQL)
dbName_SQL = "MVector"
collectionName_SQL = "MTSQL"
collection_SQL = client_SQL[dbName_SQL][collectionName_SQL]
index_name_SQL = "vector_index_SQL"

## SQL Vector setup
# Vector store setup
vector_store_SQL = MongoDBAtlasVectorSearch(
    client=client_SQL,
    database=dbName_SQL,
    collection=collection_SQL,
    index_name=index_name_SQL,
    embedding=embedding_function,
    text_key="Query"  
)

# Retriever setup
retriever_SQL = vector_store_SQL.as_retriever(search_kwargs={"k": 4})

# Define a custom logging retriever to see what the retriever is passing on
class LoggingRetrieverSQL:
    def __init__(self, retriever_SQL):
        self.retriever_SQL = retriever_SQL

    def __call__(self, query):
        # Retrieve the documents
        documents_SQL = self.retriever_SQL.invoke(query)
        
        # Log or print the retrieved documents
        print("Retrieved Documents:")
        for doc in documents_SQL:
            print(doc)
        
        # Return the retrieved documents
        return documents_SQL

# Wrap your retriever with the logging retriever
logging_retriever_SQL = LoggingRetrieverSQL(retriever_SQL)

In [6]:
# Schema Vector
client_schema = MongoClient(MONGO_URI_schema)
dbName_schema = "MVector"
collectionName_schema = "MTSchemaAll"
collection_schema = client_schema[dbName_schema][collectionName_schema]
index_name_schema = "vector_index_schema_all"

## Schema Vector setup
# Vector store setup
vector_store_schema = MongoDBAtlasVectorSearch(
    client=client_schema,
    database=dbName_schema,
    collection=collection_schema,
    index_name=index_name_schema,
    embedding=embedding_function,
    text_key="Lookup_name"  
)

# Retriever setup
retriever_schema = vector_store_schema.as_retriever(search_kwargs={"k": 5})

# Define a custom logging retriever to see what the retriever is passing on
class LoggingRetrieverSchema:
    def __init__(self, retriever_schema):
        self.retriever_schema = retriever_schema

    def __call__(self, query):
        # Retrieve the documents
        documents_schema = self.retriever_schema.invoke(query)
        
        # Log or print the retrieved documents
        print("Retrieved Schema:")
        for doc in documents_schema:
            print(doc)
        
        # Return the retrieved documents
        return documents_schema

# Wrap your retriever with the logging retriever
logging_retriever_schema = LoggingRetrieverSchema(retriever_schema)

***Chain setup***

In [8]:
# Model and parsing setup
model = ChatOpenAI(api_key=OPENAI_API_KEY, model="gpt-3.5-turbo")
parser = StrOutputParser()

# Define prompt template
template = """
Provide first a natural language Translation followed by an Explanation of the SQL statement. Go through it step by step and use the information of the Context as examples and the schema to get the names of tables and columns. If you can't answer the question, reply "I don't know".

Context: {context}

Schema: {schema}

Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)

# Chain setup
chain_2b = (
    {"context": logging_retriever_SQL, "schema": logging_retriever_schema, "question": RunnablePassthrough()}
    | prompt
    | model
    | parser
)

# Execute the chain with the logging retriever
chain_2b.invoke("SELECT T1.Name ,  sum(T2.Sales) FROM singer AS T1 JOIN song AS T2 ON T1.Singer_ID  =  T2.Singer_ID GROUP BY T1.Name")

Retrieved Documents:Retrieved Schema:
page_content='SINGER_ID SINGER ' metadata={'_id': '66e83eb4203411a6471d680a', 'Column_name': 'SINGER_ID', 'Table_name': 'SINGER ', 'DB_name': 'singer'}
page_content='SINGER_ID SINGER ' metadata={'_id': '66e83eb4203411a6471d67a5', 'Column_name': 'SINGER_ID', 'Table_name': 'SINGER ', 'DB_name': 'concert_singer'}
page_content='SINGER_ID SINGER_IN_CONCERT ' metadata={'_id': '66e83eb4203411a6471d67b5', 'Column_name': 'SINGER_ID', 'Table_name': 'SINGER_IN_CONCERT ', 'DB_name': 'concert_singer'}
page_content='SINGER_ID SONG ' metadata={'_id': '66e83eb4203411a6471d6812', 'Column_name': 'SINGER_ID', 'Table_name': 'SONG ', 'DB_name': 'singer'}
page_content='CONCERT_ID SINGER_IN_CONCERT ' metadata={'_id': '66e83eb4203411a6471d67b4', 'Column_name': 'CONCERT_ID', 'Table_name': 'SINGER_IN_CONCERT ', 'DB_name': 'concert_singer'}

page_content='SELECT T1.artist_name ,  count(*) FROM artist AS T1 JOIN song AS T2 ON T1.artist_name  =  T2.artist_name WHERE T2.languag

'Translation: Find the names and total sales of all singers who have at least one song.\n\nExplanation:\n1. SELECT T1.Name, sum(T2.Sales): We are selecting the Name column from the singer table as T1 and the total sales from the song table as T2.\n2. FROM singer AS T1 JOIN song AS T2 ON T1.Singer_ID = T2.Singer_ID: We are joining the singer table as T1 with the song table as T2 based on the Singer_ID column to match the singers with their songs.\n3. GROUP BY T1.Name: We are grouping the results by the Name column to get the total sales for each singer.\n\nExample:\nIf the singer table contains:\n- Name: John, Singer_ID: 1\n- Name: Sarah, Singer_ID: 2\n\nAnd the song table contains:\n- Title: Song1, Singer_ID: 1, Sales: 100\n- Title: Song2, Singer_ID: 2, Sales: 150\n- Title: Song3, Singer_ID: 1, Sales: 200\n\nThe query will return:\n- John, Total Sales: 300\n- Sarah, Total Sales: 150'

***Chat interface setup***

In [9]:
# Define the chain_invoke function
def chain_2b_invoke(question):
    # Execute the chain with the logging retriever
    result = chain_2b.invoke(question)
    # Return the result 
    return result

# Create a web interface for the app, using Gradio
with gr.Blocks(theme=Base(), title="Question Answering App using Vector Search + RAG") as demo:
    gr.Markdown(
        """
        # Question Answering App using Atlas Vector Search + RAG Architecture
        """)
    textbox = gr.Textbox(label="Enter your SQL statement:")
    with gr.Row():
        button = gr.Button("Submit", variant="primary")
    output = gr.Textbox(lines=1, max_lines=30, label="Natural language translation and explanation:")

# Call chain_invoke function upon clicking the Submit button
    button.click(chain_2b_invoke, textbox, outputs=output)

demo.launch()

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


