In [1]:
from langchain_core.prompts import PromptTemplate

common_prompt_template = PromptTemplate.from_template(
    """You are an expert in {dialect}. Your job is to read and understand the following [Database Schema] description, along with any [Reference Information], and then use your knowledge of {dialect} to generate an SQL statement that answers the [User Question]. Pay attention to the [Database Schema], only use tables and columns that are in the [Database Schema]. Avoid using any other tables or columns that are not in the [Database Schema].

[User Question]
{user_question}

[Database Schema]
{schema}

[Reference Information]
{example_rows}

ONLY OUTPUT THE SQL STATEMENT, NO OTHER TEXT.
"""
)
xiyan_en_prompt_template = PromptTemplate.from_template(
    """You are an expert in {dialect}. Your job is to read and understand the following 【Schema】 description, along with any 【Evidence】, and then use your knowledge of {dialect} to generate an SQL statement that answers the 【Question】.

【Question】
{user_question}

【Schema】
{schema}

【Evidence】
{example_rows}

【Question】
{user_question}
```sql
"""
)
xiyan_cn_prompt_template = PromptTemplate.from_template(
    """你是一名{dialect}专家，现在需要阅读并理解下面的【数据库schema】描述，以及可能用到的【参考信息】，并运用{dialect}知识生成sql语句回答【用户问题】。
【用户问题】
{user_question}

【数据库schema】
{schema}

【参考信息】
{example_rows}

【用户问题】
{user_question}

```sql"""
)

In [2]:
llama_model_base_config = {
    "model_name": "llama3.2:3b",
    "model_config_init": {
        "temperature": 0.1,
    },
    "prompt_template": common_prompt_template,
    "inference_type": "ollama",
    "use_mschema": False,
}
xiyan7b16_model_base_config = {
    "model_name": "hf.co/mradermacher/XiYanSQL-QwenCoder-7B-2504-GGUF:F16",
    "model_config_init": {"temperature": 0.1, "top_p": 0.8},
    "prompt_template": xiyan_en_prompt_template,
    "inference_type": "ollama",
}

model_config = {
    "llama32_3b_base": {
        **llama_model_base_config,
        "use_example_rows": False,
        "retry_times": 0,
    },
    "llama32_3b_w_ex": {
        **llama_model_base_config,
        "use_example_rows": True,
        "retry_times": 0,
    },
    "llama32_3b_w_re": {
        **llama_model_base_config,
        "use_example_rows": False,
        "retry_times": 10,
    },
    "llama32_3b_all": {
        **llama_model_base_config,
        "use_example_rows": True,
        "retry_times": 10,
    },
    "xiyansql_7b_16_base": {
        **xiyan7b16_model_base_config,
        "use_mschema": False,
        "use_example_rows": False,
        "retry_times": 0,
    },
    "xiyansql_7b_16_w_ex": {
        **xiyan7b16_model_base_config,
        "use_mschema": False,
        "use_example_rows": True,
        "retry_times": 0,
    },
    "xiyansql_7b_16_w_re": {
        **xiyan7b16_model_base_config,
        "use_mschema": False,
        "use_example_rows": False,
        "retry_times": 10,
    },
    "xiyansql_7b_16_all": {
        **xiyan7b16_model_base_config,
        "use_mschema": False,
        "use_example_rows": True,
        "retry_times": 10,
    },
    "xiyansql_7b_8_all": {
        **xiyan7b16_model_base_config,
        "model_name": "hf.co/mradermacher/XiYanSQL-QwenCoder-7B-2504-GGUF:Q8_0",
        "use_mschema": False,
        "use_example_rows": True,
        "retry_times": 10,
    },
}

In [3]:
from langchain_ollama import ChatOllama
from langchain_community.llms import VLLM
import pandas as pd
from tqdm.notebook import tqdm
from sqlalchemy import create_engine
from sqlalchemy.sql import text
from langchain_core.prompt_values import PromptValue

tqdm.pandas()


class ModelCreater:
    def __init__(self):
        self.llm = None
        self.model_name = None
        self.infer_type = None

    def create_model(self, model_config: dict):
        if (
            self.llm
            and self.model_name == model_config["model_name"]
            and self.infer_type == model_config["inference_type"]
        ):
            return self.llm
        else:
            # shutdown created vllm
            if self.infer_type == "vllm":
                pass

            if model_config["inference_type"] == "ollama":
                self.llm = ChatOllama(
                    model=model_config["model_name"],
                    **model_config["model_config_init"],
                )
            elif model_config["inference_type"] == "vllm":
                self.llm = VLLM(
                    model=model_config["model_name"],
                    trust_remote_code=True,
                    **model_config["model_config_init"],
                )
            else:
                raise ValueError(
                    f"Inference type {model_config['inference_type']} not supported"
                )
        self.model_name = model_config["model_name"]
        self.infer_type = model_config["inference_type"]
        return self.llm


model_creater = ModelCreater()
retry_prompt_template = PromptTemplate.from_template(
    """Your query is invalid. Carefully read the table schemas and the user question, and then regenerate a new query.
    
User question:
{user_question}

Table schemas:
{schema}

Here is your generated query:
{invalid_query}

And here is the error:
{error_message}

ONLY OUTPUT THE SQL STATEMENT, NO OTHER TEXT.
"""
)


def check_sql_valid(sql: str, db_path: str):
    engine = create_engine(f"sqlite:///{db_path}")
    try:
        with engine.connect() as conn:
            conn.execute(text("EXPLAIN " + sql))
    except Exception as e:
        return False, repr(e)
    return True, ""


def gen_sql(llm, prompt: PromptValue):
    message = llm.invoke(prompt)
    if not isinstance(message, str):
        message = message.content
    message = (
        message.strip()
        .replace("```sql", "")
        .replace("```", "")
        .replace("\n", "")
        .replace("\t", "")
    )
    return message


def benchmark_model(
    model: str,
    dialect: str,
    input_file_path: str,
    output_file_path: str,
    db_root_dir: str,
):
    config = model_config[model]
    llm = model_creater.create_model(config)
    prompt_template = config["prompt_template"]

    def generate_sql(
        user_question: str,
        schema: str,
        dialect: str,
        db_id: str,
        example_rows: list[str] | None = None,
    ):
        prompt_msg = prompt_template.invoke(
            {
                "user_question": user_question,
                "schema": schema,
                "dialect": dialect,
                "example_rows": example_rows if example_rows else "None",
            }
        )
        message = gen_sql(llm, prompt_msg)

        if config["retry_times"] > 0:
            db_path = f"{db_root_dir}/{db_id}/{db_id}.sqlite"
            is_valid, err_msg = check_sql_valid(message, db_path)

            i = 0
            while not is_valid and i < config["retry_times"]:
                retry_prompt_msg = retry_prompt_template.invoke(
                    {
                        "user_question": user_question,
                        "schema": schema,
                        "invalid_query": message,
                        "error_message": err_msg,
                    }
                )
                print(
                    f"({i + 1}/{config['retry_times']}) Retry prompt: {retry_prompt_msg}"
                )
                message = gen_sql(llm, retry_prompt_msg)
                is_valid, err_msg = check_sql_valid(message, db_path)
                i += 1

        return message

    df = pd.read_csv(input_file_path)
    inference_df = df.assign(
        answer=lambda df_: df_.progress_apply(
            lambda row: generate_sql(
                user_question=row.question,
                schema=row.schemas if not config["use_mschema"] else row.mschemas,
                dialect=dialect,
                db_id=row.db_id,
                example_rows=row.example_rows if config["use_example_rows"] else None,
            ),
            axis=1,
        )
    )
    with open(output_file_path, "w+") as f:
        answer_list = inference_df["answer"].tolist()
        f.write("\n".join(answer_list))

    return inference_df

In [4]:
test_list = [
    ("llama32_3b_base", "dev"),
    ("llama32_3b_w_ex", "dev"),
    ("llama32_3b_w_re", "dev"),
    ("llama32_3b_all", "dev"),
    ("xiyansql_7b_16_base", "dev"),
    ("xiyansql_7b_16_w_ex", "dev"),
    ("xiyansql_7b_16_w_re", "dev"),
    ("xiyansql_7b_16_all", "dev"),
    ("llama32_3b_all", "test"),
    ("xiyansql_7b_16_all", "test"),
    ("xiyansql_7b_8_all", "test"),
]

In [5]:
for model_name, dataset in test_list:
    dialect = "sqlite"
    input_file_path = f"prepare_data/{dataset}_input.csv"
    output_file_path = f"inference_data/{model_name}_{dataset}_inf.txt"
    db_root_dir = (
        "spider_data/database" if dataset == "dev" else "spider_data/test_database"
    )
    print(f"Benchmarking {model_name}, {dataset} dataset")
    inference_df = benchmark_model(
        model_name, dialect, input_file_path, output_file_path, db_root_dir
    )

Benchmarking llama32_3b_base, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

Benchmarking llama32_3b_w_ex, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

Benchmarking llama32_3b_w_re, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

(1/10) Retry prompt: text='Your query is invalid. Carefully read the table schemas and the user question, and then regenerate a new query.\n\nUser question:\nShow the stadium name and capacity with most number of concerts in year 2014 or after.\n\nTable schemas:\nCREATE TABLE "concert" (\n"concert_ID" int,\n"concert_Name" text,\n"Theme" text,\n"Stadium_ID" text,\n"Year" text,\nPRIMARY KEY ("concert_ID"),\nFOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID")\n)\nCREATE TABLE "stadium" (\n"Stadium_ID" int,\n"Location" text,\n"Name" text,\n"Capacity" int,\n"Highest" int,\n"Lowest" int,\n"Average" int,\nPRIMARY KEY ("Stadium_ID")\n)\n\nHere is your generated query:\nSELECT T2.Name, T1.Capacity FROM concert AS T1 INNER JOIN stadium AS T2 ON T1.Stadium_ID = T2.Stadium_ID WHERE STRFTIME(\'%Y\', T1.Year) >= \'2014\' GROUP BY T1.Stadium_ID ORDER BY COUNT(T1.concert_ID) DESC LIMIT 1\n\nAnd here is the error:\nOperationalError(\'(sqlite3.OperationalError) no such column: T1.Capacity\')\n

  0%|          | 0/1034 [00:00<?, ?it/s]

(1/10) Retry prompt: text='Your query is invalid. Carefully read the table schemas and the user question, and then regenerate a new query.\n\nUser question:\nWhat is the name and capacity of the stadium with the most concerts after 2013 ?\n\nTable schemas:\nCREATE TABLE "concert" (\n"concert_ID" int,\n"concert_Name" text,\n"Theme" text,\n"Stadium_ID" text,\n"Year" text,\nPRIMARY KEY ("concert_ID"),\nFOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID")\n)\nCREATE TABLE "stadium" (\n"Stadium_ID" int,\n"Location" text,\n"Name" text,\n"Capacity" int,\n"Highest" int,\n"Lowest" int,\n"Average" int,\nPRIMARY KEY ("Stadium_ID")\n)\n\nHere is your generated query:\nSELECT T2.Name, T1.Capacity FROM concert AS T1 INNER JOIN stadium AS T2 ON T1.Stadium_ID = T2.Stadium_ID WHERE STRFTIME(\'%Y\', T1.Year) > \'2013\' GROUP BY T1.Stadium_ID ORDER BY COUNT(T1.concert_ID) DESC LIMIT 1\n\nAnd here is the error:\nOperationalError(\'(sqlite3.OperationalError) no such column: T1.Capacity\')\n\nONLY 

  0%|          | 0/1034 [00:00<?, ?it/s]

Benchmarking xiyansql_7b_16_w_ex, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

Benchmarking xiyansql_7b_16_w_re, dev dataset


  0%|          | 0/1034 [00:00<?, ?it/s]

(1/10) Retry prompt: text="Your query is invalid. Carefully read the table schemas and the user question, and then regenerate a new query.\n\nUser question:\nFind the first name of students who have cat or dog pet.\n\nTable schemas:\nCREATE TABLE Has_Pet (\n       StuID\t\tINTEGER,\n       PetID\t\tINTEGER,\n       FOREIGN KEY(PetID) REFERENCES Pets(PetID),\n       FOREIGN KEY(StuID) REFERENCES Student(StuID)\n)\nCREATE TABLE Pets (\n       PetID\t\tINTEGER PRIMARY KEY,\n       PetType\t\tVARCHAR(20),\n       pet_age INTEGER,\n       weight REAL\n)\nCREATE TABLE Student (\n       StuID    \tINTEGER PRIMARY KEY,\n       LName\t\tVARCHAR(12),\n       Fname\t\tVARCHAR(12),\n       Age\t\tINTEGER,\n       Sex\t\tVARCHAR(1),\n       Major\t\tINTEGER,\n       Advisor\t\tINTEGER,\n       city_code\tVARCHAR(3)\n)\n\nHere is your generated query:\nSELECT DISTINCT T1.Fname FROM Student AS T1 INNER JOIN Has_Pet AS T2 ON T1.StuID = T2.StuID INNER JOIN Pets AS T3 ON T3.PetID = T2.PetID WHERE T3.pet

  0%|          | 0/1034 [00:00<?, ?it/s]

(1/10) Retry prompt: text="Your query is invalid. Carefully read the table schemas and the user question, and then regenerate a new query.\n\nUser question:\nFind the first name of students who have cat or dog pet.\n\nTable schemas:\nCREATE TABLE Has_Pet (\n       StuID\t\tINTEGER,\n       PetID\t\tINTEGER,\n       FOREIGN KEY(PetID) REFERENCES Pets(PetID),\n       FOREIGN KEY(StuID) REFERENCES Student(StuID)\n)\nCREATE TABLE Pets (\n       PetID\t\tINTEGER PRIMARY KEY,\n       PetType\t\tVARCHAR(20),\n       pet_age INTEGER,\n       weight REAL\n)\nCREATE TABLE Student (\n       StuID    \tINTEGER PRIMARY KEY,\n       LName\t\tVARCHAR(12),\n       Fname\t\tVARCHAR(12),\n       Age\t\tINTEGER,\n       Sex\t\tVARCHAR(1),\n       Major\t\tINTEGER,\n       Advisor\t\tINTEGER,\n       city_code\tVARCHAR(3)\n)\n\nHere is your generated query:\nSELECT DISTINCT T1.Fname FROM Student AS T1 INNER JOIN Has_Pet AS T2 ON T1.StuID = T2.StuID INNER JOIN Pets AS T3 ON T3.PetID = T2.PetID WHERE T3.pet

  0%|          | 0/2147 [00:00<?, ?it/s]

(1/10) Retry prompt: text='Your query is invalid. Carefully read the table schemas and the user question, and then regenerate a new query.\n\nUser question:\nWhat are the names of clubs, ordered descending by the average earnings of players within each?\n\nTable schemas:\nCREATE TABLE "club" (\n"Club_ID" int,\n"Name" text,\n"Manager" text,\n"Captain" text,\n"Manufacturer" text,\n"Sponsor" text,\nPRIMARY KEY ("Club_ID")\n)\nCREATE TABLE "player" (\n"Player_ID" real,\n"Name" text,\n"Country" text,\n"Earnings" real,\n"Events_number" int,\n"Wins_count" int,\n"Club_ID" int,\nPRIMARY KEY ("Player_ID"),\nFOREIGN KEY ("Club_ID") REFERENCES "club"("Club_ID")\n)\n\nHere is your generated query:\nSELECT T1.Name FROM club AS T1 ORDER BY AVG(T2.Earnings) DESC\n\nAnd here is the error:\nOperationalError(\'(sqlite3.OperationalError) no such column: T2.Earnings\')\n\nONLY OUTPUT THE SQL STATEMENT, NO OTHER TEXT.\n'
(2/10) Retry prompt: text='Your query is invalid. Carefully read the table schemas and 

KeyboardInterrupt: 

In [41]:
# import pandas as pd

# input_df = pd.read_csv("prepare_data/dev_input.csv")
# print(
#     xiyan_en_prompt_template.invoke(
#         {
#             "dialect": "sqlite",
#             "user_question": input_df.iloc[48]["question"],
#             "schema": input_df.iloc[48]["schemas"],
#             "example_rows": input_df.iloc[48]["example_rows"],
#         }
#     ).text
# )