In [None]:
from langchain_ollama import ChatOllama
conn = "mysql+pymysql://root:123@127.0.0.1/sys"
llm = ChatOllama(
    base_url="http://localhost:11434",
    model="llama3.1",
    temperature=0,
    # other params...
)


In [37]:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri(
    "mysql+pymysql://root:123@127.0.0.1/chinook", 
    include_tables=["employee"],
    sample_rows_in_table_info=3
    )
print(db.dialect)
print(db.get_usable_table_names())


mysql
['employee']


In [38]:
from langchain_core.prompts import ChatPromptTemplate

template = """
Please generate only an executable SQL query, strictly following the structure and using the schema below. Do not include explanations or additional text.
{schema}

Question: {question}
SQL Query:
"""
prompt = ChatPromptTemplate.from_template(template)
def get_schema(_):
    schema = db.get_table_info()
    return schema


In [39]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm
    # | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

# sql_chain.invoke({"question": "總共有多少資料"})

'SELECT COUNT(*) FROM employee;'

In [42]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}
"""

# template = """
# "Based on the query results, and aiming to compile all available data as much as possible, 
# respond in one of the preferred formats: Highcharts.js or a bullet-point summary of the results. 
# If the data involves trends, statistics, or analysis that can be visually represented, 
# prioritize determining whether a Highcharts.js visualization can be created, 
# and only return the Highcharts.chart code. Any other information is unnecessary,
# and no explanation is needed for cases where a chart cannot be generated.":
# {schema}

# Question: {question}
# SQL Query: {query}
# SQL Response: {response}
# """
prompt_response = ChatPromptTemplate.from_template(template)

def run_query(query):
    return db.run(query)

main_chain = (
    RunnablePassthrough
    .assign(query=sql_chain)
    .assign(
        schema=get_schema,
        response=lambda x : run_query(x["query"])
    )
    | prompt_response
    | llm
)
main_chain.invoke({"question":"有多少資料"})

AIMessage(content='Based on the query results, I will provide a bullet-point summary of the results in one of the preferred formats.\n\n**Summary:**\n\n* Total number of employees: 3 (not 8, as shown in the SQL response. The correct count is 3, based on the provided data.)\n\nSince there are only 3 rows of data and no trends or statistics that can be visually represented, I will not provide a Highcharts.js visualization.\n\nHere is the summary:\n\n* Total number of employees: 3', additional_kwargs={}, response_metadata={'model': 'llama3.1', 'created_at': '2024-09-23T18:20:17.40157916Z', 'message': {'role': 'assistant', 'content': ''}, 'done_reason': 'stop', 'done': True, 'total_duration': 39766226256, 'load_duration': 13454131, 'prompt_eval_count': 847, 'prompt_eval_duration': 27000851000, 'eval_count': 104, 'eval_duration': 12699200000}, id='run-ff22ebef-0e06-4d86-8707-0d2c21fe7169-0', usage_metadata={'input_tokens': 847, 'output_tokens': 104, 'total_tokens': 951})

In [None]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
#chain.get_prompts()[0].pretty_print()
response = chain.invoke({"question": "How many employees are there"})
response

In [None]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there"})