In [None]:
import requests
import operator
import functools
import pandas as pd

from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

from langchain_core.tools import tool
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDataBaseTool,
)

from langgraph.prebuilt import ToolNode
from typing import Annotated, Sequence


from langchain_core.messages import AIMessage
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    ToolMessage,
)

from langchain import hub
from langchain.agents import AgentExecutor
from langchain.agents import create_tool_calling_agent


from langchain_core.agents import AgentAction
from typing import TypedDict, Annotated, List
from langchain_core.messages import BaseMessage, trim_messages
from langgraph.graph import StateGraph, END,START
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser

In [None]:
import getpass
import os

from langchain_groq import ChatGroq

# Never hardcode credentials in a notebook: read the Groq key from the environment
# and fall back to an interactive prompt. Any key previously committed in this
# cell must be considered leaked and revoked.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Groq API key: ")
llm = ChatGroq(api_key=os.environ["GROQ_API_KEY"])

# Raw churn data; set CHURN_CSV_PATH to point elsewhere instead of editing
# a machine-specific absolute path.
CHURN_CSV_PATH = os.environ.get(
    "CHURN_CSV_PATH",
    "/home/godwin/Documents/Workflow/Customer-retention/data/raw_data/Churn.csv",
)
df = pd.read_csv(CHURN_CSV_PATH)

engine = create_engine("sqlite:///local.db")
# if_exists="replace" keeps this cell re-runnable: with the default ("fail"),
# a second run raises ValueError because local.db already has the table.
df.to_sql("customers", engine, index=False, if_exists="replace")
db = SQLDatabase(engine=engine)

In [None]:
@tool("list_tables")
def list_tables() -> str:
    """List the available tables in the database"""
    # The underlying LangChain tool is invoked with an empty-string input.
    table_lister = ListSQLDatabaseTool(db=db)
    table_names = table_lister.invoke("")
    return table_names

@tool("tables_schema")
def tables_schema(tables: str) -> str:
    """
    Input is a comma-separated list of tables, output is the schema and sample rows
    for those tables. Be sure that the tables actually exist by calling `list_tables` first!
    Example Input: table1, table2, table3
    """
    # Named so it does not shadow the module-level `tool` decorator.
    schema_inspector = InfoSQLDatabaseTool(db=db)
    schema_text = schema_inspector.invoke(tables)
    return schema_text

@tool("execute_sql")
def execute_sql(sql_query: str) -> str:
    """Execute a SQL query against the database. Returns the result"""
    # Delegate to LangChain's query tool bound to the notebook's SQLDatabase.
    query_runner = QuerySQLDataBaseTool(db=db)
    query_result = query_runner.invoke(sql_query)
    return query_result

@tool("check_sql")
def check_sql(sql_query: str) -> str:
    """
    Use this tool to double check if your query is correct before executing it. Always use this
    tool before executing a query with `execute_sql`.
    """
    # The checker asks the LLM to review the query; it expects a {"query": ...} payload.
    sql_checker = QuerySQLCheckerTool(db=db, llm=llm)
    checker_input = {"query": sql_query}
    return sql_checker.invoke(checker_input)

@tool("make_predictions")
def make_inference(input_data):
    """
    Use this tool to get model predictions for a batch of records.

    Sends `input_data` (the records to score, as JSON-serialisable data such as a
    dict or a list of row dicts) to the prediction endpoint and returns the
    endpoint's JSON response.
    """
    inference_endpoint = "https://retention.zapto.org/predict"
    # Seconds to wait for the endpoint; without a timeout a stalled server
    # would hang the whole agent run.
    timeout_seconds = 30

    # When the LLM calls this tool, arguments arrive as parsed JSON (dict/list/str),
    # not DataFrames, so only call .to_dict() on DataFrame-like objects.
    # NOTE(review): the payload schema the endpoint expects is not visible here — confirm.
    data = input_data.to_dict() if hasattr(input_data, "to_dict") else input_data

    response = requests.post(inference_endpoint, json=data, timeout=timeout_seconds)
    # Surface HTTP failures instead of returning an error body as if it were a prediction.
    response.raise_for_status()
    return response.json()

@tool("email_draft_generator")
def email_draft_generator(query: str) -> str:
    """
    Use this tool to generate an email draft.

    Uses a language model to write a sample draft email on the given topic.
    """
    email_draft = llm.invoke(f"Generate a sample draft mail on {query}.")
    # Return the generated text, not the AIMessage object: a non-string tool
    # result is stringified, which would include message metadata.
    return email_draft.content

@tool("email_subject_generator")
def email_subjects_generator(query: str) -> str:
    """
    Use this tool to generate a list of email subject suggestions based on a request.

    Uses a language model to create multiple email subject lines for a specific request. Always use this tool before 
    using the `email_draft_generator`.
    """
    topic_suggestions = llm.invoke(f"""Generate mail subjects that can fit for this "{query}" request """)
    # Return the generated text rather than the AIMessage object, so the agent
    # receives the suggestions themselves and not a stringified message.
    return topic_suggestions.content

# Tool names are sent to the model as function names, which must not contain
# spaces (allowed: letters, digits, "_" and "-").
@tool("simple_query_agent")
def simple_query_responder(query: str) -> str:
    """Use this tool to respond to simple queries, such as greetings and small talk.

    Uses a language model to respond directly to the query."""
    response = llm.invoke(query)
    # A tool returns its result text; appending to the message list is done by
    # the graph/agent that calls the tool, not by the tool itself.
    return response.content

@tool("final_response")
def final_response(response: str, research_steps):
    """
    Return the final answer to the user's query.

    Call this once the answer is complete: `response` is the answer text and
    `research_steps` (a list of strings or a single string) summarises how it was obtained.
    """
    # NOTE(review): research_steps is part of the tool schema so the model can
    # report its steps, but it is not included in the returned text. The previous
    # bullet-list formatting of it was computed and then discarded.
    return response

In [None]:
# Token budget for the supervisor's context window.
MAX_TRIM_TOKENS = 100000

# Keep the most recent messages (and the system message) within the budget,
# counting tokens with the chat model itself.
trimmer = trim_messages(
    token_counter=llm,
    max_tokens=MAX_TRIM_TOKENS,
    include_system=True,
    strategy="last",
)


def agent_node(state, agent, name):
    """Run a worker agent on the graph state and report its final reply.

    The last message produced by ``agent`` is re-emitted as a ``HumanMessage``
    tagged with ``name``, returned as a single-message ``messages`` update.
    """
    agent_output = agent.invoke(state)
    final_message = agent_output["messages"][-1]
    reply = HumanMessage(content=final_message.content, name=name)
    return {"messages": [reply]}


def create_team_supervisor(llm, system_prompt, members):
    """Build an LLM-based router that chooses the next team member or FINISH.

    Args:
        llm: chat model used for routing. It is forced to call a single
            ``route`` function whose ``next`` argument is restricted to
            ``["FINISH"] + members``.
        system_prompt: supervisor instructions. It is used as a prompt template,
            so ``{options}`` and ``{team_members}`` are filled in, and any other
            literal braces must be escaped.
        members: names of the workers the supervisor may route to.

    Returns:
        A runnable ``prompt | trimmer | llm | parser``. It takes ``{"messages": [...]}``
        and yields the parsed function arguments, e.g. ``{"next": "<member or FINISH>"}``.
        The module-level ``trimmer`` keeps the conversation within its token budget.
    """
    options = ["FINISH"] + members
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next?"
                " Or should we FINISH? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options), team_members=", ".join(members))
    return (
        prompt
        | trimmer
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

In [None]:
from langgraph.prebuilt import create_react_agent

# DatabaseTeam graph state
class DatabaseTeamState(TypedDict):
    """Graph state shared by the supervisor and the database-team workers."""

    # Conversation history; the operator.add reducer appends each node's
    # returned messages instead of overwriting the list.
    messages: Annotated[List[BaseMessage], operator.add]
    # Names of the team's workers (no node visible here writes this field).
    team_members: List[str]
    # Routing decision written by the supervisor: a worker name or "FINISH".
    next: str

# Worker names: the supervisor prompt, its routing options and the graph's node
# names must all agree on these.
DATABASE_TEAM_MEMBERS = ["database_manager", "data_analyst"]

# Prompts are assembled from plain per-line strings so the model does not receive
# the dozens of leading spaces that an indented triple-quoted string carries.
SQL_DEV_PROMPT = "\n".join(
    [
        "You are an experienced database engineer who is master at creating efficient and complex SQL queries.",
        "You have a deep understanding of how different databases work and how to optimize queries.",
        "Use the `list_tables` to find available tables.",
        "Use the `tables_schema` to understand the metadata for the tables.",
        "Use the `execute_sql` to execute queries against the database.",
        "Use the `check_sql` to check your queries for correctness.",
        "You should produce good result for analyst to use",
    ]
)
sql_dev = create_react_agent(
    llm,
    tools=[list_tables, tables_schema, execute_sql, check_sql],
    state_modifier=SQL_DEV_PROMPT,
)
db_node = functools.partial(agent_node, agent=sql_dev, name="database_manager")

DATA_ANALYST_PROMPT = "\n".join(
    [
        "You have deep experience with analyzing datasets using Python.",
        "Your work is always based on the provided data and is clear, easy-to-understand and to the point.",
        "You have attention to detail and always produce very detailed work (as long as you need).",
        "Your analyses should be good for the reporter to use for final reporting.",
    ]
)
# NOTE(review): this agent is given no tools, so it can only reason over data
# already present in the conversation (e.g. the database_manager's results).
data_analyst = create_react_agent(llm, tools=[], state_modifier=DATA_ANALYST_PROMPT)
analysis_node = functools.partial(agent_node, agent=data_analyst, name="data_analyst")

supervisor_agent = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    f" following workers: {', '.join(DATABASE_TEAM_MEMBERS)}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    DATABASE_TEAM_MEMBERS,
)

In [None]:
database_graph = StateGraph(DatabaseTeamState)
database_graph.add_node("database_manager", db_node)
database_graph.add_node("data_analyst", analysis_node)
database_graph.add_node("supervisor", supervisor_agent)

# Control flow: every worker reports back to the supervisor, which decides who acts next.
database_graph.add_edge("database_manager", "supervisor")
database_graph.add_edge("data_analyst", "supervisor")

# The supervisor may answer with any worker name or "FINISH" (see its route
# options), so every one of those values needs an entry here; an unmapped
# value makes the conditional edge fail at run time.
database_graph.add_conditional_edges(
    "supervisor",
    lambda state: state["next"],
    {
        "database_manager": "database_manager",
        "data_analyst": "data_analyst",
        "FINISH": END,
    },
)

# Workers need no conditional edges of their own: routing is the supervisor's job,
# and "next" still holds the worker's own name when that worker finishes.
database_graph.add_edge(START, "supervisor")
chain = database_graph.compile()


# Adapter from a plain user string to the database-team graph's input state.
def enter_chain(message: str):
    """Wrap a user message as the initial ``messages`` state of the team graph.

    Only ``messages`` is populated; the supervisor writes ``next`` once it runs.
    """
    opening_message = HumanMessage(content=message)
    return {"messages": [opening_message]}


# Piping the adapter into the compiled graph lets callers pass a bare string,
# e.g. research_chain.invoke("..."), instead of a state dict.
research_chain = enter_chain | chain

In [None]:
from IPython.display import Image, display

# Render the compiled team graph as a Mermaid-generated PNG (xray=True also draws nested graphs).
graph_png = chain.get_graph(xray=True).draw_mermaid_png()
display(Image(graph_png))

In [None]:
# for s in research_chain.stream(
#     "What is the churn rate in the company from the data in the database?", {"recursion_limit": 100}
# ):
#     if "__end__" not in s:
#         print(s)
#         print("---")