In [17]:
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

# Path templates; the "{dataset}" placeholder is filled in by retrieve_llm via str.format.
input_file = "prepare_data/{dataset}_input.csv"
table_list_wo_des_file = "prepare_data/{dataset}_table_list_wo_des.txt"

# Single human-turn template. Variables: {table_list_text} (contents of the
# table-list file) and {user_question} (one question from the input CSV).
# NOTE(review): the question is appended without a label (e.g. "Question: ...");
# changing the wording would make results incomparable with earlier runs.
TABLE_SELECTION_TEMPLATE = """This is the database and table list

{table_list_text}

Based on the database and table list, please select the tables that are most related to the user question.
Only select the tables that are listed in the database and table list.

{user_question}
"""
common_prompt = ChatPromptTemplate.from_messages([("human", TABLE_SELECTION_TEMPLATE)])


class SelectedTable(BaseModel):
    # One table picked by the model. It is nested inside SelectedTableList,
    # which is passed to llm.with_structured_output, so these field names and
    # Field descriptions become part of the schema the model is asked to fill.
    # Documented with comments rather than a docstring: pydantic uses a class
    # docstring as the schema "description", which would alter what the model sees.
    table_name: str = Field(description="The name of the table")
    db_name: str = Field(description="The name of the database")


class SelectedTableList(BaseModel):
    # Top-level structured-output schema: gen_selected_table_list passes this
    # class to llm.with_structured_output and returns .selected_table_list from
    # the parsed reply. Documented with comments rather than a docstring,
    # because pydantic would put a class docstring into the JSON schema sent to
    # the model and change the prompt.
    selected_table_list: list[SelectedTable] = Field(
        description="The list of selected tables"
    )


# One entry per retrieval experiment. retrieve_llm formats "input_file" and
# "table_list_file" with the "dataset" value and forwards "model_config" as
# keyword arguments to ChatOllama.
# NOTE(review): "output_file" is not read by any code visible in this notebook yet.
configs = [
    dict(
        dataset="dev",
        method="llm",
        model="llama3.2:3b",
        model_config=dict(temperature=0.0),
        input_file=input_file,
        table_list_file=table_list_wo_des_file,
        output_file="inference_data/llama32_3b_retrieval_wo_des_{dataset}.csv",
        prompt=common_prompt,
    )
]

In [18]:
import pandas as pd
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_ollama import ChatOllama


def retrieve(config: dict):
    """Route a run configuration to its retrieval implementation.

    Only ``config["method"] == "llm"`` is supported; it is delegated to
    ``retrieve_llm`` and that function's result is returned unchanged.

    Raises:
        ValueError: if ``config["method"]`` names any other method.
    """
    method = config["method"]
    if method != "llm":
        raise ValueError(f"Invalid method: {method}")
    return retrieve_llm(config)


def gen_selected_table_list(
    llm: BaseChatModel,
    prompt: ChatPromptTemplate,
    table_list_text: str,
    user_question: str,
):
    """Ask ``llm`` which tables from ``table_list_text`` are relevant to a question.

    The prompt is rendered with ``table_list_text`` and ``user_question``, and the
    model is asked to answer in the ``SelectedTableList`` schema via
    ``with_structured_output``.

    Args:
        llm: Chat model supporting ``with_structured_output``.
        prompt: Template exposing ``{table_list_text}`` and ``{user_question}``.
        table_list_text: Text listing the available databases and tables.
        user_question: Natural-language question to retrieve tables for.

    Returns:
        The ``selected_table_list`` of the parsed reply (a list of
        ``SelectedTable``; may be empty if the model selects nothing).

    Raises:
        ValueError: if the model produced no structured output at all.
    """
    structured_llm = llm.with_structured_output(SelectedTableList)
    prompt_value = prompt.invoke(
        {"table_list_text": table_list_text, "user_question": user_question}
    )
    result = structured_llm.invoke(prompt_value)
    # With tool-calling based structured output, a model that skips the tool
    # call yields None instead of raising; report that clearly rather than
    # failing later with an AttributeError on None.
    if result is None:
        raise ValueError(
            "Model returned no structured output for question: "
            f"{user_question!r}"
        )
    return result.selected_table_list


def retrieve_llm(config: dict):
    """Run LLM-based table retrieval for the first question of a dataset.

    Reads the table list and the input CSV named by ``config`` (both path
    templates are formatted with ``config["dataset"]``). Then builds a
    ``ChatOllama`` model from ``config["model"]`` and ``config["model_config"]``
    and asks it to select tables for the first row's ``question``.

    Args:
        config: One entry of ``configs``; reads the keys ``dataset``,
            ``table_list_file``, ``input_file``, ``model``, ``model_config``
            and ``prompt``.

    Returns:
        The list returned by ``gen_selected_table_list``.

    Raises:
        ValueError: if the input CSV contains no rows.
    """
    dataset = config["dataset"]
    table_list_file = config["table_list_file"].format(dataset=dataset)
    # Close the handle deterministically, and read as UTF-8 explicitly instead of
    # the platform default encoding (pd.read_csv below also defaults to UTF-8).
    with open(table_list_file, "r", encoding="utf-8") as fh:
        table_list_text = fh.read()
    input_file = config["input_file"].format(dataset=dataset)
    df = pd.read_csv(input_file)
    if df.empty:
        raise ValueError(f"No questions found in {input_file}")

    llm = ChatOllama(model=config["model"], **config["model_config"])
    # NOTE(review): only the first question is sent to the model, and
    # config["output_file"] is not written here yet.
    return gen_selected_table_list(
        llm, config["prompt"], table_list_text, df["question"].iloc[0]
    )

In [19]:
# Smoke-test the first configured run through the method dispatcher; only the
# first question of the input CSV is sent to the model (see retrieve_llm).
selected_tables = retrieve(configs[0])
selected_tables

KeyboardInterrupt: 