In [73]:
import getpass
import os
from langchain_groq import ChatGroq
import pandas as pd
from sqlalchemy import create_engine
from dotenv import load_dotenv
load_dotenv()


# Take the Groq API key from the environment (`load_dotenv()` above can supply it from a .env
# file) and only prompt for it interactively when it is missing. Never hardcode the key here:
# the key that used to be pasted on this line is exposed in the notebook's history — revoke it.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your Groq API key: ")
llm = ChatGroq(model="llama3-8b-8192")

In [74]:
# Smoke test: a one-word prompt confirms the API key and model name are accepted.
smoke_reply = llm.invoke("Hi")
smoke_reply

C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


AIMessage(content="Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat?", response_metadata={'token_usage': {'completion_tokens': 26, 'prompt_tokens': 11, 'total_tokens': 37, 'completion_time': 0.021666667, 'prompt_time': 0.001131683, 'queue_time': 0.025714814, 'total_time': 0.02279835}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_179b0f92c9', 'finish_reason': 'stop', 'logprobs': None}, id='run-5c4f6884-3907-4fec-af1b-946eb4fdf821-0')

In [58]:
# Peek at the raw salary data before it is loaded into SQLite.
df = pd.read_csv(filepath_or_buffer='ds_salaries.csv')
df.head(n=5)

Unnamed: 0,work_year,experience_level,employment_type,job_title,salary,salary_currency,salary_in_usd,employee_residence,remote_ratio,company_location,company_size
0,2023,SE,FT,Principal Data Scientist,80000,EUR,85847,ES,100,ES,L
1,2023,MI,CT,ML Engineer,30000,USD,30000,US,100,US,S
2,2023,MI,CT,ML Engineer,25500,USD,25500,US,100,US,S
3,2023,SE,FT,Data Scientist,175000,USD,175000,CA,100,CA,M
4,2023,SE,FT,Data Scientist,120000,USD,120000,CA,100,CA,M


In [3]:
# Load the salary CSV into a local SQLite file so the generated SQL has a `chat_data` table to query.
# `if_exists='replace'` rebuilds the table from the CSV on every run, so re-running this cell is safe.
DATABASE_URL = "sqlite:///chatbot.db"
engine = create_engine(DATABASE_URL)
df = pd.read_csv('ds_salaries.csv')
df.to_sql('chat_data', con=engine, index=False, if_exists='replace')

3755

## Simple Database Question Answering

In [75]:
from langchain.agents import Tool
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_community.utilities import SQLDatabase

In [104]:
# Open the same SQLite file through LangChain's SQLDatabase wrapper and sanity-check it:
# print the dialect and the visible tables, then run a hand-written reference query.
db = SQLDatabase.from_uri("sqlite:///chatbot.db")
print(db.dialect, db.get_usable_table_names(), sep="\n")
db.run("SELECT job_title, AVG(CAST(salary_in_usd AS REAL)) AS average_salary FROM chat_data GROUP BY job_title ORDER BY average_salary DESC LIMIT 10;")

sqlite
['chat_data']


"[('Data Science Tech Lead', 375000.0), ('Cloud Data Architect', 250000.0), ('Data Lead', 212500.0), ('Data Analytics Lead', 211254.5), ('Principal Data Scientist', 198171.125), ('Director of Data Science', 195140.72727272726), ('Principal Data Engineer', 192500.0), ('Machine Learning Software Engineer', 192420.0), ('Data Science Manager', 191278.77586206896), ('Applied Scientist', 190264.4827586207)]"

## Generate SQL Query Function

In [77]:
from langchain import hub

#query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

In [102]:
from typing import Optional
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import JsonOutputParser

class QueryOutput(BaseModel):
    """Generated SQL query."""
    # Output schema shared by both LLM calls in this cell:
    # `llm.with_structured_output(QueryOutput)` and `JsonOutputParser(pydantic_object=QueryOutput)`.
    # NOTE(review): the class docstring and the Field description are likely forwarded to the
    # model as part of the output schema, so rewording them may change model behaviour — confirm
    # before editing.
    query: str = Field(..., description="Syntactically valid SQL query.")

def query_validator_agent(sql_query, table_info):
    """Ask the LLM to check an SQL query against the schema and return a SQLite-compatible version.

    Args:
        sql_query: SQL text to validate (e.g. the draft produced by ``generate_sql_query``).
        table_info: Human-readable table schema inserted into the prompt.

    Returns:
        The ``query`` field of the model's JSON reply when it is a non-empty string; otherwise
        ``sql_query`` unchanged, so a malformed reply does not crash the caller.
    """
    # The final instruction asks for the JSON object described by the format instructions;
    # asking for "ONLY the SQL query" here would contradict those instructions.
    base_prompt_template = """
    You are a SQL query assistant. You will be provided with the table schema and a SQL Query. 
    Your task is to check if the SQL query is valid based on the schema and fix any issues that arise, 
    ensuring the query is compatible with SQLite.

    <CRUCIAL RULES>
    - Always provide a valid SQL query based on the schema provided. No additional text is allowed.
    - The SQL Query MUST BE COMPATIBLE WITH SQLITE. SQLite does not support advanced functions such as PERCENTILE_CONT.
    - If the query uses unsupported functions or features, replace them with SQLite-compatible alternatives.
    - Verify the columns and tables in the query.
    - If the query includes percentile calculations, use a workaround or remove them, as SQLite does not support `PERCENTILE_CONT`.

    Table Schema: 
    {SCHEMA}
    
    SQL QUERY:
    {SQL_QUERY}
    
    Corrected SQL Query: return ONLY the corrected SQL query, as the "query" field of the JSON object described below. No additional text is allowed.
    \n{format_instructions}\n
    """

    parser = JsonOutputParser(pydantic_object=QueryOutput)

    prompt = PromptTemplate(
        template=base_prompt_template,
        input_variables=["SCHEMA", "SQL_QUERY"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    chain = prompt | llm | parser
    response = chain.invoke({
        "SCHEMA": table_info,
        "SQL_QUERY": sql_query
    })

    # JsonOutputParser returns whatever JSON the model emitted. If it is not the expected
    # {"query": "..."} object, keep the original query instead of raising KeyError/TypeError.
    fixed_query = response.get("query") if isinstance(response, dict) else None
    if isinstance(fixed_query, str) and fixed_query.strip():
        return fixed_query
    return sql_query

# SQL generation function with enhanced prompt template for SQLite
def generate_sql_query(user_question: str, table_info: str, dialect: str = 'sqlite') -> str:
    """Draft a SQL query for ``user_question`` with the LLM, then pass it through the validator.

    Args:
        user_question: Natural-language question from the user.
        table_info: Human-readable schema of ``chat_data`` inserted into the prompt.
        dialect: SQL dialect named in the prompt (default ``'sqlite'``).

    Returns:
        The SQL string returned by ``query_validator_agent`` for the drafted query.
    """
    # Few-shot examples. They must match the actual data. Salaries are compared and aggregated
    # via `salary_in_usd`, because `salary` is in each row's own `salary_currency`. Categorical
    # columns hold codes such as experience_level 'SE'/'MI' and company_size 'S'/'M'/'L'.
    example_scenarios = """
    Example Scenarios:
    1. Find the total salary (in USD) for employees with a specific job title (e.g., 'Data Scientist').
       Query: SELECT SUM(salary_in_usd) FROM chat_data WHERE job_title = 'Data Scientist';

    2. Retrieve the top 5 highest salaries and their job titles.
       Query: SELECT job_title, salary_in_usd FROM chat_data ORDER BY salary_in_usd DESC LIMIT 5;

    3. Find the average salary by experience level, sorted by average salary.
       Query: SELECT experience_level, AVG(salary_in_usd) AS average_salary FROM chat_data GROUP BY experience_level ORDER BY average_salary DESC;

    4. Retrieve the lowest salary for a specific job title (e.g., 'ML Engineer').
       Query: SELECT job_title, salary_in_usd FROM chat_data WHERE job_title = 'ML Engineer' ORDER BY salary_in_usd ASC LIMIT 1;

    5. Count the number of employees with a remote ratio greater than 50%.
       Query: SELECT COUNT(*) FROM chat_data WHERE remote_ratio > 50;

    6. Find the 10th percentile salary in the dataset (workaround for SQLite as it doesn't support `PERCENTILE_CONT`).
       Query: SELECT salary_in_usd FROM chat_data ORDER BY salary_in_usd LIMIT 1 OFFSET (SELECT ROUND(0.1 * COUNT(*) - 1) FROM chat_data);

    7. List employees who work fully remotely and earn more than 100000 USD.
       Query: SELECT job_title, salary_in_usd FROM chat_data WHERE remote_ratio = 100 AND salary_in_usd > 100000;

    8. Retrieve the highest salary for each job title.
       Query: SELECT job_title, MAX(salary_in_usd) AS max_salary FROM chat_data GROUP BY job_title;

    9. Get the average salary for each company size (stored as 'S', 'M', 'L').
       Query: SELECT company_size, AVG(salary_in_usd) AS average_salary FROM chat_data GROUP BY company_size;

    10. Find senior-level employees (experience_level = 'SE') and their salary in USD.
        Query: SELECT job_title, salary_in_usd FROM chat_data WHERE experience_level = 'SE';
    """

    # Deliberately a plain string, NOT an f-string. PromptTemplate substitutes the values itself,
    # so braces in the user's question or the schema cannot be mis-parsed as template variables.
    # With an f-string they would be, and invoke() would fail with a missing-variable error.
    prompt_template = PromptTemplate(
        input_variables=["dialect", "table_info", "input", "example_scenarios"],
        template="""
        Given an input question, create a syntactically correct {dialect} query to help find the answer. 
        Unless the user specifies in their question a specific number of examples they wish to obtain, 
        always limit your query to at most 5 results. You can order the results by a relevant column to return the most interesting examples in the database.

        **Important Notes for SQLite**:
        - SQLite does not support `PERCENTILE_CONT` or other percentile/median aggregate functions.
        - Ensure that any complex calculations or percentile calculations are replaced with compatible SQLite logic (e.g., subqueries or custom aggregations).
        - You can use `LIMIT` and `OFFSET` for approximations, especially for cases like percentiles.
        - Avoid using non-SQLite syntax or unsupported functions.
        - Use `salary_in_usd` when comparing or aggregating salaries across rows; `salary` is in the row's local `salary_currency`.

        **Use the following table schema**:
        {table_info}

        **Example Scenarios**:
        {example_scenarios}

        **Question**: {input}
        """,
    )

    prompt = prompt_template.invoke(
        {
            "dialect": dialect,
            "table_info": table_info,
            "input": user_question,
            "example_scenarios": example_scenarios,
        }
    )

    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # Second LLM pass: check the draft against the schema and fix SQLite incompatibilities.
    validated_query = query_validator_agent(result.query, table_info)
    return validated_query

In [103]:
# Smoke-test the validator on a known-good query. Use a literal query and `db` (defined above)
# rather than `sql_query`/`table_info`, which are only defined in later cells and would raise
# NameError under Restart & Run All.
sample_query = "SELECT job_title, AVG(CAST(salary_in_usd AS REAL)) AS average_salary FROM chat_data GROUP BY job_title ORDER BY average_salary DESC LIMIT 10;"
print(query_validator_agent(sample_query, db.get_table_info()))

SELECT job_title, AVG(CAST(salary_in_usd AS REAL)) AS average_salary FROM chat_data GROUP BY job_title ORDER BY average_salary DESC LIMIT 10;


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


In [79]:
# Hand-written schema of the `chat_data` table, inserted into both LLM prompts.
# The example values come from the rows shown by df.head() above. They tell the model how the
# categorical and currency columns are encoded, so it does not guess e.g. 'Senior' instead of 'SE'.
table_info = """
    Table: chat_data
    Columns:
    - work_year: INTEGER (year, e.g. 2023)
    - experience_level: TEXT (code, e.g. 'SE', 'MI')
    - employment_type: TEXT (code, e.g. 'FT', 'CT')
    - job_title: TEXT (e.g. 'Data Scientist', 'ML Engineer')
    - salary: INTEGER (in the row's salary_currency)
    - salary_currency: TEXT (e.g. 'USD', 'EUR')
    - salary_in_usd: INTEGER (use this to compare or aggregate salaries)
    - employee_residence: TEXT (e.g. 'US', 'ES')
    - remote_ratio: INTEGER (e.g. 100 = fully remote)
    - company_location: TEXT (e.g. 'US', 'CA')
    - company_size: TEXT ('S', 'M' or 'L')
    """

In [90]:
# End-to-end check of SQL generation on a sample question; the query is executed in the next section.
user_question = "What are the top 10 average salaries for each job title?"
sql_query = generate_sql_query(user_question=user_question, table_info=table_info)
print(sql_query)

SELECT job_title, AVG(salary_in_usd) AS average_salary
FROM chat_data
GROUP BY job_title
ORDER BY average_salary DESC
LIMIT 10;


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


## Execute the SQL Query

In [91]:
def execute_sql_query(query: str, engine) -> pd.DataFrame:
    """Execute a generated read-only SQL query and return the result as a DataFrame.

    The query text comes from the LLM, which is driven by free-form user input, so it is treated
    as untrusted: only statements starting with ``SELECT`` or ``WITH`` are executed. This guards
    against accidental or injected writes (``DROP``, ``DELETE``, ``UPDATE`` ...). It is not a full
    SQL parser; for hard guarantees, connect through a read-only database user or URI.

    Args:
        query: SQL text to execute.
        engine: Connectable accepted by ``pandas.read_sql_query`` (a SQLAlchemy engine or a
            ``sqlite3`` connection).

    Returns:
        The query result.

    Raises:
        ValueError: If ``query`` does not start with ``SELECT`` or ``WITH``.
    """
    import re

    if not re.match(r"\s*(select|with)\b", query, flags=re.IGNORECASE):
        raise ValueError(f"Refusing to execute non-read-only SQL: {query!r}")
    return pd.read_sql_query(query, con=engine)

In [92]:
execute_sql_query(sql_query,engine)

Unnamed: 0,job_title,average_salary
0,Data Science Tech Lead,375000.0
1,Cloud Data Architect,250000.0
2,Data Lead,212500.0
3,Data Analytics Lead,211254.5
4,Principal Data Scientist,198171.125
5,Director of Data Science,195140.727273
6,Principal Data Engineer,192500.0
7,Machine Learning Software Engineer,192420.0
8,Data Science Manager,191278.775862
9,Applied Scientist,190264.482759


## Implement Chatbot Functionality with Memory

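Create a function that handles a user question end to end: it generates a SQL query, executes it, optionally plots the result, and records the exchange in the chat history. Note that the history is only recorded (and saved to JSON at the end); it is not fed back into the SQL-generation prompt, so each question is answered independently.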

In [95]:
import json
from langchain.memory import ConversationBufferMemory
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


# Create a LangChain conversation memory and take its "chat_history" entry as the running log.
# NOTE(review): with return_messages=True this is presumably the memory's own message list —
# verify. `chatbot` appends plain {"role": ..., "content": ...} dicts to it (not message
# objects), which keeps it JSON-serializable for the final save cell. The history is only
# recorded: nothing in this notebook passes it back to the LLM, so it provides no
# conversational context to the model.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
chat_history = memory.load_memory_variables({})["chat_history"]

def generate_plot(df: pd.DataFrame, plot_flag: bool) -> Optional[object]:
    """Plot a query result when ``plot_flag`` is set, choosing the chart from the column layout.

    - 1 column  -> bar chart of that column.
    - 2 columns -> scatter plot when both are numeric. Otherwise a bar chart of the numeric
      column against the other one, which covers the typical ``label, aggregate`` result such as
      ``job_title, average_salary`` in either column order.
    - 3+ columns -> seaborn pairplot.

    Args:
        df: Result of ``execute_sql_query``.
        plot_flag: Only plot when True.

    Returns:
        The matplotlib Axes (or seaborn PairGrid for 3+ columns) that was drawn, or None when
        nothing was plotted (flag off, empty frame, or no numeric column to plot).
    """
    if not plot_flag or df.empty:
        return None

    numeric_cols = set(df.select_dtypes(include="number").columns)
    if not numeric_cols:
        # Every chart below needs at least one numeric column; pandas/seaborn would raise otherwise.
        return None

    num_columns = len(df.columns)
    if num_columns == 1:
        plot = df.plot(kind="bar", title="Single Column Plot")
    elif num_columns == 2:
        first, second = df.columns
        if first in numeric_cols and second in numeric_cols:
            plot = df.plot(kind="scatter", x=first, y=second, title="Scatter Plot")
        else:
            # pandas' scatter raises for a non-numeric axis, so draw label vs. value as bars,
            # whichever order the generated SQL put the columns in.
            label, value = (second, first) if first in numeric_cols else (first, second)
            plot = df.plot(kind="bar", x=label, y=value, title="Bar Plot")
    else:
        plot = sns.pairplot(df)  # pairwise plots of the numeric columns

    plt.show()
    return plot

def chatbot(user_question: str, engine, table_info: str, plot_flag: bool = False, dialect: str = 'sqlite') -> str:
    """Answer one user question end to end and record the exchange in ``chat_history``.

    Steps: generate SQL for the question (LLM draft plus validator pass), execute it against
    ``engine``, optionally plot the result, then append the user turn and an assistant turn
    holding ``[sql_query, csv_result]`` to the module-level ``chat_history`` list.

    Args:
        user_question: Natural-language question.
        engine: Connectable the query is executed on (e.g. the SQLAlchemy engine created earlier).
        table_info: Schema text forwarded to the SQL generator.
        plot_flag: If True, draw a chart of the result via ``generate_plot``.
        dialect: SQL dialect forwarded to ``generate_sql_query``.

    Returns:
        The query result rendered as CSV text (no index column).

    Raises:
        Whatever SQL generation or execution raises (e.g. an invalid generated query). Nothing is
        appended to ``chat_history`` in that case.
    """
    # Forward `dialect` so the caller's choice actually reaches the prompt.
    sql_query = generate_sql_query(user_question, table_info, dialect=dialect)

    query_result = execute_sql_query(sql_query, engine)

    generate_plot(query_result, plot_flag)

    # Serialize the result as CSV so it can be printed and stored in the JSON-friendly history.
    result_str = query_result.to_csv(index=False)

    chat_history.append({"role": "user", "content": user_question})
    chat_history.append({"role": "assistant", "content": [sql_query, result_str]})

    return result_str


In [98]:
if __name__ == "__main__":
    # Simple REPL over `chatbot`: type 'exit' to quit. `engine` and `table_info` come from earlier cells.
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Input stream closed or the user interrupted: end the session cleanly.
            break
        if user_input.lower() == 'exit':
            break
        if not user_input:
            continue  # don't spend an LLM call on a blank line
        try:
            response = chatbot(user_question=user_input,
                               engine=engine,
                               table_info=table_info,
                               plot_flag=False)
        except Exception as exc:
            # Generated SQL can be invalid or rejected. Report the error and keep the session
            # alive instead of losing it to an uncaught traceback.
            print(f"Bot: Sorry, I couldn't answer that ({type(exc).__name__}: {exc})")
            continue
        print(f"Bot: {response}")

You:  What are the average salaries for each job title?


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


Bot: average_salary,job_title
21352.25,3D Computer Vision Researcher
136666.0909090909,AI Developer
55000.0,AI Programmer
110120.875,AI Scientist
152368.63106796116,Analytics Engineer



You:  What are the top 20 average salaries for each job title?


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


Bot: average_salary,job_title
375000.0,Data Science Tech Lead
250000.0,Cloud Data Architect
212500.0,Data Lead
211254.5,Data Analytics Lead
198171.125,Principal Data Scientist
195140.72727272726,Director of Data Science
192500.0,Principal Data Engineer
192420.0,Machine Learning Software Engineer
191278.77586206896,Data Science Manager
190264.4827586207,Applied Scientist
190000.0,Principal Machine Learning Engineer
183857.5,Head of Data
175051.66666666666,Data Infrastructure Engineer
174150.0,Business Intelligence Engineer
163220.07692307694,Machine Learning Scientist
163108.37837837837,Research Engineer
161713.77227722772,Data Architect
161214.19512195123,Research Scientist
160591.66666666666,Head of Data Science
158352.4411764706,ML Engineer



You:  For AI Programmer what is the min and max of the salary?


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


Bot: MIN(salary_in_usd),MAX(salary_in_usd)
40000,70000



You:  exit


In [97]:
# Persist the conversation (user questions plus [sql, csv_result] answers) as pretty-printed JSON.
with open('chat_history.json', 'w') as history_file:
    history_file.write(json.dumps(chat_history, indent=4))