In [None]:
# Cài dependencies (chạy 1 lần)
import importlib.util
import subprocess
import sys

def _pip_install(*args: str) -> None:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', *args])

def _ensure_pkg(module: str, *pip_args: str) -> None:
    if importlib.util.find_spec(module) is None:
        _pip_install('-U', *pip_args)

_ensure_pkg('llama_cpp', 'llama-cpp-python')
_ensure_pkg('duckdb', 'duckdb')

# huggingface-hub: keep it < 1.0 to avoid conflicts with transformers, and at
# least 0.34.0 -- both bounds of the pip spec below are enforced, not just the upper.
_HUB_PIP_SPEC = 'huggingface-hub<1.0,>=0.34.0'


def _installed_hub_version() -> tuple[int, int] | None:
    """Return the installed huggingface-hub version as ``(major, minor)``, or None.

    None means "not installed" or "version string could not be parsed"; the
    caller treats both as needing a (re)install.
    """
    from importlib.metadata import PackageNotFoundError, version as _version

    try:
        major, minor, *_ = _version('huggingface-hub').split('.')
        return int(major), int(minor)
    except (PackageNotFoundError, ValueError):
        # ValueError: fewer than two components, or non-numeric parts like "35rc1".
        return None


_hub_ver = _installed_hub_version()
if _hub_ver is None or not ((0, 34) <= _hub_ver < (1, 0)):
    _pip_install('-U', _HUB_PIP_SPEC)

from llama_cpp import Llama
import os

# Load the quantized (q4_k_m GGUF) Qwen3-4B text-to-SQL model by repo id + filename.
# n_ctx is kept small on purpose: a smaller context means a smaller KV cache, i.e. less RAM.
llm = Llama.from_pretrained(
    # --- which model / prompt format ---
    repo_id='Ellbendls/Qwen-3-4b-Text_to_SQL-GGUF',
    filename='Qwen-3-4b-Text_to_SQL-q4_k_m.gguf',
    chat_format='qwen',
    # --- compute placement: CPU only; set n_gpu_layers > 0 on a GPU-enabled build ---
    n_gpu_layers=0,
    n_threads=os.cpu_count() or 4,
    # --- memory: context window and batch sizes ---
    n_ctx=1024,
    n_batch=128,
    n_ubatch=128,
    verbose=False,
)


  from .autonotebook import tqdm as notebook_tqdm
llama_context: n_ctx_per_seq (1024) < n_ctx_train (262144) -- the full capacity of the model will not be utilized


In [None]:
import duckdb

from pathlib import Path

DB_PATH = '../data/tpch-sf100.db'

# duckdb.connect() on a missing path silently creates a brand-new, empty database,
# after which every later cell "works" against zero tables. Fail fast instead.
if not Path(DB_PATH).is_file():
    raise FileNotFoundError(f'DuckDB database not found: {Path(DB_PATH).resolve()}')

# Read-only: this notebook only queries the data. Model-generated SQL is presumably
# executed on this connection later, so refuse writes at the engine level too.
con = duckdb.connect(database=DB_PATH, read_only=True)

# Only reads catalog metadata (does not scan the ~100 GB of table data).
tables = [r[0] for r in con.execute('SHOW TABLES').fetchall()]
tables


['customer',
 'lineitem',
 'nation',
 'orders',
 'part',
 'partsupp',
 'region',
 'supplier']

In [3]:
import re

def duckdb_schema_prompt(con, *, table_schema: str = 'main', max_tables: int | None = None) -> str:
    rows = con.execute(
        '''
        SELECT table_name, column_name, data_type, ordinal_position
        FROM information_schema.columns
        WHERE table_schema = ?
        ORDER BY table_name, ordinal_position
        ''',
        [table_schema],
    ).fetchall()

    tables: dict[str, list[tuple[str, str]]] = {}
    for table_name, column_name, data_type, _ in rows:
        tables.setdefault(str(table_name), []).append((str(column_name), str(data_type)))

    table_names = sorted(tables.keys())
    if max_tables is not None:
        table_names = table_names[:max_tables]

    lines: list[str] = []
    for t in table_names:
        lines.append(f'TABLE {t} (')
        for col, typ in tables[t]:
            lines.append(f'  {col} {typ}')
        lines.append(')')
        lines.append('')
    return '\n'.join(lines).strip()

# Compact schema description; it is embedded in the prompt sent to the model.
schema_text = duckdb_schema_prompt(con, table_schema='main')
schema_text


'TABLE customer (\n  c_custkey BIGINT\n  c_name VARCHAR\n  c_address VARCHAR\n  c_nationkey INTEGER\n  c_phone VARCHAR\n  c_acctbal DECIMAL(15,2)\n  c_mktsegment VARCHAR\n  c_comment VARCHAR\n)\n\nTABLE lineitem (\n  l_orderkey BIGINT\n  l_partkey BIGINT\n  l_suppkey BIGINT\n  l_linenumber BIGINT\n  l_quantity DECIMAL(15,2)\n  l_extendedprice DECIMAL(15,2)\n  l_discount DECIMAL(15,2)\n  l_tax DECIMAL(15,2)\n  l_returnflag VARCHAR\n  l_linestatus VARCHAR\n  l_shipdate DATE\n  l_commitdate DATE\n  l_receiptdate DATE\n  l_shipinstruct VARCHAR\n  l_shipmode VARCHAR\n  l_comment VARCHAR\n)\n\nTABLE nation (\n  n_nationkey INTEGER\n  n_name VARCHAR\n  n_regionkey INTEGER\n  n_comment VARCHAR\n)\n\nTABLE orders (\n  o_orderkey BIGINT\n  o_custkey BIGINT\n  o_orderstatus VARCHAR\n  o_totalprice DECIMAL(15,2)\n  o_orderdate DATE\n  o_orderpriority VARCHAR\n  o_clerk VARCHAR\n  o_shippriority INTEGER\n  o_comment VARCHAR\n)\n\nTABLE part (\n  p_partkey BIGINT\n  p_name VARCHAR\n  p_mfgr VARCHAR\

In [4]:
_FORBIDDEN_SQL = re.compile(
    r'\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|COPY|PRAGMA|ATTACH|DETACH|EXPORT|IMPORT|CALL)\b',
    re.IGNORECASE,
)

def _extract_sql(text: str) -> str:
    text = text.strip()
    m = re.search(r'```(?:sql)?\s*(.*?)```', text, flags=re.IGNORECASE | re.DOTALL)
    if m:
        text = m.group(1).strip()
    # lấy statement đầu tiên nếu model trả nhiều statement
    if ';' in text:
        text = text.split(';', 1)[0].strip()
    return text

def _is_safe_select(sql: str) -> bool:
    """Heuristically decide whether ``sql`` is a read-only SELECT/WITH query.

    Returns False for empty/None input, for any text matched by ``_FORBIDDEN_SQL``,
    and for anything whose first keyword (after leading comments, whitespace and
    opening parentheses) is not SELECT or WITH.

    The deny-list is applied to the *raw* text on purpose. Stripping comments
    first would let a comment marker inside a string literal hide the rest of the
    line (e.g. ``SELECT '--'; DROP TABLE t``). As a consequence, a forbidden word
    inside a comment also causes rejection.
    """
    if not sql or not sql.strip():
        return False
    if _FORBIDDEN_SQL.search(sql):
        return False
    # Comments are removed only to locate the leading keyword. Comment markers
    # inside literals cannot precede the first keyword, so this cannot spoof it.
    body = re.sub(r'/\*.*?\*/', ' ', sql, flags=re.DOTALL)
    body = re.sub(r'--.*?$', '', body, flags=re.MULTILINE).strip()
    return re.match(r'[\s(]*(?:SELECT|WITH)\b', body, flags=re.IGNORECASE) is not None

def _ensure_limit(sql: str, limit: int | None) -> str:
    s = sql.strip().rstrip(';').strip()
    if limit is None:
        return s
    if re.search(r'\bLIMIT\b', s, flags=re.IGNORECASE):
        return s
    return f'{s}\nLIMIT {limit}'

def generate_sql(
    question: str,
    *,
    schema_text: str,
    default_limit: int | None = 100,
    model=None,
    max_tokens: int = 512,
) -> str:
    """Translate a natural-language ``question`` into one read-only DuckDB query.

    Args:
        question: the user's question.
        schema_text: schema description inserted verbatim into the prompt
            (e.g. the output of ``duckdb_schema_prompt``).
        default_limit: row cap passed to ``_ensure_limit``; None disables it.
        model: chat model exposing ``create_chat_completion``; defaults to the
            module-level ``llm``.
        max_tokens: generation budget for the answer.

    Returns:
        The extracted SQL after ``_ensure_limit``, i.e. exactly the text that was
        validated.

    Raises:
        ValueError: if the model produced nothing usable or the final SQL does
            not pass ``_is_safe_select``.
    """
    if model is None:
        model = llm
    system = (
        'Bạn là trợ lý chuyển câu hỏi sang SQL cho DuckDB. '
        'Chỉ trả về SQL hợp lệ, không giải thích, không markdown. '
        'Chỉ dùng bảng/cột có trong schema. '
        'Nếu không có yêu cầu rõ ràng, hãy thêm LIMIT để query nhẹ.'
    )
    user = f'''SCHEMA:
{schema_text}

QUESTION:
{question}

SQL:'''
    resp = model.create_chat_completion(
        messages=[
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': user},
        ],
        temperature=0.0,
        max_tokens=max_tokens,
    )
    # `content` may be missing or None. Treat that as empty output so the safety
    # check below raises a clear ValueError instead of failing inside extraction.
    text = resp['choices'][0]['message'].get('content') or ''
    sql = _extract_sql(text)
    sql = _ensure_limit(sql, default_limit)
    # Validate the exact text that will be returned (and presumably executed).
    if not _is_safe_select(sql):
        raise ValueError(f'Unsafe/invalid SQL generated: {sql!r}')
    return sql



In [5]:
# Example: generate SQL for a sample question
# ("Top 10 customers with the highest total order value?").
question = 'Top 10 khách hàng có tổng giá trị đơn hàng cao nhất?'
sql = generate_sql(question, schema_text=schema_text, default_limit=100)
# print() instead of a bare expression so any newlines in the SQL (such as an
# appended LIMIT line) render as real line breaks.
print(sql)


SELECT c_custkey, SUM(o_totalprice) as total_value FROM customer c JOIN orders o ON c.c_custkey = o.o_custkey GROUP BY c_custkey ORDER BY total_value DESC LIMIT 10
