Grundlage:
* https://js.langchain.com/v0.1/docs/modules/chains/popular/sqlite/
* https://js.langchain.com/v0.1/docs/integrations/toolkits/sql/

* https://python.langchain.com/docs/tutorials/sql_qa/

# Setup

In [65]:
import sqlite3
import langchain
from langchain_community.utilities import SQLDatabase

In [66]:
db_path = r"./POC-LangChain/chinook-database-master/ChinookDatabase/DataSources/Chinook_Sqlite.sqlite"
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

In [67]:
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

# Chains
Sequence of steps that does the following:
* converts the question into a SQL query;
* executes the query;
* uses the result to answer the original question.


## Application state

In [68]:
from typing_extensions import TypedDict

class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

In [69]:
from langchain import hub

query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()




Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Question: [33;1m[1;3m{input}[0m


In [70]:
!ollama list

NAME                     	ID          	SIZE  	MODIFIED     
qwen2.5-coder:7b         	2b0496514337	4.7 GB	3 hours ago 	
deepseek-r1:8b           	28f8fd6cdc67	4.9 GB	2 days ago  	
llama3.2:1b-instruct-q4_0	53f2745c8077	770 MB	3 months ago	
llama3.2:1b              	baf6a787fdff	1.3 GB	3 months ago	
llama3.1:8b              	42182419e950	4.7 GB	4 months ago	
mistral:instruct         	f974a74358d6	4.1 GB	4 months ago	


In [71]:
from langchain_ollama import ChatOllama

class MyChatOllama(ChatOllama):
    def with_structured_output(self, schema, *, include_raw=False, **kwargs):
        # Call bind_tools without the format parameter
        llm = self.bind_tools(tools=[schema])  # omit format="json"
        from langchain_core.output_parsers.pydantic import PydanticOutputParser
        from langchain_core.output_parsers.json import JsonOutputParser
        is_pydantic_schema = isinstance(schema, type)  # simplified check
        output_parser = (PydanticOutputParser(pydantic_object=schema)
                         if is_pydantic_schema else JsonOutputParser())
        # Compose the output parser
        from langchain_core.runnables import RunnableLambda
        parser_chain = RunnableLambda(lambda x: x) | output_parser
        # If include_raw is requested, add a passthrough (optional)
        if include_raw:
            from langchain_core.runnables.passthrough import RunnablePassthrough
            parser_assign = RunnablePassthrough.assign(parsed=lambda x: parser_chain.invoke(x),
                                                        parsing_error=lambda _: None)
            return llm | parser_assign
        else:
            return llm | parser_chain


In [72]:
import re
import json

def clean_response(raw_text: str) -> str:
    # Remove any header tokens like <|start_header_id|>assistant<|end_header_id|>
    cleaned = re.sub(r"<\|start_header_id\|>.*?<\|end_header_id\|>\s*", "", raw_text, flags=re.DOTALL)
    return cleaned.strip()

In [73]:
model = MyChatOllama(
    model=
        'llama3.2:1b',
    temperature=0
)

In [78]:
from typing_extensions import Annotated
class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]


In [91]:
def write_query(state: dict):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke({
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    })
    llm = model
    # Try json_schema method (you can also experiment with function_calling)
    structured_llm = llm.with_structured_output(QueryOutput, method="json_schema")
    # Invoke the chain
    result = structured_llm.invoke(prompt)
    # If result is not valid JSON, attempt to clean it
    if not isinstance(result, dict):
        # Assuming result is a string containing extra tokens, clean it first.
        cleaned = clean_response(result)
        parsed = json.loads(cleaned)
        result = parsed
    return {"query": result["query"]}

In [89]:
def write_query(state: dict):
    prompt = query_prompt_template.invoke({
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    })
    # raw response
    raw_response = model.invoke(prompt).content
    # Clean header tokens
    cleaned = clean_response(raw_response)
    # wrap the plain text SQL query in a JSON object
    return {"query": cleaned}


In [92]:
write_query({"question": 
             "Write a query to find the customer who has spent the most money on purchases. Display their full name, email, and total amount spent."
             })

OutputParserException: Invalid json output: <|start_header_id|>assistant<|end_header_id|>

To solve this problem, we need to join the `Customer`, `InvoiceLine`, `Track`, `MediaType`, `Playlist`, and `PlaylistTrack` tables based on their respective IDs. We also need to filter out customers who have not made any purchases.

Here is a SQL query that accomplishes this:

```sql
SELECT 
    c.Name AS CustomerName, 
    c.Email AS CustomerEmail, 
    SUM(PL.Total) AS TotalSpent
FROM 
    Customer c
JOIN 
    InvoiceLine il ON c.CustomerId = il.CustomerId
JOIN 
    Track t ON il.TrackId = t.TrackId
JOIN 
    MediaType m ON t.MediaTypeId = m.MediaTypeId
JOIN 
    Playlist p ON t.AlbumId = p.AlbumId
JOIN 
    PlaylistTrack pt ON p.PlaylistId = pt.PlaylistId
WHERE 
    c.CustomerId NOT IN (
        SELECT 
            CustomerId 
        FROM 
            InvoiceLine 
    )
GROUP BY 
    c.Name, c.Email
ORDER BY 
    TotalSpent DESC;
```

This query works as follows:

1. It joins the `Customer` table with other tables based on their IDs.
2. It filters out customers who have not made any purchases by joining the `InvoiceLine`, `Track`, `MediaType`, `Playlist`, and `PlaylistTrack` tables to the `Customer` table, but only for customers who do not appear in the `InvoiceLine` table.
3. It groups the results by customer name and email address.
4. Finally, it orders the results by total spent in descending order.

When you run this query on your database, it will return a list of customers who have made the most purchases, along with their full names, emails, and total amounts spent.
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/OUTPUT_PARSING_FAILURE 