## 1. Import Thư viện Cần Thiết

In [None]:
import datetime
import decimal
import json
import os
import re
import time
import uuid
from collections import deque

import google.generativeai as genai
import pandas as pd
import pyodbc

  from .autonotebook import tqdm as notebook_tqdm


## 2. Cấu hình Hệ thống

In [None]:
# Gemini API keys, rotated by get_next_available_api_key(). Read from the
# environment so no key lives in the notebook: GOOGLE_API_KEY (a single key,
# as before) plus an optional comma-separated GEMINI_API_KEYS list. Blank or
# unset values are dropped and duplicates removed (order kept), so a missing
# variable yields an empty list instead of a [None] "key".
GEMINI_API_KEYS = list(dict.fromkeys(
    key.strip()
    for key in [os.environ.get('GOOGLE_API_KEY', ''),
                *os.environ.get('GEMINI_API_KEYS', '').split(',')]
    if key and key.strip()
))

API_RATE_LIMIT_PER_MINUTE = 100  # requests allowed per key within one cooldown window
API_KEY_COOLDOWN_SECONDS = 70  # seconds after which a key's request counter is reset
MAX_API_RETRIES_PER_QUESTION = len(GEMINI_API_KEYS)  # at most one attempt per configured key

# Per-key request counters: {key: {"count": requests in window, "last_reset_time": time.time() of last reset}}.
api_key_usage = {key: {"count": 0, "last_reset_time": time.time()} for key in GEMINI_API_KEYS}
current_api_key_index = 0  # index into GEMINI_API_KEYS of the key currently in use

# SQL Server connection settings. Each can be overridden via the environment;
# the defaults are the original local-development values.
# NOTE(review): set DB_PASSWORD rather than relying on the hard-coded default.
DB_SERVER = os.environ.get('DB_SERVER', 'localhost')
DB_NAME = os.environ.get('DB_NAME', 'text_to_sql')
username = os.environ.get('DB_USER', 'sa')
password = os.environ.get('DB_PASSWORD', '123456')
DB_CONN_STRING = f'DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={DB_SERVER};DATABASE={DB_NAME};UID={username};PWD={password}'

INPUT_CSV_FILE = 'question.csv'  # questions to process ('Question' column, optional 'Condition')
OUTPUT_CSV_FILE = 'result.csv'
SCHEMA_FILE = 'm-schema.txt'

## 3. Các Hàm Hỗ Trợ

### 3.1. Hàm Chuyển đổi Date/Datetime cho JSON

In [3]:
def json_date_converter(obj):
    """Serialize values that ``json.dumps`` cannot encode natively.

    Used as the ``default=`` hook when dumping query results. Besides
    date/datetime/time (ISO-8601 strings) it handles other types pyodbc
    returns for SQL Server columns, which previously made ``json.dumps``
    raise and abort the run:

    * ``decimal.Decimal`` (DECIMAL/NUMERIC/MONEY) -> ``int`` when it is a
      finite integral value, otherwise ``float``;
    * ``uuid.UUID`` (UNIQUEIDENTIFIER) -> canonical string;
    * ``bytes`` / ``bytearray`` (BINARY/VARBINARY) -> hex string.

    Raises:
        TypeError: for any other type, as the ``json.dumps`` hook contract expects.
    """
    # datetime.datetime is a subclass of datetime.date, so it is covered here too.
    if isinstance(obj, (datetime.date, datetime.time)):
        return obj.isoformat()
    if isinstance(obj, decimal.Decimal):
        # is_finite() guards int() against Infinity/NaN, which would raise.
        if obj.is_finite() and obj == obj.to_integral_value():
            return int(obj)
        return float(obj)
    if isinstance(obj, uuid.UUID):
        return str(obj)
    if isinstance(obj, (bytes, bytearray)):
        return obj.hex()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

### 3.2. Hàm Quản lý và Xoay vòng API Key

In [4]:
def get_next_available_api_key():
    """Pick a Gemini API key that is still under its per-window request budget.

    Scans GEMINI_API_KEYS starting at the global ``current_api_key_index``.
    A key's counter is reset once more than API_KEY_COOLDOWN_SECONDS have
    passed since its last reset. The first key whose count is below
    API_RATE_LIMIT_PER_MINUTE is charged one request, passed to
    ``genai.configure`` and returned. The index only advances when a key is
    over budget, so the same key keeps being used until it is exhausted.

    When a full pass finds every key over budget, it sleeps until the
    earliest pending cooldown should expire (0.5 s if none is pending) and
    scans again. Because ``attempts`` is reset after each wait, the loop only
    exits without a key when GEMINI_API_KEYS is empty.

    Side effects: mutates the globals ``current_api_key_index`` and
    ``api_key_usage``, calls ``genai.configure``, and may block in
    ``time.sleep``.

    Returns:
        ``(api_key, index)`` on success, or ``(None, -1)`` if no key was found.
    """
    global current_api_key_index
    global api_key_usage
    start_index = current_api_key_index
    attempts = 0 
    max_attempts_for_key_search = len(GEMINI_API_KEYS) * 2 # Cap on loop iterations, meant to avoid getting stuck (but see the attempts reset below)

    while attempts < max_attempts_for_key_search:
        api_key = GEMINI_API_KEYS[current_api_key_index]
        usage_stats = api_key_usage[api_key]
        current_time = time.time()

        # Cooldown elapsed: start a fresh request window for this key.
        if current_time - usage_stats["last_reset_time"] > API_KEY_COOLDOWN_SECONDS:
            # print(f"Resetting counter for API key: ...{api_key[-4:]}")
            usage_stats["count"] = 0
            usage_stats["last_reset_time"] = current_time

        # Key still has budget: charge one request and make it the active key for the SDK.
        if usage_stats["count"] < API_RATE_LIMIT_PER_MINUTE:
            usage_stats["count"] += 1
            # print(f"Using API key: ...{api_key[-4:]}, Usage: {usage_stats['count']}/{API_RATE_LIMIT_PER_MINUTE}")
            genai.configure(api_key=api_key)
            return api_key, current_api_key_index # Return both the key and its index

        # Key is over budget: move on to the next key (wrapping around).
        current_api_key_index = (current_api_key_index + 1) % len(GEMINI_API_KEYS)
        attempts += 1
        # print(f"API key ...{api_key[-4:]} reached limit or in cooldown. Switching.")

        # Back at the starting key after a full pass: every key is over budget.
        if current_api_key_index == start_index and attempts >= len(GEMINI_API_KEYS):
            # print("All API keys cycled through and are rate-limited. Calculating wait time...")
            # Seconds until each key's window reopens (+0.5 s margin), measured from this iteration's current_time.
            wait_times = [(stats["last_reset_time"] + API_KEY_COOLDOWN_SECONDS + 0.5) - current_time 
                          for key_str, stats in api_key_usage.items()]
            # NOTE(review): non-positive waits (windows already reopened) are dropped here, so if one key
            # is ready and another is not, this sleeps for the other key's wait instead of not at all.
            positive_wait_times = [wt for wt in wait_times if wt > 0]
            
            if not positive_wait_times:
                 min_wait_time = 0.5 # No positive wait time: pause briefly, then retry
            else:
                 min_wait_time = min(positive_wait_times)
            
            if min_wait_time > 0:
                 print(f"All keys rate-limited. Waiting for {min_wait_time:.2f} seconds...")
                 time.sleep(min_wait_time)
            # Reset attempts so the scan can keep looking for a key; this means the while condition
            # never ends the loop while at least one key is configured.
            attempts = 0 
            # No need to reset start_index: the scan resumes from current_api_key_index (== start_index here).
            
    print("CRITICAL: Could not get an available API key after multiple attempts and waits.")
    return None, -1 # No key found (in practice only reached when GEMINI_API_KEYS is empty)

### 3.3. Hàm Trích xuất SQL từ Markdown

In [5]:
def _sql_parens_never_underflow(sql_text):
    """Return True if no ')' in *sql_text* closes a '(' that is not in the text.

    Characters inside single-quoted SQL string literals are ignored (a doubled
    '' escape toggles the quote state twice, so it is handled as well). Used to
    reject a candidate that starts on a nested sub-query line, whose suffix
    carries the outer query's closing parenthesis.
    """
    depth = 0
    in_string = False
    for ch in sql_text:
        if ch == "'":
            in_string = not in_string
        elif in_string:
            continue
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return False
    return True


def extract_sql_from_markdown(markdown_text):
    """Extract a SELECT query from an LLM response.

    Strategies, tried in order:
      1. the last non-empty fenced ```sql block. The prompt asks for the plan
         first and the final query last, so later blocks win;
      2. the last other fenced block (untagged or e.g. ```tsql) whose body
         starts with SELECT;
      3. unfenced text: from the last line starting with "select" whose
         remainder has balanced parentheses, through the end of the text.
         This keeps a nested sub-query line from being taken as the start of
         the statement. If no such line exists, it falls back to the last
         "select" line onwards.

    Args:
        markdown_text: raw model output (may be None or empty).

    Returns:
        The stripped SQL string, or None if nothing SQL-like is found.
    """
    if not markdown_text:
        return None

    sql_blocks = re.findall(r"```sql\s*(.*?)\s*```", markdown_text, re.DOTALL | re.IGNORECASE)
    for block in reversed(sql_blocks):
        if block.strip():
            return block.strip()

    # `[\w-]*` tolerates an optional language tag other than "sql"; `\b` (not a
    # literal space) lets SELECT be followed directly by a newline.
    select_blocks = re.findall(r"```[\w-]*\s*(SELECT\b.*?)\s*```", markdown_text, re.DOTALL | re.IGNORECASE)
    if select_blocks:
        return select_blocks[-1].strip()

    lines = markdown_text.strip().split('\n')
    fallback = None
    for start in range(len(lines) - 1, -1, -1):
        if not lines[start].strip().lower().startswith("select"):
            continue
        candidate = "\n".join(lines[start:]).strip()
        if _sql_parens_never_underflow(candidate):
            return candidate
        if fallback is None:
            # Preserve the previous behaviour if no balanced candidate exists.
            fallback = candidate
    return fallback

### 3.4. Hàm Sinh SQL với Gemini (Query Plan CoT) - Có Retry

In [None]:
def generate_sql_with_query_plan_cot(question, condition_json, db_schema):
    """Ask Gemini for an MS SQL Server query using a "query plan" chain of thought.

    Args:
        question: the natural-language question.
        condition_json: user context/conditions, inserted verbatim into the prompt.
        db_schema: schema description text, inserted verbatim into the prompt.

    Returns:
        The extracted SQL string on success. On any failure, a message that
        starts with ``"Error:"``. Callers detect failures by looking for that
        token, so every failure path below uses it.

    Makes up to MAX_API_RETRIES_PER_QUESTION attempts. Each one takes a key from
    get_next_available_api_key(), which may sleep. Rate-limit, quota and
    service-unavailable errors mark the key as exhausted and retry. Any other
    API error aborts immediately, because resending the same prompt is
    unlikely to help.
    """
    model_name = 'gemini-2.0-flash' # Or whichever model you are using
    prompt_text = f"""
    You are an expert Text-to-SQL system. Your goal is to convert a natural language question into an MS SQL Server query 
    by first generating a step-by-step Query Execution Plan (like a Chain-of-Thought) and then deriving the SQL query from that plan.

    Database Schema:
    ---
    {db_schema}
    ---

    User Context/Condition (JSON format, use this to filter data if applicable):
    ---
    {condition_json}
    ---

    Question: {question}

    Follow these steps strictly for Query Plan CoT:
    **1. Understand the Question and Evidence:**
    *   Repeat the user's question.
    *   Identify key entities, relationships, and conditions mentioned in the question and user context.
    *   Map these to the database schema (tables, columns).
    **2. Develop a Query Execution Plan (Chain-of-Thought):**
    This plan should mimic the logical steps a database engine might take. 
    Think about it as a human-readable version of an EXPLAIN plan.
    *   **Preparation Steps (if any):** Initialize variables, temporary storage (conceptually).
    *   **Data Access Steps:** 
        *   Which table(s) to access first?
        *   How to filter rows in these initial tables (based on question and context)?
        *   If multiple tables are involved, how are they joined (e.g., INNER JOIN, LEFT JOIN on specific columns)? What are the join conditions?
        *   What operations are performed on the data (e.g., aggregation - COUNT, SUM, AVG; grouping - GROUP BY; sorting - ORDER BY)?
    *   **Result Computation/Delivery:**
        *   Which final columns need to be selected to answer the question?
        *   Is there any final filtering (e.g., HAVING clause)?
        *   Is there a limit on the number of results (e.g., TOP N)?
    **3. Translate Query Plan to MS SQL Server Query:**
    *   Based on the detailed Query Execution Plan above, construct the final MS SQL Server query.
    **Final Optimized MS SQL Server Query:**
    (IMPORTANT: Provide ONLY the final, executable MS SQL Server query below this line. Do not include any other text, explanations, or markdown formatting like ```sql ... ``` around the query. The query should be a single, valid SQL statement.)
    """

    # Substrings (lower-cased) of API errors that are worth retrying with another key / after waiting.
    retryable_markers = ("rate limit", "quota", "429", "resource has been exhausted",
                         "service unavailable", "503")

    for attempt in range(MAX_API_RETRIES_PER_QUESTION):
        attempt_no = attempt + 1
        is_last_attempt = attempt_no == MAX_API_RETRIES_PER_QUESTION

        active_api_key, _ = get_next_available_api_key()
        if not active_api_key:
            return "Error: No available API key after multiple waits."

        try:
            # Model construction is inside the try so SDK errors here are handled like call errors.
            model = genai.GenerativeModel(model_name)
            response = model.generate_content(prompt_text)
            raw_llm_output = response.text.strip()
        except Exception as e:  # the SDK surfaces quota/HTTP/blocked-response problems as assorted exception types
            print(f"Lỗi khi gọi Gemini API với key ...{active_api_key[-4:]} (Attempt {attempt_no}, Query Plan CoT): {e}")
            error_str = str(e).lower()
            if not any(marker in error_str for marker in retryable_markers):
                return f"Error: Failed to generate SQL (Query Plan CoT): Unhandled API Exception - {e}"
            # Mark this key as used up so the rotation skips it until its cooldown expires.
            api_key_usage[active_api_key]["count"] = API_RATE_LIMIT_PER_MINUTE
            print(f"API key ...{active_api_key[-4:]} marked as rate-limited/exhausted.")
            if is_last_attempt:
                return f"Error: Failed to generate SQL (Query Plan CoT) after {MAX_API_RETRIES_PER_QUESTION} attempts: API Error - {e}"
            print("Retrying with next available API key or waiting...")
            continue

        generated_sql = extract_sql_from_markdown(raw_llm_output)
        if generated_sql:
            return generated_sql
        if raw_llm_output.lower().startswith("select"):
            print(f"Warning (Attempt {attempt_no}): extract_sql_from_markdown failed, but raw output seems to be SQL.")
            return raw_llm_output

        # The model answered, but not with anything that looks like SQL.
        error_msg = f"Error: Could not parse SQL from LLM (attempt {attempt_no}). Output: {raw_llm_output[:200]}..."
        print(error_msg)
        if is_last_attempt:
            return error_msg
        print("Trying next API key or waiting...")

    # Only reached when no attempt ran at all (MAX_API_RETRIES_PER_QUESTION == 0, e.g. no keys configured).
    return f"Error: Failed to generate SQL (Query Plan CoT) after {MAX_API_RETRIES_PER_QUESTION} attempts: no valid SQL obtained."

### 3.5. Hàm Đọc Schema từ File

In [7]:
def get_database_schema_from_file(filepath):
    """Read the database schema description used in the LLM prompt.

    Args:
        filepath: path to a UTF-8 text file.

    Returns:
        The file contents with surrounding whitespace stripped. On failure the
        problem is printed, and a string starting with ``"Error:"`` is returned
        instead of raising. The caller checks for that marker before continuing.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            schema_str = f.read()
    except FileNotFoundError:
        print(f"Lỗi: Không tìm thấy file schema '{filepath}'.")
        return "Error: Schema file not found. Please create it or check the path."
    except (OSError, ValueError) as e:  # permission/directory errors, undecodable bytes (UnicodeDecodeError)
        print(f"Lỗi khi đọc file schema: {e}")
        # Keep the "Error:" marker: the caller looks for it to decide whether to stop.
        return f"Error: Could not read schema file: {e}"
    return schema_str.strip()

### 3.6. Hàm Thực thi SQL Query trên MS SQL Server

In [8]:
def execute_sql_query(sql_query):
    """Run a generated SELECT query against MS SQL Server.

    Args:
        sql_query: SQL text from the generator. Input that is empty, not a
            string, contains an "error:" marker, or does not start with SELECT
            (after trimming whitespace) is rejected without touching the database.

    Returns:
        A ``(rows, error_message)`` tuple:

        * success: ``rows`` is a list of ``{column_name: value}`` dicts in
          result-set order, and ``error_message`` is None;
        * failure: ``rows`` is ``[]`` and ``error_message`` describes the problem.

        Two special cases are kept from the original contract for statements
        without a result set: ``([[]], None)`` when the driver reports a row
        count, and ``([[{"QueryExecutionStatusUnknown": True}]], message)``
        otherwise.
    """
    if (not sql_query
            or not isinstance(sql_query, str)
            or not sql_query.strip()
            or "error:" in sql_query.lower()):
        return [], f"Invalid SQL query provided to execute_sql_query: '{sql_query}'"

    # Tolerate leading/trailing whitespace left over from the LLM output.
    cleaned_sql_query = sql_query.strip()
    if not cleaned_sql_query.lower().startswith("select"):
        return [], f"Non-SELECT SQL query provided or malformed: '{cleaned_sql_query}'"

    conn = None
    try:
        conn = pyodbc.connect(DB_CONN_STRING)
        cursor = conn.cursor()
        cursor.execute(cleaned_sql_query)
        if cursor.description is None:
            if cursor.rowcount != -1:
                return [[]], None
            return ([[{"QueryExecutionStatusUnknown": True}]],
                    "Query executed but status is unknown (no description, rowcount -1)")
        columns = [column[0] for column in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()], None
    except pyodbc.Error as ex:
        # pyodbc errors usually carry (sqlstate, message), but don't index blindly.
        sqlstate = ex.args[0] if ex.args else "unknown"
        detail = ex.args[1] if len(ex.args) > 1 else ex
        return [], f"Lỗi MS SQL Server ({sqlstate}): {str(detail)} executing query: {cleaned_sql_query}"
    except Exception as e:
        return [], f"Lỗi không xác định khi thực thi SQL '{cleaned_sql_query}': {e}"
    finally:
        # Always release the connection; previously it leaked on every error path.
        # Explicit close() rather than `with conn:`: pyodbc's context manager commits
        # on exit and does not close. Connections default to autocommit=False and
        # nothing here commits, so stray writes are not persisted.
        if conn is not None:
            try:
                conn.close()
            except pyodbc.Error:
                pass  # best-effort cleanup; the query outcome has already been decided

## 4. Hàm Chính (Main Execution)

In [9]:
def _result_row(question, condition_str, generated_sql, status, error_message, result_json):
    """Build one row of the output CSV (keys are the CSV columns, in order)."""
    return {
        'Question': question,
        'Condition': condition_str,
        'GeneratedSQL': generated_sql,
        'ExecutionResult': status,
        'ErrorMessage': error_message,
        'Result': result_json,
    }


def _is_error_sentinel(value):
    """Return True if *value* is one of the "Error..." strings the helpers return on failure."""
    # The schema reader and the SQL generator report failures as strings that
    # start with "Error", in several forms, not all of which contain the exact
    # "Error:" token. Matching the prefix catches every form.
    return isinstance(value, str) and value.startswith("Error")


def _serialize_results(query_results):
    """JSON-encode query rows for the 'Result' column.

    Falls back to ``default=str`` for any value type json_date_converter
    rejects. Results are only written to disk at the end of main(), so a
    single unusual column type must not abort the whole run.
    """
    try:
        return json.dumps(query_results, default=json_date_converter, indent=2)
    except TypeError:
        return json.dumps(query_results, default=str, indent=2)


def main():
    """Run the batch: for each question in INPUT_CSV_FILE, generate SQL, execute it,
    and write one result row per question to OUTPUT_CSV_FILE.

    Each row's 'ExecutionResult' is SUCCESS, DB_ERROR or GENERATION_ERROR.
    Stops early (printing the reason) if the schema or the input CSV cannot be read.
    """
    print(f"Đang đọc schema cơ sở dữ liệu từ file '{SCHEMA_FILE}'...")
    db_schema_str = get_database_schema_from_file(SCHEMA_FILE)
    if _is_error_sentinel(db_schema_str):
        print(f"Không thể tiếp tục nếu không có schema: {db_schema_str}")
        return
    try:
        questions_df = pd.read_csv(INPUT_CSV_FILE)
    except FileNotFoundError:
        print(f"Lỗi: Không tìm thấy file '{INPUT_CSV_FILE}'. Vui lòng tạo file này.")
        return
    except Exception as e:
        print(f"Lỗi khi đọc file CSV: {e}")
        return
    if 'Question' not in questions_df.columns:
        print(f"Lỗi: File '{INPUT_CSV_FILE}' thiếu cột 'Question'.")
        return

    results_log = []
    total_questions = len(questions_df)
    for index, row in questions_df.iterrows():
        question = row['Question']
        condition_value = row.get('Condition')  # optional column
        condition_str = str(condition_value) if pd.notna(condition_value) else "{}"
        print(f"\nĐang xử lý câu hỏi {index + 1}/{total_questions}: {question}")
        generated_sql = generate_sql_with_query_plan_cot(question, condition_str, db_schema_str)

        if not generated_sql or _is_error_sentinel(generated_sql):
            # The generator's error message doubles as the "SQL" and the error text.
            results_log.append(_result_row(question, condition_str, generated_sql,
                                           'GENERATION_ERROR', generated_sql, "[]"))
            continue

        query_results, execution_error = execute_sql_query(generated_sql)
        if execution_error:
            results_log.append(_result_row(question, condition_str, generated_sql,
                                           'DB_ERROR', execution_error, "[]"))
        else:
            results_log.append(_result_row(question, condition_str, generated_sql,
                                           'SUCCESS', '', _serialize_results(query_results)))

    output_df = pd.DataFrame(results_log)
    try:
        # utf-8-sig adds a BOM so Excel opens the Vietnamese text correctly.
        output_df.to_csv(OUTPUT_CSV_FILE, index=False, encoding='utf-8-sig')
        print(f"\nKết quả đã được lưu vào file '{OUTPUT_CSV_FILE}'")
    except Exception as e:
        print(f"Lỗi khi lưu file CSV kết quả: {e}")

## 5. Chạy Chương trình

In [10]:
# Entry point: run the full question -> SQL -> result batch.
# NOTE(review): inside Jupyter/IPython __name__ is also '__main__', so executing
# this cell starts the whole batch (as the captured output below shows).
if __name__ == '__main__':
    main()

Đang đọc schema cơ sở dữ liệu từ file 'm-schema.txt'...

Đang xử lý câu hỏi 1/162: What is my student ID?

Đang xử lý câu hỏi 2/162: What is my major?

Đang xử lý câu hỏi 3/162: When did I start studying at the school?

Đang xử lý câu hỏi 4/162: What was the first semester I studied at the school?

Đang xử lý câu hỏi 5/162: Find users with birthdays in January.

Đang xử lý câu hỏi 6/162: Find students whose names start with 'Student'.

Đang xử lý câu hỏi 7/162: How many female students are in the 'Software Engineering' major?

Đang xử lý câu hỏi 8/162: Show the list of students in the 'Artificial Intelligence' major.

Đang xử lý câu hỏi 9/162: Show a list of students (ID and name) and their majors.

Đang xử lý câu hỏi 10/162: How many students are there in total at the school?

Đang xử lý câu hỏi 11/162: Which major has the fewest students?

Đang xử lý câu hỏi 12/162: Find information for user ID HE011.

Đang xử lý câu hỏi 13/162: List all students.

Đang xử lý câu hỏi 14/162: List the