In [None]:
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders.text import TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
import os

# Load API keys and other settings from the app's .env file, expected under
# ./app relative to the directory the notebook is started from.
app_dir = os.path.join(os.getcwd(), "app")
load_dotenv(os.path.join(app_dir, ".env"))


# Connection string for the Postgres instance backing the vector store.
# It can be overridden through the PGVECTOR_CONNECTION_STRING environment
# variable so real credentials never need to be written into the notebook;
# the fallback is the original hard-coded local URL.
DATABASE_URL = os.getenv(
    "PGVECTOR_CONNECTION_STRING",
    "postgresql+psycopg2://admin:admin@localhost:5432/vectordb",
)

embeddings = OpenAIEmbeddings()

store = PGVector(
    collection_name="vectordb",
    connection_string=DATABASE_URL,
    embedding_function=embeddings,
)

# Load the two source texts; each variable is named after the file it holds.
restaurant_loader = TextLoader("./data/restaurant.txt")
founder_loader = TextLoader("./data/founder.txt")

restaurant_docs = restaurant_loader.load()
founder_docs = founder_loader.load()
# Founder text first, then restaurant text (same order as before).
docs = founder_docs + restaurant_docs

# Split into small, slightly overlapping chunks before embedding.
splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=20)
chunks = splitter.split_documents(docs)

# NOTE(review): add_documents runs every time this cell is executed, so
# re-running the cell presumably stores the same chunks again -- confirm, and
# clear the collection first if duplicate retrieval hits matter.
store.add_documents(chunks)
retriever = store.as_retriever()

In [None]:
from operator import itemgetter

# Chat model and prompt for retrieval-augmented answers; the prompt tells the
# model to rely only on the retrieved context.
model = ChatOpenAI()
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# Both prompt variables are derived from the "question" key of the chain input:
# "context" sends the question through the retriever, "question" forwards it.
pick_question = itemgetter("question")
rag_inputs = {
    "context": pick_question | retriever,
    "question": pick_question,
}

rag_chain = rag_inputs | prompt | model | StrOutputParser()

In [None]:
# Try the RAG chain on its own before it is wired into the router further down.
rag_chain.invoke({"question": "Who is the owner of the restaurant?"})

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities.sql_database import SQLDatabase

# Database handle for the text-to-SQL part of the notebook.
# NOTE(review): the role is named "readonlyuser"; presumably it has no write
# privileges -- confirm on the database side.
CONNECTION_STRING = (
    "postgresql+psycopg2://"
    "readonlyuser:readonlypassword"
    "@localhost:5432/vectordb"
)
db = SQLDatabase.from_uri(CONNECTION_STRING)

# Prompt that asks the model for a SQL query given a schema description and
# the user's question.
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)


def get_schema(_):
    """Return the table info reported by ``db``; the argument is ignored."""
    return db.get_table_info()


def run_query(query):
    """Run ``query`` through ``db.run`` and return whatever that call yields."""
    outcome = db.run(query)
    return outcome

In [None]:
# Show the schema text that gets injected into the SQL prompt; get_schema ignores its argument.
print(get_schema("_"))

In [None]:
from sqlalchemy import create_engine, inspect
from tabulate import tabulate


def get_schema(_):
    """Describe the columns of the ``products`` table as a text grid.

    Args:
        _: Ignored. Present because the function is used as a chain step and
            is therefore called with the chain input.

    Returns:
        A string consisting of a heading line followed by a ``tabulate`` grid
        with one row per column (name, data type, nullable, default,
        autoincrement).
    """
    engine = create_engine(CONNECTION_STRING)
    try:
        columns = inspect(engine).get_columns("products")
    finally:
        # A new engine is built on every call, so release its connection pool
        # instead of leaving it to accumulate across chain invocations.
        engine.dispose()

    column_data = [
        {
            "Column Name": col["name"],
            "Data Type": str(col["type"]),
            "Nullable": "Yes" if col["nullable"] else "No",
            "Default": col["default"] if col["default"] else "None",
            # "autoincrement" is an optional key in reflected column info.
            "Autoincrement": "Yes" if col.get("autoincrement") else "No",
        }
        for col in columns
    ]
    schema_output = tabulate(column_data, headers="keys", tablefmt="grid")
    formatted_schema = f"Schema for 'PRODUCTS' table:\n{schema_output}"

    return formatted_schema

In [None]:
# Same check for the redefined get_schema, which only describes the "products" table.
print(get_schema("_"))

In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Chat model used by the SQL chains below.
model = ChatOpenAI()

# Question -> SQL text: add the schema description to the input, fill the SQL
# prompt, and stop generation if the model starts writing "\nSQLResult:".
sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

# Spelled "dessert" (not "desert") to match the later cells and to avoid
# steering the generated SQL towards a misspelled filter value.
sql_response.invoke({"question": "Whats the most expensive dessert you offer?"})

In [None]:
from langchain_core.runnables import RunnableLambda

# Prompt for the last step of the SQL chain: turn the schema, question,
# generated query, and query result into a natural-language answer.
template = "\n".join(
    [
        "Based on the table schema below, question, sql query, and sql response, write a natural language response:",
        "{schema}",
        "",
        "Question: {question}",
        "SQL Query: {query}",
        "SQL Response: {response}",
    ]
)
prompt_response = ChatPromptTemplate.from_template(template)


def debug(input):
    """Print the value stored under ``input["query"]`` and return ``input`` unchanged."""
    generated_sql = input["query"]
    print("SQL Output: ", generated_sql)
    return input


def _run_generated_query(state):
    """Execute the SQL stored under ``state["query"]`` via ``run_query``."""
    return run_query(state["query"])


# Step 1: add "query" (the SQL text produced by sql_response) to the input.
add_query = RunnablePassthrough.assign(query=sql_response)
# Step 2: add "schema" and "response" (the result of executing that SQL).
add_schema_and_response = add_query.assign(
    schema=get_schema,
    response=_run_generated_query,
)
# Step 3: print the SQL, then have the model phrase the final answer.
sql_chain = (
    add_schema_and_response
    | RunnableLambda(debug)
    | prompt_response
    | model
    | StrOutputParser()
)

In [None]:
# Run the full SQL chain: generate SQL, execute it, then phrase the result in natural language.
sql_chain.invoke({"question": "Whats the most expensive dessert you offer?"})

### What can go wrong? Users could run potentially malicious queries

In [None]:
# Deliberately ask for a destructive operation; the next cell repeats this request inside a try/except.
# NOTE(review): presumably this raises when the database user lacks the privilege -- confirm.
sql_chain.invoke({"question": "Drop all products from the products table"})

In [None]:
from sqlalchemy.exc import ProgrammingError
from psycopg2.errors import InsufficientPrivilege

try:
    result = sql_chain.invoke({"question": "Drop all products from the products table"})
except Exception as exc:
    # Only a ProgrammingError caused by missing privileges gets the dedicated
    # reply; every other failure is reported with the generic message.
    permission_denied = isinstance(exc, ProgrammingError) and isinstance(
        exc.orig, InsufficientPrivilege
    )
    result = (
        "Haha nice try! Got ya!"
        if permission_denied
        else "An unexpected error occurred"
    )

print(result)

### Routing

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate


# Prompt that asks the model to label a question as Database, Chat or Offtopic.
# route() only looks for the lower-cased substrings "database" and "chat" in
# the model's answer, so the labels are spelled consistently here.
classification_template = PromptTemplate.from_template(
    """You are good at classifying a question.
    Given the user question below, classify it as either being about `Database`, `Chat` or `Offtopic`.

    <If the question is about products of the restaurant OR ordering food classify the question as 'Database'>
    <If the question is about restaurant related topics like opening hours and similar topics, classify it as 'Chat'>
    <If the question is about weather, football or anything not related to the restaurant or
    products, classify the question as 'Offtopic'>

    <question>
    {question}
    </question>

    Classification:"""
)

# Prompt -> chat model -> plain string label.
classification_chain = classification_template | ChatOpenAI() | StrOutputParser()

In [None]:
# Try the classifier on a question that is unrelated to the restaurant.
classification_chain.invoke({"question": "How is the wheather?"})

In [16]:
def route(info):
    """Choose what should handle a classified question.

    ``info["topic"]`` is matched case-insensitively: a topic containing
    "database" selects ``sql_chain``, one containing "chat" selects
    ``rag_chain``, and anything else yields a fixed refusal string.
    """
    label = info["topic"].lower()
    if "database" in label:
        return sql_chain
    if "chat" in label:
        return rag_chain
    return "I am sorry, I am not allowed to answer about this topic."

In [15]:
from langchain_core.runnables import RunnableLambda, RunnableParallel

# Classify the question and forward the original question text alongside it,
# producing the {"topic", "question"} mapping that route() receives.
classify_and_forward = RunnableParallel(
    topic=classification_chain,
    question=lambda payload: payload["question"],
)

# route() then picks the chain (or refusal string) for that topic.
full_chain = classify_and_forward | RunnableLambda(route)

In [17]:
# Product question: the router is expected to hand this to the SQL chain.
full_chain.invoke({"question": "Whats the most expensive dessert you offer?"})

SQL Output:  SELECT name, price
FROM products
WHERE category = 'Dessert'
ORDER BY price DESC
LIMIT 1;


'The most expensive dessert we offer is the Panettone, priced at $15.00.'

In [18]:
# Unrelated question: route() is expected to fall through to the fixed refusal string.
full_chain.invoke({"question": "How will the wheater be tomorrow?"})

'I am sorry, I am not allowed to answer about this topic.'

In [19]:
# General restaurant question: expected to be answered by the RAG chain.
full_chain.invoke({"question": "Who is the owner of the restaurant?"})

'Chef Amico'