In [None]:
import datetime
import decimal
import json
import os
import re
import time
from collections import deque

import google.generativeai as genai
import pandas as pd
import pyodbc

In [None]:
# Gemini API keys. They are read from the environment so no secret is stored in
# the notebook; add further os.environ.get(...) entries to rotate over several keys.
GEMINI_API_KEYS = [
    os.environ.get('GOOGLE_API_KEY')
]
if not all(GEMINI_API_KEYS):
    # A missing/empty key stays in the list; SQL generation then returns
    # "Error: No available API key." for every question instead of calling Gemini.
    print("Warning: GOOGLE_API_KEY is not set; every question will fail with 'No available API key'.")

# Per-key request budget within one usage window.
API_RATE_LIMIT_PER_MINUTE = 100
# Length of one usage window in seconds: one minute plus a 10 s safety margin.
API_KEY_COOLDOWN_SECONDS = 70

# Per-key usage tracking: calls made in the current window and when that window started.
api_key_usage = {key: {"count": 0, "last_reset_time": time.time()} for key in GEMINI_API_KEYS}
current_api_key_index = 0  # index into GEMINI_API_KEYS of the key currently in use

# MS SQL Server connection settings. Each value can be overridden with an
# environment variable; the fallbacks are the original local-development values.
# NOTE(review): drop the password fallback before sharing/committing this notebook.
DB_SERVER = os.environ.get('MSSQL_SERVER', 'localhost')
DB_NAME = os.environ.get('MSSQL_DATABASE', 'text_to_sql')
username = os.environ.get('MSSQL_USER', 'sa')
password = os.environ.get('MSSQL_PASSWORD', '123456')
# Brace-quote the password (doubling any '}') so values containing ';' or '}'
# cannot break the ODBC connection-string syntax.
_quoted_password = '{' + password.replace('}', '}}') + '}'
DB_CONN_STRING = f'DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={DB_SERVER};DATABASE={DB_NAME};UID={username};PWD={_quoted_password}'

# Input questions, output report, and schema description files.
INPUT_CSV_FILE = 'question.csv'
OUTPUT_CSV_FILE = 'result.csv'
SCHEMA_FILE = 'm-schema.txt'

In [None]:
def json_date_converter(obj):
    """``default=`` hook for ``json.dumps`` covering non-JSON values found in query rows.

    Conversions:
        * ``datetime.date`` / ``datetime.datetime`` / ``datetime.time`` -> ISO-8601 string
        * ``decimal.Decimal`` -> ``int`` when finite and integral, otherwise ``float``
        * ``bytes`` / ``bytearray`` -> lowercase hex string

    Raises:
        TypeError: for any other type, which is what ``json.dumps`` expects
            from a ``default`` hook.
    """
    # datetime.datetime is a subclass of datetime.date, so this also covers datetimes.
    if isinstance(obj, (datetime.date, datetime.time)):
        return obj.isoformat()
    if isinstance(obj, decimal.Decimal):
        # Keep integral values exact (float would lose precision for large integers);
        # the is_finite() guard avoids comparing NaN/Infinity.
        if obj.is_finite() and obj == obj.to_integral_value():
            return int(obj)
        return float(obj)
    if isinstance(obj, (bytes, bytearray)):
        return obj.hex()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

In [None]:
def get_next_available_api_key():
    """Pick an API key that still has budget in its current usage window.

    Keys in GEMINI_API_KEYS are used in order; the function only advances to the
    next key when the current one has made API_RATE_LIMIT_PER_MINUTE calls within
    a window of API_KEY_COOLDOWN_SECONDS. A key's window restarts (count reset to 0)
    the first time it is checked after the window has elapsed. When every key is
    exhausted the function sleeps until the earliest window ends and tries again,
    so it can block.

    Side effects: increments the chosen key's counter in api_key_usage, updates
    current_api_key_index, and calls genai.configure(api_key=...).

    Returns:
        The selected key (which is None/"" if that is what GEMINI_API_KEYS holds),
        or None when GEMINI_API_KEYS is empty.
    """
    global current_api_key_index
    if not GEMINI_API_KEYS:
        return None
    # Keep the index valid if GEMINI_API_KEYS was shortened and the cell re-run.
    current_api_key_index %= len(GEMINI_API_KEYS)
    start_index = current_api_key_index
    while True:
        api_key = GEMINI_API_KEYS[current_api_key_index]
        now = time.time()
        # setdefault tolerates keys added to GEMINI_API_KEYS after api_key_usage was built.
        usage_stats = api_key_usage.setdefault(api_key, {"count": 0, "last_reset_time": now})
        if now - usage_stats["last_reset_time"] > API_KEY_COOLDOWN_SECONDS:
            usage_stats["count"] = 0
            usage_stats["last_reset_time"] = now
        if usage_stats["count"] < API_RATE_LIMIT_PER_MINUTE:
            usage_stats["count"] += 1
            genai.configure(api_key=api_key)
            return api_key
        current_api_key_index = (current_api_key_index + 1) % len(GEMINI_API_KEYS)
        if current_api_key_index == start_index:
            # A full cycle found no budget: every key is exhausted. Sleep until the
            # earliest window among the configured keys expires (+1 s margin).
            wait_times = [
                api_key_usage[key]["last_reset_time"] + API_KEY_COOLDOWN_SECONDS + 1 - now
                for key in GEMINI_API_KEYS
            ]
            positive_wait_times = [wt for wt in wait_times if wt > 0]
            min_wait_time = min(positive_wait_times) if positive_wait_times else 0.5
            print(f"Waiting for {min_wait_time:.2f} seconds...")
            time.sleep(min_wait_time)

def extract_sql_from_markdown(markdown_text):
    """Pull a SQL query out of an LLM response.

    Strategy, in order:
      1. If the text contains non-empty fenced code blocks (``` or ```sql),
         return the content of the LAST one. The Divide-and-Conquer CoT prompt
         asks for pseudo/intermediate SQL before the final query, so earlier
         blocks can be intermediate steps rather than the answer.
      2. If the text is a single line starting with "select", return it.
      3. Otherwise find the last line starting with "select"; return everything
         from that line to the end when that tail contains exactly one "SELECT"
         (or exactly one "select"), else return just that line.

    Returns:
        The SQL string, or None when no candidate is found (including empty input).
    """
    if not markdown_text:
        return None
    fenced_blocks = re.findall(r"```(?:sql)?\s*(.*?)\s*```", markdown_text, re.DOTALL | re.IGNORECASE)
    non_empty_blocks = [block.strip() for block in fenced_blocks if block.strip()]
    if non_empty_blocks:
        return non_empty_blocks[-1]
    lines = markdown_text.strip().split('\n')
    if len(lines) == 1 and lines[0].strip().lower().startswith("select"):
        return lines[0].strip()
    # Walk backwards by index: list.index() would resolve a repeated line to its
    # first occurrence instead of the one actually being inspected.
    for idx in range(len(lines) - 1, -1, -1):
        cleaned_line = lines[idx].strip()
        if not cleaned_line.lower().startswith("select"):
            continue
        tail = "\n".join(lines[idx:]).strip()
        if tail.count("SELECT") == 1 or tail.count("select") == 1:
            return tail
        return cleaned_line
    return None

def generate_sql_with_divide_conquer_cot(question, condition_json, db_schema, model_name='gemini-2.0-flash'):
    """Ask Gemini for an MS SQL Server query using a Divide-and-Conquer CoT prompt.

    Args:
        question: natural-language question, inserted into the prompt.
        condition_json: user context/condition text inserted into the prompt
            (the prompt describes it as JSON; not validated here).
        db_schema: schema description text inserted into the prompt.
        model_name: Gemini model name; defaults to the previously hard-coded model.

    Returns:
        The SQL extracted from the model output on success. On failure a string
        starting with "Error:" (no key available, unparsable response, or an
        exception raised during the API call / extraction).
    """
    active_api_key = get_next_available_api_key()
    if not active_api_key:
        return "Error: No available API key."
    model = genai.GenerativeModel(model_name)
    prompt = f"""
    You are an expert Text-to-SQL system using the Divide and Conquer Chain-of-Thought (CoT) method.
    Your goal is to convert a natural language question into an MS SQL Server query.
    Database Schema:
    ---
    {db_schema}
    ---
    User Context/Condition (JSON format, use this to filter data if applicable):
    ---
    {condition_json}
    ---
    Question: {question}
    Follow these steps strictly for Divide and Conquer CoT:
    **1. Divide and Conquer:**
    *   **Main Question:** Identify the main goal of the question.
    *   **Analysis (Main Question):** Briefly analyze the main question. What information is being requested? Which tables/columns might be involved based on the schema and question?
    *   **Pseudo SQL (Main Question):** Write a high-level pseudo SQL for the main question, indicating where sub-queries or complex logic will go.
    *   **Sub-question 1 (if any):** Break down the main question into the first logical sub-problem.
    *   **Analysis (Sub-question 1):** Analyze this sub-problem.
    *   **Pseudo SQL (Sub-question 1):** Write pseudo SQL for this sub-problem.
    *   **(Repeat for other sub-questions, e.g., Sub-question 1.1, Sub-question 2, etc. as needed for complex queries)**
    **2. Assembling SQL:**
    *   **SQL (Sub-question 1.1, then Sub-question 1, etc.):** Translate the pseudo SQL for the innermost sub-question into actual MS SQL Server SQL. Then, use this to build up the SQL for the next level of sub-questions.
    *   **SQL (Main Question):** Combine the SQL parts to form the complete SQL for the main question, using subqueries or JOINs as determined in the divide phase.
    **3. Simplification and Optimization (Optional but good):**
    *   Review the assembled SQL. Can any nested queries be simplified (e.g., using JOINs instead of complex subqueries)? Is the filtering applied correctly based on the User Context/Condition?
    **Final Optimized MS SQL Server Query:**
    (Provide ONLY the final, executable MS SQL Server query here. No other text, no explanations before or after the query. Ensure the query is a single SQL statement without markdown formatting.)
    """
    try:
        response = model.generate_content(prompt)
        raw_llm_output = response.text.strip()
        print(f"Raw LLM output: \n---\n{raw_llm_output}\n---") 
        generated_sql = extract_sql_from_markdown(raw_llm_output)
        if generated_sql:
            return generated_sql
        print(f"Cảnh báo: Không thể trích xuất SQL từ output của LLM. Output:\n{raw_llm_output}")
        return "Error: Could not parse SQL from LLM response."
    except Exception as e:
        key_suffix = active_api_key[-4:]
        print(f"Lỗi khi gọi Gemini API với key ...{key_suffix}: {e}")
        error_text = str(e).lower()
        if "rate limit" in error_text or "quota" in error_text or "429" in error_text:
            usage_stats = api_key_usage[active_api_key]
            usage_stats["count"] = API_RATE_LIMIT_PER_MINUTE
            # Restart the window now so the key rests for a full cooldown; otherwise an
            # already-old window could expire right away and the key be reused at once.
            usage_stats["last_reset_time"] = time.time()
            print(f"Rate limit có thể đã bị đạt sớm cho key ...{key_suffix}.")
        # Use the same "Error:" prefix as the other failure sentinels so callers
        # route this to GENERATION_ERROR instead of trying to execute it.
        return f"Error: SQL generation failed: {e}"

def get_database_schema_from_file(filepath):
    """Read the schema description text that is embedded in the LLM prompt.

    Args:
        filepath: path to a UTF-8 text file.

    Returns:
        The file contents with surrounding whitespace stripped. On failure a
        message is printed and a string starting with "Error:" is returned, so
        callers looking for that prefix detect every failure mode.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as schema_file:
            schema_str = schema_file.read()
    except FileNotFoundError:
        print(f"Lỗi: Không tìm thấy file schema '{filepath}'.")
        return "Error: Schema file not found. Please create it or check the path."
    except Exception as e:
        print(f"Lỗi khi đọc file schema: {e}")
        # Previously "Error reading schema file: ..." — no "Error:" prefix, so a read
        # failure (bad encoding, permissions) slipped through as if it were the schema.
        return f"Error: could not read schema file: {e}"
    return schema_str.strip()

def execute_sql_query(sql_query):
    """Run a single SELECT statement against DB_CONN_STRING and collect its rows.

    Args:
        sql_query: SQL text from the generator. Markdown code fences are stripped
            before execution.

    Returns:
        ``(results, error_message)``. On success ``results`` is a list of
        ``{column_name: value}`` dicts and ``error_message`` is None. When the
        statement yields no result set, ``results`` holds one status row (and
        ``error_message`` is set if the row count is unknown). On failure
        ``results`` is ``[]`` and ``error_message`` describes the problem.

    Security note: the SQL comes from an LLM and is executed as-is. The only guard
    is the SELECT-prefix check, which does NOT stop stacked statements such as
    "SELECT 1; DROP TABLE x". Point DB_CONN_STRING at a read-only login.
    """
    results = []
    error_message = None
    # Reject non-strings, blanks, and the "Error..." sentinels produced by the
    # generation helpers. Only the prefix is checked, so valid queries that merely
    # mention "error" (e.g. an error_count column) are no longer rejected. The
    # isinstance check runs first so non-str input cannot raise AttributeError.
    if (not isinstance(sql_query, str)
            or not sql_query.strip()
            or sql_query.lstrip().lower().startswith("error")):
        return [], f"Invalid SQL query provided to execute_sql_query: '{sql_query}'"
    cleaned_sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
    if not cleaned_sql_query.lower().startswith("select"):
        return [], f"Non-SELECT SQL query provided or malformed: '{cleaned_sql_query}'"

    conn = None
    try:
        conn = pyodbc.connect(DB_CONN_STRING)
        cursor = conn.cursor()
        print(f"Đang thực thi SQL: {cleaned_sql_query}")
        cursor.execute(cleaned_sql_query)

        if cursor.description is None:
            # No result set (e.g. SELECT ... INTO): report the affected row count.
            if cursor.rowcount != -1:
                print(f"Truy vấn thực thi thành công, {cursor.rowcount} dòng bị ảnh hưởng.")
                results.append([f"QueryExecutedSuccessfully {cursor.rowcount} rows affected"])
            else:
                print("Truy vấn thực thi nhưng không có description và rowcount không xác định.")
                results.append(["QueryExecutionStatusUnknown"])
                error_message = "Query executed but status is unknown (no description, rowcount -1)"
            return results, error_message

        columns = [column[0] for column in cursor.description]
        results.extend(dict(zip(columns, row_values)) for row_values in cursor.fetchall())
    except pyodbc.Error as ex:
        # pyodbc normally sets args to (sqlstate, message); guard against shorter tuples.
        sqlstate = ex.args[0] if ex.args else 'unknown'
        detail = ex.args[1] if len(ex.args) > 1 else ex
        error_message = f"Lỗi MS SQL Server ({sqlstate}): {detail} executing query: {cleaned_sql_query}"
        print(error_message)
    except Exception as e:
        error_message = f"Lỗi không xác định khi thực thi SQL '{cleaned_sql_query}': {e}"
        print(error_message)
    finally:
        # Always release the connection, including when execute/fetchall raised.
        if conn is not None:
            try:
                conn.close()
            except pyodbc.Error:
                pass  # best effort: the query outcome is already recorded
    return results, error_message

In [None]:
def _build_log_entry(question, condition_str, generated_sql, status, error_message, result_json):
    """Return one row of the output report, with the columns written to OUTPUT_CSV_FILE."""
    return {
        'Question': question,
        'Condition': condition_str,
        'GeneratedSQL': generated_sql,
        'ExecutionResult': status,
        'ErrorMessage': error_message,
        'Result': result_json,
    }


def _serialize_query_results(query_results):
    """Serialize query rows to indented JSON for the 'Result' column.

    json_date_converter raises TypeError for value types it does not handle.
    Rather than letting that abort the run (the report is only written after the
    loop, so every processed row would be lost), fall back to str() for them.
    """
    try:
        return json.dumps(query_results, default=json_date_converter, indent=2)
    except TypeError:
        return json.dumps(query_results, default=str, indent=2)


def main():
    """Generate and execute SQL for every question in INPUT_CSV_FILE, then save a report.

    Reads the schema from SCHEMA_FILE and the questions from INPUT_CSV_FILE
    (required columns 'Question' and 'Condition'). For each row it generates SQL,
    executes it, and records SUCCESS / ERROR / GENERATION_ERROR. The report is
    written to OUTPUT_CSV_FILE (UTF-8 with BOM) once all rows are processed.
    """
    print(f"Đang đọc schema cơ sở dữ liệu từ file '{SCHEMA_FILE}'...")
    db_schema_str = get_database_schema_from_file(SCHEMA_FILE)
    # The schema reader signals failure with strings starting with "Error"; a
    # prefix check catches every variant without matching schema content.
    if db_schema_str.startswith("Error"):
        print(f"Không thể tiếp tục nếu không có schema: {db_schema_str}")
        return
    print("\nSchema cơ sở dữ liệu đã nạp:")
    print(db_schema_str)
    print("-" * 30)
    try:
        questions_df = pd.read_csv(INPUT_CSV_FILE)
    except FileNotFoundError:
        print(f"Lỗi: Không tìm thấy file '{INPUT_CSV_FILE}'. Vui lòng tạo file này.")
        return
    except Exception as e:
        print(f"Lỗi khi đọc file CSV: {e}")
        return
    # Fail fast with a clear message instead of a KeyError on the first row.
    missing_columns = {'Question', 'Condition'} - set(questions_df.columns)
    if missing_columns:
        print(f"Lỗi: File '{INPUT_CSV_FILE}' thiếu cột bắt buộc: {sorted(missing_columns)}")
        return

    results_log = []
    for index, row in questions_df.iterrows():
        question = row['Question']
        raw_condition = row['Condition']
        # Empty CSV cells are read as NaN; str(NaN) would put the literal "nan" in the prompt.
        condition_str = '' if pd.isna(raw_condition) else str(raw_condition)
        print(f"\nĐang xử lý câu hỏi {index + 1}: {question}")
        print(f"Điều kiện: {condition_str}")
        generated_sql = generate_sql_with_divide_conquer_cot(question, condition_str, db_schema_str)
        print(f"SQL đã xử lý: {generated_sql}") 
        # Failure sentinels from the generator all start with "Error". A substring
        # test for "Error:" missed "Error generating SQL: ..." and would also reject
        # valid SQL that merely contains "Error:" inside a string literal.
        if not generated_sql or generated_sql.startswith("Error"):
            entry = _build_log_entry(question, condition_str, generated_sql,
                                     'GENERATION_ERROR', generated_sql, "[]")
        else:
            query_results, execution_error = execute_sql_query(generated_sql)
            if execution_error:
                entry = _build_log_entry(question, condition_str, generated_sql,
                                         'ERROR', execution_error, "[]")
            else:
                entry = _build_log_entry(question, condition_str, generated_sql,
                                         'SUCCESS', '', _serialize_query_results(query_results))
        results_log.append(entry)
        print("-" * 30)
    output_df = pd.DataFrame(results_log)
    try:
        output_df.to_csv(OUTPUT_CSV_FILE, index=False, encoding='utf-8-sig')
        print(f"\nKết quả đã được lưu vào file '{OUTPUT_CSV_FILE}'")
    except Exception as e:
        print(f"Lỗi khi lưu file CSV kết quả: {e}")

if __name__ == '__main__':
    main()

Đang đọc schema cơ sở dữ liệu từ file 'm-schema.txt'...

Schema cơ sở dữ liệu đã nạp:
【DB_ID】 text_to_sql
【Schema】
# Table: Attendance
[
  (enrollment_id:VARCHAR, Primary Key, Foreign key referencing Enrollments table. Maps to Enrollments(enrollment_id)., Examples: ['HE001_AI01_AIL', 'HE001_AI01_MAE', 'HE001_AI01_NLP']),
  (schedule_id:VARCHAR, Primary Key, Foreign key referencing Schedules table. Maps to Schedules(schedule_id)., Examples: ['AI01_AIL_1', 'AI01_AIL_2', 'AI01_AIL_3']),
  (status:VARCHAR, Attendance status (e.g., Present, Absent)., Examples: ['Present', 'Future', 'Absent'])
]
【Foreign keys】
  Attendance.enrollment_id = Enrollments.enrollment_id
  Attendance.schedule_id = Schedules.schedule_id

# Table: ClassCourse
[
  (class_course_id:VARCHAR, Primary Key, Unique identifier for the class course., Examples: ['AI01_AIL', 'AI01_MAE', 'AI01_NLP']),
  (class_id:VARCHAR, Foreign key referencing Classes table. Maps to Classes(class_id)., Examples: ['AI01', 'IA01']),
  (course_id