## Load environment variables

In [26]:
# Reuse the Streamlit secrets file as a dotenv file so the Snowflake connection
# settings become environment variables that later cells read with os.getenv.
# NOTE(review): python-dotenv only parses flat `key = "value"` lines — confirm the
# file contains no TOML [sections].
from dotenv import load_dotenv
load_dotenv(dotenv_path="../.streamlit/secrets.toml")

True

In [25]:
import re

# Matches the first fenced ```sql block in a model response. The body is non-greedy,
# so a response containing several SQL blocks yields only the first one instead of
# everything between the first opening fence and the last closing fence. The pattern
# also tolerates "```SQL", trailing spaces after the fence, and Windows line endings.
_SQL_FENCE_RE = re.compile(r"```sql[ \t]*\r?\n(.*?)\r?\n[ \t]*```", re.DOTALL | re.IGNORECASE)


def get_sql(text):
    """Return the SQL inside the first ```sql fenced block of a markdown string.

    Args:
        text: Markdown text, typically an LLM response. ``None`` or ``""`` is accepted.

    Returns:
        The code between the opening ```sql line and its closing fence, or ``None``
        when ``text`` is empty or contains no ```sql block.
    """
    if not text:
        return None
    sql_match = _SQL_FENCE_RE.search(text)
    return sql_match.group(1) if sql_match else None

## Get database information using `langchain.SQLDatabase`
The Snowflake database has 43 tables; only the tables listed in `tables` below are described to the model.

In [2]:
import os
from snowflake.sqlalchemy import URL
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

# Build the Snowflake SQLAlchemy URL from the settings loaded into the environment.
# os.getenv returns None for unset keys, and that None would be passed straight into
# URL(). Fail fast with a clear message instead of a confusing connection error.
SNOWFLAKE_ENV_KEYS = ("account", "user", "password", "database", "schema", "warehouse", "role")
_missing_env = [key for key in SNOWFLAKE_ENV_KEYS if not os.getenv(key)]
if _missing_env:
    raise RuntimeError(
        "Missing Snowflake settings in the environment: " + ", ".join(_missing_env)
    )

uri_snow = URL(**{key: os.getenv(key) for key in SNOWFLAKE_ENV_KEYS})

# Restrict the schema description (DDL + sample rows) to these 3 tables so the prompt
# stays small. Drop `include_tables` to describe every table in the schema instead.
tables = ["ecdc_global", "goog_global_mobility_report", "databank_demographics"]
db = SQLDatabase.from_uri(uri_snow, include_tables=tables)

In [4]:
# Describe each included table as its CREATE TABLE DDL followed by 3 sample rows;
# `table_info` is a property that simply returns get_table_info().
db_info = db.get_table_info()
db_info

'\nCREATE TABLE databank_demographics (\n\tiso3166_1 VARCHAR(16777216), \n\tiso3166_2 VARCHAR(16777216), \n\tfips VARCHAR(16777216), \n\tlatitude FLOAT, \n\tlongitude FLOAT, \n\tstate VARCHAR(16777216), \n\tcounty VARCHAR(16777216), \n\ttotal_population DECIMAL(38, 0), \n\ttotal_male_population DECIMAL(38, 0), \n\ttotal_female_population DECIMAL(38, 0), \n\tcountry_region VARCHAR(250)\n)\n\n/*\n3 rows from databank_demographics table:\niso3166_1\tiso3166_2\tfips\tlatitude\tlongitude\tstate\tcounty\ttotal_population\ttotal_male_population\ttotal_female_population\tcountry_region\nAF\tNone\tNone\t33.0\t65.0\tNone\tNone\t37172386\t19093281\t18079105\tAfghanistan\nAL\tNone\tNone\t41.0\t20.0\tNone\tNone\t2866376\t1460043\t1406333\tAlbania\nDZ\tNone\tNone\t28.0\t3.0\tNone\tNone\t42228429\t21332000\t20896429\tAlgeria\n*/\n\n\nCREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcase

In [5]:
# Persist the schema description to ../docs/database_info.txt.
# Write UTF-8 explicitly: the sample rows may contain non-ASCII text, which the
# platform default encoding (e.g. cp1252 on Windows) cannot always encode.
os.makedirs("../docs", exist_ok=True)
with open("../docs/database_info.txt", "w", encoding="utf-8") as fp:
    fp.write(db_info)

# Create the chat chain

In [7]:
from langchain.prompts.prompt import PromptTemplate

# System instructions for the SQL assistant. They are kept separate from the schema so
# the schema section comes from `db_info` instead of a hand-pasted copy that drifts
# whenever `tables` changes.
instructions_qa = """
You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries.
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries.
When the user expresses gratitude or says "Thanks", interpret it as a signal to conclude the conversation. Respond with an appropriate closing statement without generating further SQL queries.
If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question."
Write your response in markdown format.
"""

# PromptTemplate treats `{...}` as placeholders (see {history}/{input} below), so
# escape any literal braces that could appear in the DDL or sample rows.
schema_qa = db_info.strip().replace("{", "{{").replace("}", "}}")

template_qa = (
    instructions_qa
    + "\n"
    + schema_qa
    + "\n\nHistory: ```{history}```\n\nInput: ```{input}```\n\nAnswer:\n"
)

prompt_qa = PromptTemplate(template=template_qa, input_variables=["history", "input"])

In [27]:
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationChain


# The 16k-context model leaves room for the schema-heavy prompt plus history.
# A low temperature keeps the generated SQL close to deterministic, and max_tokens
# caps the length of each reply.
q_llm = ChatOpenAI(model_name="gpt-3.5-turbo-16k", max_tokens=500, temperature=0.1)

# ConversationChain fills {history} in prompt_qa from its memory and {input} from each
# predict() call. Pass verbose=True here to print the rendered prompt while debugging.
conv_chain = ConversationChain(prompt=prompt_qa, llm=q_llm)

In [18]:
# NOTE(review): chat_history is not read by any visible cell; ConversationChain keeps
# its own memory. Candidate for removal once no other cell depends on it.
chat_history = []

# Opening turn: ask the assistant to introduce itself and the schema without writing SQL.
question = """Now to get started, please briefly introduce yourself and describe the database at a high level. Then provide 3 example questions as bullet points. Do not include any SQL queries in this response. Write your response in markdown format."""
output = conv_chain.predict(input=question)

# show answer
print(output)

Hello! I'm an AI assistant specializing in data analysis with Snowflake SQL. I can help you with various data analysis tasks using SQL queries in the Snowflake environment. Here's a high-level overview of the database:

- The `databank_demographics` table contains demographic information such as ISO codes, latitude, longitude, population, and gender breakdown.
- The `ecdc_global` table provides global COVID-19 data, including cases, deaths, population, and dates.
- The `goog_global_mobility_report` table tracks mobility changes in different regions, including changes in grocery and pharmacy visits, parks, residential areas, retail and recreation, transit stations, and workplaces.

Now, let's move on to the example questions:

1. What is the total population in the `databank_demographics` table?
2. How many COVID-19 cases were reported in Albania on a specific date in the `ecdc_global` table?
3. What was the percentage change in workplace visits in Tennessee, United States, on a specifi

In [28]:
# A concrete analysis question; the reply is expected to contain a ```sql block that
# the next section extracts and runs.
output = conv_chain.predict(
    input="What is the total number of COVID-19 cases and deaths for each country in the `ecdc_global` table?"
)
print(output)

To find the total number of COVID-19 cases and deaths for each country in the `ecdc_global` table, you can use the following SQL query:

```sql
SELECT country_region, SUM(cases) AS total_cases, SUM(deaths) AS total_deaths
FROM ecdc_global
GROUP BY country_region;
```

This query selects the `country_region` column from the `ecdc_global` table and calculates the sum of the `cases` and `deaths` columns for each country using the `SUM` function. The `GROUP BY` clause groups the results by the `country_region` column.

Please note that this query assumes that the `ecdc_global` table contains the relevant data for COVID-19 cases and deaths.


## Execute the generated SQL in Snowflake and display the result as a DataFrame

In [29]:
from sqlalchemy import create_engine
import pandas as pd 


# Extract the SQL from the model's answer and run it in Snowflake.
# The prompt asks the model not to modify data, but nothing enforces that, so refuse
# anything that does not start as a read-only query. This is a guard, not a sandbox:
# use a read-only Snowflake role for real isolation.
generated_sql = get_sql(output)
if generated_sql is None:
    raise ValueError("The model response does not contain a ```sql code block.")
if not re.match(r"\s*(select|with)\b", generated_sql, re.IGNORECASE):
    raise ValueError("Refusing to run non-SELECT SQL:\n" + generated_sql)

engine = create_engine(uri_snow)
df = pd.read_sql(generated_sql, engine)
df

Unnamed: 0,country_region,total_cases,total_deaths
0,Albania,48530.0,1003.0
1,Algeria,92102.0,2596.0
2,Andorra,7338.0,79.0
3,Angola,16188.0,371.0
4,Anguilla,10.0,0.0
...,...,...,...
208,Poland,1135676.0,22864.0
209,Slovenia,96314.0,1459.0
210,"Virgin Islands, U.S.",1807.0,23.0
211,Western Sahara,766.0,1.0
