In [1]:
from langchain import OpenAI, SQLDatabase
from snowflake.snowpark import Session
from langchain.chains import create_sql_query_chain

from dotenv import main
import os

main.load_dotenv()
CSDW_USEC_PASSWORD = os.getenv("CSDW_USEC_PASSWORD")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

In [2]:
# Create an .env file in curr directory with your password strored in CSDW_USEC_PASSWORD variable
connection_parameters = {
    "account": "psai-csdw_usec",
    "user": "ROYDON_USEC",
    "password": CSDW_USEC_PASSWORD,
    "database" : "DB_DEV",
    "schema" : "TEST_SCHEMA",
    "role": "DATA_SCIENTIST", 
    "warehouse": "COMPUTE_VWH",  # optional
    } 
sp_session = Session.builder.configs(connection_parameters).create()  

  warn(f"Bad owner or permissions on {str(filep)}{chmod_message}")


In [3]:
snowflake_account = "psai-csdw_usec"
username = "ROYDON_USEC"
password = CSDW_USEC_PASSWORD
database= "DB_DEV"
schema = "TEST_SCHEMA"
warehouse = "COMPUTE_VWH"
role = "DATA_SCIENTIST"

### Create Custom prompt for Snowflaked SQL dialect
LangChain does not support Snowflake SQL currently (no specialised prompt to guide llm to follow snowflake sql dialect). The cell below creates one manually.

In [15]:
from langchain_core.prompts import PromptTemplate
from langchain.chains.sql_database.prompt import SQL_PROMPTS

_snowflake_prompt = """You are a Snowflake expert. Given an input question, first create a syntactically correct Snowflake query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per Snowflake. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

"""

SNOWFLAKE_PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=_snowflake_prompt + 
    """
    Only use the following tables:
    {table_info}
    Question: {input}
    """,
)

SQL_PROMPTS["snowflake"] = SNOWFLAKE_PROMPT

In [4]:
snowflake_url = f"snowflake://{username}:{password}@{snowflake_account}/{database}/{schema}?warehouse={warehouse}&role={role}"

db = SQLDatabase.from_uri(snowflake_url,sample_rows_in_table_info=1, include_tables=['orders', 'locations'])

# we can see what information is passed to the LLM regarding the database
print(db.table_info)


CREATE TABLE locations (
	location_id DECIMAL(38, 0), 
	location_name VARCHAR(50), 
	city VARCHAR(50), 
	state VARCHAR(2), 
	latitude DECIMAL(9, 6), 
	longitude DECIMAL(9, 6)
)

/*
1 rows from locations table:
location_id	location_name	city	state	latitude	longitude
1	Warehouse A	New York	NY	40.712776	-74.005974
*/


CREATE TABLE orders (
	order_id DECIMAL(38, 0), 
	customer_id DECIMAL(38, 0), 
	product_name VARCHAR(50), 
	quantity DECIMAL(38, 0), 
	order_date DATE, 
	location_id DECIMAL(38, 0), 
	order_amount DECIMAL(10, 2)
)

/*
1 rows from orders table:
order_id	customer_id	product_name	quantity	order_date	location_id	order_amount
1	101	Product A	3	2023-08-15	1	150.00
*/


In [5]:
type(db.table_info)

str

In [6]:
from langchain_core.language_models.llms import LLM
from typing import Any, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun

# Define LangChain Custom LLM object to interface with Cortex COMPLETE
class SnowflakeCortexLLM(LLM):

    sp_session: Session = None
    """Snowpark Session class instance, set before invoking the LLM to authenticate to an appropriate Snowflake account with Cortex LLMs provisioned."""

    model: str = 'mistral-7b'
    """The Snowflake cortex hosted LLM model name, default to `mistral-7b`. Refer to doc for other options."""

    cortex_function: str = 'complete'
    """The cortex function to use, defaulted to complete. for other types refer to doc"""

    llm_type: str = 'snowflake-cortex'
    """The type of LLM, defaulted to snowflake-cortex, for logging purposes only."""

    @property
    def _llm_type(self) -> str:
        return "snowflake_cortex"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Adapt the Snowflake Cortex LLM SQL-based API to this Python interface.
        Modify this accordingly to the available Snowflake Cortex LLM API.
        For example, this implementation is based on the following snowflake SQL command: 
        `SELECT SNOWFLAKE.CORTEX.COMPLETE('<model_name>', '<prompt_text>');`
        """    
        prompt_text = prompt
        sql_statement = f"""SELECT snowflake.cortex.{self.cortex_function}('{self.model}','{prompt_text}') AS LLM_RESPONSE;"""
        l_rows = self.sp_session.sql(sql_statement).collect()
        llm_response = l_rows[0]['LLM_RESPONSE'] # only 1 row is expected from the SQL statement as it is applied to 1 prompt
        return llm_response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model
            ,"cortex_function" : self.cortex_function
            ,"snowpark_session": self.sp_session.session_id
        }
    @property
    def _llm_type(cls) -> str:
        """Get the type of language model used by this chat model. Used for logging purposes only."""
        return cls.llm_type

In [16]:
# Initialize llm
llm = SnowflakeCortexLLM(sp_session=sp_session, model='mistral-7b')
llm2 = SnowflakeCortexLLM(sp_session=sp_session, model='mixtral-8x7b')
llm3 = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
  
database_chain = create_sql_query_chain(llm,db)
database_chain_2 = create_sql_query_chain(llm2,db)
database_chain_3 = create_sql_query_chain(llm3,db)


In [None]:
# Check if snowflake dialect is added
list(SQL_PROMPTS)

['crate',
 'duckdb',
 'googlesql',
 'mssql',
 'mysql',
 'mariadb',
 'oracle',
 'postgresql',
 'sqlite',
 'clickhouse',
 'prestodb',
 'snowflake']

### Mistral got `city` column wrong, used `cities` instead
As seen in the line "WHERE cities IN ('New York', 'Chicago')", column should be city

In [16]:
# Test with mistral-7b
prompt = "Show the total order amounts for the warehouses in New York and Chicago."
sql_query = database_chain.invoke({"question": prompt})

# Split the string at "SQLResult:" and take the first part
cleaned_string = sql_query.split('SQLResult:')[0].strip()
cleaned_string = cleaned_string.strip(';')


#we can visualize what sql query is generated by the LLM
print(cleaned_string)

sp_session.sql(cleaned_string).show()

SELECT locations.city, SUM(orders.order_amount) as total_order_amount
FROM locations
JOIN orders ON locations.location_id = orders.location_id
WHERE cities IN ('New York', 'Chicago')
GROUP BY locations.city


SnowparkSQLException: (1304): 01b528a0-0000-36c3-0013-9903002ba172: 000904 (42000): SQL compilation error: error line 4 at position 6
invalid identifier 'CITIES'

GPT3.5 executes successfully

In [None]:
# Test with GPT3.5
prompt = "Show the total order amounts for the warehouses in New York and Chicago."
sql_query = database_chain_3.invoke({"question": prompt})

#we can visualize what sql query is generated by the LLM
print(sql_query)

sp_session.sql(sql_query).show()

SELECT SUM(order_amount) AS total_order_amount
FROM orders
INNER JOIN locations ON orders.location_id = locations.location_id
WHERE locations.city IN ('New York', 'Chicago')
GROUP BY locations.location_name
------------------------
|"TOTAL_ORDER_AMOUNT"  |
------------------------
|450.00                |
|440.00                |
------------------------



### Mistral-7b and Mixtral-8x7b consistently adds "SQLResult:...." and ";" at the end of SQL queries (which causes errors) to its prompt when we use SQL_query_chain from LangChain.

OpenAI model works well with out of the box sql agent function, generating only the sql query and nothing else.

In [None]:
# Test Mistral-7b
prompt = "Which city is order id 1, 4 and 5 in? Return a list please."

sql_query = database_chain.invoke({"question": prompt})
print(sql_query)
sp_session.sql(sql_query).show()

SELECT l.city
FROM locations l
JOIN orders o ON l.location_id = o.location_id
WHERE o.order_id IN (1, 4, 5)
LIMIT 5;

SQLResult:
city
-----
New York
New York
(other city) -- assuming there is another city for order 5

Answer: The orders with ID 1 and 4 are in New York. The city for order 5 is not provided in the example data.


SnowparkSQLException: (1304): 01b5274d-0000-3658-0013-9903002a9686: 001003 (42000): SQL compilation error:
syntax error line 7 at position 0 unexpected 'SQLResult'.

Solution: Some cleaning of query generated before doing sql

In [None]:
# Test Mistral-7b
prompt = "Which city is order id 1, 4 and 5 in? Return a list please."

sql_query = database_chain.invoke({"question": prompt})

# Split the string at "SQLResult:" and take the first part
cleaned_string = sql_query.split('SQLResult:')[0].strip()
cleaned_string = cleaned_string.strip(';')


#we can visualize what sql query is generated by the LLM
print(cleaned_string)

sp_session.sql(cleaned_string).show()

SELECT l.city
FROM locations l
JOIN orders o ON l.location_id = o.location_id
WHERE o.order_id IN (1, 4, 5)
LIMIT 5
------------
|"CITY"    |
------------
|New York  |
|Houston   |
|Miami     |
------------



### Mistral SQL query does not work

In [18]:
prompt = "For cities with orders IDs 1, 4, and 5, provide the number of orders placed in the last year and the trend in monthly order volumes. Include a comparison with the previous year."

sql_query = database_chain.invoke({"question": prompt})

# Split the string at "SQLResult:" and take the first part
cleaned_string = sql_query.split('SQLResult:')[0].strip()
cleaned_string = cleaned_string.strip(';')


#we can visualize what sql query is generated by the LLM
print(cleaned_string)

sp_session.sql(cleaned_string).show()

SELECT c.city, COUNT(o.order_id) as num_orders, EXTRACT(MONTH FROM o.order_date) as month,
       (COUNT(o.order_id) OVER (PARTITION BY c.city ORDER BY EXTRACT(MONTH FROM o.order_date) ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) - COUNT(o.order_id) OVER (PARTITION BY c.city ORDER BY EXTRACT(MONTH FROM o.order_date) ROWS BETWEEN CURRENT ROW PRECEDING AND UNBOUNDED PRECEDING)) as monthly_change
FROM locations c
JOIN orders o ON c.location_id = o.location_id
WHERE o.order_id IN (1, 4, 5)
AND o.order_date BETWEEN DATEADD(year, -1, CURRENT_DATE()) AND CURRENT_DATE()
GROUP BY c.city, EXTRACT(MONTH FROM o.order_date)
ORDER BY c.city, month
LIMIT 5


SnowparkSQLException: (1304): 01b54225-0000-3a5a-0013-9903003200ae: 001003 (42000): SQL compilation error:
syntax error line 2 at position 257 unexpected 'PRECEDING'.

### Mixtral-8x7b hallucinates a `city` column in the orders table. Unreliable SQL genereted. 
Additionally it generates backslahes before underscores which can be manually handled, but unreliable output means this llm should not be used anyways.

In [None]:
import re

# Test with mixtral-8x7b
prompt = "For cities with orders IDs 1, 4, and 5, provide the number of orders placed in the last year and the trend in monthly order volumes. Include a comparison with the previous year."

sql_query = database_chain_2.invoke({"question": prompt})

# Split the string at "SQLResult:" and take the first part
cleaned_string = sql_query.split('SQLResult:')[0].strip()
cleaned_string = cleaned_string.strip(';')
cleaned_string = re.sub(r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r'', cleaned_string)


#we can visualize what sql query is generated by the LLM
print(cleaned_string)

sp_session.sql(cleaned_string).show()

WITH order_data AS (
SELECT o.city, COUNT(o.order_id) as num_orders, EXTRACT(YEAR FROM o.order_date) as order_year
FROM orders o
WHERE o.order_id IN (1, 4, 5)
GROUP BY o.city, order_year
),
order_trend AS (
SELECT city, order_year, num_orders,
CASE
WHEN LAG(num_orders) OVER (PARTITION BY city ORDER BY order_year) IS NULL THEN num_orders
WHEN num_orders - LAG(num_orders) OVER (PARTITION BY city ORDER BY order_year) > 0 THEN 'Increase'
ELSE 'Decrease'
END AS trend
FROM order_data
)
SELECT city, num_orders, trend
FROM order_trend
WHERE order_year = EXTRACT(YEAR FROM CURRENT_DATE)
ORDER BY city


SnowparkSQLException: (1304): 01b52756-0000-368c-0013-9903002ae3be: 000904 (42000): SQL compilation error: error line 2 at position 7
invalid identifier 'O.CITY'

In [None]:
# Test with openAI gpt3.5
prompt = "For cities with orders IDs 1, 4, and 5, provide the number of orders placed in the last year and the trend in monthly order volumes. Include a comparison with the previous year."

sql_query = database_chain_2.invoke({"question": prompt})

#we can visualize what sql query is generated by the LLM
print(sql_query)

sp_session.sql(sql_query).show()

SELECT city, COUNT(order_id) AS num_orders, SUM(quantity) AS total_quantity, SUM(order_amount) AS total_amount, MONTH(order_date) AS month, YEAR(order_date) AS year
FROM orders
INNER JOIN locations ON orders.location_id = locations.location_id
WHERE order_id IN (1, 4, 5) AND order_date >= DATEADD(year, -1, GETDATE())
GROUP BY city, MONTH(order_date), YEAR(order_date)
ORDER BY city, year, month
----------------------------------------------------------------------------------
|"CITY"    |"NUM_ORDERS"  |"TOTAL_QUANTITY"  |"TOTAL_AMOUNT"  |"MONTH"  |"YEAR"  |
----------------------------------------------------------------------------------
|Houston   |1             |4                 |200.00          |8        |2023    |
|Miami     |1             |2                 |120.00          |8        |2023    |
|New York  |1             |3                 |150.00          |8        |2023    |
----------------------------------------------------------------------------------

