# Generative AI for interacting with SQL

Reference
- [Github vanna](https://github.com/vanna-ai/vanna)
- [vaana documentation](https://vanna.ai/docs/mssql-gemini-qdrant/)

In [None]:
# 先卸載現有的 onnxruntime（如果有的話） 
!pip uninstall -y onnxruntime onnxruntime-gpu

!pip install onnxruntime==1.21.0 
!pip install vanna
!pip install "qdrant-client[fastembed]"  # 包含 fastembed 的依賴
!pip install google-generativeai 
!pip install vertexai 
!pip install fastembed

Found existing installation: onnxruntime 1.17.0
Uninstalling onnxruntime-1.17.0:
  Successfully uninstalled onnxruntime-1.17.0


You can safely remove it manually.


Collecting onnxruntime==1.21.0
  Using cached onnxruntime-1.21.0-cp312-cp312-win_amd64.whl.metadata (4.9 kB)
Using cached onnxruntime-1.21.0-cp312-cp312-win_amd64.whl (11.8 MB)
Installing collected packages: onnxruntime
Successfully installed onnxruntime-1.21.0


ERROR: Invalid requirement: '#'








## 1. Set up : Vanna AI

In [4]:
# for loading environment variables
import os
from dotenv import load_dotenv

# set up vanna ai
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from vanna.google import GoogleGeminiChat

# for text cleaning used
import re

In [5]:
# Load environment variables from .env file
load_dotenv()

# Get database connection details from environment variables

db_server = os.getenv('DB_SERVER')
db_database = os.getenv('DB_DATABASE')
db_username = os.getenv('DB_USERNAME')
db_driver = os.getenv('DB_DRIVER')

gemini_api_key = os.getenv('GEMINI_API_KEY')
gemini_model = os.getenv('GEMINI_MODEL')

qdranticlient = os.getenv('QDRANT_CLIENT')
qdrant_api_key = os.getenv("QDRANT_API_KEY")


In [6]:
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from vanna.google import GoogleGeminiChat

# Your Gemini API key and model (replace with your actual values)

# Maker sure all information is loaded and set up.
if gemini_api_key is None:
    raise ValueError("GEMINI_API_KEY environment variable not set")
if gemini_model is None:
    raise ValueError("GEMINI_MODEL environment variable not set")
if qdrant_api_key is None:
    raise ValueError("QDRANT_SERVICE_API_KEY environment variable not set")

class MyVanna(Qdrant_VectorStore, GoogleGeminiChat):  # Inherit from both
    def __init__(self, config=None):
        # Initialize the parent classes
        Qdrant_VectorStore.__init__(self, config=config)
        GoogleGeminiChat.__init__(self, config={'api_key': gemini_api_key, 'model': gemini_model})

# Set up the QdrantClient
qdrant_client = QdrantClient(host = "localhost", port=6333, api_key=qdrant_api_key, https=False)

# Pass the QdrantClient object in the config
vn = MyVanna(config={'client': qdrant_client})

print("Vanna instance created successfully!") # Verification

  qdrant_client = QdrantClient(host = "localhost", port=6333, api_key=qdrant_api_key, https=False)


Vanna instance created successfully!


# 2. Connection to Database

In [7]:
conn_str = (
    f"DRIVER={{{db_driver}}};"
    f"SERVER={db_server};"
    f"DATABASE={db_database};"
)

# Add authentication details based on whether username/password are provided
if db_username and db_password:
    print("Connecting using SQL Server Authentication...")
    conn_str += f"UID={db_username};PWD={db_password};"
else:
    print("Connecting using Windows Authentication (Trusted Connection)...")
    conn_str += "Trusted_Connection=yes;"

vn.connect_to_mssql(odbc_conn_str = conn_str)

Connecting using Windows Authentication (Trusted Connection)...


# 3. Traning (Embedding) (do it once)

3-1. Training with sql information schema

In [14]:
# The information schema query may need some tweaking depending on your database. This is a good starting point.
df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")

# This will break up the information schema into bite-sized chunks that can be referenced by the LLM
plan = vn.get_training_plan_generic(df_information_schema)
plan

# If you like the plan, then uncomment this and run it to train
# The vectors will be stored in the Qdant's "Documentation" collection.
vn.train(plan=plan)

3-2. Traning with DDL. (Data Definition Language)

In [15]:
import re

# Set path
ddl_file_path = "create_tables.sql"

# Read the DDL file
try:
    with open(ddl_file_path, 'r', encoding='utf-8-sig') as f:
        ddl_content = f.read()
    print (f"成功讀取DDL檔案: {ddl_file_path}")

    # 檢查讀取的內容是否為空
    if not ddl_content.strip():
        print("警告: DDL 檔案是空的，將不會進行訓練。")
    else:
        print("正在清理DDL內容...")
        # 1. 移除 DROP TABLE 行
        cleaned_ddl = re.sub(r"^\s*DROP TABLE.*?;?\s*$", "", ddl_content, flags=re.MULTILINE | re.IGNORECASE)

        # 2. 移除簡單的 "-- Create/Add ..." 註解 (可選)
        cleaned_ddl = re.sub(r"^\s*--\s*(Create|Add)\s+.*?\s*$", "", cleaned_ddl, flags=re.MULTILINE | re.IGNORECASE)

        # 移除可能產生的多餘空行
        cleaned_ddl = "\n".join([line for line in cleaned_ddl.splitlines() if line.strip()])

        # 3. 將讀取的 DDL 字串傳遞給 vn.train()
        print("開始使用讀取的 DDL 進行 Vanna 訓練...")
        vn.train(ddl=cleaned_ddl)
        print("使用清理後 DDL 檔案進行訓練完成。")

except FileNotFoundError:
    print(f"錯誤: 找不到 DDL 檔案 '{ddl_file_path}'。請確認檔案路徑是否正確。")
    raise
except Exception as e:
    print(f"讀取 DDL 檔案時發生錯誤: {e}")


成功讀取DDL檔案: create_tables.sql
正在清理DDL內容...
開始使用讀取的 DDL 進行 Vanna 訓練...
Adding ddl: -- Drop tables if they exist (optional, for easier re-running)
CREATE TABLE Categories (
    category_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server. Auto-increment for MYSQL. SERIAL FOR PostgreSQL.
    category_name VARCHAR(255) NOT NULL UNIQUE
);
CREATE TABLE Customers (
    customer_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server auto-increment
    first_name VARCHAR(255) NOT NULL,
    last_name VARCHAR(255) NOT NULL,
    email VARCHAR(255) NOT NULL UNIQUE,
    address TEXT,
    phone VARCHAR(50) -- Increased size to accommodate longer phone numbers
);
CREATE TABLE Products (
    product_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server auto-increment
    product_name VARCHAR(255) NOT NULL,
    description TEXT,
    price DECIMAL(10, 2) NOT NULL CHECK (price >= 0),
    category_id INT,
    FOREIGN KEY (category_id) REFERENCES Categories(category_id) ON D

3-3 Training with quries and SQL statements.

In [19]:
import json
import vanna # 假設 vn 是你已經初始化好的 Vanna 物件
import pandas as pd

# 1. 載入你的 JSON 訓練資料
training_data = None
json_file_path = "training_data_bilingual.json"

try:
    with open(json_file_path, 'r', encoding='utf-8') as f:
        training_data = json.load(f)
    print(f"Successfully loaded {len(training_data)} training examples from '{json_file_path}'.")
except FileNotFoundError:
    print(f"Error: JSON file not found at '{json_file_path}'.")
    training_data = [] # 設為空列表以防後續錯誤
except json.JSONDecodeError as e:
    print(f"Error: Could not decode JSON from '{json_file_path}'. Invalid JSON format? Error: {e}")
    training_data = []
except Exception as e:
    print(f"An unexpected error occurred while reading the JSON file: {e}")
    training_data = []

# 2. 檢查 training_data 是否成功載入且是一個列表
if training_data and isinstance(training_data, list):
    print("Starting Vanna training...")
    total_items = len(training_data)
    processed_count = 0
    failed_items = []

    # 3. 遍歷列表中的每一個項目
    for i, item in enumerate(training_data):
        item_index = i + 1
        print(f"\nProcessing item {item_index}/{total_items}...")

        try:
            # 檢查項目是否為字典且包含所有必要鍵值
            required_keys = ['query_en', 'query_zh', 'sql_syntax']
            if isinstance(item, dict) and all(key in item for key in required_keys):
                 # 從字典中取出英文問題、中文問題和 SQL
                question_en = item['query_en']
                question_zh = item['query_zh']
                sql = item['sql_syntax']

                # 確保問題和 SQL 都不是空的
                if question_en and question_zh and sql:
                    # --- 訓練英文問題 ---
                    print(f"  Training (EN): Q: {question_en[:80]}... SQL: {sql[:80]}...")
                    vn.train(question=question_en, sql=sql)

                    # --- 訓練中文問題 ---
                    print(f"  Training (ZH): Q: {question_zh[:80]}... SQL: {sql[:80]}...")
                    vn.train(question=question_zh, sql=sql)

                    processed_count += 1
                    print(f"  Item {item_index} trained successfully (both languages).")
                else:
                    missing = []
                    if not question_en: missing.append("'query_en'")
                    if not question_zh: missing.append("'query_zh'")
                    if not sql: missing.append("'sql_syntax'")
                    reason = f"Empty value(s) for {', '.join(missing)}"
                    print(f"  Warning: Skipping item {item_index} due to {reason}.")
                    failed_items.append({'index': item_index, 'item': item, 'reason': reason})
            else:
                 missing_keys = [key for key in required_keys if key not in item]
                 reason = f"Invalid format or missing keys: {', '.join(missing_keys)}" if missing_keys else "Invalid format (not a dict?)"
                 print(f"  Warning: Skipping item {item_index} because: {reason}.")
                 failed_items.append({'index': item_index, 'item': item, 'reason': reason})

        except Exception as e:
            error_message = f"Unexpected error: {str(e)}"
            print(f"  Error training item {item_index}: {error_message}")
            failed_items.append({'index': item_index, 'item': item, 'reason': error_message})

    print("\n--------------------")
    print("Training loop finished.")
    print(f"Successfully processed and trained {processed_count} items (each potentially adding EN and ZH pairs).")
    if failed_items:
        print(f"{len(failed_items)} items failed or were skipped during training:")
        # 可選擇取消註解以印出詳細錯誤
        # for failure in failed_items:
        #    print(f"  - Index: {failure['index']}, Reason: {failure['reason']}") # , Item: {failure['item']}")
    else:
        print("All items processed without errors (or were validly skipped due to empty fields).")

elif not training_data:
     print("Training data is empty or failed to load. Skipping training.")
else:
     print(f"Error: Expected training_data to be a list, but got {type(training_data)}. Skipping training.")

Successfully loaded 24 training examples from 'training_data_bilingual.json'.
Starting Vanna training...

Processing item 1/24...
  Training (EN): Q: Show me all product names.... SQL: SELECT product_name FROM Products;...
  Training (ZH): Q: 顯示所有產品的名稱。... SQL: SELECT product_name FROM Products;...
  Item 1 trained successfully (both languages).

Processing item 2/24...
  Training (EN): Q: List the first name, last name, and email of customers whose first name is 'John... SQL: SELECT first_name, last_name, email FROM Customers WHERE first_name = 'John';...
  Training (ZH): Q: 列出名字是 'John' 的客戶的名字、姓氏和電子郵件。... SQL: SELECT first_name, last_name, email FROM Customers WHERE first_name = 'John';...
  Item 2 trained successfully (both languages).

Processing item 3/24...
  Training (EN): Q: Find the top 10 most expensive products, showing their name and price, ordered b... SQL: SELECT TOP 10 product_name, price FROM Products ORDER BY price DESC;...
  Training (ZH): Q: 找出前10個最貴的產品，顯示它們的名稱和價格，並按

In [21]:
# At any time you can inspect what training data the package is able to reference
full_training_data = vn.get_training_data()
full_training_data

Unnamed: 0,id,question,content,training_data_type
0,010c4f84-9f2b-569b-ae49-68a8f2ee2207-sql,Find all customers whose email address ends wi...,"SELECT customer_id, first_name, last_name, ema...",sql
1,0393e7e6-6aee-5027-9040-068503141d77-sql,建立一個名為 'ProductCategoryView' 的檢視表，顯示產品名稱及其對應的分...,CREATE VIEW ProductCategoryView AS\nSELECT \n ...,sql
2,05dd4b7e-507d-5208-9272-ec24df51bc4a-sql,List the customer ID and the total number of o...,"SELECT C.customer_id, COUNT(O.order_id) AS Tot...",sql
3,090950d2-5392-56ba-95f6-b144ff50e5f4-sql,Create a view named 'ProductCategoryView' that...,CREATE VIEW ProductCategoryView AS\nSELECT \n ...,sql
4,096443ea-758f-5cc5-a87a-8e595abc83fa-sql,新增一位客戶，姓名為 'Peter Pan'，電子郵件是 'peter.p@neverlan...,"INSERT INTO Customers (first_name, last_name, ...",sql
5,0f2ad796-0271-56f3-828d-f4973a971ba8-sql,找出訂單 ID 為 5 的產品名稱和數量。,"SELECT P.product_name, OI.quantity \nFROM Orde...",sql
6,116b6447-e81a-5d3f-822a-77fccf977486-sql,Count the number of products in each category....,"SELECT Cat.category_name, COUNT(P.product_id) ...",sql
7,19ecb329-2c3c-564c-b3d9-c8810b2799c7-sql,列出下單超過3次的客戶的客戶ID及其訂單總數。,"SELECT C.customer_id, COUNT(O.order_id) AS Tot...",sql
8,2004a09b-6c9a-5dc4-9fe8-9d415c1aa1ed-sql,顯示價格高於所有產品平均價格的產品。,"SELECT product_name, price \nFROM Products \nW...",sql
9,24d1c632-6d6e-5826-b16e-afd50a0e1c1f-sql,計算每個分類中有多少產品。顯示分類名稱和產品數量。,"SELECT Cat.category_name, COUNT(P.product_id) ...",sql


# 4. Talk to the database

In [8]:
vn.generate_sql("找出訂單 ID 為 5 的產品名稱和數量")


SQL Prompt: ["You are a T-SQL / Microsoft SQL Server expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n-- Drop tables if they exist (optional, for easier re-running)\nCREATE TABLE Categories (\n    category_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server. Auto-increment for MYSQL. SERIAL FOR PostgreSQL.\n    category_name VARCHAR(255) NOT NULL UNIQUE\n);\nCREATE TABLE Customers (\n    customer_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server auto-increment\n    first_name VARCHAR(255) NOT NULL,\n    last_name VARCHAR(255) NOT NULL,\n    email VARCHAR(255) NOT NULL UNIQUE,\n    address TEXT,\n    phone VARCHAR(50) -- Increased size to accommodate longer phone numbers\n);\nCREATE TABLE Products (\n    product_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server auto-increment\n    product_name VARC

'SELECT P.product_name, OI.quantity\nFROM Order_Items AS OI\nINNER JOIN Products AS P ON OI.product_id = P.product_id\nWHERE OI.order_id = 5;'

In [None]:
vn.ask("找出訂單 ID 為 5 的產品名稱和數量")

SQL Prompt: ["You are a T-SQL / Microsoft SQL Server expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n-- Drop tables if they exist (optional, for easier re-running)\nCREATE TABLE Categories (\n    category_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server. Auto-increment for MYSQL. SERIAL FOR PostgreSQL.\n    category_name VARCHAR(255) NOT NULL UNIQUE\n);\nCREATE TABLE Customers (\n    customer_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server auto-increment\n    first_name VARCHAR(255) NOT NULL,\n    last_name VARCHAR(255) NOT NULL,\n    email VARCHAR(255) NOT NULL UNIQUE,\n    address TEXT,\n    phone VARCHAR(50) -- Increased size to accommodate longer phone numbers\n);\nCREATE TABLE Products (\n    product_id INT PRIMARY KEY IDENTITY(1,1), -- Use IDENTITY for SQL Server auto-increment\n    product_name VARC

In [None]:
from vanna.flask import VannaFlaskApp
app = VannaFlaskApp(vn)
app.run()