In [None]:
import os
from operator import itemgetter

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings


# Connection string handed to the PGVector store. It is read from the
# DATABASE_URL environment variable so real credentials do not have to be
# committed with the notebook; when the variable is unset, the value falls
# back to the local development database this notebook was written against.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+psycopg2://myuser:mypassword@localhost:5433/mydatabase",
)

# Embedding function passed to the vector store below.
embeddings = OpenAIEmbeddings()

# Vector store bound to the "mydatabase" collection.
store = PGVector(
    collection_name="mydatabase",
    connection_string=DATABASE_URL,
    embedding_function=embeddings,
)
# Index the source text: load ./text.txt, cut it into chunks of up to 250
# characters with a 20-character overlap, and add the chunks to the store.
splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=20)
loader = TextLoader("./text.txt")
docs = loader.load()
chunks = splitter.split_documents(docs)

# NOTE(review): add_documents runs unconditionally, so every execution of this
# cell submits the same chunks again — presumably duplicating them in the
# collection; confirm and guard if that matters.
store.add_documents(chunks)

# Retriever over the store, created with default settings.
retriever = store.as_retriever()

In [None]:
# Prompt that tells the model to answer only from the retrieved context.
template = (
    "\n"
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "\n"
    "Question: {question}\n"
)

prompt = ChatPromptTemplate.from_template(template)
model = ChatOpenAI(model_name="gpt-3.5-turbo")

# RAG pipeline: the incoming "question" is sent to the retriever to build the
# context and is also passed through unchanged; both fill the prompt, the
# model answers, and the parser returns the answer as a string.
rag_chain = (
    {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
    | prompt
    | model
    | StrOutputParser()
)

In [None]:
# Try the RAG chain on its own; the answer has to come from the chunks of
# ./text.txt that were indexed into the vector store.
rag_chain.invoke({"question": "Who is the owner of the restaurant?"})

In [20]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt for turning a natural-language question into a SQL query.
# Note: this rebinds the global names `template` and `prompt`; rag_chain was
# already composed with the earlier prompt object and keeps using that one.
template = (
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:"
)
prompt = ChatPromptTemplate.from_template(template)

from langchain_community.utilities import SQLDatabase

# Connection string for the SQL database. It is read from the DATABASE_URL
# environment variable so credentials do not have to live in the notebook;
# when the variable is unset, the value falls back to the local development
# database this notebook was written against.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+psycopg2://myuser:mypassword@localhost:5433/mydatabase",
)

# Database wrapper used by get_schema and run_query.
db = SQLDatabase.from_uri(DATABASE_URL)

def get_schema(_):
    """Return the table info reported by ``db``.

    The single positional argument is ignored. It is accepted so the function
    can be plugged into a chain step, which calls it with the chain input.
    """
    return db.get_table_info()

def run_query(query):
    """Run ``query`` through ``db.run`` and return its result unchanged.

    NOTE(review): in sql_chain the query text comes straight from the model
    and is executed as-is, with no validation.
    """
    result = db.run(query)
    return result

In [21]:
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Chat model with default settings. This rebinds the global `model`; chains
# composed before this line keep the model object they were built with.
model = ChatOpenAI()

# Question -> SQL text:
#   1. add a `schema` key (from get_schema) next to the incoming `question`,
#   2. fill the SQL-writing prompt,
#   3. call the model with "\nSQLResult:" as a stop sequence
#      (presumably so the completion ends before any result section — confirm),
#   4. return the model's reply as a plain string.
sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

# Try the SQL-generation step on its own.
sql_response.invoke({"question": "Whats the most expensive desert you offer?"})

"SELECT name, price\nFROM products\nWHERE category = 'Dessert'\nORDER BY price DESC\nLIMIT 1;"

In [22]:
# Prompt that turns the question, the generated SQL and its result into a
# natural-language answer. Rebinds the global `template` once more.
template = (
    "Based on the table schema below, question, sql query, and sql response, write a natural language response:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Response: {response}"
)
prompt_response = ChatPromptTemplate.from_template(template)

# End-to-end database pipeline: generate the SQL (`query`), then attach the
# schema and the result of running that query (`response`), and finally let
# the model phrase the answer as a string.
sql_chain = (
    RunnablePassthrough.assign(query=sql_response)
    .assign(schema=get_schema, response=lambda state: run_query(state["query"]))
    | prompt_response
    | model
    | StrOutputParser()
)

sql_chain.invoke({"question": "Whats the most expensive desert you offer?"})

'The most expensive dessert we offer is the "Tiramisu della Nonna" priced at $8.50.'

### Routing

In [23]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate


# Prompt asking the model to label a question with one of three routing
# topics. `route` later matches the lower-cased label by substring, so the
# label's casing does not matter.
# Fix: the off-topic rule said "whether" where "weather" is meant.
classification_template = PromptTemplate.from_template(
    """You are good at classifying a question.
    Given the user question below, classify it as either being about `Database`, `Chat` or 'Offtopic'.

    <If the question is about products of the restaurant OR ordering food classify the question as 'Database'>
    <If the question is about restaurant related topics like opening hours and similar topics, classify it as 'Chat'>
    <If the question is about weather, football or anything not related to the restaurant or
    products, classify the question as 'offtopic'>

    <question>
    {question}
    </question>

    Classification:"""
)

# Classifier pipeline: prompt -> chat model (default settings) -> plain-text label.
classification_chain = classification_template | ChatOpenAI() | StrOutputParser()


In [24]:
# Sanity-check the classifier with a question that is not about the restaurant.
classification_chain.invoke({"question": "How is the wheather?"})

'Offtopic'

In [26]:
# def route(info):
#     print(f"INFO: {info}")
#     if "database" in info["topic"].lower():
#         return "This has to be handled in the database"
#     elif "chat" in info["topic"].lower():
#         return "This has to be handled in Chat"
#     else:
#         return "I am sorry, I am not allowed to answer questions about this topic."

def route(info):
    """Choose the handler for a classified question.

    ``info["topic"]`` is the classifier's label. It is lower-cased and matched
    by substring, with "database" checked before "chat".

    Returns ``sql_chain`` for database questions, ``rag_chain`` for chat
    questions, and a fixed refusal string for anything else.
    """
    topic = info["topic"].lower()
    if "database" in topic:
        return sql_chain
    if "chat" in topic:
        return rag_chain
    return "I am sorry, I am not allowed to answer about this topic."

In [27]:
from langchain_core.runnables import RunnableLambda

# Router: classify the question and keep the original text, then let `route`
# pick what handles the {"topic", "question"} mapping.
full_chain = (
    {
        "topic": classification_chain,
        "question": lambda inputs: inputs["question"],
    }
    | RunnableLambda(route)
)

In [28]:
# Product question: the classification prompt asks for the 'Database' label here.
# NOTE(review): the output recorded below (an "INFO: ..." print plus a placeholder
# string) matches the commented-out debug version of route, not the current
# one — re-run the notebook to refresh it.
full_chain.invoke({"question": "Whats the most expensive desert you offer?"})

INFO: {'topic': 'Database', 'question': 'Whats the most expensive desert you offer?'}


'This has to be handled in the database'

In [29]:
# Question unrelated to the restaurant: should end in route's refusal string.
# NOTE(review): the recorded output below includes an "INFO: ..." line and a
# differently worded refusal, which matches the commented-out debug version of
# route — re-run the notebook to refresh it.
full_chain.invoke({"question": "How will the wheater be tomorrow?"})

INFO: {'topic': 'Offtopic', 'question': 'How will the wheater be tomorrow?'}


'I am sorry, I am not allowed to answer questions about this topic.'

In [30]:
# General restaurant question: the classification prompt asks for the 'Chat' label here.
# NOTE(review): the output recorded below (an "INFO: ..." print plus a placeholder
# string) matches the commented-out debug version of route, not the current
# one — re-run the notebook to refresh it.
full_chain.invoke({"question": "Who is the owner of the restaurant?"})

INFO: {'topic': 'Chat', 'question': 'Who is the owner of the restaurant?'}


'This has to be handled in Chat'