In this project I'll be using the Mistral-7B-Instruct-v0.3 model.

In [1]:
from langchain_huggingface import HuggingFaceEndpoint
import os
from getpass import getpass

# Read the Hugging Face token with getpass so it is not hardcoded in the notebook.
HUGGINGFACEHUB_API_TOKEN = getpass("Hugging Face API Token: ")
# Also expose the token through the environment variable of the same name.
os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
# Hub repository id of the model that every chain in this notebook calls.
repo_id = "mistralai/Mistral-7B-Instruct-v0.3"

# Endpoint client for the model above; `llm` is the model step of the chains built later.
# temperature is kept close to zero, presumably so the generated SQL varies as little as possible.
# NOTE(review): the output length is capped through model_kwargs["max_length"]; confirm that the
# installed langchain_huggingface version honours this key (some versions expose `max_new_tokens`).
llm = HuggingFaceEndpoint(
    repo_id=repo_id,
    temperature=0.01,
    model_kwargs={"max_length": 100},
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
)

    



The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /Users/alessandropaolini/.cache/huggingface/token
Login successful


In [2]:
from langchain.prompts import ChatPromptTemplate

# Prompt text for the SQL-generation step; {schema} and {question} are filled in at run time.
# NOTE(review): the last line is "SQL Query" without a trailing colon, unlike the
# "SQL Query: {query}" line of the response template defined later — confirm this is intended.
template = """ 
    Based on the table schema below, write a SQL query that would answer the user's question:
    {schema}

    Question: {question}
    SQL Query
    """

# Wrap the text as a chat prompt template; formatting it produces a single "Human:" message.
prompt = ChatPromptTemplate.from_template(template)

In [3]:
# Sanity check: render the prompt with placeholder values to inspect the exact text it produces.
prompt.format(schema="my schema", question="How many users are there?")

"Human:  \n    Based on the table schema below, write a SQL query that would answer the user's question:\n    my schema\n\n    Question: How many users are there?\n    SQL Query\n    "

In [4]:
from langchain_community.utilities import SQLDatabase

# Connection URI: MySQL through the mysqlconnector driver, user root, localhost:3306, database Chinook.
# `xxxxx` is a placeholder for the password; do not commit a real password in this cell.
db_uri = "mysql+mysqlconnector://root:xxxxx@localhost:3306/Chinook"
# Database wrapper used below to run queries (db.run) and to read the schema (db.get_table_info).
db = SQLDatabase.from_uri(db_uri)

This will not work unless you replace the `xxxxx` in the connection URI with your MySQL password. You also need a running MySQL instance with the Chinook sample database loaded; Chinook is a friendly database for small projects like this one.

In [6]:
# Smoke-test the connection: fetch the first five rows of the Album table.
db.run("SELECT * FROM Album LIMIT 5;")

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

In [7]:
def get_schema(_):
    """Return the table description of the connected database.

    The single positional argument is ignored; it exists so the function can
    be called with one input value (for example as a step inside a chain).

    Returns:
        Whatever ``db.get_table_info()`` returns for the notebook-level ``db``.
    """
    table_info = db.get_table_info()
    return table_info

In [8]:
# Preview the schema text; the argument is ignored by get_schema, so None is passed.
get_schema(None)

'\nCREATE TABLE `Album` (\n\t`AlbumId` INTEGER NOT NULL, \n\t`Title` VARCHAR(160) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`ArtistId` INTEGER NOT NULL, \n\tPRIMARY KEY (`AlbumId`), \n\tCONSTRAINT `FK_AlbumArtistId` FOREIGN KEY(`ArtistId`) REFERENCES `Artist` (`ArtistId`)\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE `Artist` (\n\t`ArtistId` INTEGER NOT NULL, \n\t`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, \n\tPRIMARY KEY (`ArtistId`)\n)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/\n\n\nCREATE TABLE `Customer` (\n\t`CustomerId` INTEGER NOT NULL, \n\t`FirstName` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`LastNa

In [9]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Question -> SQL pipeline:
#   1. add a `schema` key (computed by get_schema) to the input dict,
#   2. fill the prompt template with the dict,
#   3. send the prompt to the model,
#   4. parse the model output into a plain string.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm
    | StrOutputParser()
)

In [10]:
# Generate SQL for a sample question; only "question" is supplied, the chain adds the schema itself.
sql_chain.invoke({"question": "how many artists are there?"})




'----------\n    SELECT COUNT(*) FROM Artist;'

In [11]:
# Prompt text for the answer step: besides {schema} and {question} it takes the generated
# SQL ({query}) and the database result ({response}).
# NOTE: this cell rebinds `template` and `prompt`. sql_chain, built in an earlier cell, keeps the
# prompt object it was created with, but re-running that cell after this one would make it use
# this prompt instead — run the notebook top to bottom.
template = """
    Based on the table schema below, question, sql query and sql response, write a natural response:
    {schema}

    Question: {question}
    SQL Query: {query}
    SQL Response: {response}
    """

# Chat prompt template consumed by full_chain.
prompt = ChatPromptTemplate.from_template(template)

In [12]:
def run_query(query):
    """Execute ``query`` on the notebook-level ``db`` and hand back its result.

    Args:
        query: SQL text passed unchanged to ``db.run``.

    Returns:
        The value returned by ``db.run(query)``.
    """
    result = db.run(query)
    return result

In [13]:
# Check run_query with the SQL the model produced for the sample question.
# The statement is terminated by a single ';' (the duplicated ';;' was a typo).
run_query("SELECT COUNT(*) FROM Artist;")

'[(275,)]'

In [17]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

def clean_query(variables):
    """Extract the SQL statement from the raw text produced by the SQL chain.

    The model may emit a throwaway first line (such as a ``----------``
    separator) before the statement, so everything up to and including the
    first newline is dropped. When nothing usable follows the first newline
    (single-line output, or text that merely ends with a newline) the whole
    text is kept, instead of raising ``IndexError`` or returning an empty
    query.

    Args:
        variables: Mapping with a ``"query"`` key holding the model output.

    Returns:
        The SQL text with surrounding whitespace removed.

    Raises:
        KeyError: If ``variables`` has no ``"query"`` entry.
    """
    raw_query = variables["query"]
    # partition never fails: without a newline the remainder is simply empty.
    _first_line, _newline, remainder = raw_query.partition("\n")
    remainder = remainder.strip()
    if remainder:
        return remainder
    # Nothing after the first line: the first line is the query itself.
    return raw_query.strip()

# Wrap clean_query as a runnable so it can be used as a step of the chain below.
clean_query_runnable = RunnableLambda(func=clean_query)

# Question -> natural-language answer pipeline:
#   1. query    <- raw SQL text generated by sql_chain,
#   2. query    <- the same text after clean_query (overwrites the raw value),
#   3. schema   <- get_schema; response <- result of running the cleaned query,
#   4. fill `prompt` (whichever template is bound to that name when this cell runs),
#      call the model, and parse its output into a plain string.
# NOTE: the model-generated SQL is passed to run_query without any validation.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain)
    .assign(query=clean_query_runnable)
    .assign(
        schema=get_schema,
        response=lambda variables: run_query(variables["query"])
    )
    | prompt
    | llm
    | StrOutputParser()
)


In [28]:
# End-to-end run: generate SQL for the question, execute it, and phrase the result as a sentence.
full_chain.invoke({"question": "how many songs are in the whole database?"})

'3503 songs are in the whole database.'