In [1]:
from openai import OpenAI
import os
from src.model.inference_endpoints import LLM
import sqlite3
from src.query_fix.utils import query_fixer, parse_query_fix_output

def query_fix(
    database_name: str,
    database_root_path: str,
    database_path: str,
    candidates: list[str],
    model: str,
    ir: list[str],
    question: str,
    hint: str,
    n_retries: int=5
):
    """Execute each candidate SQL query, asking an LLM to repair the ones that fail.

    Every candidate is run against the SQLite database at ``database_path``.
    When execution raises, the failing query and the exception are passed to
    ``query_fixer``. The parsed repair is then executed in place of the
    failing query. This repeats up to ``n_retries`` times per candidate.

    Args:
        database_name: Database name, forwarded to ``query_fixer``.
        database_root_path: Database directory, forwarded to ``query_fixer``.
        database_path: Path of the SQLite file the queries are executed against.
        candidates: SQL queries to execute (and repair if they fail).
        model: Model identifier used to build the LLM wrapper.
        ir: Schema references forwarded to ``query_fixer`` (presumably the
            retrieved table/column/value items -- TODO confirm).
        question: Natural-language question the queries should answer.
        hint: Extra hint text forwarded to ``query_fixer``.
        n_retries: Maximum number of LLM repair attempts per candidate. Every
            repair is executed, so a candidate runs at most ``n_retries + 1``
            times.

    Returns:
        A list of ``(query, rows)`` tuples in candidate order, one for each
        candidate that eventually executed. ``query`` is the version that
        succeeded (the original or a repair), and ``rows`` is ``fetchall()``.
        Candidates that still fail after all repairs are omitted.

    Raises:
        FileNotFoundError: If ``database_path`` is not an existing file.
            Without this check ``sqlite3.connect`` would create an empty
            database, every query would fail, and the LLM would be called
            for nothing.
    """
    if not os.path.isfile(database_path):
        raise FileNotFoundError(f"SQLite database not found: {database_path}")

    # Built lazily: credentials/client are only needed if some candidate fails.
    llm = None
    fixed_queries = []
    for i, query in enumerate(candidates):
        # Attempt 0 runs the original; attempts 1..n_retries run repairs.
        for attempt in range(n_retries + 1):
            try:
                result = _run_query(database_path, query)
            except Exception as e:
                print(e)
                if attempt == n_retries:
                    # Repair budget exhausted: drop this candidate.
                    break
                if llm is None:
                    llm = _build_llm(model)
                raw_fix = query_fixer(
                    database_name=database_name,
                    database_root_path=database_root_path,
                    ir=ir,
                    query_to_correct=query,
                    question=question,
                    hint=hint,
                    result=e,
                    model=llm
                )
                query = parse_query_fix_output(raw_fix)
                print(query)
            else:
                print(f"query {i}, result {result}")
                fixed_queries.append((query, result))
                break
    return fixed_queries


def _run_query(database_path: str, query: str) -> list:
    """Run ``query`` on a fresh connection and return all rows, always closing the connection."""
    conn = sqlite3.connect(database_path)
    try:
        return conn.cursor().execute(query).fetchall()
    finally:
        conn.close()


def _build_llm(model: str):
    """Build the deterministic, non-streaming LLM wrapper used for query repair.

    Reads the BASE_URL_DEEPSEEK and API_KEY_DEEPSEEK environment variables.
    """
    client = OpenAI(
        base_url=os.environ['BASE_URL_DEEPSEEK'],
        api_key=os.environ['API_KEY_DEEPSEEK']
    )
    return LLM(
        client=client,
        model=model,
        gen_params={
            'STREAM': False,
            'TEMPERATURE': 0,
            'MAX_NEW_TOKENS': 2048
        }
    )  # TODO: original author noted this function needs to change
    


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Target database; query_fix needs both the directory and the .sqlite file path.
database_name = "california_schools"
database_root_path = f"{os.environ['DATABASE_ROOT_PATH']}/{database_name}"
database_path = f"{database_root_path}/{database_name}.sqlite"

# Candidate SQL to run through the fixer. These are deliberately kept as-is,
# including their defects, because they are the fixer's input.
candidates = [
    # Quoted-identifier variant; the recorded run reports "no such column: T1.City".
    """SELECT 'T1'.'City', 'T1'.'Low Grade', 'T2'.'Latitude' FROM 'frpm' AS 'T1' INNER JOIN 'schools' AS 'T2' ON 'T1'.'CDSCode' = 'T2'.'CDSCode' WHERE 'T2'.'State' = 'CA' AND 'T2'.'Latitude' = (SELECT MIN('T2'.'Latitude') FROM 'schools' AS 'T2' WHERE 'T2'.'State' = 'CA')""",
    # Uses a `coordinates` table; the recorded run reports "no such table: coordinates".
    """SELECT T1.city, T1.school_name, T1.lowest_grade
FROM (
    SELECT S1.city, S1.school_name, S1.grade, MIN(C1.latitude) AS lowest_latitude
    FROM schools AS S1
    INNER JOIN coordinates AS C1 ON S1.id_school = C1.id_school
    WHERE S1.state = 'CA'
    GROUP BY S1.id_school
) AS T1
WHERE T1.lowest_latitude = (
    SELECT MIN(lowest_latitude)
    FROM (
        SELECT S1.city, S1.school_name, S1.grade, MIN(C1.latitude) AS lowest_latitude
        FROM schools AS S1
        INNER JOIN coordinates AS C1 ON S1.id_school = C1.id_school
        WHERE S1.state = 'CA'
        GROUP BY S1.id_school
    ) AS T2
);""",
    # Column names containing spaces are left unquoted.
    """SELECT s.City, f.School Name, f.Low Grade
    FROM schools s
    JOIN frpm f ON s.CDSCode = f.CDSCode
    WHERE s.State = 'CA'
    AND s.Latitude = (SELECT MIN(Latitude) FROM schools WHERE State = 'CA')
    LIMIT 1"""
]

model = 'tgi'

# Schema references handed to the fixer as context.
ir = [
    "`schools`.`City`.`San Diego`",
    "`frpm`.`Low Grade`",
    "`frpm`.`School Name`.`Vidya Mandir`",
    "`frpm`.`CDSCode`",
    "`schools`.`CDSCode`",
    "`schools`.`State`",
    "`schools`.`Latitude`",
]
question = "In which city can you find the school in the state of California with the lowest latitude coordinates and what is its lowest grade? Indicate the school name."
hint = "State of California refers to state = 'CA'"

# Last expression: the list of (query, rows) pairs is displayed as the cell output.
query_fix(
    candidates=candidates,
    question=question,
    hint=hint,
    ir=ir,
    model=model,
    database_name=database_name,
    database_root_path=database_root_path,
    database_path=database_path,
)

no such column: T1.City

SELECT 'T1'.'City', 'T1'.'Low Grade', 'T2'.'Latitude' FROM 'frpm' AS 'T1' INNER JOIN 'schools' AS 'T2' ON 'T1'.'CDSCode' = 'T2'.'CDSCode' WHERE 'T2'.'State' = 'CA' AND 'T2'.'Latitude' = (SELECT MIN('T2'.'Latitude') FROM 'schools' AS 'T2' WHERE 'T2'.'State' = 'CA')

no such column: T1.City

SELECT T1.'City', T1.'Low Grade', T1.'School Name' FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.'CDSCode' = T2.'CDSCode' WHERE T2.'State' = 'CA' AND T2.'Latitude' = (SELECT MIN(T2.'Latitude') FROM schools AS T2 WHERE T2.'State' = 'CA')

no such column: T1.City

SELECT T2.'City', T1.'Low Grade', T1.'School Name' FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.'CDSCode' = T2.'CDSCode' WHERE T2.'State' = 'CA' AND T2.'Latitude' = (SELECT MIN(T2.'Latitude') FROM schools AS T2 WHERE T2.'State' = 'CA')

query 0, result [('San Ysidro', 'K', 'Willow Elementary')]
no such table: coordinates

SELECT T1.city, T1.school_name, T1.lowest_grade
FROM (
    SELECT S1.city, S1.school_name, S1.

[("\nSELECT T2.'City', T1.'Low Grade', T1.'School Name' FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.'CDSCode' = T2.'CDSCode' WHERE T2.'State' = 'CA' AND T2.'Latitude' = (SELECT MIN(T2.'Latitude') FROM schools AS T2 WHERE T2.'State' = 'CA')\n",
  [('San Ysidro', 'K', 'Willow Elementary')]),
 ('\nSELECT s.City, f."School Name", f."Low Grade"\nFROM schools s\nJOIN frpm f ON s.CDSCode = f.CDSCode\nWHERE s.State = \'CA\'\nAND s.Latitude = (SELECT MIN(Latitude) FROM schools WHERE State = \'CA\')\nLIMIT 1;\n',
  [('San Ysidro', 'Willow Elementary', 'K')])]