In [None]:
!pip install langchain langchain-groq langchain-chroma sentence-transformers langchain-community chromadb langchain-huggingface

In [None]:
import os
from google.colab import userdata
# Read the Groq key from Colab's secret store and mirror it into the process
# environment; ChatGroq below is built without an explicit key, so it
# presumably picks it up from there.
GROQ_API_KEY = os.environ['GROQ_API_KEY'] = userdata.get('GROQ_API_KEY')

In [None]:
import warnings
# Suppress every warning category for the rest of the session to keep cell
# output readable (note: this also hides deprecation notices).
warnings.simplefilter('ignore')

In [None]:
import os
import sqlite3

# One connection to the on-disk SQLite file plus a single shared cursor;
# later cells reuse both globals.
conn = sqlite3.connect(database="company.db")
cursor = conn.cursor()

In [None]:
# Schema for the employees table; IF NOT EXISTS keeps this cell re-runnable.
employees_ddl = """
CREATE TABLE IF NOT EXISTS employees (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    department TEXT,
    salary INTEGER,
    hire_date DATE,
    age INTEGER
)
"""
cursor.execute(employees_ddl)

<sqlite3.Cursor at 0x7fdfc84a2640>

In [None]:
# Schema for the departments table; IF NOT EXISTS keeps this cell re-runnable.
departments_ddl = """
CREATE TABLE IF NOT EXISTS departments (
    dept_id INTEGER PRIMARY KEY,
    dept_name TEXT NOT NULL,
    budget INTEGER,
    manager TEXT
)
"""
cursor.execute(departments_ddl)

<sqlite3.Cursor at 0x7fdfc84a2640>

In [None]:
# Sample rows, inserted into the database by the next cell.
# employees tuples follow the employees table's column order:
#   (id, name, department, salary, hire_date as 'YYYY-MM-DD' text, age)
employees_data = [
    (1, "John Smith", "Engineering", 85000, "2022-01-15", 28),
    (2, "Sarah Johnson", "Marketing", 65000, "2021-06-20", 32),
    (3, "Mike Davis", "Engineering", 92000, "2020-03-10", 35),
    (4, "Lisa Chen", "Sales", 58000, "2023-02-28", 26),
    (5, "David Wilson", "HR", 72000, "2019-11-05", 41)
]

# departments tuples follow the departments table's column order:
#   (dept_id, dept_name, budget, manager)
departments_data = [
    (1, "Engineering", 500000, "Alice Brown"),
    (2, "Marketing", 200000, "Tom White"),
    (3, "Sales", 300000, "Emma Green"),
    (4, "HR", 150000, "Chris Black")
]

In [None]:
# Load both tables' sample rows. INSERT OR REPLACE overwrites rows whose
# primary key already exists, so re-running the cell does not duplicate data.
for insert_sql, rows in (
    ("INSERT OR REPLACE INTO employees VALUES (?,?,?,?,?,?)", employees_data),
    ("INSERT OR REPLACE INTO departments VALUES (?,?,?,?)", departments_data),
):
    cursor.executemany(insert_sql, rows)
conn.commit()
print("Database created with sample data!")

Database created with sample data!


In [None]:
from langchain_groq import ChatGroq
# Groq-hosted Llama 3.3 70B; a low temperature keeps SQL generation close to
# deterministic across runs.
llm = ChatGroq(temperature=0.1, model="llama-3.3-70b-versatile")

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings
# Sentence-transformers model used to embed the table descriptions for retrieval.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [None]:
from langchain_chroma import Chroma
# Vector store holding one document per database table. No persist directory
# is passed here — presumably Chroma keeps it in memory (TODO confirm).
schema_vectorstore = Chroma(
    embedding_function=embeddings,
    collection_name="database_schemas",
)

In [None]:
from langchain_core.documents import Document
from langchain_community.vectorstores.utils import filter_complex_metadata

def get_schema_info(db_cursor=None, sample_rows=3):
    """Describe every user table in the database as a LangChain ``Document``.

    Each document's text lists the table's columns (with declared types) and a
    few sample rows, so a similarity search over the documents can surface the
    tables relevant to a natural-language question.

    Args:
        db_cursor: SQLite cursor to introspect. Defaults to the notebook-wide
            ``cursor`` so existing zero-argument calls keep working.
        sample_rows: Number of rows to include as sample data per table.

    Returns:
        list[Document]: one document per table. Metadata keys are
        ``table_name`` (str), ``columns`` (comma-separated column names) and
        ``source`` (always ``"database_schema"``).
    """
    cur = cursor if db_cursor is None else db_cursor

    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    # Fetch everything before reusing the cursor inside the loop; skip SQLite's
    # internal bookkeeping tables (e.g. sqlite_sequence).
    table_names = [
        row[0] for row in cur.fetchall() if not row[0].startswith("sqlite_")
    ]

    schema_docs = []
    for table_name in table_names:
        # Identifiers cannot be bound as parameters, so quote them instead;
        # this keeps names with spaces, keywords or quotes from breaking SQL.
        quoted_name = '"' + table_name.replace('"', '""') + '"'

        cur.execute(f"PRAGMA table_info({quoted_name});")
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk).
        columns = cur.fetchall()
        column_names = [col[1] for col in columns]
        column_info = [f"{col[1]} ({col[2]})" for col in columns]

        cur.execute(f"SELECT * FROM {quoted_name} LIMIT ?;", (sample_rows,))
        sample_data = cur.fetchall()

        schema_text = "\n".join([
            f"Table: {table_name}",
            f"Columns: {', '.join(column_info)}",
            f"Sample data: {sample_data}",
            f"Description: This table contains information about {table_name}",
        ])

        schema_docs.append(Document(
            page_content=schema_text,
            metadata={
                "table_name": table_name,
                # Stored as a string: filter_complex_metadata drops list values,
                # so a list here would never reach the vector store.
                "columns": ", ".join(column_names),
                "source": "database_schema",
            },
        ))

    return schema_docs

schema_documents = get_schema_info()
filtered_schema_documents = filter_complex_metadata(schema_documents)
# Use each table name as its document id so re-running this cell overwrites the
# existing entries instead of appending duplicate copies to the collection
# (without ids, every run adds a fresh set of documents).
schema_vectorstore.add_documents(
    filtered_schema_documents,
    ids=[doc.metadata["table_name"] for doc in filtered_schema_documents],
)
print("Schema information stored in vector store!")

Schema information stored in vector store!


In [None]:
def _extract_sql(text):
    """Return the bare SQL statement from a raw LLM reply.

    If the reply contains a Markdown code fence, the first fenced section is
    used and an optional language tag right after the opening fence (```sql,
    ```SQL, ```sqlite, ```sqlite3) is dropped. Surrounding whitespace is
    always stripped.
    """
    text = text.strip()
    if "```" in text:
        text = text.split("```")[1]
        for tag in ("sqlite3", "sqlite", "sql"):
            after_tag = text[len(tag):]
            # Only treat it as a tag when it is a whole word (followed by
            # whitespace or nothing), compared case-insensitively.
            if text[:len(tag)].lower() == tag and (not after_tag or after_tag[0].isspace()):
                text = after_tag
                break
    return text.strip()


def natural_language_to_sql(question, k=2):
    """Translate a natural-language question into a single SQLite query.

    Retrieves the ``k`` most similar table descriptions from
    ``schema_vectorstore``, embeds them in a prompt, asks ``llm`` for a query
    and returns just the SQL text (code fences and language tags removed).

    Args:
        question: The user's question in plain language.
        k: Number of schema documents to retrieve as context (default 2).

    Returns:
        str: The generated SQL statement, stripped of surrounding whitespace.
    """
    relevant_schemas = schema_vectorstore.similarity_search(question, k=k)

    schema_context = "".join(doc.page_content + "\n\n" for doc in relevant_schemas)

    sql_prompt = f"""
    Based on the following database schema information, generate a SQL query to answer the user's question.

    Database Schema:
    {schema_context}

    User Question: {question}

    Instructions:
    - Generate only a valid SQLite SQL query
    - Use proper table and column names from the schema
    - Return only the SQL query without explanation
    - Make sure the query is syntactically correct

    SQL Query:
    """

    response = llm.invoke(sql_prompt)
    return _extract_sql(response.content)

In [None]:
def query_database_with_rag(question):
    """Answer a natural-language question by generating, running and explaining SQL.

    Translates ``question`` to SQL with ``natural_language_to_sql``, executes it
    on the shared ``cursor``, then asks ``llm`` to phrase the rows as an answer.

    Returns:
        dict: on success ``{"answer", "sql_query", "results"}``, where
        ``results`` is a list of ``{column_name: value}`` dicts (one per row);
        on any failure ``{"error", "sql_query"}``, with ``sql_query`` set to
        ``"Not generated"`` when the failure happened before SQL was produced.
    """
    sql_query = None
    try:
        sql_query = natural_language_to_sql(question)
        print(f"Generated SQL: {sql_query}")

        # The SQL is LLM output driven by user text, so treat it as untrusted:
        # only statements starting with SELECT/WITH are executed. (DDL such as
        # DROP TABLE would otherwise run immediately and irreversibly.)
        # NOTE(review): a WITH prefix can still front a DML statement in SQLite.
        first_word = (sql_query.split() or [""])[0].upper()
        if first_word not in ("SELECT", "WITH"):
            raise ValueError(
                "refusing to execute a statement that is not read-only "
                "(expected SELECT or WITH)"
            )

        cursor.execute(sql_query)
        results = cursor.fetchall()

        # Each cursor.description entry is a 7-tuple whose first item is the
        # column name; description is None when the statement has no result
        # columns.
        column_names = [desc[0] for desc in cursor.description or ()]
        formatted_results = [dict(zip(column_names, row)) for row in results]

        answer_prompt = f"""
        Based on the following SQL query results, provide a clear and natural language answer to the user's question.

        Original Question: {question}
        SQL Query: {sql_query}
        Query Results: {formatted_results}

        Provide a clear, conversational answer based on the data:
        """

        final_response = llm.invoke(answer_prompt)
        return {
            "answer": final_response.content,
            "sql_query": sql_query,
            "results": formatted_results
        }

    except Exception as e:
        # Best-effort demo helper: report the failure instead of raising.
        return {
            "error": f"An error occurred: {str(e)}",
            "sql_query": sql_query if sql_query is not None else "Not generated"
        }

In [None]:
# Demo: push a batch of sample questions through the pipeline and print the
# answer, the SQL and the raw rows (or the error) for each one.
questions = [
    "How many employees are there in total?",
    "What is the average salary by department?",
    "Who are the employees in the Engineering department?",
    "Which department has the highest budget?",
    "Show me employees older than 30 years"
]

divider = "=" * 50
for question in questions:
    print(f"\n{divider}")
    print(f"Question: {question}")
    print(divider)

    result = query_database_with_rag(question)

    if "error" not in result:
        print(f"Answer: {result['answer']}")
        print(f"SQL Query: {result['sql_query']}")
        print(f"Raw Results: {result['results']}")
    else:
        print(f"Error: {result['error']}")


Question: How many employees are there in total?
Generated SQL: 
SELECT COUNT(id) FROM employees;

Answer: There are 5 employees in total.
SQL Query: 
SELECT COUNT(id) FROM employees;

Raw Results: [{('COUNT(id)', None, None, None, None, None, None): 5}]

Question: What is the average salary by department?
Generated SQL: 
SELECT department, AVG(salary) AS average_salary
FROM employees
GROUP BY department;

Answer: The average salary varies by department. In the Engineering department, the average salary is $88,500. The HR department has an average salary of $72,000. For Marketing, the average salary is $65,000, and in Sales, it's $58,000. This gives you a general idea of how salaries compare across different departments within the company.
SQL Query: 
SELECT department, AVG(salary) AS average_salary
FROM employees
GROUP BY department;

Raw Results: [{('department', None, None, None, None, None, None): 'Engineering', ('average_salary', None, None, None, None, None, None): 88500.0}, {('

In [None]:
def ask_database(question):
    """Run ``question`` through ``query_database_with_rag`` and print a short summary.

    Prints the error message on failure; otherwise prints the natural-language
    answer followed by the raw result rows. Returns None.
    """
    outcome = query_database_with_rag(question)
    if "error" in outcome:
        print(f"❌ {outcome['error']}")
        return
    print(f"🤖 {outcome['answer']}")
    print(f"📊 Data: {outcome['results']}")

# A few follow-up questions, each answered and printed in turn.
for follow_up in (
    "What's the total payroll for all employees?",
    "Show me the oldest employee in each department",
    "Which departments have budgets over 250000?",
):
    ask_database(follow_up)

Generated SQL: 
SELECT SUM(salary) AS total_payroll FROM employees;

🤖 The total payroll for all employees is $372,000.
📊 Data: [{('total_payroll', None, None, None, None, None, None): 372000}]
Generated SQL: 
SELECT department, name, age
FROM employees
WHERE (department, age) IN (
    SELECT department, MAX(age)
    FROM employees
    GROUP BY department
)

🤖 The oldest employee in each department is as follows: 

In the Marketing department, the oldest employee is Sarah Johnson, who is 32 years old. 
In the Engineering department, the oldest employee is Mike Davis, who is 35 years old. 
In the Sales department, the oldest employee is Lisa Chen, who is 26 years old. 
And in the HR department, the oldest employee is David Wilson, who is 41 years old.

So, the overall oldest employee across all departments is David Wilson from the HR department at 41 years old.
📊 Data: [{('department', None, None, None, None, None, None): 'Marketing', ('name', None, None, None, None, None, None): 'Sarah