# Install required packages

In [12]:
%pip install langchain huggingface_hub ctransformers


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.0 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


# Import dependencies


In [2]:
from huggingface_hub import hf_hub_download
from langchain.llms import CTransformers
from langchain.prompts import ChatPromptTemplate
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

  from .autonotebook import tqdm as notebook_tqdm


# Download the model

In [3]:
# Hugging Face Hub repository and the GGUF weights file to fetch from it.
model_name = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
model_file = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"

# hf_hub_download returns the local path of the downloaded file
# (a location inside the Hugging Face cache, as the output below shows).
model_path = hf_hub_download(filename=model_file, repo_id=model_name)
print(f"Model downloaded to: {model_path}")

Model downloaded to: C:\Users\USER\.cache\huggingface\hub\models--TheBloke--TinyLlama-1.1B-Chat-v1.0-GGUF\snapshots\52e7645ba7c309695bec7ac98f4f005b139cf465\tinyllama-1.1b-chat-v1.0.Q8_0.gguf


# Initialize the model


In [15]:
# Load the downloaded GGUF weights through the ctransformers backend as a llama-type model.
# Generation settings: at most 512 new tokens, sampling temperature 0.7,
# and a 2048-token context window.
llm = CTransformers(
    model_type="llama",
    model=model_path,
    config=dict(
        max_new_tokens=512,
        temperature=0.7,
        context_length=2048,
    ),
)

# Set up database connection

In [5]:
import os
from getpass import getpass
from urllib.parse import quote_plus

from langchain.sql_database import SQLDatabase

# Connection settings come from the environment instead of being hardcoded in the notebook.
# The defaults match the original local setup (root@localhost:3306/chinook).
MYSQL_USER = os.environ.get("MYSQL_USER", "root")
MYSQL_HOST = os.environ.get("MYSQL_HOST", "localhost")
MYSQL_PORT = os.environ.get("MYSQL_PORT", "3306")
MYSQL_DATABASE = os.environ.get("MYSQL_DATABASE", "chinook")

# Never store the password in the notebook: take it from the environment, or prompt for it.
mysql_password = os.environ.get("MYSQL_PASSWORD") or getpass("MySQL password: ")

# quote_plus escapes characters such as '@', ':' or '/' that would otherwise break the URI.
db = SQLDatabase.from_uri(
    f"mysql+pymysql://{quote_plus(MYSQL_USER)}:{quote_plus(mysql_password)}"
    f"@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DATABASE}"
)
del mysql_password  # don't leave the plaintext password lying around in the kernel namespace


# Create chat prompt template

In [6]:
from langchain.prompts import ChatPromptTemplate

# Prompt asking the model to turn a natural-language question into a single SQL query.
# Placeholders filled from the chain input at run time:
#   {db_schema} - the schema summary for the connected database
#   {question}  - the user's question
# NOTE(review): the 4-space indentation inside this literal (and the trailing
# "\n    ") is sent to the model verbatim; consider textwrap.dedent(...).strip()
# if the extra whitespace hurts generation quality.
template = """You are an AI assistant that helps users query MySQL databases.
    Given the following database schema:
    {db_schema}
    
    Generate a SQL query to answer the user's question.
    Only generate the SQL query, no other text.
    Question : {question}
    """
prompt = ChatPromptTemplate.from_template(template)

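# Create and configure the SQL chain

The chain adds a summary of the database schema (built by `get_schema`, defined in the next cell) to the input, fills the prompt, runs the LLM, and returns the generated SQL as a plain string.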

In [61]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Pipeline: add `db_schema` to the input dict -> fill the prompt -> run the LLM -> plain string.
sql_chain = (
    # Wrap get_schema in a lambda so the name is resolved when the chain is invoked,
    # not when this cell runs: get_schema is defined in a later cell, so a direct
    # reference raises NameError under "Restart Kernel -> Run All".
    RunnablePassthrough.assign(db_schema=lambda chain_input: get_schema(chain_input))
    | prompt
    # Stop generating if the model starts writing an "SQLResult:" section.
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()  # return the model output as a str
)

In [53]:
def get_schema(_, max_columns=None):
    """Return a compact ``{table_name: [column_name, ...]}`` view of the database schema.

    Parses the DDL text returned by ``db.get_table_info()`` (``db`` is the
    module-level SQLDatabase instance) and keeps only table and column names,
    which is far shorter than the full DDL when embedded in a prompt.

    Args:
        _: The chain input. Ignored; accepted because the chain calls this
            function with the current input.
        max_columns: If given, keep at most this many columns per table, in DDL
            order. ``None`` (the default) keeps every column.

    Returns:
        dict mapping each table name (backticks removed) to a list of its
        column names (backticks removed).
    """
    schema_info = db.get_table_info()

    relevant_schema = {}
    current_table = None

    for raw_line in schema_info.splitlines():
        line = raw_line.strip()

        if line.startswith("CREATE TABLE"):
            # "CREATE TABLE `artist` (" -> "artist"; strip backticks like column names below.
            current_table = line.split()[2].strip("`")
            relevant_schema[current_table] = []
        elif line.startswith(")"):
            # End of the CREATE TABLE body: whatever follows it is not a column
            # definition, so stop collecting until the next CREATE TABLE.
            current_table = None
        elif current_table and line.startswith("`"):
            # Column lines look like "`ArtistId` INTEGER NOT NULL,"; keep only the name.
            # NOTE(review): columns rendered without backtick quoting are not captured.
            relevant_schema[current_table].append(line.split()[0].strip("`"))

    if max_columns is not None:
        relevant_schema = {
            table: columns[:max_columns] for table, columns in relevant_schema.items()
        }
    return relevant_schema


In [69]:
# Ask a natural-language question; the chain answers with the generated SQL text.
question = "How many artist in this database ?"
result = sql_chain.invoke({"question": question})
print(result)

 SELECT COUNT(*) AS `num_artists` FROM `artist`;
