#  Install required packages

In [12]:
%pip install langchain huggingface_hub ctransformers


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.0 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


# Import dependencies


In [4]:
from huggingface_hub import hf_hub_download
from langchain.llms import CTransformers
from langchain.prompts import ChatPromptTemplate
from langchain.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Download the model

In [5]:
MODEL_NAME = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
MODEL_FILE = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
DB_URI = "mysql+pymysql://root:password@localhost:3306/chinook"

MODEL_CONFIG = {
    'max_new_tokens': 512,
    'temperature': 0.7,
    'context_length': 2048,
}

In [6]:
PRIORITY_TABLES = {
    'Artist': ['ArtistId', 'Name'],
    'Album': ['AlbumId', 'Title', 'ArtistId'],
    'Track': ['TrackId', 'Name', 'AlbumId', 'MediaTypeId', 'GenreId', 'Composer', 'Milliseconds', 'UnitPrice'],
    'Genre': ['GenreId', 'Name'],
    'MediaType': ['MediaTypeId', 'Name'],
    'Playlist': ['PlaylistId', 'Name'],
    'PlaylistTrack': ['PlaylistId', 'TrackId'],
    'Customer': ['CustomerId', 'FirstName', 'LastName', 'Company'],
    'Invoice': ['InvoiceId', 'CustomerId', 'InvoiceDate', 'Total'],
    'InvoiceLine': ['InvoiceLineId', 'InvoiceId', 'TrackId', 'UnitPrice', 'Quantity']
}

In [7]:
KEY_RELATIONSHIPS = """
- Artists create Albums
- Albums contain Tracks
- Tracks belong to Genres and MediaTypes
- Tracks can be in multiple Playlists
- Customers make purchases (Invoices)
- Invoices contain InvoiceLines which reference Tracks
"""

In [8]:
model_path = hf_hub_download(MODEL_NAME, filename=MODEL_FILE)
print(f"Model downloaded to: {model_path}")

llm = CTransformers(
    model=model_path,
    model_type="llama",
    config=MODEL_CONFIG
)

Model downloaded to: C:\Users\USER\.cache\huggingface\hub\models--TheBloke--TinyLlama-1.1B-Chat-v1.0-GGUF\snapshots\52e7645ba7c309695bec7ac98f4f005b139cf465\tinyllama-1.1b-chat-v1.0.Q8_0.gguf


In [9]:
db = SQLDatabase.from_uri(DB_URI)

In [10]:
def optimize_schema(db):
    schema_info = db.get_table_info()
    optimized_schema = {}
    
    for table, columns in PRIORITY_TABLES.items():
        table_start = False
        column_definitions = []
        
        for line in schema_info.splitlines():
            if line.strip().startswith(f"CREATE TABLE `{table}`"):
                table_start = True
            elif table_start and any(f"`{col}`" in line for col in columns):
                # Simplify column definitions to save tokens
                simplified_line = line.split('COMMENT')[0].strip().rstrip(',')
                column_definitions.append(simplified_line)
        
        if column_definitions:
            table_def = f"CREATE TABLE {table} (\n  "
            table_def += ",\n  ".join(column_definitions)
            table_def += "\n);"
            optimized_schema[table] = table_def
    
    return "\n\n".join(optimized_schema.values())

In [11]:
optimized_schema = optimize_schema(db)

In [12]:
def get_optimized_schema(_):
    return optimized_schema + "\n\nKey Relationships:\n" + KEY_RELATIONSHIPS

In [13]:
template = """You are an AI assistant that helps users query a music store database.
Given the following simplified database schema:
{db_schema}

Generate a SQL query to answer the user's question.
Guidelines:
1. Use JOIN operations when needed to connect related tables
2. Consider relationships between Artists, Albums, Tracks, and Sales
3. For aggregations, consider using GROUP BY and appropriate functions

Question: {question}
SQL Query:"""

prompt = ChatPromptTemplate.from_template(template)

In [14]:
sql_chain = (
    RunnablePassthrough.assign(db_schema=get_optimized_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:", "\nHuman:", "\nAssistant:"])
    | StrOutputParser()
)

In [22]:
result = sql_chain.invoke({"question":"How many artist are there ?"})
print(result)

 SELECT DISTINCT artist_name FROM artists;

Output: There are 568075 Artists.
