In [37]:
from sqlalchemy import create_engine
import sqlalchemy
from langchain_community.utilities.sql_database import SQLDatabase
from vertexai import init
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.tools import Tool
from langchain import hub
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent
import vertexai
from google.cloud import aiplatform
from langchain_google_vertexai import VertexAI
from langchain.chains import RetrievalQA
from typing import List, TypedDict
from langgraph.graph import StateGraph
from typing import Any
from pydantic import Field

from langchain_google_vertexai import (
    VertexAIEmbeddings,
    VectorSearchVectorStore,
)

from prompt import LLM_SQL_SYS_PROMPT
from IPython.display import Image, display

from langchain.schema import AIMessage
import json


import re

import os
from urllib.parse import quote_plus

# Role -> list of currency codes. Presumably the currencies each role is
# allowed to see; it is not referenced by the rest of the visible code
# (TODO confirm where it is consumed).
ROLE_ACCESS_LIST = {"cn":['CNY'], "kr":['KRW'], "gb":['TWD', 'KRW', 'CNY', 'USD']}
############################################
# 1. Database connection settings
############################################
PROJECT_ID = "tsmccareerhack2025-bsid-grp2"
REGION = "us-central1"
INSTANCE = "sql-instance-relational"
TABLE_NAME = "fin_data"

# Connection parameters can be overridden through environment variables so
# that credentials do not have to live in the notebook. The fallbacks are the
# values that were previously hard-coded, so behaviour is unchanged when the
# variables are not set.
# NOTE(review): the fallback password is still committed to source control;
# rotate it and drop the default once every environment sets DB_PASSWORD.
DATABASE = os.getenv("DB_NAME", "postgres")
DB_HOST = os.getenv("DB_HOST", "34.56.145.52")  # Cloud SQL public IP
DB_PORT = os.getenv("DB_PORT", "5432")  # PostgreSQL default port
_USER = os.getenv("DB_USER", "postgres")
_PASSWORD = os.getenv("DB_PASSWORD", "postgres")

# URL-quote the credentials so characters such as "@", ":" or "/" in a user
# name or password cannot corrupt the connection URL.
db_url = (
    f'postgresql+psycopg2://{quote_plus(_USER)}:{quote_plus(_PASSWORD)}'
    f'@{DB_HOST}:{DB_PORT}/{DATABASE}'
)
engine = sqlalchemy.create_engine(db_url)
db = SQLDatabase(engine)

############################################
# 2. Initialise Vertex AI & the LLM
############################################
init(project=PROJECT_ID, location=REGION)
llm = init_chat_model("gemini-2.0-flash", model_provider="google_vertexai")

############################################
# 3. 設定 SQL Agent 工具
############################################
def extract_sql_from_response(response):
    """Return the stripped text of the last message in an agent response.

    Args:
        response: Either a mapping with a ``"messages"`` entry (a sequence of
            objects exposing a ``content`` attribute) or any other object.

    Returns:
        The ``content`` of the last message with surrounding whitespace
        removed, ``""`` when the ``"messages"`` sequence is empty, or
        ``str(response).strip()`` when ``response`` is not such a mapping.
    """
    from collections.abc import Mapping

    # Only mappings are treated as agent state. On a plain string,
    # `"messages" in response` is a substring test and the subsequent
    # `response["messages"]` lookup would raise TypeError; on an int the
    # membership test itself raises.
    if isinstance(response, Mapping) and "messages" in response:
        messages = response["messages"]
        if not messages:
            # Nothing was produced; avoid IndexError on messages[-1].
            return ""
        return messages[-1].content.strip()
    return str(response).strip()

class SQLQueryGenerator(Tool):
    """Tool that runs the SQL agent and returns the SQL text it produced.

    The wrapped ``agent_executor`` is invoked with the user's question, and
    the SQL passed to the ``sql_db_query`` tool is recovered from the
    resulting message list instead of returning the agent's final answer.

    NOTE(review): this class only recovers the SQL text. Whether the query
    is also executed depends on the tools owned by ``agent_executor``; if it
    holds a live ``sql_db_query`` tool, the agent has presumably already run
    the unfiltered query by the time ``run`` returns -- confirm.
    """

    # Declared as a field so that pydantic accepts the attribute on the model.
    agent_executor: Any = Field(...)

    def __init__(self, agent_executor, **kwargs):
        """Create the tool around ``agent_executor``.

        Args:
            agent_executor: Object exposing ``invoke({"messages": [...]})``
                and returning a mapping with a ``"messages"`` list.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(
            name="Generate SQL",
            description="Generate SQL query based on user input without executing it.",
            func=self.run,  # the base Tool needs a callable; use our own run
            agent_executor=agent_executor,
            **kwargs
        )

    @staticmethod
    def _extract_sql_queries(message):
        """Return the SQL strings one message passed to ``sql_db_query``.

        The normalised ``AIMessage.tool_calls`` list (dicts with ``name`` and
        ``args``) is read first. The provider-specific
        ``additional_kwargs["function_call"]`` entry, whose ``arguments`` is
        a JSON string, is used only as a fallback so that a call present in
        both places is not reported twice.
        """
        if not isinstance(message, AIMessage):
            return []

        function_call = message.additional_kwargs.get("function_call")
        if function_call:
            print("function_call: ", function_call)

        queries = []
        for tool_call in getattr(message, "tool_calls", None) or []:
            if tool_call.get("name") != "sql_db_query":
                continue
            sql_query = (tool_call.get("args") or {}).get("query")
            if sql_query:
                queries.append(sql_query)
        if queries:
            return queries

        if not function_call or function_call.get("name") != "sql_db_query":
            return []
        try:
            function_args = json.loads(function_call["arguments"])
        except (KeyError, TypeError, json.JSONDecodeError) as e:
            print("JSON 解析錯誤: ", e)
            return []
        # A call without a "query" argument is skipped instead of raising.
        if not isinstance(function_args, dict):
            return []
        sql_query = function_args.get("query")
        return [sql_query] if sql_query else []

    def run(self, query):
        """Ask the agent about ``query`` and return the first SQL it issued.

        Args:
            query: Natural-language question sent to the agent as a
                ``HumanMessage``.

        Returns:
            The first SQL string passed to ``sql_db_query`` during the run,
            or the literal ``"No SQL query generated"`` when there is none.
        """
        response = self.agent_executor.invoke({"messages": [HumanMessage(content=query)]})
        print(response)  # debug: show the raw response structure

        db_queries = []
        for message in response["messages"]:
            db_queries.extend(self._extract_sql_queries(message))

        print("db_queries: ", db_queries)

        if db_queries:
            return db_queries[0]  # only the first SQL query is used
        return "No SQL query generated"


# Keywords are only searched in the "masked" query (see _mask_nested_sql), so
# matches always refer to the outermost SELECT.
_WHERE_RE = re.compile(r"\bWHERE\b", re.IGNORECASE)
# Clauses that must come after the WHERE clause of a SELECT statement.
_TRAILING_CLAUSE_RE = re.compile(
    r"\b(?:GROUP\s+BY|HAVING|WINDOW|ORDER\s+BY|LIMIT|OFFSET|FETCH|FOR)\b",
    re.IGNORECASE,
)


def _mask_nested_sql(query: str) -> str:
    """Return a same-length copy of ``query`` with nested text blanked out.

    Characters inside single- or double-quoted runs and inside parentheses
    (including the quotes and parentheses themselves) are replaced by spaces,
    so indexes found in the result are valid for the original string.
    """
    masked = []
    depth = 0
    quote = None
    for ch in query:
        if quote is not None:
            # Inside a quoted run; a doubled quote simply closes and reopens.
            if ch == quote:
                quote = None
            masked.append(" ")
        elif ch in ("'", '"'):
            quote = ch
            masked.append(" ")
        elif ch == "(":
            depth += 1
            masked.append(" ")
        elif ch == ")":
            depth = max(depth - 1, 0)
            masked.append(" ")
        else:
            masked.append(ch if depth == 0 else " ")
    return "".join(masked)


def modify_query_with_companies(query: str, allow_companies: list) -> str:
    """Restrict a SELECT statement to rows of the allowed companies.

    A ``company_name IN (...)`` condition is merged into the outermost WHERE
    clause (or a WHERE clause is added), placed before any trailing
    GROUP BY / HAVING / ORDER BY / LIMIT / OFFSET clause.

    Args:
        query: SQL text, optionally ending with ``;``.
        allow_companies: Company names the caller may access. The exact value
            ``['all']`` disables filtering and returns ``query`` unchanged.
            An empty list produces a condition that matches no rows.

    Returns:
        The rewritten SQL, with the trailing ``;`` restored if it was present.

    Notes:
        This is a textual rewrite, not a SQL parser. SQL comments and
        top-level set operations (UNION / INTERSECT / EXCEPT) are not
        handled; NOTE(review): prefer database-level row security for
        anything that must be a hard guarantee.
    """
    if allow_companies == ['all']:
        return query

    stripped = query.strip()
    has_semicolon = stripped.endswith(";")
    query = stripped.rstrip(";").rstrip()

    if allow_companies:
        # Double embedded single quotes so a name cannot terminate the literal.
        company_filter = ", ".join(
            "'" + str(company).replace("'", "''") + "'"
            for company in allow_companies
        )
        condition = f"company_name IN ({company_filter})"
    else:
        # `IN ()` is not valid SQL; an empty allow-list means "no access".
        condition = "1 = 0"

    masked = _mask_nested_sql(query)
    where_match = _WHERE_RE.search(masked)
    tail_match = _TRAILING_CLAUSE_RE.search(
        masked, where_match.end() if where_match else 0
    )
    cut = tail_match.start() if tail_match else len(query)

    head = query[:cut].rstrip()
    tail = query[cut:].strip()  # everything from the first trailing clause on

    if where_match:
        where_end = where_match.end()
        existing = head[where_end:].strip()
        if existing:
            # Parenthesise the original predicate so that a top-level OR in
            # it cannot bypass the company restriction (AND binds tighter).
            head = f"{head[:where_end]} ({existing}) AND {condition}"
        else:
            head = f"{head[:where_end]} {condition}"
    else:
        head = f"{head} WHERE {condition}"

    query = f"{head} {tail}" if tail else head

    if has_semicolon:
        query += ";"

    return query


# Build the LangChain SQL toolkit on top of the database handle and the LLM,
# and collect the tools it exposes.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
# Fetch the SQL-agent system prompt from the LangChain hub (network call) and
# fill in its `dialect` and `top_k` placeholders.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
system_message = prompt_template.format(dialect="PostgreSQL", top_k=5)
# The agent is given every toolkit tool, so it can call them while answering.
agent_executor = create_react_agent(llm, tools, prompt=system_message)
sql_generator = SQLQueryGenerator(agent_executor)  # create the SQL-generation tool
# NOTE(review): "cn" is a key of ROLE_ACCESS_LIST, but USER_ROLE is not read
# by any of the visible cells -- confirm where it is meant to be applied.
USER_ROLE = "cn"




In [38]:
# Ask the agent for SQL answering a natural-language question.
query = "Retrieve Amazon's `Revenue` data for Q1 2020."
generated_sql = sql_generator.run(query)

# Apply the access-control condition to the query.
# Passing ["all"] makes modify_query_with_companies return the SQL unchanged.
secured_sql = modify_query_with_companies(generated_sql, ["all"])
print(f"Secured SQL: {secured_sql}\n")

# Execute the (possibly modified) SQL.
final_response = db.run(secured_sql)
print("final_response: \n", final_response)

{'messages': [HumanMessage(content="Retrieve Amazon's `Revenue` data for Q1 2020.", additional_kwargs={}, response_metadata={}, id='da431bd6-cb0c-4521-9495-1553c600c1f0'), AIMessage(content="Okay, I need to find the relevant table that contains Amazon's revenue data. I'll start by listing the available tables.\n", additional_kwargs={'function_call': {'name': 'sql_db_list_tables', 'arguments': '{}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 521, 'candidates_token_count': 35, 'total_token_count': 556, 'prompt_tokens_details': [{'modality': 1, 'token_count': 521}], 'candidates_tokens_details': [{'modality': 1, 'token_count': 35}], 'cached_content_token_count': 0, 'cache_tokens_details': []}, 'finish_reason': 'STOP', 'avg_logprobs': -0.21232271194458008}, id='run-ea223a7d-2361-4453-b681-9683eac8c742-0', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': '90af05e6-2d19-4854-a8f8-82a1c856a073', 'type': 'tool_call'}], us

In [39]:
# Exercise the company filter on a hand-written query (no LLM involved).
query = "SELECT * FROM fin_data;"
allow_companies = ["Baidu"]

secured_sql = modify_query_with_companies(query, allow_companies)
print(f"Secured SQL: {secured_sql}\n")

# Execute the modified SQL.
final_response = db.run(secured_sql)
print("final_response: \n", final_response)

Secured SQL: SELECT * FROM fin_data WHERE company_name IN ('Baidu');

final_response: 
 [(571, 'Baidu', 'Cost of Goods Sold', 2020, 'Q1', 2104.1, 'CNY', 'Million', 14687.0), (572, 'Baidu', 'Operating Expense', 2020, 'Q1', 3292.46, 'CNY', 'Million', 22982.0), (573, 'Baidu', 'Operating Income', 2020, 'Q1', -62.61, 'CNY', 'Million', -437.0), (574, 'Baidu', 'Revenue', 2020, 'Q1', 3229.85, 'CNY', 'Million', 22545.0), (575, 'Baidu', 'Tax Expense', 2020, 'Q1', 28.37, 'CNY', 'Million', 198.0), (576, 'Baidu', 'Total Asset', 2020, 'Q1', 42220.32, 'CNY', 'Million', 299017.0), (577, 'Baidu', 'Cost of Goods Sold', 2020, 'Q2', 1853.15, 'CNY', 'Million', 13134.0), (578, 'Baidu', 'Operating Expense', 2020, 'Q2', 3159.13, 'CNY', 'Million', 22390.0), (579, 'Baidu', 'Operating Income', 2020, 'Q2', 514.15, 'CNY', 'Million', 3644.0), (580, 'Baidu', 'Revenue', 2020, 'Q2', 3673.29, 'CNY', 'Million', 26034.0), (581, 'Baidu', 'Tax Expense', 2020, 'Q2', 172.42, 'CNY', 'Million', 1222.0), (582, 'Baidu', 'Total A

In [40]:
query = "Local_Currency 有哪些"
# Let the LLM agent produce the SQL query.
generated_sql = sql_generator.run(query)
print(f"Generated SQL: {generated_sql}\n")

# Apply the access-control condition so only rows for "Baidu" are returned.
secured_sql = modify_query_with_companies(generated_sql, ["Baidu"])
print(f"Secured SQL: {secured_sql}\n")

# Execute the modified SQL.
final_response = db.run(secured_sql)
print("final_response: \n", final_response)

{'messages': [HumanMessage(content='Local_Currency 有哪些', additional_kwargs={}, response_metadata={}, id='b098ee10-8ede-43f6-ab35-fd393382e2ae'), AIMessage(content='', additional_kwargs={'function_call': {'name': 'sql_db_list_tables', 'arguments': '{}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 509, 'candidates_token_count': 7, 'total_token_count': 516, 'prompt_tokens_details': [{'modality': 1, 'token_count': 509}], 'candidates_tokens_details': [{'modality': 1, 'token_count': 7}], 'cached_content_token_count': 0, 'cache_tokens_details': []}, 'finish_reason': 'STOP', 'avg_logprobs': -0.059874896492276876}, id='run-125bb51c-87e1-48b6-9dce-d78002777769-0', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': '17fc6922-bbfd-4d4f-b82c-bc28f597e0fc', 'type': 'tool_call'}], usage_metadata={'input_tokens': 509, 'output_tokens': 7, 'total_tokens': 516}), ToolMessage(content='fin_data, trainscript_data, users', name='sql_db_li

In [26]:
# Display the result of the most recent db.run(...) call.
final_response

'[(75452.0,)]'

In [11]:
# Display the value last returned by the SQL generator.
# NOTE(review): this cell's execution count (11) is lower than that of the
# defining cells, so the output shown for it is presumably from an earlier run.
generated_sql

"Amazon's Revenue for Q1 2020 was 75452.0 Million USD."