In [2]:
import os

from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders import DirectoryLoader
from pydantic import BaseModel

# Load environments variables
from dotenv import load_dotenv
load_dotenv("../.streamlit/secrets.toml") 

True

## Get database information using langchain.SQLDatabase
Snowflake Database has 43 tables 

In [3]:
from snowflake.sqlalchemy import URL
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain


# create snowflake connection
uri_snow = URL(
    account=os.getenv("account"),
    user=os.getenv("user"),
    password=os.getenv("password"),
    database=os.getenv("database"),
    schema=os.getenv("schema"),
    warehouse=os.getenv("warehouse"),
    role=os.getenv("role"),
)

# sample 2 tables
tables = ["ecdc_global", "goog_global_mobility_report"]
db = SQLDatabase.from_uri(uri_snow, include_tables=tables)

# sample all tables
# db = SQLDatabase.from_uri(uri_snow)

In [4]:
# get DDL and 3 rows samples for every table
db_info = db.table_info
db_info

'\nCREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcases_since_prev_day FLOAT, \n\tdeaths_since_prev_day FLOAT, \n\tpopulation FLOAT, \n\tdate DATE, \n\tlast_update_date TIMESTAMP_NTZ, \n\tlast_reported_flag BOOLEAN\n)\n\n/*\n3 rows from ecdc_global table:\ncountry_region\tcontinentexp\tiso3166_1\tcases\tdeaths\tcases_since_prev_day\tdeaths_since_prev_day\tpopulation\tdate\tlast_update_date\tlast_reported_flag\nAfghanistan\tAsia\tAF\t746.0\t6.0\t0.0\t0.0\t38041757.0\t2020-12-14\t2023-07-15 00:04:02.382733\tTrue\nAfghanistan\tAsia\tAF\t298.0\t9.0\t-448.0\t3.0\t38041757.0\t2020-12-13\t2023-07-15 00:04:02.382733\tFalse\nAfghanistan\tAsia\tAF\t113.0\t11.0\t-185.0\t2.0\t38041757.0\t2020-12-12\t2023-07-15 00:04:02.382733\tFalse\n*/\n\n\nCREATE TABLE goog_global_mobility_report (\n\tcountry_region VARCHAR(250), \n\tprovince_state VARCHAR(250), \n\tiso_3166_1 VARCHAR(2), \n\tiso

In [5]:
# save database info into
with open("../docs/database_info.txt", "w") as fp:
    fp.write(db_info)

In [6]:
# read file
with open("../docs/database_info.txt", "r") as fp:
    text_file = fp.read()

### Create langchain documents from tables DDL 

In [7]:
from langchain.text_splitter import CharacterTextSplitter

text_splitter = CharacterTextSplitter(        
    separator = "\n\n\n",
    chunk_size = 0,
    chunk_overlap = 0,
    length_function = len
)

# Split document
# texts_split = text_file.split("\n\n\n")
texts_split = text_splitter.split_text(text_file)
metadatas = [ {"document": "database_info.txt"} for _ in texts_split]
docs = text_splitter.create_documents(texts_split, metadatas)
docs

Created a chunk of size 779, which is longer than the specified 0


[Document(page_content='CREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcases_since_prev_day FLOAT, \n\tdeaths_since_prev_day FLOAT, \n\tpopulation FLOAT, \n\tdate DATE, \n\tlast_update_date TIMESTAMP_NTZ, \n\tlast_reported_flag BOOLEAN\n)\n\n/*\n3 rows from ecdc_global table:\ncountry_region\tcontinentexp\tiso3166_1\tcases\tdeaths\tcases_since_prev_day\tdeaths_since_prev_day\tpopulation\tdate\tlast_update_date\tlast_reported_flag\nAfghanistan\tAsia\tAF\t746.0\t6.0\t0.0\t0.0\t38041757.0\t2020-12-14\t2023-07-15 00:04:02.382733\tTrue\nAfghanistan\tAsia\tAF\t298.0\t9.0\t-448.0\t3.0\t38041757.0\t2020-12-13\t2023-07-15 00:04:02.382733\tFalse\nAfghanistan\tAsia\tAF\t113.0\t11.0\t-185.0\t2.0\t38041757.0\t2020-12-12\t2023-07-15 00:04:02.382733\tFalse\n*/', metadata={'document': 'database_info.txt'}),
 Document(page_content='CREATE TABLE goog_global_mobility_report (\n\tcountry_r

### Create embeddings and save database in local using Chroma and OpenAIEmbeddings
- Generate the embeddings using openAI with documents from above cell 
- Create the database in Chroma (local)

In [8]:
# remove folder and avoid conflicts
! rmdir /s /q "../chroma_db"

# create object for embeddings using OpenAI
embeddings = OpenAIEmbeddings()

# create database and save embeddings in local
vector_store = Chroma.from_documents(docs, embeddings, persist_directory="../chroma_db")
vector_store.persist()

# get first embedding in database 
# vector_store.get(limit=1, include=['embeddings', 'documents', 'metadatas'])
vector_store.get(limit=1)

El sistema no puede encontrar el archivo especificado.


{'ids': ['95a7896f-236b-11ee-9067-ac7ed0d21e7b'],
 'embeddings': None,
 'documents': ['CREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcases_since_prev_day FLOAT, \n\tdeaths_since_prev_day FLOAT, \n\tpopulation FLOAT, \n\tdate DATE, \n\tlast_update_date TIMESTAMP_NTZ, \n\tlast_reported_flag BOOLEAN\n)\n\n/*\n3 rows from ecdc_global table:\ncountry_region\tcontinentexp\tiso3166_1\tcases\tdeaths\tcases_since_prev_day\tdeaths_since_prev_day\tpopulation\tdate\tlast_update_date\tlast_reported_flag\nAfghanistan\tAsia\tAF\t746.0\t6.0\t0.0\t0.0\t38041757.0\t2020-12-14\t2023-07-15 00:04:02.382733\tTrue\nAfghanistan\tAsia\tAF\t298.0\t9.0\t-448.0\t3.0\t38041757.0\t2020-12-13\t2023-07-15 00:04:02.382733\tFalse\nAfghanistan\tAsia\tAF\t113.0\t11.0\t-185.0\t2.0\t38041757.0\t2020-12-12\t2023-07-15 00:04:02.382733\tFalse\n*/'],
 'metadatas': [{'document': 'database_info.txt'}]}

In [9]:
vector_store._collection.count()

2

In [38]:
# load from disk 
vector_store = Chroma(persist_directory="../chroma_db", embedding_function=embeddings)
docs = vector_store.similarity_search("which table contains the names of countries")
docs

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


[Document(page_content='CREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcases_since_prev_day FLOAT, \n\tdeaths_since_prev_day FLOAT, \n\tpopulation FLOAT, \n\tdate DATE, \n\tlast_update_date TIMESTAMP_NTZ, \n\tlast_reported_flag BOOLEAN\n)\n\n/*\n3 rows from ecdc_global table:\ncountry_region\tcontinentexp\tiso3166_1\tcases\tdeaths\tcases_since_prev_day\tdeaths_since_prev_day\tpopulation\tdate\tlast_update_date\tlast_reported_flag\nAfghanistan\tAsia\tAF\t746.0\t6.0\t0.0\t0.0\t38041757.0\t2020-12-14\t2023-07-15 00:04:02.382733\tTrue\nAfghanistan\tAsia\tAF\t298.0\t9.0\t-448.0\t3.0\t38041757.0\t2020-12-13\t2023-07-15 00:04:02.382733\tFalse\nAfghanistan\tAsia\tAF\t113.0\t11.0\t-185.0\t2.0\t38041757.0\t2020-12-12\t2023-07-15 00:04:02.382733\tFalse\n*/', metadata={'document': 'database_info.txt'}),
 Document(page_content='CREATE TABLE goog_global_mobility_report (\n\tcountry_r

In [39]:
# search document which better match from question
vector_store.similarity_search("which table contains the id for payment")

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


[Document(page_content='CREATE TABLE goog_global_mobility_report (\n\tcountry_region VARCHAR(250), \n\tprovince_state VARCHAR(250), \n\tiso_3166_1 VARCHAR(2), \n\tiso_3166_2 VARCHAR(5), \n\tdate DATE, \n\tgrocery_and_pharmacy_change_perc FLOAT, \n\tparks_change_perc FLOAT, \n\tresidential_change_perc FLOAT, \n\tretail_and_recreation_change_perc FLOAT, \n\ttransit_stations_change_perc FLOAT, \n\tworkplaces_change_perc FLOAT, \n\tlast_update_date TIMESTAMP_NTZ, \n\tlast_reported_flag BOOLEAN, \n\tsub_region_2 VARCHAR(256)\n)\n\n/*\n3 rows from goog_global_mobility_report table:\ncountry_region\tprovince_state\tiso_3166_1\tiso_3166_2\tdate\tgrocery_and_pharmacy_change_perc\tparks_change_perc\tresidential_change_perc\tretail_and_recreation_change_perc\ttransit_stations_change_perc\tworkplaces_change_perc\tlast_update_date\tlast_reported_flag\tsub_region_2\nUnited States\tTennessee\tUS\tTN\t2022-02-24\t6.0\t-29.0\t6.0\t-10.0\t-30.0\t-7.0\t2023-07-15 00:04:23.220323\tFalse\tMontgomery County

# Create chat

In [60]:
from langchain.prompts.prompt import PromptTemplate

template_questions = """Considering the provided chat history and a subsequent question, rewrite the follow-up question to be an independent query. Alternatively, conclude the conversation if it appears to be complete.
Chat History:\"""
{chat_history}
\"""
Follow Up Input: \"""
{question}
\"""
Standalone question:"""


template_qa = """ 
You will be acting as an AI Snowflake SQL Expert. 
Your goal is to give correct, executable sql query to users.
When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries. 
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment with limit 10. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries. 
When the user expresses gratitude or says "Thanks", interpret it as a signal to conclude the conversation. Respond with an appropriate closing statement without generating further SQL queries.
If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question."
Don't forget to use "ilike %keyword%" for fuzzy match queries (especially for variable_name column)
Write your response in markdown format.

Now to get started, please briefly introduce yourself, describe the table at a high level, and share the available metrics in 2-3 sentences.
Then provide 3 example questions using bullet points.

Question: ```{question}```
{context}


Answer:
"""

condense_question_prompt = PromptTemplate.from_template(template_questions)
prompt_qa = PromptTemplate(template=template_qa, input_variables=["question", "context"])

In [61]:
condense_question_prompt

PromptTemplate(input_variables=['chat_history', 'question'], output_parser=None, partial_variables={}, template='Considering the provided chat history and a subsequent question, rewrite the follow-up question to be an independent query. Alternatively, conclude the conversation if it appears to be complete.\nChat History:"""\n{chat_history}\n"""\nFollow Up Input: """\n{question}\n"""\nStandalone question:', template_format='f-string', validate_template=True)

In [58]:
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain

q_llm = ChatOpenAI(
    model_name="gpt-3.5-turbo-16k",
    temperature=0.1,
    max_tokens=500
)

llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0.5,
    max_tokens=500,
    # streaming=True,
)

question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt)

doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=prompt_qa)
conv_chain = ConversationalRetrievalChain(
    retriever=vector_store.as_retriever(),
    combine_docs_chain=doc_chain,
    question_generator=question_generator
)

In [59]:
result = conv_chain(
            {"question": "hola", "chat_history": []}
        )
answer = result["answer"]
print(answer)

Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


Hello! I am an AI Snowflake SQL Expert and I'm here to assist you with your data analysis tasks using Snowflake SQL. I have access to two tables: `goog_global_mobility_report` and `ecdc_global`. 

The `goog_global_mobility_report` table contains information about global mobility trends, such as changes in grocery and pharmacy visits, parks visits, residential activity, retail and recreation visits, transit station visits, and workplace visits. It also includes details like the country, province/state, ISO codes, date, last update date, and sub-region.

The `ecdc_global` table contains global COVID-19 data, including the number of cases, deaths, cases since the previous day, deaths since the previous day, population, date, last update date, and a flag indicating if the data was last reported.

Now, let's move on to your questions! Please keep in mind that I can only provide SQL queries for data analysis tasks.

1. What are the countries with the highest number of COVID-19 cases?
2. Can 

In [48]:
result

{'question': 'Give me the different countries in table ecdc_global',
 'chat_history': [],
 'answer': "To retrieve the different countries in the `ecdc_global` table, you can use the following SQL query:\n\n```sql\nSELECT DISTINCT country_region\nFROM ecdc_global\nLIMIT 10;\n```\n\nThis query selects the distinct values from the `country_region` column in the `ecdc_global` table. The `DISTINCT` keyword ensures that only unique values are returned. The `LIMIT 10` clause limits the result to 10 rows.\n\nPlease note that the query may return fewer than 10 rows if there are fewer unique countries in the table.\n\nLet me know if there's anything else I can help you with!",
 'source_documents': [Document(page_content='CREATE TABLE ecdc_global (\n\tcountry_region VARCHAR(16777216), \n\tcontinentexp VARCHAR(16777216), \n\tiso3166_1 VARCHAR(2), \n\tcases FLOAT, \n\tdeaths FLOAT, \n\tcases_since_prev_day FLOAT, \n\tdeaths_since_prev_day FLOAT, \n\tpopulation FLOAT, \n\tdate DATE, \n\tlast_update_