In [1]:
import sqlite3
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from collections import defaultdict
import os
from src.query_fix.utils import *
from openai import OpenAI
from transformers import AutoProcessor, MarkupLMModel
import time

def fetch_results_with_timeout(database_path, query):
    """Run ``query`` against a SQLite database and return its columns and rows.

    Despite the name, this function does not enforce a timeout itself. It
    opens its own connection on every call so that it can be submitted to a
    worker thread, and the caller applies the timeout when waiting on the
    resulting future.

    Args:
        database_path: Path of the SQLite database file to connect to.
        query: SQL statement to execute.

    Returns:
        A ``(columns, rows)`` tuple, in that order. ``columns`` is the list
        of column names taken from ``cursor.description`` (an empty list when
        the statement produces no result set) and ``rows`` is the list
        returned by ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If the statement cannot be executed. The connection is
            closed before the exception propagates.
    """
    # Create a new connection and cursor in each thread
    conn = sqlite3.connect(database_path)
    try:
        cursor = conn.cursor()
        cursor.execute(query)

        # cursor.description is None for statements without a result set.
        if cursor.description is not None:
            columns = [description[0] for description in cursor.description]
        else:
            columns = []

        return columns, cursor.fetchall()
    finally:
        # Close on every path: failing queries are the common case here, and
        # without this each failure would leak an open connection.
        conn.close()

def _run_query_with_timeout(database_path, query, timeout, timeout_message):
    """Run ``query`` in a worker thread, waiting at most ``timeout`` seconds.

    Returns the ``(columns, rows)`` pair produced by
    ``fetch_results_with_timeout``. On timeout, ``timeout_message`` is printed
    and ``([], [])`` is returned. Any exception raised by the query itself is
    re-raised to the caller.
    """
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(fetch_results_with_timeout, database_path, query)
        try:
            columns, rows = future.result(timeout=timeout)
        except TimeoutError:
            print(timeout_message)
            columns, rows = [], []
    finally:
        # wait=False is essential: leaving a ``with ThreadPoolExecutor()``
        # block joins the worker thread, which would block until the query
        # finishes and make the timeout above meaningless. The abandoned
        # query keeps running in the background until it completes.
        executor.shutdown(wait=False)
    return columns, rows


def query_fix(
    database_name: str,
    database_root_path: str,
    database_path: str,
    candidates: dict[str, str],
    model: str,
    ir: list[str],
    question: str,
    hint: str,
    ground_truth: str,
    markup_processor: AutoProcessor,
    markup_model: MarkupLMModel,
    n_retries: int=10,
):
    """Iteratively execute and repair candidate SQL queries, tracking entropy.

    For ``n_retries`` rounds, every candidate is executed against the SQLite
    database. Candidates that raise are passed to ``query_fixer`` together
    with the exception and replaced by the parsed output. After each round
    the execution outcome of every candidate (result table or error) is
    rendered to HTML, embedded with the MarkupLM model, clustered, and the
    entropy over the clusters is recorded.

    Args:
        database_name: Database name forwarded to ``query_fixer``.
        database_root_path: Database directory forwarded to ``query_fixer``.
        database_path: Path of the SQLite file the queries are executed on.
        candidates: Mapping whose keys are the candidate SQL queries and whose
            values are the labels passed as ``methods`` to
            ``calculate_semantic_entropy``.
        model: Model name handed to the ``LLM`` wrapper.
        ir: Forwarded unchanged to ``query_fixer``.
        question: Natural-language question forwarded to ``query_fixer``.
        hint: Hint text forwarded to ``query_fixer``.
        ground_truth: Reference SQL passed to ``check_exec_accuracy``.
        markup_processor: MarkupLM processor passed to ``html_to_features``.
        markup_model: MarkupLM model passed to ``html_to_features``.
        n_retries: Number of execute/repair rounds to run.

    Returns:
        A 5-tuple ``(fixed_queries, qents, log_values, pi_correct,
        method_percents)``: the ``(query, rows)`` pairs recorded the first
        time each candidate executed without raising, the per-round
        entropies, ``{"INTERMEDIATE_QR": ...}`` mapping round index to that
        round's ``(query, rows_or_exception)`` pairs, the second value
        returned by ``cluster_sql_queries`` in the last round (``None`` when
        it is falsy or was never computed), and the per-round cluster
        percentages.

    Raises:
        KeyError: If ``BASE_URL_DEEPSEEK`` or ``API_KEY_DEEPSEEK`` is not set
            in the environment.
    """
    client = OpenAI(
        base_url=os.environ['BASE_URL_DEEPSEEK'],
        api_key=os.environ['API_KEY_DEEPSEEK']
    )

    llm = LLM(
        client = client,
        model = model, 
        gen_params = {
            'STREAM': False,
            'TEMPERATURE': 0,
            'MAX_NEW_TOKENS': 2048 
        }
    )

    # Split the mapping into parallel lists: position i of ``methods`` labels
    # position i of ``candidates`` for the whole run.
    methods = list(candidates.values())
    candidates = list(candidates)

    fixed_flags = defaultdict(bool)
    fixed_queries = []
    qents = []
    method_percents = []
    all_qr = {}
    # Bound up front so the final return is valid even when n_retries == 0.
    pi_correct = False

    for attempt in range(n_retries):
        # None (not False/0) means "no correct candidate seen this round",
        # so that a correct candidate at index 0 is not mistaken for "none".
        # NOTE(review): correctness is only checked in the round where a
        # candidate first executes cleanly, so later rounds usually leave
        # this as None — confirm that is intended.
        correct_index = None
        new_candidates = []
        intermediate_qr = []

        for i, query in enumerate(candidates):
            try:
                _, result = _run_query_with_timeout(
                    database_path,
                    query,
                    timeout=5,
                    timeout_message=f"Query {i} timed out.",
                )

                if fixed_flags[i]:
                    # Already known to execute: just record this round's rows.
                    # Appending only after a successful run keeps exactly one
                    # entry per candidate if the query starts failing again.
                    new_candidates.append(query)
                    intermediate_qr.append((query, result))
                    continue

                print(f"query {i}, result {result}")
                correct_flag = check_exec_accuracy(database_path=database_path, query=query, ground_truth_query=ground_truth)
                if correct_flag:
                    correct_index = i
                fixed_queries.append((query, result))
                intermediate_qr.append((query, result))
                new_candidates.append(query)
                fixed_flags[i] = True

            except Exception as e:
                # Any failure (SQL error or otherwise) sends the query to the
                # LLM together with the exception for repair.
                fixed_flags[i] = False
                query = query_fixer(
                    database_name=database_name,
                    database_root_path=database_root_path,
                    ir=ir,
                    query_to_correct=query,
                    question=question,
                    hint=hint,
                    result=e,
                    model=llm
                )
                query = parse_query_fix_output(query)
                new_candidates.append(query)
                intermediate_qr.append((query, e))

        all_qr[attempt] = intermediate_qr
        all_features = []

        for cand in new_candidates:
            try:
                columns, results = _run_query_with_timeout(
                    database_path,
                    cand,
                    timeout=10,
                    timeout_message=f"Query for candidate {cand} timed out.",
                )
                html_result = sql_result_to_html(column_names=columns, result=results)
            except sqlite3.Error as e:
                # A failing candidate is embedded via its error message.
                print(f"An error occurred: {e}")
                html_result = sql_result_to_html(error=e)

            features = html_to_features(
                html_string=html_result, 
                markup_lm_processor=markup_processor, 
                markup_lm_model=markup_model,
            )
            all_features.append(features.detach().squeeze(dim=0))

        pi_correct = False
        if correct_index is not None:
            # NOTE(review): assumes cluster_sql_queries returns a pair whenever
            # correct_ind is supplied, including correct_ind=0 — confirm.
            clusters_DB, pi_correct = cluster_sql_queries(embeddings=np.array(all_features), correct_ind=correct_index)
        else:
            clusters_DB = cluster_sql_queries(embeddings=np.array(all_features))

        entropy, cluster_percentages = calculate_semantic_entropy(clusters=clusters_DB, methods=methods)
        qents.append(entropy)
        method_percents.append(cluster_percentages)

        # Carry the repaired queries into the next round; otherwise every
        # round would re-repair the same original query.
        candidates = new_candidates

    log_values = {"INTERMEDIATE_QR": all_qr} 
    return fixed_queries, qents, log_values, (pi_correct if pi_correct else None), method_percents

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Target database: both paths are derived from the DATABASE_ROOT_PATH
# environment variable and the database name.
database_name = "california_schools"
database_root_path = f"{os.environ['DATABASE_ROOT_PATH']}/{database_name}"
database_path = f"{database_root_path}/{database_name}.sqlite"

# The task: question, hint, reference SQL, and the `ir` list handed to the fixer.
question = "In which city can you find the school in the state of California with the lowest latitude coordinates and what is its lowest grade? Indicate the school name."
hint = "State of California refers to state = 'CA'"
ground_truth = "SELECT T2.City, T1.`Low Grade`, T1.`School Name` FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.State = 'CA' ORDER BY T2.Latitude ASC LIMIT 1"
ir = [
    "`schools`.`City`.`San Diego`",
    "`frpm`.`Low Grade`",
    "`frpm`.`School Name`.`Vidya Mandir`",
    "`frpm`.`CDSCode`",
    "`schools`.`CDSCode`",
    "`schools`.`State`",
    "`schools`.`Latitude`",
]

# Candidate SQL (keys) mapped to the label of the method that produced each
# one (values). The SQL text is kept verbatim, including its errors: these
# are the queries query_fix is expected to repair.
candidates = {
    """SELECT 'T1'.'City', 'T1'.'Low Grade', 'T2'.'Latitude' FROM 'frpm' AS 'T1' INNER JOIN 'schools' AS 'T2' ON 'T1'.'CDSCode' = 'T2'.'CDSCode' WHERE 'T2'.'State' = 'CA' AND 'T2'.'Latitude' = (SELECT MIN('T2'.'Latitude') FROM 'schools' AS 'T2' WHERE 'T2'.'State' = 'CA')""": "DAC",
    """SELECT T1.city, T1.school_name, T1.lowest_grade
FROM (
    SELECT S1.city, S1.school_name, S1.grade, MIN(C1.latitude) AS lowest_latitude
    FROM schools AS S1
    INNER JOIN coordinates AS C1 ON S1.id_school = C1.id_school
    WHERE S1.state = 'CA'
    GROUP BY S1.id_school
) AS T1
WHERE T1.lowest_latitude = (
    SELECT MIN(lowest_latitude)
    FROM (
        SELECT S1.city, S1.school_name, S1.grade, MIN(C1.latitude) AS lowest_latitude
        FROM schools AS S1
        INNER JOIN coordinates AS C1 ON S1.id_school = C1.id_school
        WHERE S1.state = 'CA'
        GROUP BY S1.id_school
    ) AS T2
);""": "DAC",
    """SELECT s.City, f.School Name, f.Low Grade
    FROM schools s
    JOIN frpm f ON s.CDSCode = f.CDSCode
    WHERE s.State = 'CA'
    AND s.Latitude = (SELECT MIN(Latitude) FROM schools WHERE State = 'CA')
    LIMIT 1""": "DAC",
}

# Model name for the LLM wrapper and the MarkupLM processor/model pair that
# query_fix receives as markup_processor / markup_model.
model = 'tgi'
markup_processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
markup_model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")

# Left as a bare expression so the notebook displays the returned tuple.
query_fix(
    database_name=database_name,
    database_root_path=database_root_path,
    database_path=database_path,
    candidates=candidates,
    model=model,
    ir=ir,
    question=question,
    hint=hint,
    ground_truth=ground_truth,
    markup_processor=markup_processor,
    markup_model=markup_model,
)



An error occurred: no such column: T1.City
An error occurred: no such column: S1.school_name
An error occurred: no such column: f.School
An error occurred: no such column: T1.City
An error occurred: no such column: S1.school_name
An error occurred: no such column: f.School_Name
