In [1]:
# 1. which bucket the correct answer falls in - what is the probability pi of that bucket
# 2. which buckets each of the method's answer fall in - " " "

In [2]:
import numpy as np
from transformers import AutoProcessor, MarkupLMModel
from openai import OpenAI
import os
from src.model.inference_endpoints import LLM
import sqlite3
from src.query_fix.utils import (
    query_fixer, 
    parse_query_fix_output, 
    sql_result_to_html, 
    html_to_features, 
    cluster_sql_queries, 
    calculate_semantic_entropy,
    check_exec_accuracy
)
from collections import defaultdict

markup_processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
markup_model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")

def query_fix(
    database_name: str,
    database_root_path: str,
    database_path: str,
    candidates: list[str],
    model: str,
    ir: list[str],
    question: str,
    hint: str,
    ground_truth: str,
    n_retries: int=10
):
    client = OpenAI(
        base_url=os.environ['BASE_URL_DEEPSEEK'],
        api_key=os.environ['API_KEY_DEEPSEEK']
    )

    llm = LLM(
        client = client,
        model = model, 
        gen_params = {
            'STREAM': False,
            'TEMPERATURE': 0,
            'MAX_NEW_TOKENS': 2048 
        }
    ) # Need to change this function

    fixed_flags = defaultdict(bool)
    fixed_queries = []
    qents= []
    attempts = 0
    while attempts < n_retries:
        new_candidates = []
        for i, query in enumerate(candidates):
            try:
                if fixed_flags[i] == 1:
                    new_candidates.append(query)
                    continue
                else:
                    conn = sqlite3.connect(database_path)
                    cursor = conn.cursor()
                    cursor.execute(query)
                    result = cursor.fetchall()
                    print(f"query {i}, result {result}")
                    correct_flag = check_exec_accuracy(database_path=database_path, query=query, ground_truth_query=ground_truth)
                    if correct_flag:
                        correct_index = i
                    conn.close()
                    fixed_queries.append((query, result))
                    new_candidates.append(query)
                    fixed_flags[i] = 1
            
            except Exception as e:
                print(e)
                fixed_flags[i] = 0
                query = query_fixer(
                    database_name=database_name,
                    database_root_path = database_root_path,
                    ir = ir,
                    query_to_correct = query,
                    question = question,
                    hint = hint,
                    result = e,
                    model = llm
                )
                query = parse_query_fix_output(query)
                new_candidates.append(query)
        
        all_features = []
        for i, cand in enumerate(new_candidates):
            try:
                # Connect to the SQLite database
                conn = sqlite3.connect(database_path)
                cursor = conn.cursor()
                
                # Execute the query
                cursor.execute(cand)
                
                # Fetch all results
                results = cursor.fetchall()
                columns = [description[0] for description in cursor.description]
                html_result = sql_result_to_html(column_names=columns, result=results)
                features = html_to_features(
                    html_string=html_result, 
                    markup_lm_processor=markup_processor, 
                    markup_lm_model=markup_model,
                )
                all_features.append(features.detach().squeeze(dim=0))
                
            except sqlite3.Error as e:
                print(f"An error occurred: {e}")
                html_result = sql_result_to_html(error=e)
                features = html_to_features(
                    html_string=html_result, 
                    markup_lm_processor=markup_processor, 
                    markup_lm_model=markup_model,
                )
                all_features.append(features.detach().squeeze(dim=0))
            
            print('shape', np.array(all_features).shape)
        clusters_DB, pi_correct = cluster_sql_queries(embeddings=np.array(all_features), correct_ind = correct_index)
        qents.append(calculate_semantic_entropy(clusters=clusters_DB))
        attempts+=1
        
    return fixed_queries, qents, pi_correct
    


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
database_name = "california_schools"
database_path = f"{os.environ['DATABASE_ROOT_PATH']}/{database_name}/{database_name}.sqlite"
database_root_path = f"{os.environ['DATABASE_ROOT_PATH']}/{database_name}"
candidates = [
    """SELECT 'T1'.'City', 'T1'.'Low Grade', 'T2'.'Latitude' FROM 'frpm' AS 'T1' INNER JOIN 'schools' AS 'T2' ON 'T1'.'CDSCode' = 'T2'.'CDSCode' WHERE 'T2'.'State' = 'CA' AND 'T2'.'Latitude' = (SELECT MIN('T2'.'Latitude') FROM 'schools' AS 'T2' WHERE 'T2'.'State' = 'CA')""",
    """SELECT T1.city, T1.school_name, T1.lowest_grade
FROM (
    SELECT S1.city, S1.'school name', S1.grade, MIN(C1.latitude) AS lowest_latitude
    FROM schools AS S1
    INNER JOIN coordinates AS C1 ON S1.'id school' = C1.'id school'
    WHERE S1.state = 'CA'
    GROUP BY S1.'id school'
) AS T1
WHERE T1.lowest_latitude = (
    SELECT MIN(lowest_latitude)
    FROM (
        SELECT S1.city, S1.school_name, S1.grade, MIN(C1.latitude) AS lowest_latitude
        FROM schools AS S1
        INNER JOIN coordinates AS C1 ON S1.id_school = C1.id_school
        WHERE S1.state = 'CA'
        GROUP BY S1.id_school
    ) AS T2
);""",
    """SELECT s.City, f.School Name, f.Low Grade
    FROM schools s
    JOIN frpm f ON s.CDSCode = f.CDSCode
    WHERE s.State = 'CA'
    AND s.Latitude = (SELECT MIN(Latitude) FROM schools WHERE State = 'CA')
    LIMIT 1""",
    "SELECT T2.City, T1.`Low Grade`, T1.`School Name` FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.State = 'CA' ORDER BY T2.Latitude ASC LIMIT 1"
]

model = 'tgi'

ir = ["`schools`.`City`.`San Diego`", "`frpm`.`Low Grade`", "`frpm`.`School Name`.`Vidya Mandir`", "`frpm`.`CDSCode`", "`schools`.`CDSCode`", "`schools`.`State`", "`schools`.`Latitude`"]
question = "In which city can you find the school in the state of California with the lowest latitude coordinates and what is its lowest grade? Indicate the school name."
hint = "State of California refers to state = 'CA'"
ground_truth = "SELECT T2.City, T1.`Low Grade`, T1.`School Name` FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.State = 'CA' ORDER BY T2.Latitude ASC LIMIT 1"

fixed_queries, quents, pi_correct = query_fix(
    database_name=database_name,
    database_path=database_path,
    database_root_path=database_root_path,
    candidates=candidates,
    model=model,
    ir=ir,
    question=question,
    hint=hint, 
    ground_truth = ground_truth
)

no such column: T1.City
no such table: coordinates
no such column: f.School
query 3, result [('Oroville', '7', 'Central Middle - RISE')]
An error occurred: no such column: T1.City
shape (1, 768)
An error occurred: no such column: S1.school_name
shape (2, 768)
An error occurred: no such column: f.School
shape (3, 768)
shape (4, 768)
no such column: T1.City
no such table: coordinates
no such column: f.School
An error occurred: no such column: T1.City
shape (1, 768)
An error occurred: no such table: coordinates
shape (2, 768)
shape (3, 768)
shape (4, 768)
no such column: T1.City
no such table: coordinates
no such column: f.School
An error occurred: no such column: T1.City
shape (1, 768)
An error occurred: near "Final": syntax error
shape (2, 768)
An error occurred: no such column: f.School_Name
shape (3, 768)
shape (4, 768)
no such column: T1.City
no such table: coordinates
no such column: f.School
An error occurred: no such column: T1.City
shape (1, 768)
An error occurred: no such column

In [4]:
quents

[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]

In [5]:
fixed_queries

[("SELECT T2.City, T1.`Low Grade`, T1.`School Name` FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.State = 'CA' ORDER BY T2.Latitude ASC LIMIT 1",
  [('Oroville', '7', 'Central Middle - RISE')])]

In [6]:
pi_correct

0.25