In [21]:
from langchain_core.prompts import PromptTemplate

# Generic text-to-SQL prompt, used by the non-XiYan entries in `model_config`.
# Template variables: {dialect}, {user_question}, {schema}.
# The [Reference Information] section is hard-coded to "None", and the model
# is told to restrict itself to schema tables/columns and to output only SQL.
common_prompt_template = PromptTemplate.from_template(
    """You are an expert in {dialect}. Your job is to read and understand the following [Database Schema] description, along with any [Reference Information], and then use your knowledge of {dialect} to generate an SQL statement that answers the [User Question]. Pay attention to the [Database Schema], only use tables and columns that are in the [Database Schema]. Avoid using any other tables or columns that are not in the [Database Schema].

[User Question]
{user_question}

[Database Schema]
{schema}

[Reference Information]
None

ONLY OUTPUT THE SQL STATEMENT, NO OTHER TEXT.
"""
)
# English prompt for the XiYanSQL entries. Template variables: {dialect},
# {user_question}, {schema}; [Reference Information] is hard-coded to "None".
# The question appears twice (before the schema and again at the end), and the
# prompt ends with an opened ```sql fence for the model to continue.
# NOTE(review): this layout presumably mirrors the prompt recommended on the
# XiYanSQL-QwenCoder model card — confirm before editing the wording.
xiyan_en_prompt_template = PromptTemplate.from_template(
    """You are an expert in {dialect}. Your job is to read and understand the following [Database Schema] description, along with any [Reference Information], and then use your knowledge of {dialect} to generate an SQL statement that answers the [User Question].

[User Question]
{user_question}

[Database Schema]
{schema}

[Reference Information]
None

[User Question]
{user_question}
```sql
"""
)
# Chinese-language prompt for the XiYanSQL entries. Its opening line reads:
# "You are a {dialect} expert; read and understand the [database schema]
# description below and any [reference information] that may help, and use
# {dialect} knowledge to generate SQL that answers the [user question]."
# Same structure as the English XiYan prompt: the question is repeated at the
# end, reference information is hard-coded to "None", and the text ends with an
# opened ```sql fence. Unlike the English version, there is a blank line before
# the fence and no newline after it.
xiyan_cn_prompt_template = PromptTemplate.from_template(
    """你是一名{dialect}专家，现在需要阅读并理解下面的【数据库schema】描述，以及可能用到的【参考信息】，并运用{dialect}知识生成sql语句回答【用户问题】。
【用户问题】
{user_question}

【数据库schema】
{schema}

【参考信息】
None

【用户问题】
{user_question}

```sql"""
)

In [22]:
# Benchmark registry. Every key maps to a dict with:
#   "model_name"        -> Ollama model tag handed to ChatOllama(model=...)
#   "model_config_init" -> extra keyword arguments for the ChatOllama constructor
#   "prompt_template"   -> PromptTemplate used to render each question
# NOTE(review): the F16 XiYan entries sample with temperature=0.1/top_p=0.8
# while every other entry is greedy (temperature=0.0), so F16-vs-Q8 results
# differ in decoding settings as well as quantization — confirm this is intended.
model_config = {
    key: {
        "model_name": ollama_tag,
        "model_config_init": init_kwargs,
        "prompt_template": template,
    }
    for key, ollama_tag, init_kwargs, template in (
        (
            "llama32_3b",
            "llama3.2:3b",
            {"temperature": 0.0},
            common_prompt_template,
        ),
        (
            "sqlllama_7b_16",
            "hf.co/mradermacher/SQL-Llama-v0.5-GGUF:F16",
            {"temperature": 0.0},
            common_prompt_template,
        ),
        (
            "xiyansql_7b_8_en_prompt",
            "hf.co/mradermacher/XiYanSQL-QwenCoder-7B-2504-GGUF:Q8_0",
            {"temperature": 0.0},
            xiyan_en_prompt_template,
        ),
        (
            "xiyansql_7b_8_cn_prompt",
            "hf.co/mradermacher/XiYanSQL-QwenCoder-7B-2504-GGUF:Q8_0",
            {"temperature": 0.0},
            xiyan_cn_prompt_template,
        ),
        (
            "xiyansql_7b_16_en_prompt",
            "hf.co/mradermacher/XiYanSQL-QwenCoder-7B-2504-GGUF:F16",
            {"temperature": 0.1, "top_p": 0.8},
            xiyan_en_prompt_template,
        ),
        (
            "xiyansql_7b_16_cn_prompt",
            "hf.co/mradermacher/XiYanSQL-QwenCoder-7B-2504-GGUF:F16",
            {"temperature": 0.1, "top_p": 0.8},
            xiyan_cn_prompt_template,
        ),
    )
}

In [26]:
import ast
import re
from pathlib import Path

import pandas as pd
from langchain_ollama import ChatOllama
from tqdm.notebook import tqdm

# Register tqdm's pandas integration so DataFrame.progress_apply (used in
# benchmark_model below) is available and shows a per-row progress bar.
tqdm.pandas()


def _format_schema(schemas) -> str:
    """Render the ``schemas`` cell of one input row as prompt text.

    ``pd.read_csv`` returns text cells as ``str``; a list of schema statements
    written with ``DataFrame.to_csv`` therefore comes back as its Python repr,
    e.g. ``"['CREATE TABLE a (...)', 'CREATE TABLE b (...)']"``. Newline-joining
    that string directly would put every single character on its own line, so
    a list repr is parsed back into a list first.

    Args:
        schemas: A list/tuple of schema strings, a string holding the repr of
            such a list, or a plain schema string.

    Returns:
        The schema statements joined with newlines. A string that is not a
        list/tuple literal is returned unchanged.
    """
    if isinstance(schemas, str):
        try:
            parsed = ast.literal_eval(schemas)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal: treat it as already-formatted schema text.
            return schemas
        if not isinstance(parsed, (list, tuple)):
            return schemas
        schemas = parsed
    return "\n".join(str(statement) for statement in schemas)


def _normalize_sql(raw: str) -> str:
    """Turn a raw model reply into a single-line SQL statement.

    Markdown code fences (```sql, ```sqlite in any letter case, or bare ```) are
    removed. Every whitespace run containing a line break or tab (including
    carriage returns) is collapsed into ONE space. This keeps multi-line SQL
    valid (a newline between ``SELECT a`` and ``FROM t`` becomes a space rather
    than gluing the tokens together) and guarantees one prediction per line in
    the output file.
    """
    without_fences = re.sub(r"```(?:sqlite|sql)?", "", raw, flags=re.IGNORECASE)
    single_line = re.sub(r"\s*[\r\n\t]\s*", " ", without_fences)
    return single_line.strip()


def benchmark_model(
    model: str, input_file_path: str, output_file_path: str, dialect: str
):
    """Generate SQL for every question in a CSV with one configured model.

    Args:
        model: Key into ``model_config`` selecting the Ollama model tag, the
            ``ChatOllama`` constructor kwargs and the prompt template.
        input_file_path: CSV with at least ``question`` and ``schemas`` columns.
        output_file_path: Text file that receives one SQL statement per input
            row, in input order (UTF-8). Missing parent directories are created.
        dialect: SQL dialect name substituted into the prompt (e.g. "sqlite").

    Returns:
        The input DataFrame with an added ``answer`` column holding the
        generated, single-line SQL.

    Raises:
        KeyError: if ``model`` is not a key of ``model_config``.
    """
    config = model_config[model]
    llm = ChatOllama(model=config["model_name"], **config["model_config_init"])
    prompt_template = config["prompt_template"]

    # Create the output directory before inference: otherwise a missing
    # directory only surfaces after every row has been generated.
    output_path = Path(output_file_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def generate_sql(user_question: str, schemas) -> str:
        prompt = prompt_template.invoke(
            {
                "user_question": user_question,
                "schema": _format_schema(schemas),
                "dialect": dialect,
            }
        )
        return _normalize_sql(llm.invoke(prompt).content)

    df = pd.read_csv(input_file_path)
    inference_df = df.assign(
        answer=lambda df_: df_.progress_apply(
            lambda row: generate_sql(row.question, row.schemas), axis=1
        )
    )
    # One prediction per line, aligned with the input rows.
    # NOTE(review): an empty prediction becomes a blank line; confirm the
    # downstream evaluator does not skip blank lines (that would misalign rows).
    output_path.write_text(
        "\n".join(inference_df["answer"].tolist()), encoding="utf-8"
    )

    return inference_df

In [30]:
# Models to run in the benchmark loop below, in this order. Flip a flag to
# include or exclude a `model_config` entry from the run.
test_list = [
    model_key
    for model_key, enabled in (
        ("llama32_3b", False),
        ("sqlllama_7b_16", True),
        ("xiyansql_7b_8_en_prompt", True),
        ("xiyansql_7b_8_cn_prompt", True),
        ("xiyansql_7b_16_en_prompt", False),
        ("xiyansql_7b_16_cn_prompt", False),
    )
    if enabled
]

In [None]:
# Run every enabled model on both splits. Each run writes its predictions to
# inference_data/<model>_<split>_inf.txt. Every run's DataFrame is kept in
# `inference_results`, keyed by (model_name, dataset), instead of only the last.
dialect = "sqlite"
inference_results = {}
for dataset in ["dev", "test"]:
    input_file_path = f"prepare_data/{dataset}_input.csv"
    for model_name in test_list:
        output_file_path = f"inference_data/{model_name}_{dataset}_inf.txt"
        print(f"Benchmarking {model_name}, {dataset} dataset")
        inference_df = benchmark_model(
            model_name, input_file_path, output_file_path, dialect
        )
        inference_results[(model_name, dataset)] = inference_df

Benchmarking sqlllama_7b_16, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

Benchmarking xiyansql_7b_8_en_prompt, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

Benchmarking xiyansql_7b_8_cn_prompt, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

Benchmarking sqlllama_7b_16, test dataset


  0%|          | 0/2147 [00:00<?, ?it/s]

Benchmarking xiyansql_7b_8_en_prompt, test dataset


  0%|          | 0/2147 [00:00<?, ?it/s]