In [None]:
import json

import requests
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility

# -----------------------------
# 1. Load the JSON schema
# -----------------------------
# schema.json is expected to look like this:
# [
#   {
#     "table_name": "employees",
#     "columns": [
#       {"name": "id", "desc": "Employee ID"},
#       {"name": "name", "desc": "Full name of employee"},
#       {"name": "salary", "desc": "Monthly salary"}
#     ]
#   }
# ]

# Parse the whole file once; later steps iterate over `schema_data`.
with open("schema.json", "r", encoding="utf-8") as schema_file:
    schema_data = json.load(schema_file)

# -----------------------------
# 2. Connect to Milvus
# -----------------------------
MILVUS_HOST = "10.13.18.40"
MILVUS_PORT = "19530"
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

# Name of the collection that stores the schema embeddings.
COLLECTION_NAME = "SQLRAG_ALS"

# -----------------------------
# 3. Create the collection (if it does not exist yet)
# -----------------------------
# `Collection` has no `list()` method; existence checks live in `pymilvus.utility`.
if not utility.has_collection(COLLECTION_NAME):
    fields = [
        # Milvus requires a primary-key field. With auto_id=True the server
        # generates the key, so inserts only supply the remaining four columns.
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=768),  # assumes 768-dim embeddings
        FieldSchema(name="table", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name="column", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name="description", dtype=DataType.VARCHAR, max_length=500)
    ]
    schema = CollectionSchema(fields=fields, description="SQL schema embeddings")
    collection = Collection(name=COLLECTION_NAME, schema=schema)
else:
    collection = Collection(COLLECTION_NAME)

# search() needs an index on the vector field and a loaded collection.
# The index uses the same metric (COSINE) and index family (IVF, tuned via
# `nprobe`) as the search parameters used later in this notebook.
if not collection.has_index():
    collection.create_index(
        field_name="vector",
        index_params={
            "index_type": "IVF_FLAT",
            "metric_type": "COSINE",
            "params": {"nlist": 128},
        },
    )
collection.load()

# -----------------------------
# 4. Embedding function
# -----------------------------
EMBED_API = "http://10.13.18.40:14514/embed"


def get_embedding(text: str, timeout: float = 30.0):
    """Request an embedding for ``text`` from the embedding service.

    Args:
        text: The text to embed.
        timeout: Seconds to wait for the service before giving up. Without a
            timeout ``requests`` waits forever, which would hang the kernel
            if the service stalls.

    Returns:
        The value stored under the ``"embedding"`` key of the JSON response
        (presumably a list of floats; confirm against the service).

    Raises:
        requests.HTTPError: If the service answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    payload = {"model": "Conan-embedding-v1", "input": text}
    resp = requests.post(EMBED_API, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()["embedding"]

# -----------------------------
# 5. Insert vectors
# -----------------------------
# Gather one (embedding, table, column, description) record per schema column.
records = []
for table in schema_data:
    table_name = table["table_name"]
    for col in table["columns"]:
        col_name = col["name"]
        # Fall back to the column name when no description is given.
        desc = col.get("desc", col_name)
        vec = get_embedding(f"Table: {table_name}, Column: {col_name}, Description: {desc}")
        records.append((vec, table_name, col_name, desc))

# Split the records into the four parallel lists passed to insert().
vectors = [record[0] for record in records]
tables = [record[1] for record in records]
columns = [record[2] for record in records]
descriptions = [record[3] for record in records]

# Write everything to Milvus in a single insert, then flush.
collection.insert([vectors, tables, columns, descriptions])
collection.flush()
print(f"Inserted {len(vectors)} embeddings into {COLLECTION_NAME}.")

# -----------------------------
# 6. Query example (plain vector search)
# -----------------------------
# A sample natural-language question.
query_text = "列出每個部門的平均薪水"
query_vec = get_embedding(query_text)

search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
results = collection.search(
    [query_vec],
    "vector",
    param=search_params,
    limit=3,
    output_fields=["table", "column", "description"],
)

# Show the table, column and description of each hit for the single query.
print("Top results:")
for hit in results[0]:
    print(*(hit.entity.get(field) for field in ("table", "column", "description")))


In [None]:
# -----------------------------
# 7. Natural-language query -> generated SQL
# -----------------------------
LLM_URL = "http://10.13.18.40:55700/v1"
LLM_MODEL = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"


def generate_sql(nl_query: str, top_k: int = 3, timeout: float = 120.0):
    """Turn a natural-language request into SQL using retrieved schema context.

    The request is embedded, the ``top_k`` closest schema entries are fetched
    from the Milvus collection, and those entries plus the request are sent
    to the LLM as a completion prompt.

    Args:
        nl_query: The user's request in natural language.
        top_k: Number of schema entries to retrieve as context.
        timeout: Seconds to wait for the LLM response before giving up.

    Returns:
        The text of the first completion choice, stripped of surrounding
        whitespace.

    Raises:
        requests.HTTPError: If the LLM endpoint answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    # 1. Embed the request and search for related table/column entries.
    query_vec = get_embedding(nl_query)
    search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
    results = collection.search(
        [query_vec],
        "vector",
        param=search_params,
        limit=top_k,
        output_fields=["table", "column", "description"],
    )

    # 2. Format each hit as "table.column: description" for the prompt.
    retrieved_text = "\n".join(
        f"{hit.entity.get('table')}.{hit.entity.get('column')}: {hit.entity.get('description')}"
        for hit in results[0]
    )

    prompt = f"""
你是一個 SQL 生成助手，根據下列資料表欄位資訊，請將使用者的自然語言需求轉換成可執行 SQL。

資料表欄位:
{retrieved_text}

使用者需求:
{nl_query}

請生成對應 SQL：
"""
    # 3. Call the LLM. LLM_URL is the API base; the completions route is
    #    <base>/completions, and the "choices[0].text" response shape parsed
    #    below belongs to that route. Posting to the bare base URL does not
    #    reach it.
    payload = {
        "model": LLM_MODEL,
        "prompt": prompt,
        "max_tokens": 500
    }
    resp = requests.post(f"{LLM_URL.rstrip('/')}/completions", json=payload, timeout=timeout)
    resp.raise_for_status()
    # NOTE(review): a reasoning model may emit its reasoning before the SQL and
    # 500 tokens may truncate the answer; confirm against real outputs.
    sql = resp.json()["choices"][0]["text"]
    return sql.strip()

# -----------------------------
# 8. Smoke test
# -----------------------------
nl_query = "列出每個部門的平均薪水"
sql_query = generate_sql(nl_query)
# Print the heading and the generated SQL on separate lines.
print("生成 SQL：", sql_query, sep="\n")


In [None]:
from vanna.milvus import Milvus_VectorStore
from vanna.vllm import Vllm

# Initialise the Milvus_VectorStore.
# NOTE(review): the keyword arguments below (host/port/collection_name) and the
# later `add_documents` call could not be verified from this notebook; confirm
# them against the installed Vanna version, whose documented usage appears to
# pass a single `config` dict instead.
store = Milvus_VectorStore(
    host="10.13.18.40",
    port="19530",
    collection_name="SQLRAG_ALS"
)

# Initialise the LLM.
# NOTE(review): confirm that `Vllm` accepts `model_name` / `model_url` keywords.
llm = Vllm(
    model_name="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
    model_url="http://10.13.18.40:55700/v1"
)

# Build one descriptive text line per column of the schema JSON.
# Relies on `schema_data` loaded by an earlier cell.
schema_texts = []
for table in schema_data:
    table_name = table["table_name"]
    for col in table["columns"]:
        col_name = col["name"]
        desc = col.get("desc", "")  # empty description when none is given
        schema_texts.append(f"Table: {table_name}, Column: {col_name}, Desc: {desc}")

# Store the texts in Milvus through the vector store.
store.add_documents(schema_texts)


In [None]:
# Sample natural-language question to turn into SQL.
query = "列出每個部門的平均薪水"

# Run the Vanna.ai RAG pipeline, using `store` as the retriever.
# NOTE(review): `run_rag` and its `retriever` / `top_k` keywords could not be
# verified from this notebook; confirm them against the installed Vanna version.
result = llm.run_rag(
    query=query,
    retriever=store,
    top_k=3
)

# Print the heading followed by the generated SQL.
print("生成 SQL:")
print(result)
