In [None]:
from langchain_ollama import ChatOllama
conn = "mysql+pymysql://root:123@127.0.0.1/sys"
llm = ChatOllama(
    base_url="http://localhost:11434",
    model="llama3.1",
    temperature=0,
    # other params...
)
llm.predict("123")

In [None]:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri(
    "mysql+pymysql://root:123@127.0.0.1/chinook", 
    include_tables=["employee"],
    sample_rows_in_table_info=3
    )
print(db.dialect)
print(db.get_usable_table_names())


In [30]:
from langchain_core.prompts import ChatPromptTemplate

template = """
Please generate only an executable SQL query, strictly following the structure and using the schema below. Do not include explanations or additional text.
{schema}

Question: {question}
SQL Query:
"""
prompt = ChatPromptTemplate.from_template(template)
def get_schema(_):
    schema = db.get_table_info()
    return schema


In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm
    # | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

sql_chain.invoke({"question": "總共有多少資料"})

In [None]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}
"""

# template = """
# "Based on the query results, and aiming to compile all available data as much as possible, 
# respond in one of the preferred formats: Highcharts.js or a bullet-point summary of the results. 
# If the data involves trends, statistics, or analysis that can be visually represented, 
# prioritize determining whether a Highcharts.js visualization can be created, 
# and only return the Highcharts.chart code. Any other information is unnecessary,
# and no explanation is needed for cases where a chart cannot be generated.":
# {schema}

# Question: {question}
# SQL Query: {query}
# SQL Response: {response}
# """
prompt_response = ChatPromptTemplate.from_template(template)

def run_query(query):
    return db.run(query)

main_chain = (
    RunnablePassthrough
    .assign(query=sql_chain) # 上一個chain的結果
    .assign(
        schema=get_schema,
        response=lambda x : run_query(x["query"]) # sql 執行的結果
    )
    | prompt_response
    | llm
)
main_chain.invoke({"question":"有多少資料"})

串列處理

並列處理

In [None]:
from langchain_core.runnables import RunnableLambda

def add_one(x: int) -> int:
    return x + 1

def add_two(x: int) -> int:
    return x + 2


runnable_1 = RunnableLambda(add_one)
runnable_2 = RunnableLambda(add_two)

parallel = {"runnable_1": runnable_1, "runnable_2": runnable_2}

chain = RunnableLambda(lambda x: x) | parallel
answer = chain.invoke(1)

print(answer)

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel

summ_chain = (
    ChatPromptTemplate.from_template("幫我依據資料總結資訊，以{type}呈現")
    | llm
    | StrOutputParser()
)

char_chain = (
    ChatPromptTemplate.from_template("依據資料生成可繪製的highchart.js格式，不需要其他資訊")
    | llm
    | StrOutputParser()
)

main =  RunnableParallel(joke=summ_chain, poem=char_chain)

main.invoke({"type":"條列式"}) 


In [None]:
main.get_graph().print_ascii()

完整

In [1]:
from langchain_ollama import ChatOllama
from langchain_community.utilities import SQLDatabase
conn = "mysql+pymysql://root:123@127.0.0.1/sys"
llm = ChatOllama(
    base_url="http://localhost:11434",
    model="llama3.1",
    temperature=0,
    # other params...
)

db = SQLDatabase.from_uri(
    "mysql+pymysql://root:123@127.0.0.1/chinook", 
    include_tables=["employee"],
    sample_rows_in_table_info=3
    )



In [5]:
from langchain_core.prompts import ChatPromptTemplate

template = """
Please generate only an executable SQL query, strictly following the structure and using the schema below. Do not include explanations or additional text.
{schema}

Question: {question}
SQL Query:
"""
prompt = ChatPromptTemplate.from_template(template)
def get_schema(_):
    schema = db.get_table_info()
    return schema
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm
    | StrOutputParser()
)
sql = sql_chain.invoke({"question":"有多少資料"})
print(sql)

SELECT COUNT(*) FROM employee;


In [7]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}
"""


prompt_response = ChatPromptTemplate.from_template(template)

def run_query(query):
    return db.run(query)

main_chain = (
    RunnablePassthrough
    .assign(query=lambda x : sql)
    .assign(
        schema=get_schema,
        response=lambda x : run_query(x["query"])
    )
    | prompt_response
    | llm
)

main_chain.invoke({"question":"有多少資料"})

AIMessage(content='有 8 個員工的資料。', additional_kwargs={}, response_metadata={'model': 'llama3.1', 'created_at': '2024-09-29T13:57:29.632042244Z', 'message': {'role': 'assistant', 'content': ''}, 'done_reason': 'stop', 'done': True, 'total_duration': 25814255383, 'load_duration': 13407892, 'prompt_eval_count': 767, 'prompt_eval_duration': 24635190000, 'eval_count': 10, 'eval_duration': 1111316000}, id='run-a3b8b3a3-bc55-448a-aad5-af057fe99ada-0', usage_metadata={'input_tokens': 767, 'output_tokens': 10, 'total_tokens': 777})

In [8]:
main_chain.get_graph().print_ascii()

                 +----------------------+                    
                 | Parallel<query>Input |                    
                 +----------------------+                    
                       **         ***                        
                     **              *                       
                    *                 **                     
             +--------+          +-------------+             
             | Lambda |          | Passthrough |             
             +--------+          +-------------+             
                       **         ***                        
                         **      *                           
                           *   **                            
                 +-----------------------+                   
                 | Parallel<query>Output |                   
                 +-----------------------+                   
                             *                               
        

In [9]:
sql_chain.get_graph().print_ascii()

         +-----------------------+           
         | Parallel<schema>Input |           
         +-----------------------+           
             ***            **               
           **                 **             
         **                     **           
+------------+              +-------------+  
| get_schema |              | Passthrough |  
+------------+              +-------------+  
             ***            **               
                **        **                 
                  **    **                   
        +------------------------+           
        | Parallel<schema>Output |           
        +------------------------+           
                     *                       
                     *                       
                     *                       
          +--------------------+             
          | ChatPromptTemplate |             
          +--------------------+             
                     *            