#  Install required packages

In [12]:
%pip install langchain huggingface_hub ctransformers


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.0 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


# Import dependencies


In [3]:
from huggingface_hub import hf_hub_download
from langchain.llms import CTransformers
from langchain.prompts import ChatPromptTemplate
from langchain.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser

  from .autonotebook import tqdm as notebook_tqdm


# Configuration constants

In [4]:

MODEL_NAME = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
MODEL_FILE = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
DB_URI = "mysql+pymysql://root:password@localhost:3306/chinook"

MODEL_CONFIG = {
    'max_new_tokens': 512,
    'temperature': 0.7,
    'context_length': 2048,
}


# Define priority tables and relationships

In [5]:

PRIORITY_TABLES = {
    'Artist': ['ArtistId', 'Name'],
    'Album': ['AlbumId', 'Title', 'ArtistId'],
    'Track': ['TrackId', 'Name', 'AlbumId', 'MediaTypeId', 'GenreId', 'Composer', 'Milliseconds', 'UnitPrice'],
    'Genre': ['GenreId', 'Name'],
    'MediaType': ['MediaTypeId', 'Name'],
    'Playlist': ['PlaylistId', 'Name'],
    'PlaylistTrack': ['PlaylistId', 'TrackId'],
    'Customer': ['CustomerId', 'FirstName', 'LastName', 'Company'],
    'Invoice': ['InvoiceId', 'CustomerId', 'InvoiceDate', 'Total'],
    'InvoiceLine': ['InvoiceLineId', 'InvoiceId', 'TrackId', 'UnitPrice', 'Quantity']
}

KEY_RELATIONSHIPS = """
- Artists create Albums
- Albums contain Tracks
- Tracks belong to Genres and MediaTypes
- Tracks can be in multiple Playlists
- Customers make purchases (Invoices)
- Invoices contain InvoiceLines which reference Tracks
"""

# Example queries for few-shot learning

In [6]:
EXAMPLE_QUERIES = {
    "count": {
        "question": "How many albums are there?",
        "sql": "SELECT COUNT(*) FROM Album",
    },
    "join": {
        "question": "What are the track names in the album 'Big Ones'?",
        "sql": "SELECT Track.Name FROM Track JOIN Album ON Track.AlbumId = Album.AlbumId WHERE Album.Title = 'Big Ones'",
    },
    "aggregate": {
        "question": "What is the total sales amount?",
        "sql": "SELECT SUM(Total) FROM Invoice",
    }
}


# Optimize schema due to limitation of TinyLlama

In [7]:
def optimize_schema(db):
    optimized_schema = {}
    for table, columns in PRIORITY_TABLES.items():
        column_str = ", ".join(columns)
        optimized_schema[table] = f"{table} ({column_str})"
    return " | ".join(optimized_schema.values())

In [8]:
def prepare_sql_context(input_dict):
    question = input_dict["question"].lower()
    
    # Select appropriate example
    if any(word in question for word in ['how many', 'count', 'total']):
        example = EXAMPLE_QUERIES['count']
    elif any(word in question for word in ['sum', 'amount', 'sales']):
        example = EXAMPLE_QUERIES['aggregate']
    else:
        example = EXAMPLE_QUERIES['join']
    
    # Prepare and return all context
    return {
        "question": input_dict["question"],
        "db_schema": optimize_schema(db),
        "relationships": KEY_RELATIONSHIPS.strip(),
        "example_q": example['question'],
        "example_sql": example['sql']
    }


# Set up model and database

In [9]:
model_path = hf_hub_download(MODEL_NAME, filename=MODEL_FILE)
llm = CTransformers(
    model=model_path,
    model_type="llama",
    config=MODEL_CONFIG
)
db = SQLDatabase.from_uri(DB_URI)

# Create prompt template

In [10]:
TEMPLATE = '''You are an SQL expert. Using this simplified schema:
{db_schema}

Key relationships:
{relationships}

Example:
Q: {example_q}
A: {example_sql}

Generate a SQL query to answer this question: {question}
Guidelines:
1. Use appropriate JOINs when linking tables
2. Use aggregation functions when needed (COUNT, SUM, AVG, etc.)
3. Keep the query efficient and accurate

SQL Query:'''

enhanced_prompt = ChatPromptTemplate.from_template(TEMPLATE)

# Create a function to invoke sql chain

In [12]:
def run_sql_chain(question):
    # First, extract the question
    input_data = {"question": question}
    
    # Prepare the SQL context with the optimized schema and example query
    context_data = prepare_sql_context(input_data)
    
    # Render the prompt using the enhanced template
    prompt = enhanced_prompt.format(**context_data)
    
    # Invoke the LLM with the prompt and stop tokens
    result = llm(prompt, stop=["\nSQLResult:", "\nHuman:", "\nAssistant:"])
    
    # Parse the output result as a string
    parsed_result = StrOutputParser().parse(result)
    
    return parsed_result

In [24]:
#        "How many artists are there?",
#        "What are the top 5 longest songs?",
#        "Who is the customer with the highest total purchases?",
#        "How many tracks are in each genre?"

In [14]:
print(run_sql_chain("How many tracks are in each genre?"))


SELECT GenreId, COUNT(*) AS GenreCount FROM Track GROUP BY GenreId;

Result:
GenreId | GenreCount
--------|------------
 1      | 2
 2      | 1
 3      | 1
 4      | 0
 5      | 0

Note: The above query groups Track rows by their respective genre (a string), and counts the number of tracks belonging to each genre.
This results in one row for every track with a given genre, which can be aggregated using COUNT.
