#### Imports

In [52]:
import os
import re

from dotenv import load_dotenv
from langchain.prompts.chat import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

#### Set Environment Variables

In [53]:
# Load variables from the local .env file into os.environ. If python-dotenv
# reports that it "could not parse statement", fix that line in .env: it is
# skipped, so the variable defined there is not loaded.
load_dotenv()

# load_dotenv() already exports OPENAI_API_KEY when it is present, so there is
# nothing to copy. Re-assigning os.getenv(...) would raise an opaque TypeError
# when the key is missing, so fail with an actionable message instead.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set; add it to .env or export it in the environment.")

Python-dotenv could not parse statement starting at line 1
Python-dotenv could not parse statement starting at line 5


#### Connect to PostgreSQL

In [54]:
# Read the connection string from the environment (for example, the .env file
# loaded by load_dotenv) so no database password is stored in the notebook.
# Expected form: postgresql+psycopg2://<user>:<password>@<host>:<port>/<database>
postgres_url = os.environ.get("POSTGRES_URL")
if not postgres_url:
    raise RuntimeError(
        "POSTGRES_URL is not set; add e.g. "
        "POSTGRES_URL=postgresql+psycopg2://<user>:<password>@localhost:5432/churn to .env"
    )

# Limit the database wrapper to the `socialnet7` schema and include views
# alongside tables.
db = SQLDatabase.from_uri(
    postgres_url,
    schema="socialnet7",
    view_support=True,
)


In [55]:
def get_schema(_):
    """Return the schema description of the connected database.

    The single argument is the chain's current input dict. It is ignored and
    exists only because ``RunnablePassthrough.assign`` calls this function
    with the inputs.
    """
    schema_text = db.get_table_info()
    return schema_text

### SQL-generation prompt template

In [56]:
# Step 1 prompt: ask the model to translate a question into SQL for the given
# schema. `{schema}` and `{question}` are filled in when the chain runs.
template = """
Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:
"""

prompt = ChatPromptTemplate.from_template(template=template)

In [57]:
# Preview the rendered SQL-generation prompt with placeholder values.
prompt.format(question="how many users are there", schema="my schema")

"Human: \nBased on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there\nSQL Query:\n"

### Create the LLM and build the SQL chains

In [58]:
# Pin the model snapshot and use temperature 0 so the generated SQL (and the
# outputs recorded in this notebook) are as reproducible as possible.
LLM_MODEL = "gpt-4-0613"
llm = ChatOpenAI(model=LLM_MODEL, temperature=0)

In [59]:
# Question -> SQL: add the live schema to the inputs, render the SQL-generation
# prompt, call the model, and return its reply as plain text.
# `prompt` is captured when this cell runs. A later cell rebinds it, so re-run
# the template cell first if this cell is ever re-executed.
sql_chain = RunnablePassthrough.assign(schema=get_schema).pipe(
    prompt,
    llm.bind(stop="\nSQL Result"),
    StrOutputParser(),
)

In [60]:
# Generate (but do not execute) SQL for a sample question.
sql_chain.invoke(dict(question="how many users churned in this database?"))

'SELECT COUNT(DISTINCT account_id)\nFROM socialnet7.observation\nWHERE is_churn = TRUE;'

In [61]:
# Step 2 prompt: turn the question, the generated SQL and its raw result into a
# natural-language answer.
# NOTE: this rebinds `template` and `prompt`. `sql_chain` was already built with
# the SQL-generation prompt, but re-running the earlier prompt-preview or
# `sql_chain` cells after this one would pick up this prompt instead.
template = """
Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL query: {query}
SQL Response: {response}
"""
prompt = ChatPromptTemplate.from_template(template=template)

In [62]:
def run_query(query, database=None):
    """Execute model-generated SQL and return the database's result.

    Chat models often wrap their answer in a Markdown code fence
    (```sql ... ```), which is not valid SQL. If a fenced block is present,
    only its contents are executed. Otherwise the text is executed as-is,
    trimmed of surrounding whitespace.

    Args:
        query: SQL text, typically the output of ``sql_chain``.
        database: Object exposing ``run(sql)``. Defaults to the notebook's
            ``db`` connection; mainly exposed so this can be tested without a
            live database.

    Returns:
        Whatever ``database.run`` returns for the cleaned SQL.

    NOTE(review): the SQL comes straight from the model and is not validated.
    Prefer connecting with a read-only database role so a bad generation cannot
    modify data.
    """
    if database is None:
        database = db
    sql = query.strip()
    fenced = re.search(r"```(?:sql)?\s*(.*?)\s*```", sql, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        sql = fenced.group(1)
    return database.run(sql)

In [66]:
# End-to-end chain: question -> SQL (via sql_chain) -> execute it -> prose answer.
# The first assign adds "query". The second adds the schema and the executed
# result ("response"). Together they fill the answer-writing prompt.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain)
    .assign(
        schema=get_schema,
        response=lambda variables: run_query(variables["query"]),
    )
    .pipe(prompt, llm, StrOutputParser())
)

In [67]:
# Ask a question end to end: generate SQL, run it, and phrase the result.
full_chain.invoke(dict(question="how many unique churn users are there?"))

'The number of unique churn users is 224.'