In [112]:
import json
import os
from typing import List

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, SystemMessage


In [113]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

# Pydantic data model for the JSON object the LLM is expected to return.
# The field names must match the keys requested in `sql_generation_prompt`
# and read back by the graph nodes: 'conversation_type' and 'response'.
# (Kept as comments rather than a class docstring so the text does not end up
# inside the JSON schema that is sent to the model.)
class SQLResponse(BaseModel):
    conversation_type: str = Field(
        description="Type of conversation: either 'general' or 'sql_query'"
    )
    response: str = Field(
        description=(
            "The reply to the user for a 'general' conversation, "
            "or the SQL query to be executed for a 'sql_query' conversation"
        )
    )

# Parses the model output as JSON; SQLResponse supplies the schema used for
# the format instructions injected into the prompt.
parser = JsonOutputParser(pydantic_object=SQLResponse)

In [None]:
# Chat model shared by the graph nodes. The key is taken from the
# OPENAI_API_KEY environment variable so it never has to be typed into (and
# saved with) the notebook; the placeholder is only used when it is unset.
llm = ChatOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "YOUR-API-KEY"),
    model="gpt-3.5-turbo",
    temperature=0,  # deterministic output keeps the generated SQL reproducible
)

In [116]:
# Prompt template for the SQL-generation node.
# - Describes the 'PRODUCT' and 'SALES' tables the model may query.
# - Asks for a JSON reply with the keys 'conversation_type' and 'response';
#   get_state() branches on 'conversation_type', and run_sql_command() /
#   display_last_message() read 'response'.
# - {format_instructions} and {user_query} are filled in by the PromptTemplate
#   built in sql_generation_chain().
sql_generation_prompt = """You are an expert SQL developer and you have access to a MySQL 8.0 database with 2 tables called 'PRODUCT' and 'SALES'. 
The 'PRODUCT' table just have 1 column called 'product_name' and the 'SALES' table has 3 columns called 'sale_date', 'product_name' and 'total_quantity'.
'product_name' is the primary key of the 'PRODUCT' table and a foreign key in the 'SALES' table. 'sale_date' and 'product_name' are the composite primary key of the 'SALES' table.
The 'PRODUCT' table contains a list of products and the 'SALES' table contains sales records for those products. If the user
is having a general conversation then answer normal conversation. If the user is asking for a SQL query, then generate a SQL query based on the context provided.
Generate a SQL query based on the following context to answer the query asked by the user. The final output is a json with keys 'conversation_type' and 'response'.
The 'conversation_type' key should be either 'general' or 'sql_query' and the 'response' key should contain your response either to the general questions or the SQL query to be executed.
{format_instructions}\n user_query: {user_query}"""


# Example user query: "What is the total quantity of product pizza sold in the last month?"

def sql_generation_chain(state):
    """Graph node: turn the latest message into a structured LLM reply.

    Fills `sql_generation_prompt` with the parser's format instructions and
    the content of the last message in ``state["messages"]``, runs it through
    `llm`, and parses the answer with `parser`.

    Args:
        state: Graph state; only ``state["messages"][-1].content`` is read.

    Returns:
        dict: ``{"messages": [AIMessage]}`` where the message content is the
        parsed result re-serialised with ``json.dumps(..., indent=2)``.
    """
    format_instructions = parser.get_format_instructions()
    template = PromptTemplate(
        template=sql_generation_prompt,
        input_variables=["user_query"],
        partial_variables={"format_instructions": format_instructions},
    )

    # prompt -> model -> JSON parser
    pipeline = template | llm | parser

    latest_message = state["messages"][-1]
    parsed = pipeline.invoke({"user_query": latest_message.content})

    reply = AIMessage(content=json.dumps(parsed, indent=2))
    return {"messages": [reply]}

In [117]:

from google.cloud.sql.connector import Connector
import sqlalchemy
import os
import pandas as pd

def run_sql_command(state):
    """Graph node: run the SQL held in the last message and return the rows as JSON.

    The last message in ``state["messages"]`` must contain a JSON document
    whose ``'response'`` key holds the SQL text. The statement is executed
    against the Cloud SQL MySQL instance through the Cloud SQL Python
    Connector and a SQLAlchemy engine.

    Connection settings can be overridden with the ``MYSQL_USER``,
    ``MYSQL_PASSWORD`` and ``MYSQL_DATABASE`` environment variables. The values
    previously written in this cell remain as defaults so existing runs keep
    working.

    Args:
        state: Graph state; only ``state["messages"][-1].content`` is read.

    Returns:
        dict: ``{"messages": [AIMessage]}`` whose content is the result set
        serialised with ``DataFrame.to_json(orient='records')``.

    Raises:
        json.JSONDecodeError: If the last message is not valid JSON.
        KeyError: If the decoded message has no ``'response'`` key.
        Exception: Database / driver errors from ``pd.read_sql`` propagate
            unchanged (after the connection resources have been released).
    """
    # NOTE(review): these defaults are real-looking credentials committed in
    # the notebook. Prefer supplying them via the environment (or a secret
    # manager) and rotating the password.
    user = os.environ.get("MYSQL_USER", "shrey")
    password = os.environ.get("MYSQL_PASSWORD", "shrey")
    database = os.environ.get("MYSQL_DATABASE", "combined_transaction_data")

    # Decode the message before opening anything, so malformed input fails
    # fast and cannot leave a connector behind.
    sql_text = json.loads(state["messages"][-1].content)['response']

    connector = Connector()

    def getconn():
        # Called by SQLAlchemy whenever the engine needs a new DBAPI connection.
        return connector.connect(
            "primordial-veld-450618-n4:us-central1:mlops-sql",  # Cloud SQL instance connection name
            "pymysql",  # Database driver
            user=user,
            password=password,
            db=database,
        )

    try:
        engine = sqlalchemy.create_engine("mysql+pymysql://", creator=getconn)
        try:
            # NOTE(review): sql_text is LLM-generated and executed verbatim.
            # Use a read-only database account, or validate the statement,
            # before exposing this to untrusted users.
            df = pd.read_sql(sql_text, engine)
        finally:
            # Return pooled connections before the connector is shut down.
            engine.dispose()
    finally:
        # Always release the connector, even when the query fails.
        connector.close()

    df_json = df.to_json(orient='records')
    return {"messages": [AIMessage(content=df_json)]}

In [118]:
def display_last_message(state):
    """Graph node: unwrap the 'response' text from the last message.

    Args:
        state: Graph state; the content of ``state["messages"][-1]`` must be a
            JSON document with a ``'response'`` key.

    Returns:
        dict: ``{"messages": [AIMessage]}`` whose content is that 'response' value.
    """
    payload = json.loads(state["messages"][-1].content)
    reply = AIMessage(content=payload['response'])
    return {"messages": [reply]}

In [119]:
from langgraph.graph import StateGraph, START, END
def get_state(state):
    """Pick the next node from the 'conversation_type' of the last message.

    Args:
        state: Graph state; the content of ``state["messages"][-1]`` must be a
            JSON document with a ``'conversation_type'`` key.

    Returns:
        str: ``"run_sql_command"`` when the type is exactly ``"sql_query"``,
        otherwise ``"display_last_message"``.
    """
    decoded = json.loads(state["messages"][-1].content)
    wants_sql = decoded["conversation_type"] == "sql_query"
    return "run_sql_command" if wants_sql else "display_last_message"

In [120]:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from typing import Annotated
from typing_extensions import TypedDict


class State(TypedDict):
    """State schema passed to ``StateGraph``; it carries a single ``messages`` key."""

    # A list annotated with LangGraph's `add_messages`.
    # NOTE(review): per the LangGraph docs this annotation makes node updates
    # get merged into the existing list rather than replacing it — confirm
    # against the installed version.
    messages: Annotated[list, add_messages]


# Assemble the graph: one generation node, then a branch to either the SQL
# runner or the plain-text reply, both of which end the run.
memory = MemorySaver()
workflow = StateGraph(State)

# Every node is registered under the name of the function implementing it.
for _node_fn in (sql_generation_chain, run_sql_command, display_last_message):
    workflow.add_node(_node_fn.__name__, _node_fn)

workflow.add_edge(START, "sql_generation_chain")

# get_state() returns the name of the node to run next.
workflow.add_conditional_edges(
    "sql_generation_chain",
    get_state,
    ["run_sql_command", "display_last_message"],
)

# Both branches are terminal.
for _terminal in ("display_last_message", "run_sql_command"):
    workflow.add_edge(_terminal, END)

graph = workflow.compile(checkpointer=memory)

In [121]:
# from IPython.display import Image, display

# display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
import uuid

# Thread id handed to the checkpointed graph; a fresh UUID is generated each
# time this cell runs, and every turn of the loop below reuses it.
config = {"configurable": {"thread_id": str(uuid.uuid4())}}

# Minimal chat loop: read a line, run it through the graph, and print the last
# message produced by each node that executed.
while True:
    user = input("User (q/Q to quit): ")
    # Echo the entry back into the cell output.
    print(f"User (q/Q to quit): {user}")
    if user in {"q", "Q"}:
        print("AI: Byebye")
        break
    if not user.strip():
        # Nothing to ask; skip the LLM round-trip for a blank line.
        continue
    # stream_mode="updates" yields one {node_name: update} dict per executed node.
    for output in graph.stream(
        {"messages": [HumanMessage(content=user)]}, config=config, stream_mode="updates"
    ):
        last_message = next(iter(output.values()))["messages"][-1]
        last_message.pretty_print()