In [29]:
# A single import binds every name this cell needs. The former
# `from langchain.llms import OpenAI` line was dead: the `from langchain import
# OpenAI, ...` line immediately rebound `OpenAI`.
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

# create llm
# temperature=0 requests the least random completions, which is what we want
# when the model is generating SQL.
llm = OpenAI(temperature=0)

In [3]:
import os
from getpass import getpass

from snowflake.sqlalchemy import URL

# create snowflake connection
# Credentials come from the environment (or an interactive prompt) instead of
# being hardcoded: anything typed into a cell is saved in the .ipynb and ends up
# in version control.
# NOTE(review): the password that was previously committed in this cell should
# be treated as leaked and rotated.
# Non-secret settings fall back to the values this notebook was developed against.
uri_snow = URL(
    account=os.environ.get("SNOWFLAKE_ACCOUNT", "vs57217.east-us-2.azure"),
    user=os.environ.get("SNOWFLAKE_USER") or input("Snowflake user: "),
    password=os.environ.get("SNOWFLAKE_PASSWORD") or getpass("Snowflake password: "),
    database=os.environ.get("SNOWFLAKE_DATABASE", "COVID19_EPIDEMIOLOGICAL_DATA"),
    schema=os.environ.get("SNOWFLAKE_SCHEMA", "public"),
    warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE", "COMPUTE_WH"),
    # NOTE(review): the chain and agent below execute SQL written by the LLM
    # under this role. ACCOUNTADMIN is Snowflake's most privileged system role;
    # a read-only role scoped to this database is strongly preferable. Override
    # it with SNOWFLAKE_ROLE.
    role=os.environ.get("SNOWFLAKE_ROLE", "ACCOUNTADMIN"),
)

# Only these tables are exposed to the LLM (the agent's `sql_db_list_tables`
# call in the trace below lists exactly these two).
tables = ["ecdc_global", "goog_global_mobility_report"]
db = SQLDatabase.from_uri(uri_snow, include_tables=tables)
# To inspect the table DDL and sample rows the LLM is given: print(db.table_info)

# Example: answering a question with `SQLDatabaseChain`

In [4]:
# create chain
# The chain asks `llm` to write SQL for the question, runs it against `db`, and
# has the model phrase the result as an answer (see SQLQuery / SQLResult /
# Answer in the output below).
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

# NOTE(review): in the recorded run, the model read "region" as
# `country_region` (not `continentexp`) and capped the query at LIMIT 5, so the
# answer covers only five countries.
chain_question = "can you give me the numbers of cases by region? try to group by region"
db_chain.run(chain_question)



> Entering new SQLDatabaseChain chain...
can you give me the numbers of cases by region? try to group by region
SQLQuery: SELECT country_region, SUM(cases) AS total_cases FROM ecdc_global GROUP BY country_region ORDER BY total_cases DESC LIMIT 5;
SQLResult: [('United States', 16256754.0), ('India', 9884100.0), ('Brazil', 6901952.0), ('Russian Federation', 2653928.0), ('France', 2376852.0)]
Answer: The top 5 regions with the highest number of cases are United States, India, Brazil, Russian Federation, and France.
> Finished chain.


'The top 5 regions with the highest number of cases are United States, India, Brazil, Russian Federation, and France.'

# Example: answering a question with a SQL agent (`create_sql_agent`)

In [30]:
from langchain.agents import AgentExecutor, create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.sql_database import SQLDatabase

# NOTE(review): SQLDatabase and OpenAI were already imported in the first cell,
# and AgentExecutor / ChatOpenAI are not used in the cells shown. Kept here in
# case later cells rely on them.

# The toolkit exposes `db` to the agent as tools (e.g. `sql_db_list_tables` and
# `sql_db_schema` in the trace below).
llm = OpenAI(temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [31]:
# Reuse the `llm` the toolkit was built with rather than constructing a second,
# identical OpenAI client for the agent.
# NOTE(review): the agent issues SQL it writes itself through the toolkit, with
# the privileges of whatever Snowflake role `db` connects as. Prefer a
# read-only role.
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

In [35]:
# Ask the agent a question; it decides which tools to call (see trace below).
agent_question = "what is the country with more deaths?"
agent_executor.run(agent_question)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: ecdc_global, goog_global_mobility_report
Thought: I should query the schema of the ecdc_global table
Action: sql_db_schema
Action Input: ecdc_global
Observation:
CREATE TABLE ecdc_global (
	country_region VARCHAR(16777216), 
	continentexp VARCHAR(16777216), 
	iso3166_1 VARCHAR(2), 
	cases FLOAT, 
	deaths FLOAT, 
	cases_since_prev_day FLOAT, 
	deaths_since_prev_day FLOAT, 
	population FLOAT, 
	date DATE, 
	last_update_date TIMESTAMP_NTZ, 
	last_reported_flag BOOLEAN
)

/*
3 rows from ecdc_global table:
country_region	continentexp	iso3166_1	cases	deaths	cases_since_prev_day	deaths_since_prev_day	population	date	last_update_date	last_reported_flag
Afghanistan	Asia	AF	746.0	6.0	0.0	0.0	38041757.0	2020-12-14	2023-07-15 00:04:02.382733	True
Afghanistan	Asia	AF	298.0	9.0	-448.0	3.0	38041757.0	2020-12-13	2023-07-15 00:04:02.

'The country with the most deaths is the United States with 299177 deaths.'