Importing Libraries






In [None]:
# !pip install langchain_openai langchain_community langchain pymysql chromadb -q
import os
import sqlite3
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough




*   Database Setup


1.   Connecting to SQLite3
2.   Creating a database and loading the CSV file as a table.





In [None]:
DB_PATH = "database.db"
CSV_FILE = "Data/LLM_data.csv"
# Read the OpenAI key from the environment instead of hardcoding it. A key that
# has ever been committed to source or a shared notebook must be treated as
# leaked and revoked.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
# Human-readable description of the LLM_sql table (columns and their meaning).
# DemandDate is parsed from the CSV's mm/dd/yy text into datetimes before it is
# written to SQLite, so queries see ISO-formatted text, not mm/dd/yy.
table_info = """
Table name: LLM_sql
Column names: PartID, RepairID, CarID, DemandDate, Quantity, ModelYear
Here is the description of each column:
* PartID: Unique identifier for each part (Character)
* RepairID: Unique identifier for each repair instance (Alphanumeric)
* CarID: Unique identifier for cars; a single CarID may have multiple RepairIDs (Alphanumeric)
* DemandDate: Date of repair demand, stored as ISO text 'YYYY-MM-DD HH:MM:SS' from 2019-01-01 to 2023-10-16; use strftime() to extract the year or month.
* Quantity: Number of parts used or required for the repair (Integer)
* ModelYear: Year the car model was manufactured (Integer)
"""


# Connects to the SQLite database.
def connect_to_db(db_path=DB_PATH):
    """Open a sqlite3 connection to ``db_path``.

    sqlite3 creates the database file if it does not exist yet.

    Args:
        db_path: Path to the SQLite database file (defaults to ``DB_PATH``).

    Returns:
        An open ``sqlite3.Connection``. The caller is responsible for closing it.

    Raises:
        sqlite3.Error: if the connection cannot be opened. The error is
            reported and re-raised rather than returning ``None``, which the
            caller would otherwise trip over later with a less helpful error.
    """
    try:
        return sqlite3.connect(db_path)
    except sqlite3.Error as e:
        print(f"Error connecting to database: {e}")
        raise



# Loads data from CSV to the database.
def load_data_to_db(csv_file, db_conn, table_name="LLM_sql", date_format="%m/%d/%y"):
    """Load the repair CSV into a SQLite table, replacing any existing table.

    Args:
        csv_file: Path (or file-like object) of the CSV, read with ``pd.read_csv``.
        db_conn: Open sqlite3 connection to write into.
        table_name: Destination table. Defaults to ``"LLM_sql"``, the table the
            prompts and few-shot examples query.
        date_format: ``strptime``-style format of the ``DemandDate`` column in
            the CSV (defaults to mm/dd/yy).
    """
    df = pd.read_csv(csv_file)
    # Parse to real datetimes so the column is written as sortable ISO text,
    # which SQLite's strftime() (used by the example queries) understands.
    df["DemandDate"] = pd.to_datetime(df["DemandDate"], format=date_format)
    df.to_sql(table_name, db_conn, if_exists="replace", index=False)



In [None]:

# Few-shot examples: natural-language questions paired with the SQLite query
# each should produce against the LLM_sql table. The selected examples are shown
# to the model verbatim, so every query must be valid SQL for that table.
examples = [
    {
        "input": "What is the earliest demand date for each model year?",
        "sql_cmd": "SELECT ModelYear, MIN(DemandDate) AS EarliestDemandDate FROM LLM_sql GROUP BY ModelYear ORDER BY ModelYear ASC;",
    },
    {
        "input": "What is the most common combination of repair ID and model year?",
        "sql_cmd": """SELECT RepairID, ModelYear, COUNT(*) AS count
FROM LLM_sql
GROUP BY RepairID, ModelYear
ORDER BY count DESC
LIMIT 1;""",
    },
    {
        "input": "What's the average quantity of parts used per repair?",
        # Total parts per repair first, then average across repairs, so the
        # answer is a single number rather than one row per repair.
        "sql_cmd": """SELECT AVG(TotalQuantity) AS AvgQuantityPerRepair
FROM (
    SELECT RepairID, SUM(Quantity) AS TotalQuantity
    FROM LLM_sql
    GROUP BY RepairID
);""",
    },
    {
        "input": "Give me repair trends for the month of January across years?",
        "sql_cmd": """SELECT strftime('%Y', DemandDate) AS year, COUNT(*) AS total_repairs
FROM LLM_sql
WHERE strftime('%m', DemandDate) = '01' -- Represents January
GROUP BY strftime('%Y', DemandDate)
ORDER BY year ASC;""",
    },
    {
        "input": "What is the year-over-year change trend  in repairs request for each car model ?",
        "sql_cmd": """WITH cte AS (
    SELECT
        ModelYear,
        strftime('%Y', DemandDate) as DemandYear,
        COUNT(*) AS TotalRepairs
    FROM
        LLM_sql
    GROUP BY
        ModelYear, strftime('%Y', DemandDate)
)
SELECT
    ModelYear,
    DemandYear,
    TotalRepairs,
    TotalRepairs - LAG(TotalRepairs, 1) OVER (PARTITION BY ModelYear ORDER BY DemandYear) AS YearOverYearChange
FROM
    cte;""",
    },
]

def create_few_shot_prompt(examples, k=2):
    """Build the few-shot chat prompt used for SQL generation.

    The ``k`` stored examples most semantically similar to the incoming
    question (compared with OpenAI embeddings in a Chroma vector store) are
    inserted as human/AI message pairs between the system instructions and
    the user's question.

    Args:
        examples: List of dicts, each with an ``"input"`` question and the
            ``"sql_cmd"`` query that answers it.
        k: Number of similar examples to include (default 2).

    Returns:
        A ``ChatPromptTemplate`` whose input variables include ``input``,
        ``top_k`` and ``table_info``, as ``create_sql_query_chain`` requires.
    """
    # Reset the default Chroma collection so repeated calls in the same kernel
    # do not keep adding the same examples. NOTE(review): assumes Chroma()
    # reuses one in-process collection across calls; confirm for the
    # installed chromadb version.
    vectorstore = Chroma()
    vectorstore.delete_collection()

    example_selector = SemanticSimilarityExampleSelector.from_examples(
        examples,
        OpenAIEmbeddings(),
        vectorstore,
        k=k,
        input_keys=["input"],  # embed only the question, not the SQL answer
    )
    example_prompt = ChatPromptTemplate.from_messages(
        [
            ("human", "{input}\nSQL Query:"),
            ("ai", "{sql_cmd}"),
        ]
    )
    fs_prompt = FewShotChatMessagePromptTemplate(
        example_prompt=example_prompt,
        example_selector=example_selector,
        # create_sql_query_chain requires the prompt to expose these three
        # variables. The system message below does not template them, so
        # they are declared here.
        input_variables=["input", "top_k", "table_info"],
    )
    # Reuse the module-level schema description instead of a pasted copy, so
    # it only has to be maintained in one place.
    system_message = (
        "You are a SQLite expert. Given an input question, create a "
        "syntactically correct SQLite query to run. Give me column names as "
        "well. Do not use LIMIT functionality unless otherwise specified.\n"
        "Here is the relevant table info:"
        + table_info
        + "Below are a number of examples of questions and their corresponding SQL queries."
    )
    return ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            fs_prompt,
            ("human", "{input}"),
        ]
    )



  #  """Generates the SQL query and executes it."""

def generate_and_execute_query(user_question, final_prompt, llm, db, execute=True):
    """Generate SQL for a question with ``final_prompt`` and run it on ``db``.

    Args:
        user_question: Natural-language question to translate into SQL.
        final_prompt: Chat prompt passed to ``create_sql_query_chain``
            (e.g. the few-shot prompt built by ``create_few_shot_prompt``).
        llm: Chat model that writes the SQL.
        db: ``SQLDatabase`` the query is generated for and executed against.
            The SQL is model-generated, so connect ``db`` read-only if it must
            not modify data.
        execute: When True (default), ``result`` holds the output of running
            the generated SQL through ``db.run``. When False, nothing is
            executed and ``result`` is the SQL text itself.

    Returns:
        Dict with ``question``, ``query`` (the generated SQL) and ``result``.
        Errors raised by ``db.run`` (e.g. for invalid SQL) propagate.
    """
    generate_query = create_sql_query_chain(llm, db, final_prompt)

    def _run_query(inputs):
        # `assign` passes the dict built so far, which includes the "query".
        return db.run(inputs["query"])

    run_result = _run_query if execute else itemgetter("query")
    chain = RunnablePassthrough.assign(query=generate_query).assign(result=run_result)
    return chain.invoke({"question": user_question})

def query_with_no_prompt(user_question, llm, db):
    """Generate SQL with LangChain's default SQL prompt (no few-shot examples).

    The generated query is printed and also returned, so callers can inspect
    or run it, as they can with ``generate_and_execute_query``.

    Args:
        user_question: Natural-language question to translate into SQL.
        llm: Chat model that writes the SQL.
        db: ``SQLDatabase`` whose schema is given to the model.

    Returns:
        The generated SQL string.
    """
    generate_query = create_sql_query_chain(llm, db)
    query = generate_query.invoke({"question": user_question})
    print(query)
    return query

def query_generation(user_question):
    """Generate SQL for one question, with and without few-shot prompting.

    Ensures the SQLite database exists (loading it from ``CSV_FILE`` the first
    time), then prints the query from LangChain's default prompt followed by
    the few-shot result dict returned by ``generate_and_execute_query``.

    Args:
        user_question: Natural-language question about the LLM_sql table.

    Raises:
        RuntimeError: if no OpenAI API key is available, or the database
            cannot be opened.
    """
    # Prefer a key already present in the environment; fall back to the
    # configured value only when the environment has none.
    if OPENAI_API_KEY and not os.environ.get("OPENAI_API_KEY"):
        os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
    if not os.environ.get("OPENAI_API_KEY"):
        raise RuntimeError("Set the OPENAI_API_KEY environment variable before running.")

    # Load data only if the database file doesn't exist yet.
    if not os.path.exists(DB_PATH):
        conn = connect_to_db()
        if conn is None:
            raise RuntimeError(f"Could not open database at {DB_PATH}")
        try:
            # The connection context manager commits or rolls back; it does
            # not close the connection, so that is done explicitly below.
            with conn:
                load_data_to_db(CSV_FILE, conn)
        except Exception:
            conn.close()
            # sqlite3.connect already created the file. Remove it so the next
            # run retries the load instead of skipping it and querying an
            # empty database.
            os.remove(DB_PATH)
            raise
        conn.close()

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    # Open read-only so LLM-generated SQL run through `db` cannot modify data.
    db = SQLDatabase.from_uri(f"sqlite:///file:{DB_PATH}?mode=ro&uri=true")

    print('Query without fewshot prompting : \n')
    query_with_no_prompt(user_question, llm, db)

    print('\n Query with fewshot prompting : \n')
    prompt = create_few_shot_prompt(examples)
    print(generate_and_execute_query(user_question, prompt, llm, db))



In [None]:
# Compare the SQL generated with and without few-shot examples for one question.
sample_question = "What is the  change trend for every two years  in repairs "
query_generation(sample_question)

Query without fewshot prompting : 

SELECT "ModelYear", COUNT("RepairID") as "Number of Repairs"
FROM "LLM_sql"
GROUP BY "ModelYear"
ORDER BY "ModelYear" DESC
LIMIT 5;

 Query with fewshot prompting : 

{'question': 'What is the  change trend for every two years  in repairs ', 'query': "WITH cte AS (\n    SELECT\n        strftime('%Y', DemandDate) as DemandYear,\n        COUNT(*) AS TotalRepairs\n    FROM\n        LLM_sql\n    GROUP BY\n        strftime('%Y', DemandDate)\n)\nSELECT\n    DemandYear,\n    TotalRepairs,\n    TotalRepairs - LAG(TotalRepairs, 2) OVER (ORDER BY DemandYear) AS TwoYearChange\nFROM\n    cte;", 'result': "WITH cte AS (\n    SELECT\n        strftime('%Y', DemandDate) as DemandYear,\n        COUNT(*) AS TotalRepairs\n    FROM\n        LLM_sql\n    GROUP BY\n        strftime('%Y', DemandDate)\n)\nSELECT\n    DemandYear,\n    TotalRepairs,\n    TotalRepairs - LAG(TotalRepairs, 2) OVER (ORDER BY DemandYear) AS TwoYearChange\nFROM\n    cte;"}
