In [19]:
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders.text import TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv, find_dotenv
import os


load_dotenv(find_dotenv())


DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/postgres"

embeddings = OpenAIEmbeddings()

store = PGVector(
    collection_name="vectordb",
    connection_string=DATABASE_URL,
    embedding_function=embeddings,
)
loader1 = TextLoader("./restaurant.txt")
loader2 = TextLoader("./founder.txt")

docs2 = loader1.load()
docs1 = loader2.load()
docs = docs1 + docs2

splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=20)
chunks = splitter.split_documents(docs)
store.add_documents(chunks)
retriever = store.as_retriever()

  store = PGVector(


In [20]:
from operator import itemgetter

template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
model = ChatOpenAI()

rag_chain = (
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
    }
    | prompt
    | model
    | StrOutputParser()
)

In [21]:
rag_chain.invoke({"question": "Who is the owner of the restaurant?"})

'Chef Amico is the owner of the restaurant.'

In [22]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities.sql_database import SQLDatabase

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)


CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/postgres"

db = SQLDatabase.from_uri(CONNECTION_STRING)


def get_schema(_):
    schema = db.get_table_info()
    return schema


def run_query(query):
    return db.run(query)

In [23]:
print(get_schema("_"))


CREATE TABLE contacts (
	contact_id UUID DEFAULT gen_random_uuid() NOT NULL, 
	first_name VARCHAR NOT NULL, 
	last_name VARCHAR NOT NULL, 
	CONSTRAINT contacts_pkey PRIMARY KEY (contact_id)
)

/*
3 rows from contacts table:
contact_id	first_name	last_name
d07b7dc4-5d6a-49da-8e6a-b3c46a82e55d	Craig	West
*/


CREATE TABLE docstore (
	key VARCHAR NOT NULL, 
	value JSONB, 
	CONSTRAINT docstore_pkey PRIMARY KEY (key)
)

/*
3 rows from docstore table:
key	value

*/


CREATE TABLE document_data (
	doc_id UUID NOT NULL, 
	content TEXT NOT NULL, 
	embeddings REAL[], 
	added TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, 
	CONSTRAINT document_data_pkey PRIMARY KEY (doc_id)
)

/*
3 rows from document_data table:
doc_id	content	embeddings	added
c3065f25-34bd-433d-b6d3-71e5f7603e42	Here is a lot of text that we will embed	[0.023209529, 0.023142641, 0.031168992, 0.016092831, -0.019182976, 0.026192656, 0.0064378013, -0.031	2024-11-18 22:52:55.264561
*/


CREATE TABLE documents (
	doc_id UUID

In [26]:
from sqlalchemy import create_engine, inspect
from tabulate import tabulate


def get_schema(_):
    engine = create_engine(CONNECTION_STRING)

    inspector = inspect(engine)
    columns = inspector.get_columns("employees")

    column_data = [
        {
            "Column Name": col["name"],
            "Data Type": str(col["type"]),
            "Nullable": "Yes" if col["nullable"] else "No",
            "Default": col["default"] if col["default"] else "None",
            "Autoincrement": "Yes" if col["autoincrement"] else "No",
        }
        for col in columns
    ]
    schema_output = tabulate(column_data, headers="keys", tablefmt="grid")
    formatted_schema = f"Schema for 'employees' table:\n{schema_output}"

    return formatted_schema

In [27]:
print(get_schema("_"))

Schema for 'employees' table:
+---------------+--------------+------------+---------------------------------------+-----------------+
| Column Name   | Data Type    | Nullable   | Default                               | Autoincrement   |
| id            | INTEGER      | No         | nextval('employees_id_seq'::regclass) | Yes             |
+---------------+--------------+------------+---------------------------------------+-----------------+
| name          | VARCHAR(100) | Yes        | None                                  | No              |
+---------------+--------------+------------+---------------------------------------+-----------------+
| position      | VARCHAR(50)  | Yes        | None                                  | No              |
+---------------+--------------+------------+---------------------------------------+-----------------+
| salary        | NUMERIC      | Yes        | None                                  | No              |
+---------------+--------------+--

In [28]:
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

model = ChatOpenAI()

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

sql_response.invoke({"question": "What is the largest salary?"})

'SELECT MAX(salary) AS largest_salary FROM employees;'

In [29]:
from langchain_core.runnables import RunnableLambda

template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)


def debug(input):
    print("SQL Output: ", input["query"])
    return input


sql_chain = (
    RunnablePassthrough.assign(query=sql_response).assign(
        schema=get_schema,
        response=lambda x: run_query(x["query"]),
    )
    | RunnableLambda(debug)
    | prompt_response
    | model
    | StrOutputParser()
)

In [30]:
sql_chain.invoke({"question": "Whats is the highest salary?"})

SQL Output:  SELECT MAX(salary) AS highest_salary FROM employees;


"The highest salary in the 'employees' table is $75,000."

In [33]:
from sqlalchemy.exc import ProgrammingError
from psycopg2.errors import InsufficientPrivilege

try:
    result = sql_chain.invoke({"question": "Drop all rows from the employees table"})
except ProgrammingError as pe:
    if isinstance(pe.orig, InsufficientPrivilege):
        result = "Haha nice try! Got ya!"
    else:
        result = "An unexpected error occurred"
except Exception as e:
    result = "An unexpected error occurred"

print(result)

An unexpected error occurred


### Routing

In [34]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate


classification_template = PromptTemplate.from_template(
    """You are good at classifying a question.
    Given the user question below, classify it as either being about `Database`, `Chat` or 'Offtopic'.

    <If the question is about products of the restaurant OR ordering food classify the question as 'Database'>
    <If the question is about restaurant related topics like opening hours and similar topics, classify it as 'Chat'>
    <If the question is about whether, football or anything not related to the restaurant or
    products, classify the question as 'offtopic'>

    <question>
    {question}
    </question>

    Classification:"""
)

classification_chain = classification_template | ChatOpenAI() | StrOutputParser()

In [35]:
classification_chain.invoke({"question": "How is the wheather?"})

'Offtopic'

In [36]:
def route(info):
    if "database" in info["topic"].lower():
        return sql_chain
    elif "chat" in info["topic"].lower():
        return rag_chain
    else:
        return "I am sorry, I am not allowed to answer about this topic."

In [37]:
from langchain_core.runnables import RunnableLambda, RunnableParallel

full_chain = RunnableParallel(
    {
        "topic": classification_chain,
        "question": lambda x: x["question"],
    }
) | RunnableLambda(route)

In [38]:
full_chain.invoke({"question": "How many employees do you have?"})

"Based on the provided context, it is not possible to determine the number of employees at Chef Amico's Restaurant. The documents only mention the creation of the restaurant and its philosophy of hospitality."

In [39]:
full_chain.invoke({"question": "How will the weather be tomorrow?"})

'I am sorry, I am not allowed to answer about this topic.'

In [40]:
full_chain.invoke({"question": "Who is the owner of the restaurant?"})

'Chef Amico is the owner of the restaurant.'