In [1]:
import json
import logging
import os
import sqlite3
import time
from collections import Counter
from pathlib import Path

import openai
import pandas as pd
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from tqdm import tqdm


# Pull LLM_API_* settings from a local .env file into the process environment.
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
    # basicConfig is a no-op once the root logger has handlers; force=True makes
    # re-running this cell actually apply the configuration.
    force=True,
)
logger = logging.getLogger(__name__)
# Quiet the per-request INFO logs of the OpenAI SDK and its HTTP layer. These are the
# public logger names the SDK uses ("openai", "httpx"), so no private-module access is needed.
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

In [3]:
# Local copy of the Spider dataset (tables.json, dev.json, dev_gold.sql and database/).
spider_path = Path("data/spider")

# Decode explicitly as UTF-8 rather than the platform default encoding (e.g. cp1252 on
# Windows), which can fail on or mis-decode any non-ASCII text in the dataset files.
with (spider_path / "tables.json").open(encoding="utf-8") as f:
    tables = json.load(f)

with (spider_path / "dev.json").open(encoding="utf-8") as f:
    dev_examples = json.load(f)

# One gold query per line; lines keep their trailing newline.
with (spider_path / "dev_gold.sql").open(encoding="utf-8") as f:
    gold_sql = f.readlines()

# Schema entries indexed by database id for get_db_schema.
db_schemas = {table["db_id"]: table for table in tables}

In [5]:
# Endpoint, token and model name of the default LLM backend, read from the environment.
llm_api_url, api_key, llm_model = (
    os.getenv(env_name) for env_name in ("LLM_API_URL", "LLM_API_TOKEN", "LLM_API_MODEL")
)
client = openai.Client(base_url=llm_api_url, api_key=api_key)

# Smoke test: a single round-trip proves the endpoint, token and model name are usable.
hello_messages = [{"role": "user", "content": "Hi!"}]
hello_response = client.chat.completions.create(messages=hello_messages, model=llm_model)
hello_response.choices[0].message.content

'Hello! How can I assist you today?'

In [6]:
# Structured-output schema passed as `response_format` to the parse call in generate_sql.
# NOTE: deliberately documented with comments, not a docstring — pydantic copies a class
# docstring into the JSON schema "description", which would change what the model is sent.
# NOTE(review): `reasoning` is presumably declared first so the model reasons before it
# emits the query; keep the field order unless that is confirmed irrelevant.
class SqlQuery(BaseModel):
    # Field description (Russian) = "Write your thoughts on how you are forming the SQL query".
    reasoning: str = Field(..., description="Напиши свои мысли, как ты формируешь sql запрос")
    # Required but nullable: the model may return null here; callers must handle None.
    sql_query: str | None


def get_db_schema(db_id: str) -> str:
    schema = db_schemas[db_id]
    table_names = schema["table_names_original"]
    column_data = schema["column_names_original"]

    table_columns = {table: [] for table in table_names}

    for table_idx, col_name in column_data:
        if table_idx == -1 or col_name == "*":
            continue
        table_name = table_names[table_idx]
        table_columns[table_name].append(col_name)

    result_lines = [f"Схема базы данных: {db_id}\n"]
    for table, columns in table_columns.items():
        result_lines.extend([f"Таблица: {table}", "Столбцы:", *[f"- {col}" for col in columns], ""])

    return "\n".join(result_lines)


def generate_sql(
    db_id: str,
    question: str,
    client: openai.Client,
    model: str,
    temperature: float = 0.25,
) -> str | None:
    """Ask ``model`` to translate ``question`` into SQL for the Spider database ``db_id``.

    The system prompt embeds the schema text from ``get_db_schema``. The reply is requested
    as structured output and parsed into ``SqlQuery``.

    Args:
        db_id: id of the database whose schema is put into the prompt.
        question: natural-language question to answer.
        client: OpenAI-compatible client used for the request.
        model: model identifier passed to the API.
        temperature: sampling temperature (0.25 was the previously hard-coded value).

    Returns:
        The generated SQL string, or ``None`` in either of two cases:
        - the model set ``sql_query`` to null;
        - the reply could not be parsed into ``SqlQuery`` (e.g. a refusal).
    """
    full_schema = get_db_schema(db_id)

    system_prompt = (
        "Ты — AI-ассистент, генерирующий SQL-запросы на основе пользовательских запросов.\n"
        "Ниже — схема базы данных:\n\n"
        f"{full_schema}\n\n"
        "1) Проанализируй запрос.\n"
        "2) Опиши reasoning.\n"
        "3) Cгенерируй корректный SELECT и верни его в поле sql_query.\n"
        "4) Оптимизируй запрос для минимальной нагрузки на БД."
    )

    user_prompt = f"Запрос пользователя: {question}"

    response = client.beta.chat.completions.parse(
        model=model,
        temperature=temperature,
        messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
        response_format=SqlQuery,
    )

    message = response.choices[0].message
    if message.parsed is None:
        # No structured payload (e.g. the model refused). Report it instead of failing
        # with AttributeError on `None.sql_query`.
        logger.warning(
            "No parsed SqlQuery for db %s (refusal: %s)", db_id, getattr(message, "refusal", None)
        )
        return None
    return message.parsed.sql_query


def execute_sql(db_id: str, sql: str) -> str | None:
    db_path = spider_path / f"database/{db_id}/{db_id}.sqlite"
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
        result = cursor.fetchall()
    except Exception:
        logger.exception("Error executing SQL:")
        return None
    finally:
        conn.close()
    return result


In [7]:
# Smoke test on the first dev example: print the question, the model's SQL and the gold SQL.
example = dev_examples[0]
example_question = example["question"]
print("Вопрос: ", example_question)

example_generated_sql = generate_sql(example["db_id"], example_question, client=client, model=llm_model)
print("Сгенерированный SQL:", example_generated_sql)

# The gold query is the text before the first tab on its dev_gold.sql line.
example_gold_sql = gold_sql[0].partition("\t")[0]
print("Эталонный (gold) SQL: ", example_gold_sql)

Вопрос:  How many singers do we have?
Сгенерированный SQL: SELECT COUNT(*) AS number_of_singers FROM singer;
Эталонный (gold) SQL:  SELECT count(*) FROM singer


In [9]:
def compare_execution(generated_result: list, gold_result: list) -> bool:
    return generated_result == gold_result


# Score the first dev example by execution: run gold and generated SQL and compare the rows.
gold_example = gold_sql[0].strip().split("\t")  # dev_gold.sql line: "<query>\t<db_id>"
gold_sql_query = gold_example[0]
gold_db_id = gold_example[1]
gold_execution_result = execute_sql(gold_db_id, gold_sql_query)
# Use the model this client was configured for (LLM_API_MODEL), as the previous cell does,
# instead of a hard-coded model name that may not exist on this endpoint.
generation_result = generate_sql(example["db_id"], example["question"], client, llm_model)
execution_result = execute_sql(example["db_id"], generation_result)

# execute_sql signals failure with None; an empty list is a valid (empty) result set.
if gold_execution_result is not None:
    print(f"Execution Accuracy: {'Correct' if compare_execution(execution_result, gold_execution_result) else 'Incorrect'}")
else:
    print("Error: Unable to execute gold SQL query.")

Execution Accuracy: Correct


In [10]:
def evaluate_execution_accuracy(
    examples: list[dict],
    client: openai.Client,
    model_name: str,
    num_questions: int,
) -> tuple[float, float]:
    """Measure execution accuracy and mean generation latency of one model on Spider examples.

    For each of the first ``num_questions`` examples, the gold SQL and the model-generated
    SQL are executed against the example's database and their results are compared.

    Scoring rules:
    - Examples whose gold SQL fails are skipped.
    - Once the gold SQL succeeds, the example counts toward the denominator, so a failed
      generation or a failed execution of the generated SQL is scored as incorrect.

    Args:
        examples: Spider examples; each needs ``question``, ``db_id`` and ``query``.
        client: OpenAI-compatible client used for generation.
        model_name: model identifier passed to the API.
        num_questions: how many examples from the start of ``examples`` to evaluate.

    Returns:
        ``(accuracy, avg_latency)``:
        - ``accuracy`` is correct / scored examples;
        - ``avg_latency`` is the mean wall-clock seconds of the generation calls that
          returned without raising.
        Either value is 0.0 when its count is zero.
    """
    correct = 0
    total = 0
    total_response_time = 0.0
    response_count = 0

    for i, ex in enumerate(tqdm(examples[:num_questions], desc=f"Evaluating {model_name}")):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql_query = ex["query"]

        # execute_sql reports failure by returning None (it never returns an error string).
        gold_result = execute_sql(db_id, gold_sql_query)
        if gold_result is None:
            logger.warning("[%d] ⚠️ Ошибка в gold SQL, пример пропущен", i)
            continue

        # Scoreable from here on: later failures count as incorrect instead of being
        # dropped from the denominator, which would inflate the accuracy.
        total += 1

        try:
            start_time = time.perf_counter()
            generated_sql = generate_sql(db_id, question, client, model_name)
            total_response_time += time.perf_counter() - start_time
            response_count += 1

            generated_result = execute_sql(db_id, generated_sql)
        except Exception:
            logger.exception("[%d] ❌ Ошибка генерации или выполнения SQL", i)
            continue

        if generated_result is None:
            logger.warning("[%d] ⚠️ Ошибка выполнения сгенерированного SQL", i)
            continue

        if compare_execution(generated_result, gold_result):
            correct += 1
            logger.info("[%d] ✅ Корректно", i)
        else:
            logger.info("[%d] ❌ Некорректно", i)

    accuracy = correct / total if total else 0.0
    avg_latency = total_response_time / response_count if response_count else 0.0

    logger.info("✅ Execution Accuracy: %.2f%%", accuracy * 100)
    logger.info("⏱️ Среднее время ответа: %.2f сек", avg_latency)

    return accuracy, avg_latency


def get_openai_client(model_name: str) -> tuple[openai.Client, str]:
    """Create a client for ``model_name`` and return it with the model id to request.

    Endpoint and token are chosen by model name:
    - names starting with ``"gpt"`` use ``LLM_API_URL_OPENAI`` / ``LLM_API_TOKEN_OPENAI``;
    - every other model uses ``LLM_API_URL`` / ``LLM_API_TOKEN``.

    The model id is returned unchanged.
    """
    if model_name.startswith("gpt"):
        url_env, token_env = "LLM_API_URL_OPENAI", "LLM_API_TOKEN_OPENAI"
    else:
        url_env, token_env = "LLM_API_URL", "LLM_API_TOKEN"

    return openai.Client(base_url=os.getenv(url_env), api_key=os.getenv(token_env)), model_name


def benchmark_models_on_spider(models: list[str], examples: list[dict], num_questions: int = 100) -> pd.DataFrame:
    """Evaluate each model in ``models`` on the first ``num_questions`` examples.

    Returns one row per model with columns ``model``, ``execution_accuracy`` and
    ``avg_latency_sec``. A model whose run raises is logged and recorded with ``None``
    metrics, so one failing model does not abort the whole benchmark.
    """

    def _row(name: str, accuracy: float | None, latency: float | None) -> dict:
        return {"model": name, "execution_accuracy": accuracy, "avg_latency_sec": latency}

    rows = []
    for model_name in models:
        try:
            logger.info(f"🚀 Запуск модели: {model_name}")
            model_client, model_id = get_openai_client(model_name)
            acc, avg_latency = evaluate_execution_accuracy(
                examples=examples, client=model_client, model_name=model_id, num_questions=num_questions
            )
        except Exception:
            logger.exception(f"❌ Model {model_name} failed")
            rows.append(_row(model_name, None, None))
        else:
            rows.append(_row(model_name, acc, avg_latency))

    return pd.DataFrame(rows)

In [11]:
# Size of the Spider dev split — evaluation slices examples[:num_questions], so this bounds num_questions.
len(dev_examples)

1034

In [12]:
models_to_test = ["gpt-4o", "gpt-4.1", "gpt-4o-mini"]
benchmark_num_questions = 500  # dev examples scored per model (the dev split has 1034)

df_results = benchmark_models_on_spider(models_to_test, dev_examples, num_questions=benchmark_num_questions)
# The sample size is part of the file name, so runs of different sizes don't overwrite each other.
# index=False is the documented way to omit the row index (None only worked by being falsy).
df_results.to_csv(f"data/evaluation_results_{benchmark_num_questions}.csv", index=False)
df_results

2025-05-15 11:24:56,517 [INFO] 🚀 Запуск модели: gpt-4o
Evaluating gpt-4o:   0%|          | 0/500 [00:00<?, ?it/s]2025-05-15 11:24:59,700 [INFO] [0] ✅ Корректно
Evaluating gpt-4o:   0%|          | 1/500 [00:03<26:14,  3.16s/it]2025-05-15 11:25:01,989 [INFO] [1] ✅ Корректно
Evaluating gpt-4o:   0%|          | 2/500 [00:05<21:57,  2.65s/it]2025-05-15 11:25:04,834 [INFO] [2] ✅ Корректно
Evaluating gpt-4o:   1%|          | 3/500 [00:08<22:40,  2.74s/it]2025-05-15 11:25:08,118 [INFO] [3] ✅ Корректно
Evaluating gpt-4o:   1%|          | 4/500 [00:11<24:24,  2.95s/it]2025-05-15 11:25:10,736 [INFO] [4] ✅ Корректно
Evaluating gpt-4o:   1%|          | 5/500 [00:14<23:21,  2.83s/it]2025-05-15 11:25:14,820 [INFO] [5] ✅ Корректно
Evaluating gpt-4o:   1%|          | 6/500 [00:18<26:49,  3.26s/it]2025-05-15 11:25:16,950 [INFO] [6] ❌ Некорректно
Evaluating gpt-4o:   1%|▏         | 7/500 [00:20<23:44,  2.89s/it]2025-05-15 11:25:20,705 [INFO] [7] ✅ Корректно
Evaluating gpt-4o:   2%|▏         | 8/500 [00:2