### Store the schema information in ChromaDB

Store each table's schema as a document so that only the tables and columns relevant to the user's question are retrieved.
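
The retrieval idea can be illustrated without any database. The following is a minimal sketch, assuming cosine similarity as the ranking measure (the distance function that ChromaDB actually uses depends on how the collection is configured). The three-dimensional vectors and the helper names `cosine_similarity` and `rank_tables` are made up for illustration.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def rank_tables(query_vec, table_vecs, top_n=2):
    """Return the names of the top_n tables most similar to the query vector."""
    scored = sorted(
        table_vecs.items(),
        key=lambda item: cosine_similarity(query_vec, item[1]),
        reverse=True,
    )
    return [name for name, _ in scored[:top_n]]

# Toy "embeddings" (hypothetical values, not produced by a real model)
toy_vectors = {
    "employees": [0.9, 0.1, 0.0],
    "departments": [0.6, 0.4, 0.1],
    "regions": [0.0, 0.2, 0.9],
}

top_tables = rank_tables([1.0, 0.2, 0.0], toy_vectors)
print(top_tables)
```

Only the schema text of the top-ranked tables would then be passed to the language model, which keeps the prompt short.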

In [1]:
%pip install chromadb

Note: you may need to restart the kernel to use updated packages.





In [15]:
import chromadb

client = chromadb.PersistentClient(path="D:\\Internship\\text-to-SQL\\application\\chroma")

In [4]:
client.heartbeat()

1739787377597971700
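
The heartbeat value is easier to read as a timestamp. This sketch assumes the number is nanoseconds since the Unix epoch, which is how the ChromaDB documentation describes it; the variable names are illustrative.

```python
from datetime import datetime, timezone

# Value printed by client.heartbeat() in the cell above
heartbeat_ns = 1739787377597971700

# Assumption: the heartbeat is nanoseconds since the Unix epoch
heartbeat_time = datetime.fromtimestamp(heartbeat_ns // 1_000_000_000, tz=timezone.utc)
print(heartbeat_time.isoformat())
```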

In [16]:
collection = client.get_or_create_collection("sql_data")

In [6]:
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")



In [18]:
schemas = [
    {
        "id": "table_regions",
        "text": """Table: regions
Purpose: Stores geographical regions that categorize different countries.
Columns:
- region_id (INT, Primary Key, Auto-increment) → Unique ID for each region.
- region_name (VARCHAR, 25) → Name of the region (e.g., Asia, Europe).""",
        "metadata": {"type": "table", "name": "regions"}
    },
    {
        "id": "table_countries",
        "text": """Table: countries
Purpose: Stores country details and links them to regions.
Columns:
- country_id (CHAR, 2, Primary Key) → Unique 2-character country code (e.g., 'US' for United States).
- country_name (VARCHAR, 40) → Name of the country.
- region_id (INT, Foreign Key) → References regions.region_id, linking the country to its region.""",
        "metadata": {"type": "table", "name": "countries"}
    },
    {
        "id": "table_locations",
        "text": """Table: locations
Purpose: Stores address details for different locations.
Columns:
- location_id (INT, Primary Key, Auto-increment) → Unique location identifier.
- street_address (VARCHAR, 40) → Street address of the location.
- postal_code (VARCHAR, 12) → Postal code of the location.
- city (VARCHAR, 30, NOT NULL) → City where the location is situated.
- state_province (VARCHAR, 25) → State or province.
- country_id (CHAR, 2, Foreign Key) → References countries.country_id, linking the location to a country.""",
        "metadata": {"type": "table", "name": "locations"}
    },
    {
        "id": "table_jobs",
        "text": """Table: jobs
Purpose: Stores job positions and salary ranges.
Columns:
- job_id (INT, Primary Key, Auto-increment) → Unique job identifier.
- job_title (VARCHAR, 35, NOT NULL) → Name of the job (e.g., Software Engineer).
- min_salary (DECIMAL, 8,2) → Minimum salary for the job.
- max_salary (DECIMAL, 8,2) → Maximum salary for the job.""",
        "metadata": {"type": "table", "name": "jobs"}
    },
    {
        "id": "table_departments",
        "text": """Table: departments
Purpose: Stores company department details.
Columns:
- department_id (INT, Primary Key, Auto-increment) → Unique department identifier.
- department_name (VARCHAR, 30, NOT NULL) → Name of the department (e.g., IT, HR).
- location_id (INT, Foreign Key) → References locations.location_id, linking the department to a location.""",
        "metadata": {"type": "table", "name": "departments"}
    },
    {
        "id": "table_employees",
        "text": """Table: employees
Purpose: Stores employee records, including job and department details.
Columns:
- employee_id (INT, Primary Key, Auto-increment) → Unique employee identifier.
- first_name (VARCHAR, 20) → First name of the employee.
- last_name (VARCHAR, 25, NOT NULL) → Last name of the employee.
- email (VARCHAR, 100, NOT NULL) → Employee's email address.
- phone_number (VARCHAR, 20) → Contact number.
- hire_date (DATE, NOT NULL) → Date the employee was hired.
- job_id (INT, Foreign Key) → References jobs.job_id, linking to the employee's job role.
- salary (DECIMAL, 8,2, NOT NULL) → Employee's salary.
- manager_id (INT, Foreign Key) → References employees.employee_id (self-referencing for hierarchy).
- department_id (INT, Foreign Key) → References departments.department_id.""",
        "metadata": {"type": "table", "name": "employees"}
    },
    {
        "id": "table_dependents",
        "text": """Table: dependents
Purpose: Stores dependents of employees.
Columns:
- dependent_id (INT, Primary Key, Auto-increment) → Unique dependent identifier.
- first_name (VARCHAR, 50, NOT NULL) → First name of the dependent.
- last_name (VARCHAR, 50, NOT NULL) → Last name of the dependent.
- relationship (VARCHAR, 25, NOT NULL) → Relationship to the employee (e.g., Spouse, Child).
- employee_id (INT, Foreign Key) → References employees.employee_id, linking to the employee.""",
        "metadata": {"type": "table", "name": "dependents"}
    }
]

for schema in schemas:
    embedding = embedder.encode(schema["text"]).tolist()  # Convert text to embedding
    collection.add(
        ids=[schema["id"]],
        embeddings=[embedding],
        metadatas=[schema["metadata"]],
        documents=[schema["text"]]
    )

print("Schema successfully stored in ChromaDB!")

Schema successfully stored in ChromaDB!
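
The schema documents above are written by hand. As one possible alternative, the sketch below builds an entry in the same layout from a structured column list. The helpers `format_table_doc` and `make_schema_entry` are hypothetical and not part of the notebook.

```python
def format_table_doc(name, purpose, columns):
    """Build a schema document string in the same layout as the entries above.

    columns is a list of (column_name, type_info, description) tuples.
    """
    lines = [f"Table: {name}", f"Purpose: {purpose}", "Columns:"]
    for col_name, type_info, description in columns:
        lines.append(f"- {col_name} ({type_info}) → {description}")
    return "\n".join(lines)

def make_schema_entry(name, purpose, columns):
    """Wrap the document text with the id and metadata fields used above."""
    return {
        "id": f"table_{name}",
        "text": format_table_doc(name, purpose, columns),
        "metadata": {"type": "table", "name": name},
    }

regions_entry = make_schema_entry(
    "regions",
    "Stores geographical regions that categorize different countries.",
    [
        ("region_id", "INT, Primary Key, Auto-increment", "Unique ID for each region."),
        ("region_name", "VARCHAR, 25", "Name of the region (e.g., Asia, Europe)."),
    ],
)
print(regions_entry["text"])
```

Generating the text this way keeps every document in one format, which may make the embeddings more comparable across tables.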


### Retrieve schema information from the database based on the query

In [19]:
def retrieve_schema(query, top_n=5):
    query_embedding = embedder.encode(query).tolist()

    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_n
    )

    retrieved_docs = []
    for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
        retrieved_docs.append(f"Table: {meta['name']}\nInfo: {doc}\n")

    return "\n".join(retrieved_docs)
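
The `[0]` indexing in `retrieve_schema` is needed because `collection.query` returns one inner list per query embedding. The sketch below uses a hand-written dictionary that mimics that shape, so it runs without ChromaDB. The distance values, the `max_distance` cutoff, and the helper `format_results` are assumptions added for illustration.

```python
# Hand-written stand-in for the dictionary returned by collection.query();
# each value holds one inner list per query embedding.
fake_results = {
    "ids": [["table_employees", "table_departments"]],
    "documents": [[
        "Table: employees\nPurpose: Stores employee records.",
        "Table: departments\nPurpose: Stores company department details.",
    ]],
    "metadatas": [[
        {"type": "table", "name": "employees"},
        {"type": "table", "name": "departments"},
    ]],
    "distances": [[0.42, 0.57]],  # made-up values
}

def format_results(results, max_distance=None):
    """Format the hits for the first query, optionally dropping distant ones."""
    docs = []
    rows = zip(results["documents"][0], results["metadatas"][0], results["distances"][0])
    for doc, meta, dist in rows:
        if max_distance is not None and dist > max_distance:
            continue  # Skip tables that are too far from the question
        docs.append(f"Table: {meta['name']}\nInfo: {doc}\n")
    return "\n".join(docs)

print(format_results(fake_results, max_distance=0.5))
```

A distance cutoff like this is one way to avoid passing loosely related tables to the model when `top_n` is larger than the number of relevant tables.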

In [20]:
# Example query
query = "In which store was the customer with email 'MARY.SMITH@sakilacustomer.org' registered? Provide the address of the store."
retrieved_context = retrieve_schema(query)

print("Retrieved Schema Context:\n", retrieved_context)

Retrieved Schema Context:
 Table: customer_list
Info: View: customer_list
Purpose: Provides a list of customers along with their address, city, country, and store details.
Columns:
- ID (customer_id) → Unique customer ID.
- name → Concatenated first and last name of the customer.
- address → Address of the customer.
- zip code → Postal code of the customer.
- phone → Customer's contact number.
- city → City of the customer.
- country → Country of the customer.
- notes → Status of the customer (active or inactive).
- SID → Store ID to which the customer is linked.

Table: store
Info: Table: store
Purpose: Stores information about film rental stores.
Columns:
- store_id (TINYINT UNSIGNED, Primary Key, Auto-increment) → Unique ID for each store.
- manager_staff_id (TINYINT UNSIGNED, Foreign Key) → References staff.staff_id, linking store to its manager.
- address_id (SMALLINT UNSIGNED, Foreign Key) → References address.address_id, linking store to an address.
- last_update (TIMESTAMP, NOT

### Pass the retrieved schema info as context to LLMs for generating the query.

In [11]:
!pip install langchain_mistralai



In [21]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document

In [None]:
from langchain.chat_models import init_chat_model

model = init_chat_model("mistral-medium", model_provider="mistralai", temperature=0.8, mistral_api_key="YOUR API KEY")
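
Rather than pasting the key into the notebook, it can be read from the environment. This is a minimal sketch; the helper `load_api_key` is hypothetical, and the variable name `MISTRAL_API_KEY` is an assumption about how the key is stored on your machine.

```python
import os

def load_api_key(var_name="MISTRAL_API_KEY"):
    """Read the API key from an environment variable and fail early if it is missing."""
    key = os.environ.get(var_name)
    if not key:
        raise RuntimeError(f"Set the {var_name} environment variable before creating the model.")
    return key
```

The returned value could then be passed as `mistral_api_key=load_api_key()` in the `init_chat_model` call above.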

#### Define the prompt templates

In [23]:
prompt_template = PromptTemplate(
    template = """
    You are an expert SQL generator. Based on the provided database schema information, generate only the SQL query that answers the following question.

    Schema Information:
    {schema_info}

    Question:
    {question}

    SQL Query:
    """,
    input_variables=["schema_info", "question"]
)
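
The names listed in `input_variables` are meant to match the `{...}` placeholders in the template text. The sketch below checks that with the standard library only. The helper `placeholders` and the shortened template are illustrative and not part of LangChain.

```python
from string import Formatter

def placeholders(template):
    """Return the set of {name} placeholders found in a format-style template."""
    return {field for _, field, _, _ in Formatter().parse(template) if field}

# Shortened copy of the SQL prompt, used only for this check
template_text = """
Schema Information:
{schema_info}

Question:
{question}

SQL Query:
"""

declared = ["schema_info", "question"]
missing = placeholders(template_text) - set(declared)  # used in the text but not declared
unused = set(declared) - placeholders(template_text)   # declared but never used
print("missing:", missing, "unused:", unused)

# Filling the template works the same way as str.format
filled = template_text.format(
    schema_info="Table: regions",
    question="How many regions are there?",
)
print(filled)
```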

In [24]:
insights_prompt_template = PromptTemplate(
    template= """
    You are an excellent data analyst. Rephrase the following information, derive meaningful insights from it, and tabulate it as well.
    Information : {result}

    Insights:
    """,
    input_variables=["result"]
)

In [25]:
error_handling_prompt_template = PromptTemplate(
    template= """
    You are an expert in SQL query generation. The following SQL query has encountered an error.
    
    **Failed SQL Query:**  
    {failed_query}  

    **Error Message:**  
    {error}  

    Carefully analyze the error and modify the query to fix the issue while maintaining the intent of the original question.
    Strictly specify  **Corrected SQL Query:**  before giving the corrected query.
    **Corrected SQL Query:**  
    """,
    input_variables=["failed_query", "error"]
)

In [27]:
import mysql.connector

def execute_sql_query(sql_query):
    """Run a SQL query and return a dict describing success or failure."""
    conn = None
    try:
        conn = mysql.connector.connect(
            host="localhost",
            user="root",
            password="12345",
            database="sakila"
        )
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        cursor.close()
        return {"success": True, "result": result}
    except Exception as e:
        return {"success": False, "failed_query": sql_query, "error": str(e)}
    finally:
        # Close the connection even when the query fails
        if conn is not None and conn.is_connected():
            conn.close()
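
The success/failure dictionary returned by `execute_sql_query` is the contract the rest of the pipeline relies on. To try that contract without a MySQL server, the sketch below swaps in the standard-library `sqlite3` module with an in-memory database. The function name `execute_sql_query_sqlite` and the tiny `departments` table are assumptions made for this example only.

```python
import sqlite3

def execute_sql_query_sqlite(conn, sql_query):
    """Same return contract as execute_sql_query, but backed by SQLite."""
    try:
        cursor = conn.execute(sql_query)
        return {"success": True, "result": cursor.fetchall()}
    except sqlite3.Error as e:
        return {"success": False, "failed_query": sql_query, "error": str(e)}

# In-memory database with a two-row table (illustrative data)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE departments (department_id INTEGER PRIMARY KEY, department_name TEXT)")
conn.execute("INSERT INTO departments (department_name) VALUES ('IT'), ('HR')")

ok = execute_sql_query_sqlite(
    conn, "SELECT department_id, department_name FROM departments ORDER BY department_id"
)
bad = execute_sql_query_sqlite(conn, "SELECT epartment_name FROM departments")
conn.close()

print(ok)
print(bad)
```

The failing call returns the original query together with the error text, which is the input the error-handling prompt needs.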

In [28]:
def extract_corrected_sql(ai_message):
    """
    Extracts the corrected SQL query from the AI model's response.
    If no corrected query is found, returns the full response text.
    """
    # Fall back to str() for inputs that lack a .content attribute
    content = getattr(ai_message, "content", str(ai_message))
    marker = "**Corrected SQL Query:**"
    if marker in content:
        corrected_query = content.split(marker)[-1].strip()
        print("\nExtracted Corrected SQL Query:\n", corrected_query)
        return corrected_query
    return content  # No marker found: return the full response
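
Model responses often wrap SQL in a Markdown code fence and escape underscores (for example `store\_id`), and neither form is valid SQL. The sketch below is one possible cleanup step. The helper `clean_model_sql` is hypothetical, and it assumes that any fence is either unlabeled or labeled `sql`.

```python
import re

FENCE = "`" * 3  # Markdown code fence marker, built indirectly to keep this block readable

# Matches an optional "sql" label followed by the fenced body
FENCED_SQL = re.compile(FENCE + r"(?:sql)?\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)

def clean_model_sql(text):
    """Strip a Markdown code fence and Markdown-escaped underscores from model output."""
    match = FENCED_SQL.search(text)
    if match:
        text = match.group(1)  # Keep only the fenced body
    # Undo Markdown escaping such as store\_id -> store_id
    return text.replace("\\_", "_").strip()

raw = "Here is the fix:\n" + FENCE + "sql\nSELECT c.store\\_id FROM customer c;\n" + FENCE
cleaned = clean_model_sql(raw)
print(cleaned)
```

Replacing only `\_` is narrower than removing every backslash, so backslashes that legitimately appear inside SQL string literals are left alone.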

#### Create the chain

In [29]:
from langchain_core.runnables import RunnableBranch, RunnableLambda

retriever = RunnableLambda(lambda x: {"schema_info": retrieve_schema(x["question"]), "question": x["question"]})

passfunc = RunnableLambda(lambda res: {"result": res["result"]})

failfunc = (
    RunnableLambda(lambda x: {
        "failed_query": x.get("failed_query", ""),  # Get the actual failed SQL query
        "error": x.get("error", "SQL execution failed")  # Use the actual error message
    })
    | RunnableLambda(lambda x: (
        print("\nReceived Failed Query:", x['failed_query']),  # Debug print
        print("\nReceived Error Message:", x['error']),  # Debug print
        x  # Ensure the function returns x
    )[-1])  # Use [-1] trick to return x
    | error_handling_prompt_template  
    | RunnableLambda(lambda x: (print("\nError template:", x), x)[-1])  # Ensure return x
    | model  
    | RunnableLambda(lambda x: (print("Model output: ", x), x)[-1])  # Ensure return x
    | RunnableLambda(lambda x: {
        "success": False,
        "result": extract_corrected_sql(x)  # Extract corrected SQL
    })  
    | RunnableLambda(lambda x: {
        "success": False, 
        "result": execute_sql_query(x["result"].replace("\\", ""))  # Replace backslashes in corrected SQL
    })
)

chain = (
    retriever
    | prompt_template
    | RunnableLambda(lambda x: print("Prompt Input:", x) or x)  # Print input to LLM
    | model
    | RunnableLambda(lambda x: print("Generated SQL Query:", x) or x)  
    | StrOutputParser()
    | RunnableLambda(lambda sql_query: print("Cleaned SQL Query:", sql_query.replace("\\", "")) or sql_query)  
    | RunnableLambda(lambda sql_query: execute_sql_query(sql_query.replace("\\", "")))  
    | RunnableBranch(
        (lambda res: "success" in res and res["success"], passfunc),  # Check explicitly for "success"
        failfunc
    )
    | RunnableLambda(lambda result: print("Raw Result obtained from DB:", result) or result)  
    | insights_prompt_template  
    | model 
    | RunnableLambda(lambda insights: print("Final Insights:", insights) or insights) 
)
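
The routing performed by `RunnableBranch` can be hard to follow inside a long pipe expression. The sketch below expresses the same idea in plain Python: execute the query and, on failure, ask for a repaired query and try again. The bounded retry loop, the function `run_with_repair`, and the stub executor and repairer are assumptions for illustration; the chain above makes a single repair attempt with the model.

```python
def run_with_repair(sql_query, execute, repair, max_attempts=2):
    """Execute a query; on failure ask `repair` for a new query and retry."""
    outcome = execute(sql_query)
    attempts = 0
    while not outcome["success"] and attempts < max_attempts:
        sql_query = repair(outcome["failed_query"], outcome["error"])
        outcome = execute(sql_query)
        attempts += 1
    return outcome

# Stubs so the control flow runs without a database or a model
def fake_execute(sql_query):
    if "d.epartment_name" in sql_query:
        return {
            "success": False,
            "failed_query": sql_query,
            "error": "Unknown column 'd.epartment_name' in 'field list'",
        }
    return {"success": True, "result": [(1, "Administration")]}

def fake_repair(failed_query, error):
    # A real repairer would send the query and error to the model
    return failed_query.replace("d.epartment_name", "d.department_name")

outcome = run_with_repair(
    "SELECT d.department_id, d.epartment_name FROM departments d",
    fake_execute,
    fake_repair,
)
print(outcome)
```

Capping the number of attempts keeps a persistently wrong query from triggering an unbounded series of model calls.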

#### Generate the final answer

In [31]:
query = "In which store was customer with email 'MARY.SMITH@sakilacustomer.org' registered in?"

retrieved_context = retrieve_schema(query)
print("Retrieved Context: ", retrieved_context)

# The chain ends with the insights prompt, so the result is the final AIMessage, not raw SQL
insights = chain.invoke({"question": query})
insights

Retrieved Context:  Table: customer_list
Info: View: customer_list
Purpose: Provides a list of customers along with their address, city, country, and store details.
Columns:
- ID (customer_id) → Unique customer ID.
- name → Concatenated first and last name of the customer.
- address → Address of the customer.
- zip code → Postal code of the customer.
- phone → Customer's contact number.
- city → City of the customer.
- country → Country of the customer.
- notes → Status of the customer (active or inactive).
- SID → Store ID to which the customer is linked.

Table: store
Info: Table: store
Purpose: Stores information about film rental stores.
Columns:
- store_id (TINYINT UNSIGNED, Primary Key, Auto-increment) → Unique ID for each store.
- manager_staff_id (TINYINT UNSIGNED, Foreign Key) → References staff.staff_id, linking store to its manager.
- address_id (SMALLINT UNSIGNED, Foreign Key) → References address.address_id, linking store to an address.
- last_update (TIMESTAMP, NOT NULL, D

AIMessage(content="It appears that the query failed due to an error in the SQL syntax. Specifically, there is a reference to a non-existent column named 'SID' in the 'staff' table, which is represented by the alias 's'. This error can be resolved by correcting the query to join the 'staff' and 'staff\\_list' tables using their common 'store\\_id' column.\n\nThe corrected query should allow for the retrieval of store information for a specific customer. This information includes the store ID, name, address, city, and country.\n\n| Table Alias | Column Name | Corrected Column Name |\n| --- | --- | --- |\n| s | SID | store\\_id |\n| sl | - | store\\_id |\n\nAdditional context:\n\n* The error message suggests that the query was intended to join the 'staff' and 'staff\\_list' tables using their common 'store\\_id' column, but instead attempted to use the non-existent 'SID' column.\n* The corrected query includes the appropriate join condition, which is 'ON c.store\\_id = s.store\\_id AND s.

### Check the failure-handling case

In [28]:
query = """
SELECT d.department_id, d.epartment_name, COUNT(e.employee_id) AS number_of_employees
FROM departments d
LEFT JOIN employees e ON d.department_id = e.department_id
GROUP BY d.department_id, d.epartment_name;
"""

test = execute_sql_query(query)
test

{'success': False,
 'failed_query': '\nSELECT d.department_id, d.epartment_name, COUNT(e.employee_id) AS number_of_employees\nFROM departments d\nLEFT JOIN employees e ON d.department_id = e.department_id\nGROUP BY d.department_id, d.epartment_name;\n',
 'error': "1054 (42S22): Unknown column 'd.epartment_name' in 'field list'"}

In [29]:
passfunc = RunnableLambda(lambda res: {"result": res["result"]})

failfunc = (
    RunnableLambda(lambda x: {
        "failed_query": x.get("failed_query", ""),  
        "error": x.get("error", "SQL execution failed")  
    })
    | RunnableLambda(lambda x: (
        print("\nReceived Failed Query:", x['failed_query']),  
        print("\nReceived Error Message:", x['error']),  
        x  # Ensure the function returns x
    )[-1])  # Use [-1] trick to return x
    | error_handling_prompt_template  
    | RunnableLambda(lambda x: (print("\nError template:", x), x)[-1]) 
    | model  
    | RunnableLambda(lambda x: (print("Model output: ", x), x)[-1])  
    | RunnableLambda(lambda x: {
        "success": False,
        "result": extract_corrected_sql(x)  # Extract corrected SQL
    })  
    | RunnableLambda(lambda x: {
        "success": False, 
        "result": execute_sql_query(x["result"].replace("\\", ""))  # Replace backslashes in corrected SQL
    })
)

test_chain = RunnableBranch(
    (lambda res: "success" in res and res["success"], passfunc),  # Check explicitly for "success"
    failfunc
)

test_chain.invoke(test)



Received Failed Query: 
SELECT d.department_id, d.epartment_name, COUNT(e.employee_id) AS number_of_employees
FROM departments d
LEFT JOIN employees e ON d.department_id = e.department_id
GROUP BY d.department_id, d.epartment_name;


Received Error Message: 1054 (42S22): Unknown column 'd.epartment_name' in 'field list'

Error template: text="\n    You are an expert in SQL query generation. The following SQL query has encountered an error.\n    \n    **Failed SQL Query:**  \n    \nSELECT d.department_id, d.epartment_name, COUNT(e.employee_id) AS number_of_employees\nFROM departments d\nLEFT JOIN employees e ON d.department_id = e.department_id\nGROUP BY d.department_id, d.epartment_name;\n  \n\n    **Error Message:**  \n    1054 (42S22): Unknown column 'd.epartment_name' in 'field list'  \n\n    Carefully analyze the error and modify the query to fix the issue while maintaining the intent of the original question.\n    Strictly specify  **Corrected SQL Query:**  before giving the core

{'success': False,
 'result': {'success': True,
  'result': [(1, 'Administration', 1),
   (2, 'Marketing', 2),
   (3, 'Purchasing', 6),
   (4, 'Human Resources', 1),
   (5, 'Shipping', 7),
   (6, 'IT', 5),
   (7, 'Public Relations', 1),
   (8, 'Sales', 6),
   (9, 'Executive', 3),
   (10, 'Finance', 6),
   (11, 'Accounting', 2)]}}
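
The output above is nested: the outer `success` is hard-coded to `False` by `failfunc`, while the inner dictionary reports that the corrected query actually succeeded. The sketch below shows one way to reach the innermost outcome. The helper `unwrap_result` is hypothetical, and the sample data is a shortened copy of the output above.

```python
def unwrap_result(outcome):
    """Follow nested 'result' dictionaries down to the innermost outcome."""
    while isinstance(outcome.get("result"), dict):
        outcome = outcome["result"]
    return outcome

# Shortened copy of the nested structure returned by test_chain.invoke(test)
nested = {
    "success": False,
    "result": {
        "success": True,
        "result": [(1, "Administration", 1), (2, "Marketing", 2)],
    },
}

final = unwrap_result(nested)
print(final["success"], final["result"])
```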