In [37]:
# --- Run configuration -------------------------------------------------------
# Model selection: `model_name` is the key looked up in `model_config`;
# `dialect` is the SQL dialect substituted into the prompt.
model_name = "llama3.2"
dialect = "sqlite"

# File locations: the CSV of questions/schemas that is read, and the text file
# the generated answers are written to.
input_file_path = "prepare_data/dev_input.csv"
output_file_path = "inference_data/dev_inference.txt"

In [20]:
from langchain_core.prompts import PromptTemplate

# Prompt text behind `common_prompt_template`. It has three placeholders
# ({dialect}, {user_question}, {schema}); the [Reference Information] section
# is fixed to the literal "None".
_COMMON_PROMPT_TEXT = """You are an expert in {dialect}. Your job is to read and understand the following [Database Schema] description, along with any [Reference Information], and then use your knowledge of {dialect} to generate an SQL statement that answers the [User Question].

[User Question]
{user_question}

[Database Schema]
{schema}

[Reference Information]
None

ONLY OUTPUT THE SQL STATEMENT, NO OTHER TEXT.
"""

common_prompt_template = PromptTemplate.from_template(_COMMON_PROMPT_TEXT)

In [21]:
# Registry of selectable models, keyed by the value of `model_name`.
# Each entry carries the model identifier handed to the chat client, the extra
# keyword arguments for constructing that client, and the prompt template to use.
model_config = {
    "llama3.2": dict(
        model_name="llama3.2",
        model_config_init=dict(temperature=0.0),
        prompt_template=common_prompt_template,
    ),
}

In [23]:
from langchain_ollama import ChatOllama


# Look up the selected model's registry entry once, then build the chat client
# from its identifier and constructor keyword arguments.
_active_model = model_config[model_name]
llm = ChatOllama(
    model=_active_model["model_name"],
    **_active_model["model_config_init"],
)

# Smoke test: a single round trip whose reply is shown as the cell output.
llm.invoke("Hello world")

AIMessage(content="Hello! It's nice to meet you. Is there something I can help you with or would you like to chat?", additional_kwargs={}, response_metadata={'model': 'llama3.2', 'created_at': '2025-06-20T08:14:10.16633Z', 'done': True, 'done_reason': 'stop', 'total_duration': 3070754947, 'load_duration': 29486716, 'prompt_eval_count': 27, 'prompt_eval_duration': 547266687, 'eval_count': 25, 'eval_duration': 2492119598, 'model_name': 'llama3.2'}, id='run--0fa2758a-64ab-4fba-b50a-b38edd4d691f-0', usage_metadata={'input_tokens': 27, 'output_tokens': 25, 'total_tokens': 52})

In [39]:
# Prompt template registered for the selected model; generate_sql reads this
# module-level name when it builds the prompt.
prompt_template = model_config[model_name]["prompt_template"]


def generate_sql(user_question: str, schema_list: "list[str] | str", dialect: str) -> str:
    """Ask the LLM for a SQL statement that answers ``user_question``.

    Relies on two module-level names defined in earlier cells:
    ``prompt_template`` (renders the prompt) and ``llm`` (the chat model).

    Args:
        user_question: Natural-language question to translate into SQL.
        schema_list: ``CREATE TABLE`` statements describing the database. May
            be a list of strings, the string repr of such a list (which is
            what a list column becomes after a CSV round trip), or a single
            schema string.
        dialect: SQL dialect name substituted into the prompt.

    Returns:
        The model's answer as a single line of text: markdown code fences are
        removed and every line break / tab is replaced by one space.
    """
    import ast
    import re

    # A list column read back from CSV arrives as a string such as
    # "['CREATE TABLE ...']". Joining that string directly would put a newline
    # between every character, so turn it back into a real list first.
    if isinstance(schema_list, str):
        try:
            parsed = ast.literal_eval(schema_list)
        except (ValueError, SyntaxError):
            parsed = None
        if isinstance(parsed, (list, tuple)):
            schema_list = list(parsed)
        else:
            # Not a stringified list: treat the text as one schema statement.
            schema_list = [schema_list]

    message = prompt_template.invoke(
        {
            "user_question": user_question,
            "schema": "\n".join(str(schema) for schema in schema_list),
            "dialect": dialect,
        }
    )
    raw_answer = llm.invoke(message).content

    # Drop markdown fences, with or without a sql/sqlite language tag.
    without_fences = re.sub(r"```(?:sqlite|sql)?", "", raw_answer, flags=re.IGNORECASE)
    # Flatten to one line. Line breaks and tabs become a single space rather
    # than being deleted, otherwise tokens on adjacent lines fuse together
    # ("SELECT name\nFROM t" must not become "SELECT nameFROM t").
    sql = re.sub(r"[ \t]*[\r\n]+[ \t]*|\t+", " ", without_fences).strip()

    print(f"Generated answer for question: {user_question[:20]}...")
    return sql

In [40]:
# Single-question sanity check using one hand-written table schema.
generate_sql(
    user_question="How many singers do we have?",
    schema_list=[
        'CREATE TABLE "singer" (\n"Singer_ID" int,\n"Name" text,\n"Country" text,\n"Song_Name" text,\n"Song_release_year" text,\n"Age" int,\n"Is_male" bool,\nPRIMARY KEY ("Singer_ID")\n)'
    ],
    dialect=dialect,
)

Generated answer for question: How many singers do ...


'SELECT COUNT(*) FROM singer;'

In [30]:
import pandas as pd

# Load the questions and schemas from the configured CSV file.
# NOTE(review): list-like columns such as `tables` and `schemas` appear to be
# loaded as plain strings rather than lists — confirm what consumers expect.
df = pd.read_csv(filepath_or_buffer=input_file_path)

# Bare name as the last expression so the frame is rendered as the cell output.
df

Unnamed: 0,question_number,question,db_id,tables,schemas
0,0,How many singers do we have?,concert_singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""..."
1,1,What is the total number of singers?,concert_singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""..."
2,2,"Show name, country, age for all singers ordere...",concert_singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""..."
3,3,"What are the names, countries, and ages for ev...",concert_singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""..."
4,4,"What is the average, minimum, and maximum age ...",concert_singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""..."
...,...,...,...,...,...
1029,1029,What are the citizenships that are shared by s...,singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""..."
1030,1030,How many available features are there in total?,real_estate_properties,['Other_Available_Features'],['CREATE TABLE `Other_Available_Features` (\n`...
1031,1031,What is the feature type name of feature AirCon?,real_estate_properties,"['Other_Available_Features', 'Ref_Feature_Types']",['CREATE TABLE `Other_Available_Features` (\n`...
1032,1032,Show the property type descriptions of propert...,real_estate_properties,"['Properties', 'Ref_Property_Types']",['CREATE TABLE `Properties` (\n`property_id` I...


In [35]:
def _answer_for_row(row):
    """Generate the SQL answer for one row of the question frame."""
    return generate_sql(row.question, row.schemas, dialect)


# Run the model on the first two rows only and attach the result as `answer`.
inference_df = df.head(2).assign(
    answer=lambda frame: frame.apply(_answer_for_row, axis=1)
)
inference_df

Generated answer for question: How many singers do ...
Generated answer for question: What is the total nu...


Unnamed: 0,question_number,question,db_id,tables,schemas,answer
0,0,How many singers do we have?,concert_singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""...",SELECT COUNT(*) FROM singer;
1,1,What is the total number of singers?,concert_singer,['singer'],"['CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""...",SELECT COUNT(*) FROM singer;


In [38]:
from pathlib import Path

# Write the generated answers to the configured output file, joined by newlines.
answer_list = inference_df["answer"].tolist()

# Create the destination directory first so a fresh checkout without it does
# not fail with FileNotFoundError.
output_path = Path(output_file_path)
output_path.parent.mkdir(parents=True, exist_ok=True)

# Plain write mode is enough (nothing is read back); an explicit encoding keeps
# the file identical across platforms.
with open(output_path, "w", encoding="utf-8") as f:
    f.write("\n".join(answer_list))