In [2]:
# Agents
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.agents import tool
# from langchain_community.chat_models import BedrockChat
from langchain_aws.chat_models import ChatBedrock
from langchain.chains import create_sql_query_chain

# SQL Database Toolkit
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# Tracing
from langsmith import traceable


# Few shot prompting
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_aws import BedrockEmbeddings
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate
from langchain_core.prompts.chat import SystemMessagePromptTemplate
from langchain_core.prompts import PromptTemplate


from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
import langchain
langchain.debug=True
import yaml

In [2]:
config_file = './model_settings.yml'
with open(config_file) as stream:
    try:
        config = yaml.safe_load(stream)
    except:
        print(exec)

config

{'sql_generator': {'model_name': 'anthropic.claude-3-sonnet-20240229-v1:0',
  'model_provider': 'bedrock',
  'temperature': 0,
  'max_tokens': 1024,
  'top_p': 0.8,
  'system_prompt': 'You are an {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run. If you cannot generate the SQL for any reason, DO NOT PROVIDE ANY ADDITIONAL CONTEXT. Just return "NO".\nUnless the user specifies in the question a specific number of examples to obtain, query for at most 10 results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is

In [None]:
def get_table_context(
    table_name: str, table_description: str = None, metadata_query: str = None
):
    table = table_name.split(".")
    conn = st.connection("snowflake")
    columns = conn.query(
        f"""
        SELECT COLUMN_NAME, DATA_TYPE FROM {table[0].upper()}.INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = '{table[1].upper()}' AND TABLE_NAME = '{table[2].upper()}'
        """,
        show_spinner=False,
    )
    columns = "\n".join(
        [
            f"- **{columns['COLUMN_NAME'][i]}**: {columns['DATA_TYPE'][i]}"
            for i in range(len(columns["COLUMN_NAME"]))
        ]
    )
    context = f"""
    Here is the table name <tableName> {'.'.join(table)} </tableName> <tableDescription>{table_description}</tableDescription>
    Here are the columns of the {'.'.join(table)} <columns>\n\n{columns}\n\n</columns>
    """
    if metadata_query:
        metadata = conn.query(metadata_query, show_spinner=False)
        metadata = "\n".join(
            [
                f"- **{metadata['VARIABLE_NAME'][i]}**: {metadata['DEFINITION'][i]}"
                for i in range(len(metadata["VARIABLE_NAME"]))
            ]
        )
        context = context + f"\n\nAvailable variables by VARIABLE_NAME:\n\n{metadata}"
    return context


In [3]:
# Reference for this code: https://python.langchain.com/v0.1/docs/modules/model_io/prompts/few_shot_examples/

# vectorstore = Chroma()
# vectorstore.delete_collection()

few_shots = [
    {
        "input": "How many rides were taken in spring of 2022?", 
        "query": "SELECT COUNT(*) FROM rides WHERE started_at >= '2022-03-01' AND started_at < '2022-06-01';"
    }, 
    {
        "input": "What was the busiest month?",
        "query": "SELECT EXTRACT(MONTH FROM started_at) AS month, COUNT(*) AS ride_count FROM ride_data GROUP BY month ORDER BY ride_count DESC LIMIT 1;"
    }, 
    {
        "input": "How many rides were taken in the cold night of april 13th of 2023",
        "query": "SELECT COUNT(*) FROM ride_data WHERE started_at >= '2022-11-15 00:00:00' AND started_at < '2022-11-15 06:00:00';"
    }
]


example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}"),
        ("ai", "{query}"),
    ]
)

print("Example single prompt: ", example_prompt.format(**few_shots[1]))


few_shot_prompt = FewShotChatMessagePromptTemplate(
    examples=few_shots,
    example_prompt=example_prompt,
    input_variables=["input", "top_k"],
    
)

few_shot_prompt.format(input="What ride was on the 30th?")

Example single prompt:  Human: What was the busiest month?
AI: SELECT EXTRACT(MONTH FROM started_at) AS month, COUNT(*) AS ride_count FROM ride_data GROUP BY month ORDER BY ride_count DESC LIMIT 1;


"Human: How many rides were taken in spring of 2022?\nAI: SELECT COUNT(*) FROM rides WHERE started_at >= '2022-03-01' AND started_at < '2022-06-01';\nHuman: What was the busiest month?\nAI: SELECT EXTRACT(MONTH FROM started_at) AS month, COUNT(*) AS ride_count FROM ride_data GROUP BY month ORDER BY ride_count DESC LIMIT 1;\nHuman: How many rides were taken in the cold night of april 13th of 2023\nAI: SELECT COUNT(*) FROM ride_data WHERE started_at >= '2022-11-15 00:00:00' AND started_at < '2022-11-15 06:00:00';"

In [86]:
# Block for table context
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
from operator import itemgetter

from langchain_openai import ChatOpenAI
# from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain.chains import create_extraction_chain_pydantic
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser

# from langchain_aws.function_calling import AnthropicFunctions

import os


class Table(BaseModel):
    '''
    Table in the SQL database
    '''
    name: str = Field(description="Name of the table in the database")


def get_tables(tables: List[Table]) -> List[str]:
    '''
    Any post processing with tables goes here.

    @return: List of table names
    '''
    tables = [table.name for table in tables]
    
    return tables

print(os.getcwd())

try:
    table_deets = open('multi_table_context.txt', 'r').read()

except FileNotFoundError as e:
    print(f"Could not find file the table context file. {e}")


table_prompt_template = '''
    Return the names of all the tables that are relevant to the users question.
    The complete set of tables are: 
    {table_info}

    Be generous with the tables you choose and include ALL RELEVANT TABLES even if you don't think
    they are needed

    '''.format(table_info=snowflake_all_engine.get_context())


table_prompt = ChatPromptTemplate.from_messages([
    ("system", table_prompt_template),
    ("human", '{input}'),
    
])

# table_prompt = ChatPromptTemplate(
#     template=table_prompt_template,
#     input_variables=["table_info", "input"],
# 
# )


model = ChatBedrock(
    credentials_profile_name="bedrock-admin", 
    model_id=config['sql_generator']['model_name'], 
    model_kwargs={
        "temperature": config['sql_generator']['temperature'] if config['sql_generator']['temperature'] else 0,
        "top_p": config['sql_generator']['top_p'] if config['sql_generator']['top_p'] else 0.8,
        'max_tokens': config['sql_generator']['max_tokens'] if config['sql_generator']['max_tokens'] else 1024
    },
    region_name="us-east-1",
    )


# structured_llm = model.with_structured_output(Table)
# structured_llm.invoke("Hello")

parser = PydanticOutputParser(pydantic_object=Table)
table_chain = LLMChain(llm=model, prompt=table_prompt)

tchain = table_prompt | model 

# table_chain.run(input="What are the transaction ids?", table_info=snowflake_engine.get_context()['table_info'])
table_chain.invoke(input="What are the transaction ids?")
# table_prompt.invoke({"input": "Hello", "table_info":snowflake_engine.get_context()['table_info']})

/Users/hemanthrajan/Desktop/Projects/doordash-fincopilot/src/backend
[32;1m[1;3m[chain/start][0m [1m[chain:LLMChain] Entering Chain run with input:
[0m{
  "'table_info'": "What are the transaction ids?"
}
[31;1m[1;3m[chain/error][0m [1m[chain:LLMChain] [0ms] Chain run errored with error:
[0m"ValueError(\"Missing some input keys: {'input'}\")Traceback (most recent call last):\n\n\n  File \"/opt/miniconda3/envs/doordash-agent-test/lib/python3.11/site-packages/langchain/chains/base.py\", line 154, in invoke\n    self._validate_inputs(inputs)\n\n\n  File \"/opt/miniconda3/envs/doordash-agent-test/lib/python3.11/site-packages/langchain/chains/base.py\", line 284, in _validate_inputs\n    raise ValueError(f\"Missing some input keys: {missing_keys}\")\n\n\nValueError: Missing some input keys: {'input'}"


ValueError: Missing some input keys: {'input'}

In [25]:
snowflake_engine.get_context()

{'table_info': '\nCREATE TABLE rpt_gl_transactions (\n\tgl_entity VARCHAR(7), \n\ttransaction_id FLOAT NOT NULL, \n\ttransaction_number VARCHAR(552), \n\tdocument_number VARCHAR(16000), \n\ttransaction_source VARCHAR(16000), \n\ttransaction_type VARCHAR(768), \n\ttrx_created_date TIMESTAMP_TZ, \n\tbill_payment_date TIMESTAMP_TZ, \n\ttrx_closed_date TIMESTAMP_TZ, \n\ttransaction_currency VARCHAR(16), \n\tgl_account_number VARCHAR(240), \n\tgl_account_name VARCHAR(372), \n\tgl_account_type VARCHAR(512), \n\tgl_account_parent_id FLOAT, \n\tdepartment_id FLOAT, \n\tdepartment_name VARCHAR(240), \n\tgl_amount FLOAT, \n\tgl_period_name VARCHAR(256), \n\tamortization_schedule_number DECIMAL(6, 0), \n\tamortization_start_date DATE, \n\tamortization_end_date DATE\n)\n\n/*\n1 rows from rpt_gl_transactions table:\ngl_entity\ttransaction_id\ttransaction_number\tdocument_number\ttransaction_source\ttransaction_type\ttrx_created_date\tbill_payment_date\ttrx_closed_date\ttransaction_currency\tgl_acco

In [None]:
langchain.debug = True
sql_chain.invoke({"question": "Do we have any late transactions?", "table_context": table_context, "history": chat_history})

In [19]:
raw_system_prompt = config['sql_generator']['system_prompt']



# Replace this with actual table context
table_context = '''
Here is the table name <tableName> DEMO_ANALYTICS.DEMO.RPT_GL_TRANSACTIONS </tableName> <tableDescription>None</tableDescription>
Here are the columns of the DEMO_ANALYTICS.DEMO.RPT_GL_TRANSACTIONS <columns>

- **GL_ACCOUNT_PARENT_ID**: FLOAT
- **GL_ACCOUNT_NUMBER**: TEXT
- **BILL_PAYMENT_DATE**: TIMESTAMP_TZ
- **AMORTIZATION_SCHEDULE_NUMBER**: NUMBER
- **GL_ACCOUNT_NAME**: TEXT
- **GL_ENTITY**: TEXT
- **GL_PERIOD_NAME**: TEXT
- **TRX_CLOSED_DATE**: TIMESTAMP_TZ
- **DEPARTMENT_NAME**: TEXT
- **DEPARTMENT_ID**: FLOAT
- **TRX_CREATED_DATE**: TIMESTAMP_TZ
- **TRANSACTION_TYPE**: TEXT
- **AMORTIZATION_START_DATE**: DATE
- **TRANSACTION_NUMBER**: TEXT
- **TRANSACTION_CURRENCY**: TEXT
- **GL_AMOUNT**: FLOAT
- **GL_ACCOUNT_TYPE**: TEXT
- **TRANSACTION_ID**: FLOAT
- **DOCUMENT_NUMBER**: TEXT
- **AMORTIZATION_END_DATE**: DATE
- **TRANSACTION_SOURCE**: TEXT

</columns>
'''

# SystemMessagePromptTemplate(prompt=table_context_template)


# template = ChainedPromptTemplate([
#     SystemMessagePromptTemplate.from_template("You have access to {tools}."),
#     ChatPromptTemplate.from_messages([
#         SystemMessagePromptTemplate.from_template("Your objective is to answer human questions."),
#     ]),
#     "Tell me: {question}?",
# ])


# Define your custom prompt template for the smaller llm
custom_prompt_template = "Custom prompt for the smaller llm: {input}"

# Create a new prompt object with the custom prompt template and the table_info variable
new_prompt = ChatPromptTemplate.from_messages([
    ('system', custom_prompt_template),
    ('user', '{table_info}'),
])



chat_history = ChatMessageHistory()

final_prompt = ChatPromptTemplate.from_messages(
    [
        ('system', raw_system_prompt),
        few_shot_prompt,
        # MessagesPlaceholder(variable_name="history"),
        # table_prompt,
        ('user', '{input}'),
    ]
)



# final_prompt.invoke({
#     "input": "select * from table", 
#     "history": chat_history.messages, 
#     # "agent_scratchpad": [],
#     })

final_prompt.format(input="select * from table", history=chat_history.messages, table_info=table_context, dialect='sql')

'System: You are an sql expert. Given an input question, create a syntactically correct sql query to run. If you cannot generate the SQL for any reason, DO NOT PROVIDE ANY ADDITIONAL CONTEXT. Just return "NO".\nUnless the user specifies in the question a specific number of examples to obtain, query for at most 10 results using the LIMIT clause as per sql. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question involves "today".\nDo not make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to th

In [5]:
llm = ChatBedrock(
    credentials_profile_name="bedrock-admin", 
    model_id=config['sql_generator']['model_name'], 
    model_kwargs={
        "temperature": config['sql_generator']['temperature'] if config['sql_generator']['temperature'] else 0,
        "top_p": config['sql_generator']['top_p'] if config['sql_generator']['top_p'] else 0.8,
        'max_tokens': config['sql_generator']['max_tokens'] if config['sql_generator']['max_tokens'] else 1024
    },
    region_name="us-east-1",
    )
llm

llm.invoke("Hello world")

NameError: name 'config' is not defined

In [4]:

def get_snowflake_uri():
    # uri = f"snowflake://{user}:{password}@{account}/{database}/{schema}?warehouse={warehouse}&role={role}"
    # uri = f"snowflake://NISHIKANT:Nishi123#@rla01593/DEMO_ANALYTICS/DEMO?warehouse=COMPUTE_WH&role=DATA_ENGINEER_ROLE"
    uri = f"snowflake://NISHIKANT:Nishi123#@rla01593/FINCOPILOT_CDM/ACCOUNTS_RECEIVABLE?warehouse=COMPUTE_WH&role=DATA_ENGINEER_ROLE"
    return uri


snowflake_uri = get_snowflake_uri() 

snowflake_uri

# snowflake_engine = SQLDatabase.from_uri(
#     database_uri=snowflake_uri, 
#     sample_rows_in_table_info=1, 
#     view_support=True,
#     include_tables=['rpt_gl_transactions']
#     )

'snowflake://NISHIKANT:Nishi123#@rla01593/FINCOPILOT_CDM/ACCOUNTS_RECEIVABLE?warehouse=COMPUTE_WH&role=DATA_ENGINEER_ROLE'

In [22]:
snowflake_all_engine = SQLDatabase.from_uri(
    database_uri=snowflake_uri, 
    sample_rows_in_table_info=1, 
    view_support=True,
    )

In [19]:


# selected_tables = set(['ar_customer_invoices'])
# snowflake_all_engine._all_tables = selected_tables

custom_table_info= {}
for table in selected_tables:
    result = snowflake_all_engine._engine.execute("show VIEWS like 'ar_customer_invoices' in fincopilot_cdm.accounts_receivable")
    rows = result.fetchall()
    if len(rows[0]) > 0:
        custom_table_info[table] = rows[0][7]   # 7th column is the view definition
    # custom_table_info[table] = f"custom info for {table}"

custom_table_info


{'ar_customer_invoices': 'create or replace secure  view FINCOPILOT_CDM.accounts_receivable.ar_customer_invoices\n  \n    \n    \n(\n  \n    "NETSUITE_TRANSACTION_ID" COMMENT $$Unique identifier for each transaction record in the system.$$, \n  \n    "INVOICE_NUMBER" COMMENT $$Unique identifier assigned to each invoice created within the system$$, \n  \n    "INVOICE_POSTED_ON" COMMENT $$Period on which invoice is posted in the accounting system.$$, \n  \n    "INVOICED_TO_CUSTOMER_ID" COMMENT $$Unique identifier assigned to each customer to link the invoice.$$, \n  \n    "INVOICED_TO_CUSTOMER_NAME" COMMENT $$Full name of the customer who is billed for the service provided in invoice.$$, \n  \n    "INVOICED_TO_CUSTOMER_COUNTRY" COMMENT $$Country associated with the customer to whom the invoice is issued.$$, \n  \n    "INVOICED_IN_CURRENCY_CODE" COMMENT $$Currency code in which an invoice is issued.$$, \n  \n    "TOTAL_INVOICED_AMOUNT" COMMENT $$Total amount on an invoice.$$, \n  \n    "T

In [33]:
custom_table_info_cache = {}

result = snowflake_all_engine._engine.execute("show VIEWS in fincopilot_cdm.common")
rows = result.fetchall()

selected_tables = set(['DIM_ACCOUNTING_PERIOD'.lower(), 'DIM_CUSTOMER'.lower()])
print(len(rows))

for row in rows:
    table_name  = row[1].lower()
    description = row[7]
    print(table_name, description[0:10])
    custom_table_info_cache[table_name] = description
    

{k : custom_table_info_cache[k] for k in selected_tables}


5
dim_accounting_period create or 
dim_customer create or 
dim_general_ledger_account create or 
dim_general_ledger_account_hierarchy create or 
dim_subsidiary create or 


{'dim_customer': 'create or replace secure  view FINCOPILOT_CDM.common.dim_customer\n  \n    \n    \n(\n  \n    "CUSTOMER_KEY" COMMENT $$$$, \n  \n    "CUSTOMER_ID" COMMENT $$$$, \n  \n    "CUSTOMER_TYPE" COMMENT $$$$, \n  \n    "CUSTOMER_NAME" COMMENT $$$$, \n  \n    "PARENT_ID" COMMENT $$$$, \n  \n    "BRAND" COMMENT $$$$, \n  \n    "CHILD_DETAILS" COMMENT $$$$\n  \n)\n\n   as (\n    \n\nwith customer as (\n    select \n        customer_key as Customer_Key,\n        customer_id as Customer_Id,\n        customer_type as Customer_Type,\n        customer_name as Customer_Name,\n        parent_id as Parent_Id,\n        brand as Brand,\n        child_customer_details as Child_Details\n   \n    from FINCOPILOT_ANALYTICS.FINANCE.dim_customer\n),\nfinal as (\n    select * from customer\n)\nselect * from final\n  );',
 'dim_accounting_period': 'create or replace secure  view FINCOPILOT_CDM.common.dim_accounting_period\n  \n    \n    \n(\n  \n    "ACCOUNTING_PERIOD_ID" COMMENT $$$$, \n  \n    

In [34]:
snowflake_all_engine._custom_table_info = custom_table_info

In [21]:
snowflake_all_engine.get_table_info()

'create or replace secure  view FINCOPILOT_CDM.accounts_receivable.ar_customer_invoices\n  \n    \n    \n(\n  \n    "NETSUITE_TRANSACTION_ID" COMMENT $$Unique identifier for each transaction record in the system.$$, \n  \n    "INVOICE_NUMBER" COMMENT $$Unique identifier assigned to each invoice created within the system$$, \n  \n    "INVOICE_POSTED_ON" COMMENT $$Period on which invoice is posted in the accounting system.$$, \n  \n    "INVOICED_TO_CUSTOMER_ID" COMMENT $$Unique identifier assigned to each customer to link the invoice.$$, \n  \n    "INVOICED_TO_CUSTOMER_NAME" COMMENT $$Full name of the customer who is billed for the service provided in invoice.$$, \n  \n    "INVOICED_TO_CUSTOMER_COUNTRY" COMMENT $$Country associated with the customer to whom the invoice is issued.$$, \n  \n    "INVOICED_IN_CURRENCY_CODE" COMMENT $$Currency code in which an invoice is issued.$$, \n  \n    "TOTAL_INVOICED_AMOUNT" COMMENT $$Total amount on an invoice.$$, \n  \n    "TOTAL_PAYMENT_RECEIVED_AMO

In [18]:
rows[0][7]

'create or replace secure  view FINCOPILOT_CDM.accounts_receivable.ar_customer_invoices\n  \n    \n    \n(\n  \n    "NETSUITE_TRANSACTION_ID" COMMENT $$Unique identifier for each transaction record in the system.$$, \n  \n    "INVOICE_NUMBER" COMMENT $$Unique identifier assigned to each invoice created within the system$$, \n  \n    "INVOICE_POSTED_ON" COMMENT $$Period on which invoice is posted in the accounting system.$$, \n  \n    "INVOICED_TO_CUSTOMER_ID" COMMENT $$Unique identifier assigned to each customer to link the invoice.$$, \n  \n    "INVOICED_TO_CUSTOMER_NAME" COMMENT $$Full name of the customer who is billed for the service provided in invoice.$$, \n  \n    "INVOICED_TO_CUSTOMER_COUNTRY" COMMENT $$Country associated with the customer to whom the invoice is issued.$$, \n  \n    "INVOICED_IN_CURRENCY_CODE" COMMENT $$Currency code in which an invoice is issued.$$, \n  \n    "TOTAL_INVOICED_AMOUNT" COMMENT $$Total amount on an invoice.$$, \n  \n    "TOTAL_PAYMENT_RECEIVED_AMO

In [6]:
snowflake_all_engine.get_table_names()

  warn_deprecated(


['dim_accounting_period',
 'dim_customer',
 'dim_general_ledger_account',
 'dim_general_ledger_account_hierarchy',
 'dim_subsidiary']

In [82]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
chain = create_sql_query_chain(llm, snowflake_all_engine, prompt=final_prompt)
# chain = create_sql_query_chain(llm, snowflake_all_engine)


tables = ['rpt_gl_transactions']        # Get the list of
snowflake_all_engine._all_tables = set(tables)


# custom_table_info = snowflake_all_engine.get_table_info()

# chain_with_custom_info = RunnablePassthrough.assign(table_info=RunnableLambda(lambda _: custom_table_info)) | chain 


chain.invoke({
    'question': "Hello",
    "history": chat_history,
    # 'dialect': snowflake_all_engine.dialect,
    # 'table_info': custom_table_info
})


# chain_with_custom_info.invoke({
#     'question': "Hellp",
#     "history": chat_history
# })

[32;1m[1;3m[chain/start][0m [1m[chain:RunnableSequence] Entering Chain run with input:
[0m[inputs]
[32;1m[1;3m[chain/start][0m [1m[chain:RunnableSequence > chain:RunnableAssign<input,table_info>] Entering Chain run with input:
[0m[inputs]
[32;1m[1;3m[chain/start][0m [1m[chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info>] Entering Chain run with input:
[0m[inputs]
[32;1m[1;3m[chain/start][0m [1m[chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info> > chain:RunnableLambda] Entering Chain run with input:
[0m[inputs]
[36;1m[1;3m[chain/end][0m [1m[chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:RunnableParallel<input,table_info> > chain:RunnableLambda] [0ms] Exiting Chain run with output:
[0m{
  "output": "Hello\nSQLQuery: "
}
[32;1m[1;3m[chain/start][0m [1m[chain:RunnableSequence > chain:RunnableAssign<input,table_info> > chain:Ru

'NO'

In [81]:
tables = ['rpt_gl_transactions']
snowflake_all_engine._all_tables = set(tables)
snowflake_all_engine.get_table_names()

snowflake_all_engine._all_tables = set(['bridge_gl_account'])
snowflake_all_engine.get_table_names()

['bridge_gl_account']

In [101]:
snowflake_all_engine._all_tables = set(['dim_general_ledger_account'])
snowflake_all_engine.get_context()

{'table_info': '\nCREATE TABLE dim_general_ledger_account (\n\t"GL_Account_Primary_Key" VARCHAR(32), \n\t"GL_Account_Id" DECIMAL(38, 0), \n\t"GL_Account_Class" VARCHAR(16777216), \n\t"GL_Account_Type" VARCHAR(16777216), \n\t"GL_Account_Name" VARCHAR(16777216), \n\t"GL_Account_Number" VARCHAR(16777216), \n\t"GL_Account_Parent_Id" VARCHAR(16777216), \n\t"GL_Child_Account_Details" ARRAY, \n\t"Is_Debit_Or_Credit_Account" VARCHAR(2), \n\t"Is_Balance_Sheet_Account" BOOLEAN, \n\t"Is_Income_Statement_Account" BOOLEAN, \n\t"Is_Revenue_Account" BOOLEAN\n)\n\n/*\n1 rows from dim_general_ledger_account table:\nGL_Account_Primary_Key\tGL_Account_Id\tGL_Account_Class\tGL_Account_Type\tGL_Account_Name\tGL_Account_Number\tGL_Account_Parent_Id\tGL_Child_Account_Details\tIs_Debit_Or_Credit_Account\tIs_Balance_Sheet_Account\tIs_Income_Statement_Account\tIs_Revenue_Account\n01161aaa0b6d1345dd8fe4e481144d84\t236\tAsset\tOther Current Asset\t1323 - Inventory Asset : New York\t1323\t1320\t[\n  {\n    "full_n

In [79]:
snowflake_all_engine.get_table_names()

['bridge_gl_account',
 'bridge_inventory_item_class',
 'customer360',
 'dim_accounting_period_date',
 'dim_bins',
 'dim_company_location',
 'dim_customer',
 'dim_dayrange',
 'dim_department',
 'dim_employee',
 'dim_entity',
 'dim_finance_account',
 'dim_forecast_actual_date',
 'dim_gl_transaction_date',
 'dim_inventory_item',
 'dim_inventory_snapshot_date',
 'dim_item',
 'dim_locations',
 'dim_master_date',
 'dim_pricerange',
 'dim_product',
 'dim_product_class',
 'dim_revenue_target_date',
 'dim_salesorder_fulfillment_date',
 'dim_salesrep',
 'dim_shiptolocation',
 'dim_units',
 'dim_wmszone',
 'dim_zipcodes',
 'fact_accumulating_salesorder',
 'fact_bin_counts',
 'fact_fulfilmentsalesorder',
 'fact_inventory_item_counts',
 'fact_invoicedsalesorder',
 'fact_quotesalesorder',
 'fact_revenue_actual_target_combined',
 'fact_revenue_target',
 'fact_rmasalesorder',
 'fact_salesorder_operational_efficiency',
 'fact_shippedsalesorder',
 'fact_snapshot_salesorder_monthly',
 'fact_transactions_

In [60]:
final_prompt.input_variables

['dialect', 'history', 'input', 'table_info', 'top_k']

pydantic.v1.main.ChatPromptTemplateOutput

In [57]:
print(custom_table_info)


CREATE TABLE rpt_gl_transactions (
	gl_entity VARCHAR(7), 
	transaction_id FLOAT NOT NULL, 
	transaction_number VARCHAR(552), 
	document_number VARCHAR(16000), 
	transaction_source VARCHAR(16000), 
	transaction_type VARCHAR(768), 
	trx_created_date TIMESTAMP_TZ, 
	bill_payment_date TIMESTAMP_TZ, 
	trx_closed_date TIMESTAMP_TZ, 
	transaction_currency VARCHAR(16), 
	gl_account_number VARCHAR(240), 
	gl_account_name VARCHAR(372), 
	gl_account_type VARCHAR(512), 
	gl_account_parent_id FLOAT, 
	department_id FLOAT, 
	department_name VARCHAR(240), 
	gl_amount FLOAT, 
	gl_period_name VARCHAR(256), 
	amortization_schedule_number DECIMAL(6, 0), 
	amortization_start_date DATE, 
	amortization_end_date DATE
)

/*
1 rows from rpt_gl_transactions table:
gl_entity	transaction_id	transaction_number	document_number	transaction_source	transaction_type	trx_created_date	bill_payment_date	trx_closed_date	transaction_currency	gl_account_number	gl_account_name	gl_account_type	gl_account_parent_id	department

In [48]:
snowflake_engine.table_info

'\nCREATE TABLE rpt_gl_transactions (\n\tgl_entity VARCHAR(7), \n\ttransaction_id FLOAT NOT NULL, \n\ttransaction_number VARCHAR(552), \n\tdocument_number VARCHAR(16000), \n\ttransaction_source VARCHAR(16000), \n\ttransaction_type VARCHAR(768), \n\ttrx_created_date TIMESTAMP_TZ, \n\tbill_payment_date TIMESTAMP_TZ, \n\ttrx_closed_date TIMESTAMP_TZ, \n\ttransaction_currency VARCHAR(16), \n\tgl_account_number VARCHAR(240), \n\tgl_account_name VARCHAR(372), \n\tgl_account_type VARCHAR(512), \n\tgl_account_parent_id FLOAT, \n\tdepartment_id FLOAT, \n\tdepartment_name VARCHAR(240), \n\tgl_amount FLOAT, \n\tgl_period_name VARCHAR(256), \n\tamortization_schedule_number DECIMAL(6, 0), \n\tamortization_start_date DATE, \n\tamortization_end_date DATE\n)\n\n/*\n1 rows from rpt_gl_transactions table:\ngl_entity\ttransaction_id\ttransaction_number\tdocument_number\ttransaction_source\ttransaction_type\ttrx_created_date\tbill_payment_date\ttrx_closed_date\ttransaction_currency\tgl_account_number\tgl_

In [None]:
snowflake_engine.get_usable_table_names()

In [None]:
snowflake_engine.get_table_info()

In [None]:
suffix = """Begin!

Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)

Question: {input}
Thought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""

memory = ConversationBufferMemory(memory_key = 'history' , input_key = 'input')

In [None]:
sql_agent = create_sql_agent(
    llm=llm,
    db=snowflake_engine,
    verbose=False,
    return_intermediate_steps=True,
    # agent_type="tool-calling",
    # prompt=final_prompt,
    suffix=suffix,
    agent_executor_kwargs={
        'memory': memory, 
        'return_intermediate_steps':True,
    }
    
)



In [None]:
langchain.debug = False
langchain.verbose = False
sample_run = sql_agent.invoke("List all the columns in this table")

In [None]:
sql_agent.astream("List all the columns in this table")



In [None]:
async for chunk in sql_agent.astream("Fetch data for account 2110110 for entity DDE for period Nov 2021."):
    
    # if 'steps' in chunk:
    #     chunk_actions = chunk['steps'][0]
    #     print(chunk_actions.observation)
    # if 'output' in chunk:
    #     print(chunk['output'])
    
    print(chunk, flush=True, end='\n\n\n')
    # print('\n\n\n')

In [None]:
sample_run

In [None]:
steps = sample_run['intermediate_steps']

for agent_action, output in steps:
    print(output)

In [None]:
output = sample_run['output']
print(output) #

In [None]:
from langchain_openai import ChatOpenAI
from langchain.chains.sql_database.prompt import SQLDatabaseChain

import os

os.environ['OPENAI_API_KEY'] = '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'
open_ai_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

# chain_prompt_template = '''Given an input question, first create a syntactically correct {dialect} query to run.
# Use the following format:

# Question: "Question here"
# SQLQuery: "SQL Query to run"
# Only use the following tables:

# {table_info}.

# Question: {input}'''

# chain_prompt = PromptTemplate(
#     input_variables=["input", "table_info", "dialect"], template=chain_prompt_template
# )

sql_chain = create_sql_query_chain(
    llm=llm, 
    db=snowflake_engine,
    prompt = final_prompt,
    
)



In this example, the create_dynamic_sql_chain function takes the LLM, the database connection, and the user's input. It then uses the determine_relevant_tables function to identify the relevant tables based on the user's input and the available table information.
The final_prompt is then created using the PromptTemplate, which includes the table_info variable that can be dynamically updated based on the relevant tables.
Finally, the SQLDatabaseChain is created with the dynamic prompt, and the table_names_to_use parameter is set to the relevant tables.
This approach allows you to tailor the prompt and the table information based on the user's input, which can improve the quality of the generated SQL queries.

In [None]:
from langchain.prompts import PromptTemplate
from langchain.chains.sql_database.prompt import SQLDatabaseChain

def create_dynamic_sql_chain(llm, db, user_input):
    # Get the available tables and their schemas
    table_info = db.get_context()

    # Determine the relevant tables based on the user's input
    relevant_tables = determine_relevant_tables(user_input, table_info)

    # Create the prompt template with the relevant table information
    final_prompt = PromptTemplate(
        template="""
        You are a SQL expert. Given the following tables and schemas:

        {table_info}

        And the user question:
        Question: {input}

        Create a SQL query to answer the question, and provide the result.

        Question: {input}
        SQLQuery: 
        SQLResult:
        Answer:
                """,
        input_variables=["input", "table_info"],
    )

    # Create the SQL chain with the dynamic prompt
    sql_chain = SQLDatabaseChain(
        llm=llm,
        database=db,
        prompt=final_prompt,
        input_key="input",
        output_key="result",
        top_k_results=5,
        table_names_to_use=relevant_tables,
    )

    return sql_chain

def determine_relevant_tables(user_input, table_info):
    # Implement your logic to determine the relevant tables based on the user's input
    # This could involve things like keyword matching, entity extraction, etc.
    relevant_tables = ["table1", "table2", "table3"]
    return relevant_tables


In [None]:
langchain.debug = True
sql_chain.invoke({"question": "Do we have any late transactions?", "table_context": table_context, "history": chat_history})

In [None]:
sql_chain.get_prompts()[0]

In [None]:
for i in final_prompt.get_prompts()[0].messages:
    try: 
        print(i.prompt.template) 
    except:
        pass

In [None]:
final_prompt.get_prompts()[0].messages

In [2]:
import instructor
from pydantic import BaseModel, Field
from typing import List


class TableColumn(BaseModel):
    '''
    Gets the table column information from the given schema
    '''
    name: str = Field(description="The name of the column")
    dtype: str = Field(description="The data type of the column")
    comment: str = Field(description="The comment about the column")

class Table(BaseModel):
    '''
    Gets the table object from the given schema
    '''
    name: str = Field(description="The name of the table")
    description: str = Field(description="The description of the table")
    columns: List[TableColumn] = Field(description="The information about the column in each table")






In [9]:
from anthropic import AnthropicBedrock
client = instructor.from_anthropic(AnthropicBedrock())
client = instructor.patch(llm, mode=instructor.Mode.MD_JSON)

AttributeError: module 'instructor' has no attribute 'from_anthropic'

In [3]:
# json functions
import json

dictionary = {
    "test": "1",
    "value": "2"
}




TypeError: the JSON object must be str, bytes or bytearray, not dict

In [7]:
snowflake_uri = f'snowflake://NISHIKANT:Nishi123#@rla01593/FINCOPILOT_CDM/ACCOUNTS_RECEIVABLE?warehouse=COMPUTE_WH&role=DATA_ENGINEER_ROLE'
snowflake_engine = SQLDatabase.from_uri(
    database_uri=snowflake_uri, 
    sample_rows_in_table_info=1, 
    view_support=True,
    # include_tables=['rpt_gl_transactions']
    )
def initialize_custom_table_info_cache() -> dict:
    '''
    Initialize the custom table info cache
    '''
    cache = {}
    
    result = snowflake_engine._engine.execute(f"show VIEWS in fincopilot_cdm.accounts_receivable;")
    rows = result.fetchall()

    for row in rows:
        table_name = row[1].lower()
        description = row[7] 

        # Add "as" filter
        description = description.split('as')[0]
        description = description.strip()
        
        cache[table_name] = description
        

    

    


    # Add sample rows?

    

    return cache

# snowflake_engine._custom_table_info = initialize_custom_table_info_cache()

# snowflake_engine._custom_table_info

ProgrammingError: (snowflake.connector.errors.ProgrammingError) 001059 (22023): SQL compilation error:
Must specify the full search path starting from database for FINCOPILOT_CDM
[SQL: SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA fincopilot_cdm]
(Background on this error at: https://sqlalche.me/e/14/f405)

In [None]:
snowflake_engine._get_sample_rows

In [22]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
import os

os.environ["LANGCHAIN_API_KEY"] = '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'
os.environ['LANGCHAIN_TRACING_V2'] = 'false'
os.environ['LANGCHAIN_PROJECT'] = 'chain_testing'

openai_key = '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'

prompt_template = "Tell me a {adjective} joke in the {dialect} language and using this table information {table_info}"


database_prompt = ChatPromptTemplate(input_variables=['input'])


final_prompt = ChatPromptTemplate.from_messages([
    ('system', prompt_template),
    # MessagesPlaceholder(variable_name="history"),
    ('human', '{input}'),
    database_prompt
])


llm = ChatOpenAI(openai_api_key=openai_key)
chain = final_prompt | llm | StrOutputParser()

response = chain.invoke({"adjective": "funny", "input": "Tell me something thoughtful."})


## NOTE: Do not try to replicate the snowflake engine stuff here. 
#        It will not work
# Instead, create the regular chain and make a new function that returns this:
# {
#     "dialect": "snowflake",
#     "table_info": "table information"
# }
# With the parameters: (table list, number of rows)



[32;1m[1;3m[chain/start][0m [1m[chain:RunnableSequence] Entering Chain run with input:
[0m{
  "adjective": "funny",
  "input": "Tell me something thoughtful."
}
[32;1m[1;3m[chain/start][0m [1m[chain:RunnableSequence > prompt:ChatPromptTemplate] Entering Prompt run with input:
[0m{
  "adjective": "funny",
  "input": "Tell me something thoughtful."
}
[31;1m[1;3m[chain/error][0m [1m[chain:RunnableSequence > prompt:ChatPromptTemplate] [3ms] Prompt run errored with error:
[0m"KeyError(\"Input to ChatPromptTemplate is missing variables {'table_info', 'dialect'}.  Expected: ['adjective', 'dialect', 'input', 'table_info'] Received: ['adjective', 'input']\")Traceback (most recent call last):\n\n\n  File \"/opt/miniconda3/envs/sql-generator/lib/python3.11/site-packages/langchain_core/runnables/base.py\", line 1599, in _call_with_config\n    context.run(\n\n\n  File \"/opt/miniconda3/envs/sql-generator/lib/python3.11/site-packages/langchain_core/runnables/config.py\", line 380, in 

KeyError: "Input to ChatPromptTemplate is missing variables {'table_info', 'dialect'}.  Expected: ['adjective', 'dialect', 'input', 'table_info'] Received: ['adjective', 'input']"