## Load environment variables

In [17]:
# Load environment variables from the Streamlit secrets file.
# The Snowflake connection settings are read back with os.getenv in a later cell.
# load_dotenv returns True here (see the cell output).
from dotenv import load_dotenv
load_dotenv("../.streamlit/secrets.toml") 

True

## Get database information using langchain.SQLDatabase
The Snowflake database has 43 tables.

In [21]:
import os
from snowflake.sqlalchemy import URL
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

# Build the Snowflake SQLAlchemy connection URI from the settings that
# load_dotenv placed in the environment.
uri_snow = URL(
    account=os.getenv("account"),
    user=os.getenv("user"),
    password=os.getenv("password"),
    database=os.getenv("database"),
    schema=os.getenv("schema"),
    warehouse=os.getenv("warehouse"),
    role=os.getenv("role"),
)

# Optional: restrict the database wrapper to a subset of tables.
# tables = ["ecdc_global", "goog_global_mobility_report", "databank_demographics", "demographics"]
# db = SQLDatabase.from_uri(uri_snow, include_tables=tables)

# Sample all tables. `db` has to be created here because the next cell reads
# `db.table_info`; with both assignments commented out the notebook raised
# NameError on a fresh kernel (Restart & Run All).
db = SQLDatabase.from_uri(uri_snow)

# Keep this as the last expression so the cell still displays the name of the
# database being used.
os.getenv("database")

'ORDERS_SAMPLE'

In [4]:
# Get the DDL and 3 sample rows for every table.
# table_info is a single string: one CREATE TABLE statement per table, each
# followed by a /* ... */ comment holding the sample rows (see the output).
db_info = db.table_info
db_info

'\nCREATE TABLE customers (\n\tcustomer_id DECIMAL(38, 0) NOT NULL, \n\tcustomer_name VARCHAR(50), \n\taddress VARCHAR(100), \n\tphone_number VARCHAR(15), \n\tCONSTRAINT "SYS_CONSTRAINT_5baa0130-6d36-4da3-bd78-919bb9520841" PRIMARY KEY (customer_id)\n)\n\n/*\n3 rows from customers table:\ncustomer_id\tcustomer_name\taddress\tphone_number\n1\tJohn Doe\t123 Main St\t123-456-7890\n2\tJane Smith\t456 Elm St\t987-654-3210\n3\tMichael Johnson\t789 Oak Ave\t555-123-4567\n*/\n\n\nCREATE TABLE orderitems (\n\torder_item_id DECIMAL(38, 0) NOT NULL, \n\torder_id DECIMAL(38, 0), \n\tproduct_name VARCHAR(50), \n\tquantity DECIMAL(38, 0), \n\tCONSTRAINT "SYS_CONSTRAINT_76efb4dc-883b-41fc-97f2-d91ebbbbefc2" PRIMARY KEY (order_item_id), \n\tCONSTRAINT "SYS_CONSTRAINT_f3b7502c-198a-4713-b687-9848113b5af3" FOREIGN KEY(order_id) REFERENCES orders (order_id)\n)\n\n/*\n3 rows from orderitems table:\norder_item_id\torder_id\tproduct_name\tquantity\n1\t1\tProduct A\t2\n2\t2\tProduct B\t1\n3\t3\tProduct C\t3\

In [5]:
# Persist the schema description (DDL + sample rows) to a text file.
with open("../docs/database_info.txt", mode="w") as info_file:
    info_file.write(db_info)

In [6]:
# Load the saved schema description back from disk into `text_file`.
with open("../docs/database_info.txt") as info_file:
    text_file = info_file.read()

### Create LangChain documents from the tables' DDL

In [7]:
from langchain.text_splitter import CharacterTextSplitter

# Splitter that cuts the schema dump on triple newlines, which is what separates
# one table's block from the next in the saved file. With chunk_size=0 every
# table ends up in its own chunk; the splitter logs a "longer than the
# specified 0" warning per chunk, as seen in the output below.
text_splitter = CharacterTextSplitter(
    separator="\n\n\n",
    chunk_size=0,
    chunk_overlap=0,
    length_function=len,
)

# One text chunk per table, each tagged with the name of the file it came from.
texts_split = text_splitter.split_text(text_file)
metadatas = [dict(document="database_info.txt") for _ in texts_split]
docs = text_splitter.create_documents(texts=texts_split, metadatas=metadatas)
docs

Created a chunk of size 438, which is longer than the specified 0
Created a chunk of size 493, which is longer than the specified 0


[Document(page_content='CREATE TABLE customers (\n\tcustomer_id DECIMAL(38, 0) NOT NULL, \n\tcustomer_name VARCHAR(50), \n\taddress VARCHAR(100), \n\tphone_number VARCHAR(15), \n\tCONSTRAINT "SYS_CONSTRAINT_5baa0130-6d36-4da3-bd78-919bb9520841" PRIMARY KEY (customer_id)\n)\n\n/*\n3 rows from customers table:\ncustomer_id\tcustomer_name\taddress\tphone_number\n1\tJohn Doe\t123 Main St\t123-456-7890\n2\tJane Smith\t456 Elm St\t987-654-3210\n3\tMichael Johnson\t789 Oak Ave\t555-123-4567\n*/', metadata={'document': 'database_info.txt'}),
 Document(page_content='CREATE TABLE orderitems (\n\torder_item_id DECIMAL(38, 0) NOT NULL, \n\torder_id DECIMAL(38, 0), \n\tproduct_name VARCHAR(50), \n\tquantity DECIMAL(38, 0), \n\tCONSTRAINT "SYS_CONSTRAINT_76efb4dc-883b-41fc-97f2-d91ebbbbefc2" PRIMARY KEY (order_item_id), \n\tCONSTRAINT "SYS_CONSTRAINT_f3b7502c-198a-4713-b687-9848113b5af3" FOREIGN KEY(order_id) REFERENCES orders (order_id)\n)\n\n/*\n3 rows from orderitems table:\norder_item_id\torder_

### Create embeddings and save the database locally using Chroma and OpenAIEmbeddings
- Generate the embeddings with OpenAI from the documents created in the cell above
- Create the Chroma database (stored locally)

In [8]:
# import Chroma Library that allow to store vector database in local
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings

import shutil

# Remove any previous vector store folder to avoid conflicts with stale data.
# shutil.rmtree replaces the Windows-only `! rmdir /s /q` shell call so the
# cell runs on any OS; ignore_errors=True keeps it a no-op when the folder
# does not exist yet.
shutil.rmtree("../chroma_db", ignore_errors=True)

# create object for embeddings using OpenAI
embeddings = OpenAIEmbeddings()

# embed the documents and save the Chroma database locally
vector_store = Chroma.from_documents(docs, embeddings, persist_directory="../chroma_db")
vector_store.persist()

# Show the first stored record. The plain get() call reports 'embeddings': None
# (see the output); the commented variant asks for them through `include`.
# vector_store.get(limit=1, include=['embeddings', 'documents', 'metadatas'])
vector_store.get(limit=1)

{'ids': ['1d215242-263a-11ee-91fe-ac7ed0d21e7b'],
 'embeddings': None,
 'documents': ['CREATE TABLE customers (\n\tcustomer_id DECIMAL(38, 0) NOT NULL, \n\tcustomer_name VARCHAR(50), \n\taddress VARCHAR(100), \n\tphone_number VARCHAR(15), \n\tCONSTRAINT "SYS_CONSTRAINT_5baa0130-6d36-4da3-bd78-919bb9520841" PRIMARY KEY (customer_id)\n)\n\n/*\n3 rows from customers table:\ncustomer_id\tcustomer_name\taddress\tphone_number\n1\tJohn Doe\t123 Main St\t123-456-7890\n2\tJane Smith\t456 Elm St\t987-654-3210\n3\tMichael Johnson\t789 Oak Ave\t555-123-4567\n*/'],
 'metadatas': [{'document': 'database_info.txt'}]}

In [9]:
# Count the records stored in the vector store (3 in the captured output).
# NOTE(review): `_collection` is an underscore-prefixed (private) attribute of
# the wrapper, so this may break across library versions - confirm before relying on it.
vector_store._collection.count()

3

In [23]:
# Reload the persisted vector store from disk and run a sample similarity search.
vector_store = Chroma(persist_directory="../chroma_db", embedding_function=embeddings)

# Keep the hits under their own name. Reusing `docs` would overwrite the source
# documents that the indexing cell passes to Chroma.from_documents, so re-running
# that cell afterwards would index search results instead of the table DDL.
matched_docs = vector_store.similarity_search("which table contains the names of countries")
matched_docs

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


[Document(page_content='CREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcases_since_prev_day FLOAT, \n\tdeaths_since_prev_day FLOAT, \n\tpopulation FLOAT, \n\tdate DATE, \n\tlast_update_date TIMESTAMP_NTZ, \n\tlast_reported_flag BOOLEAN\n)\n\n/*\n3 rows from ecdc_global table:\ncountry_region\tcontinentexp\tiso3166_1\tcases\tdeaths\tcases_since_prev_day\tdeaths_since_prev_day\tpopulation\tdate\tlast_update_date\tlast_reported_flag\nAfghanistan\tAsia\tAF\t746.0\t6.0\t0.0\t0.0\t38041757.0\t2020-12-14\t2023-07-18 00:03:41.837110\tTrue\nAfghanistan\tAsia\tAF\t298.0\t9.0\t-448.0\t3.0\t38041757.0\t2020-12-13\t2023-07-18 00:03:41.837110\tFalse\nAfghanistan\tAsia\tAF\t113.0\t11.0\t-185.0\t2.0\t38041757.0\t2020-12-12\t2023-07-18 00:03:41.837110\tFalse\n*/', metadata={'document': 'database_info.txt'}),
 Document(page_content='CREATE TABLE goog_global_mobility_report (\n\tcountry_r

In [25]:
# Find the stored documents that best match the question.
vector_store.similarity_search("which table country has more deaths")

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


[Document(page_content='CREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcases_since_prev_day FLOAT, \n\tdeaths_since_prev_day FLOAT, \n\tpopulation FLOAT, \n\tdate DATE, \n\tlast_update_date TIMESTAMP_NTZ, \n\tlast_reported_flag BOOLEAN\n)\n\n/*\n3 rows from ecdc_global table:\ncountry_region\tcontinentexp\tiso3166_1\tcases\tdeaths\tcases_since_prev_day\tdeaths_since_prev_day\tpopulation\tdate\tlast_update_date\tlast_reported_flag\nAfghanistan\tAsia\tAF\t746.0\t6.0\t0.0\t0.0\t38041757.0\t2020-12-14\t2023-07-18 00:03:41.837110\tTrue\nAfghanistan\tAsia\tAF\t298.0\t9.0\t-448.0\t3.0\t38041757.0\t2020-12-13\t2023-07-18 00:03:41.837110\tFalse\nAfghanistan\tAsia\tAF\t113.0\t11.0\t-185.0\t2.0\t38041757.0\t2020-12-12\t2023-07-18 00:03:41.837110\tFalse\n*/', metadata={'document': 'database_info.txt'}),
 Document(page_content='CREATE TABLE goog_global_mobility_report (\n\tcountry_r

# Create chat

In [10]:
from langchain.prompts.prompt import PromptTemplate

# Prompt for the question-condensing step: given the chat history and a
# follow-up input, ask the model for a standalone question.
# Placeholders: {chat_history} and {question}.
template_questions = """Considering the provided chat history and a subsequent question, rewrite the follow-up question to be an independent query. Alternatively, conclude the conversation if it appears to be complete.
Chat History:\"""
{chat_history}
\"""
Follow Up Input: \"""
{question}
\"""
Standalone question:"""


# Prompt for the answering step. Placeholders: {question} and {context}.
# NOTE(review): {context} is presumably filled with the retrieved table
# documents by the "stuff" QA chain built in a later cell - confirm.
template_qa = """ 
You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries. 
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries. 
When the user expresses gratitude or says "Thanks", interpret it as a signal to conclude the conversation. Respond with an appropriate closing statement without generating further SQL queries.
If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question."
Write your response in markdown format.

Question: ```{question}```
{context}

Answer:
"""

# from_template derives the input variables from the placeholders in the
# template (the repr in the next cell shows ['chat_history', 'question']).
condense_question_prompt = PromptTemplate.from_template(template_questions)
# Here the input variables are listed explicitly instead.
prompt_qa = PromptTemplate(template=template_qa, input_variables=["question", "context"])

In [40]:
condense_question_prompt

PromptTemplate(input_variables=['chat_history', 'question'], output_parser=None, partial_variables={}, template='Considering the provided chat history and a subsequent question, rewrite the follow-up question to be an independent query. Alternatively, conclude the conversation if it appears to be complete.\nChat History:"""\n{chat_history}\n"""\nFollow Up Input: """\n{question}\n"""\nStandalone question:', template_format='f-string', validate_template=True)

In [46]:
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain

# Model used by the answering chain (prompt_qa).
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0.5,
    max_tokens=500,
    # streaming=True,
)

# Model used to rewrite follow-up questions (condense_question_prompt).
q_llm = ChatOpenAI(model_name="gpt-3.5-turbo-16k", temperature=0.1, max_tokens=500)

# The two building blocks of the conversational chain.
doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=prompt_qa)
question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt)

# Wire the retriever, the question rewriter and the answering chain together.
conv_chain = ConversationalRetrievalChain(
    question_generator=question_generator,
    combine_docs_chain=doc_chain,
    retriever=vector_store.as_retriever(),
)

In [59]:
# First turn: start from an empty history and ask the assistant to introduce itself.
chat_history = []
question = """Now to get started, please briefly introduce yourself, describe the database at a high level. Then provide 3 example questions using bullet points. this reponse without query. Write your response in markdown format."""
result = conv_chain(dict(question=question, chat_history=chat_history))
answer = result["answer"]

# Record this exchange as a (question, answer) pair in the chat history.
chat_history = chat_history + [(question, answer)]

# Display the model's reply.
print(answer)

Hi there! I'm an AI assistant specializing in data analysis with Snowflake SQL. I can help you with various data analysis tasks using SQL queries in the Snowflake environment. Here's a brief overview of the two databases mentioned:

1. `ecdc_global` table:
   - This table contains information about COVID-19 cases and deaths globally.
   - It includes columns such as `country_region`, `continentexp`, `iso3166_1`, `cases`, `deaths`, `cases_since_prev_day`, `deaths_since_prev_day`, `population`, `date`, `last_update_date`, and `last_reported_flag`.

2. `goog_global_mobility_report` table:
   - This table provides mobility data for different regions.
   - It includes columns such as `country_region`, `province_state`, `iso_3166_1`, `iso_3166_2`, `date`, `grocery_and_pharmacy_change_perc`, `parks_change_perc`, `residential_change_perc`, `retail_and_recreation_change_perc`, `transit_stations_change_perc`, `workplaces_change_perc`, `last_update_date`, `last_reported_flag`, and `sub_region_2`.

In [61]:
# Ask a concrete data question; an empty chat history is passed for this call.
result = conv_chain(
    dict(
        question="What is the total number of COVID-19 cases and deaths for each country in the `ecdc_global` table?",
        chat_history=[],
    )
)
answer = result["answer"]
print(answer)


To obtain the total number of COVID-19 cases and deaths for each country in the `ecdc_global` table, you can use the following SQL query:

```sql
SELECT country_region, SUM(cases) AS total_cases, SUM(deaths) AS total_deaths
FROM ecdc_global
GROUP BY country_region;
```

This query selects the `country_region` column and calculates the sum of the `cases` and `deaths` columns for each unique country using the `SUM()` function. The result is grouped by the `country_region` column using the `GROUP BY` clause.

Please note that this query assumes that the `ecdc_global` table contains the necessary data and columns mentioned in the question. If there are any additional requirements or if you need further assistance, please let me know.
