In [2]:
from langchain.chat_models import ChatOpenAI
from langchain.agents import create_sql_agent
from langchain.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.agents import AgentType
from os import getenv

case_management_schema = {
    "tables": [
        {
            "name": "cases",
            "columns": [
                {"name": "case_id", "type": "SERIAL", "primary_key": True},
                {"name": "customer_id", "type": "INTEGER", "foreign_key": "customers.customer_id"},
                {"name": "description", "type": "TEXT"},
                {"name": "status", "type": "VARCHAR(20)"},
                {"name": "created_at", "type": "TIMESTAMP", "default": "now()"}
            ]
        },
        {
            "name": "customers",
            "columns": [
                {"name": "customer_id", "type": "SERIAL", "primary_key": True},
                {"name": "name", "type": "VARCHAR(100)"},
                {"name": "email", "type": "VARCHAR(100)", "unique": True},
                {"name": "phone", "type": "VARCHAR(20)"}
            ]
        },
        {
            "name": "case_notes",
            "columns": [
                {"name": "note_id", "type": "SERIAL", "primary_key": True},
                {"name": "case_id", "type": "INTEGER", "foreign_key": "cases.case_id"},
                {"name": "note", "type": "TEXT"},
                {"name": "created_at", "type": "TIMESTAMP", "default": "now()"}
            ]
        }
    ]
}

def get_schema_string(schema: dict) -> str:
    """
    Converts the schema dictionary into a string representation.
    """
    schema_string = ""
    for table in schema["tables"]:
        schema_string += f"Table: {table['name']}\n"
        for column in table["columns"]:
            schema_string += f"  Column: {column['name']}  Type: {column['type']}"
            if column.get("primary_key"):
                schema_string += "  Primary Key"
            if column.get("foreign_key"):
                schema_string += f"  Foreign Key ({column['foreign_key']})"
            if column.get("unique"):
                schema_string += "  Unique"
            if column.get("default"):
                schema_string += f"  Default: {column['default']}"
            schema_string += "\n"
    return schema_string

def get_llm():
    llm = ChatOpenAI(
    openai_api_base=getenv("OPENROUTER_BASE_URL"),
    openai_api_key=getenv("OPENROUTER_API_KEY"),
    model_name="google/gemma-3-27b-it:free",
    )
    return llm


In [8]:
# import psycopg2 # If you are using postgres

# Your database connection details.  Use a secure method to store these!
# DB_USER = "your_user"
# DB_PASSWORD = "your_password"
# DB_HOST = "your_host"
# DB_NAME = "your_database_name"
# DB_PORT = "your_port" # Add the port if it is not the default

# # 2. Connect to the database using psycopg2
# conn_string = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
# db = SQLDatabase.from_uri(conn_string)

from langchain.agents import ZeroShotAgent, initialize_agent

llm = get_llm()

# # Create SQL agent.  This is a high-level agent for SQL.
# agent_executor = create_sql_agent(
#     llm=llm,
#     verbose=True,  # Set to True to see the intermediate steps
#     agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, # Use a simple agent type
# )

def generate_sql(schema: str, instruction: str):
    """
    Generates SQL code (stored procedure or query) based on the schema and instruction.
    """

    prompt_template = """
    You are a SQL expert.
    Here is the SQL schema for the case management database:
    {schema}

    Generate SQL code to perform the following action:
    {instruction}
    
    IMPORTANT:
    1.  For stored procedures, enclose the entire code within CREATE PROCEDURE ... END;
    2.  For stored procedures, make sure to handle any errors.
    3.  For queries, just return the SQL query.
    4.  Do not include any explanation before or after the SQL code.
    """
    prompt = PromptTemplate(input_variables=["schema", "instruction"], template=prompt_template)
    chain = LLMChain(llm=llm, prompt=prompt)
    sql_code = chain.run(schema=schema, instruction=instruction)
    return sql_code

In [9]:
# def execute_sql(conn_string, sql_code, is_procedure=False):
#     """
#     Executes the generated SQL code (stored procedure or query) against the database.
#     """
#     conn = None
#     cursor = None
#     try:
#         conn = psycopg2.connect(conn_string)
#         cursor = conn.cursor()
#         cursor.execute(sql_code)
#         if is_procedure:
#             conn.commit()  # Commit the transaction for stored procedure creation
#             return "Stored procedure created successfully."
#         else:
#             results = cursor.fetchall()
#             conn.commit() # Commit the query
#             return results
#     except Exception as e:
#         conn.rollback()
#         return f"Error: {e}"
#     finally:
#         if cursor:
#             cursor.close()
#         if conn:
#             conn.close()

# Example Usage
if __name__ == "__main__":
    schema_str = get_schema_string(case_management_schema)
    # Example 1: Generate a stored procedure
    instruction_sp = """
    Create a stored procedure named create_new_case that takes a customer_id and a description as input,
    creates a new case with the given information, sets the status to 'Open', and returns the new case_id.
    """
    sp_code = generate_sql(schema_str, instruction_sp)
    print(f"\nGenerated Stored Procedure Code:\n{sp_code}")
    # sp_result = execute_sql(conn_string, sp_code, is_procedure=True)
    # print(f"\nStored Procedure Execution Result:\n{sp_result}")

    # Example 2: Generate a query
    instruction_query = "Get the name and email of all customers who have open cases."
    query_code = generate_sql(schema_str, instruction_query)
    print(f"\nGenerated Query Code:\n{query_code}")
    # query_results = execute_sql(conn_string, query_code)
    # print(f"\nQuery Results:\n{query_results}")


Generated Stored Procedure Code:

```sql
CREATE PROCEDURE create_new_case(
  IN p_customer_id INTEGER,
  IN p_description TEXT
)
LANGUAGE plpgsql
AS $$
DECLARE
  v_new_case_id INTEGER;
BEGIN
  -- Check if the customer_id exists
  IF NOT EXISTS (SELECT 1 FROM customers WHERE customer_id = p_customer_id) THEN
    RAISE EXCEPTION 'Customer with ID % does not exist', p_customer_id;
  END IF;

  -- Insert the new case
  INSERT INTO cases (customer_id, description, status)
  VALUES (p_customer_id, p_description, 'Open')
  RETURNING case_id INTO v_new_case_id;

  -- Return the new case_id
  SELECT v_new_case_id;

EXCEPTION
  WHEN OTHERS THEN
    RAISE EXCEPTION 'Error creating case: %', SQLERRM;
END;
$$;
```

Generated Query Code:

```sql
SELECT DISTINCT
  c.name,
  c.email
FROM customers AS c
JOIN cases AS ca
  ON c.customer_id = ca.customer_id
WHERE
  ca.status != 'Closed';
```


In [None]:
CREATE PROCEDURE create_new_case(
  IN p_customer_id INTEGER,
  IN p_description TEXT
)
LANGUAGE plpgsql
AS $$
DECLARE
  v_new_case_id INTEGER;
BEGIN
  -- Check if the customer_id exists
  IF NOT EXISTS (SELECT 1 FROM customers WHERE customer_id = p_customer_id) THEN
    RAISE EXCEPTION 'Customer with ID % does not exist', p_customer_id;
  END IF;

  -- Insert the new case
  INSERT INTO cases (customer_id, description, status)
  VALUES (p_customer_id, p_description, 'Open')
  RETURNING case_id INTO v_new_case_id;

  -- Return the new case_id
  SELECT v_new_case_id;
...
  ON c.customer_id = ca.customer_id
WHERE
  ca.status != 'Closed';