In [22]:
from src.chains.lang_chain import sql_chain
from src.chains.llama_chain import retriever, execute_sql
from src.data.data_loader import create_table_from_file
from sqlalchemy import create_engine
from dotenv import load_dotenv
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.core import SQLDatabase


In [23]:
# Load environment variables (e.g. the HuggingFace API token) before any model
# client is created.
# NOTE(review): presumably these are read from a local .env file - confirm its location.

load_dotenv()

True

In [24]:
# Source CSV, given relative to the project root rather than as an absolute path
# inside one user's home directory, so the notebook runs on other machines.
# NOTE(review): assumes the kernel's working directory is the project root
# (the directory that contains `src/` and `files/`) - confirm.
file = "files/avg_house_price_scotland.csv"

### Database setup

In [25]:
# Open a SQLAlchemy engine on an in-memory DuckDB instance, then create a table
# in it from the source file.

duckdb_url = "duckdb:///:memory:"
engine = create_engine(duckdb_url)
create_table_from_file(engine, file)

### LangChain example

In [26]:
# Model from a HuggingFace repo. This could be any model - feel free to play
# around with different models.

model = "mistralai/Mistral-Nemo-Instruct-2407"

# `model` is a Hub repo id, not a URL, so it is passed as `repo_id`;
# `endpoint_url` is meant for the address of an inference endpoint.
llm = HuggingFaceEndpoint(repo_id=model, temperature=0.1)

In [27]:
# Questions that can be asked of the house-price table. They are kept as data
# rather than commented-out assignments so any of them can be selected by index.
questions = (
    "What was the average detached house price in Scotland in 2022?",
    "What region and date had the highest average detached house price?",
    "What region and date had the highest average detached house price in 2022?",
    "What was the average semi-detached house price in Scotland in 2023?",
    "What was the highest average price for a detached house in East Renfrewshire in 2023?",
    "What was the highest average semi-detached house price in East Renfrewshire in 2023?",
    "What was the lowest average semi-detached house price in South Lanarkshire in all of 2023?",
)

# The question passed to both the LangChain and the LlamaIndex sections below.
question = questions[-1]

In [28]:
# Build the LangChain SQL chain on top of the DuckDB engine and run it on the
# selected question; the result of the final call is the cell's output.
chain = sql_chain(db=engine, model=llm)
payload = {"question": question}
chain.invoke(payload)



'[(178082.0,)]'

### LlamaIndex example

In [29]:
# `model` arrives here as the repo id string set in the LangChain section and is
# rebound to the LlamaIndex LLM wrapper. The guard keeps the cell safe to re-run:
# without it, a second run would pass the wrapper itself as `model_name`.
if isinstance(model, str):
    model = HuggingFaceInferenceAPI(model_name=model)

# Wrap the SQLAlchemy engine for use by LlamaIndex.
db = SQLDatabase(engine)

  model = HuggingFaceInferenceAPI(model_name=model)


In [30]:
# Bind the result to a new name: assigning it back to `retriever` would shadow
# the imported factory function and make this cell fail when run a second time.
sql_retriever = retriever(model=model, db=db)

# The `.text` of the first retrieved item is what gets passed on as the SQL to run.
query = sql_retriever.retrieve(question)

execute_sql(db=db, sql_query=query[0].text)

('[(178082.0,)]',
 {'result': [(178082.0,)], 'col_keys': ['min(Semi_Detached_Average_Price)']})