In [1]:
import pandas as pd
from sqlalchemy import create_engine

In [4]:
import getpass
import os
from langchain_groq import ChatGroq

# SECURITY: never commit an API key to the notebook. The key is taken from the
# GROQ_API_KEY environment variable; if it is not set, prompt for it without
# echoing so the secret never lands in the file or in cell output.
# Any key that was ever saved in this cell must be revoked and rotated.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your Groq API key: ")

# Shared chat model used by the cells below.
llm = ChatGroq(model="llama3-8b-8192")

In [5]:
# Smoke test: send a trivial message to confirm the API key and model name work.
llm.invoke('Hi')

AIMessage(content="Hi! It's nice to meet you. Is there something I can help you with or would you like to chat?", response_metadata={'token_usage': {'completion_tokens': 25, 'prompt_tokens': 11, 'total_tokens': 36, 'completion_time': 0.020833333, 'prompt_time': 0.001737674, 'queue_time': 0.021624417, 'total_time': 0.022571007}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_6a6771ae9c', 'finish_reason': 'stop', 'logprobs': None}, id='run-7c95fced-4ca0-4b15-a74f-4976ed3f3066-0')

In [3]:
# Connection string for the local SQLite file that backs the chatbot.
DATABASE_URL = "sqlite:///chatbot.db"
# Backward-compatible alias for the original, misspelled constant name.
DATABSE_URL = DATABASE_URL
engine = create_engine(DATABASE_URL)

# Load the salary dataset and (re)create the `chat_data` table from it;
# if_exists='replace' overwrites the table, so re-running the cell is safe.
df = pd.read_csv('ds_salaries.csv')
df.to_sql('chat_data',con=engine,index=False,if_exists='replace')

3755

## Simple Database Question Answering

In [11]:
from langchain.agents import Tool
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_community.utilities import SQLDatabase

In [12]:
# Open the same SQLite file through LangChain's SQL wrapper and sanity-check it:
# report the dialect and the usable tables, then preview the first ten rows.
db = SQLDatabase.from_uri("sqlite:///chatbot.db")
for detail in (db.dialect, db.get_usable_table_names()):
    print(detail)
db.run("SELECT * FROM chat_data LIMIT 10;")

sqlite
['chat_data']


"[(2023, 'SE', 'FT', 'Principal Data Scientist', 80000, 'EUR', 85847, 'ES', 100, 'ES', 'L'), (2023, 'MI', 'CT', 'ML Engineer', 30000, 'USD', 30000, 'US', 100, 'US', 'S'), (2023, 'MI', 'CT', 'ML Engineer', 25500, 'USD', 25500, 'US', 100, 'US', 'S'), (2023, 'SE', 'FT', 'Data Scientist', 175000, 'USD', 175000, 'CA', 100, 'CA', 'M'), (2023, 'SE', 'FT', 'Data Scientist', 120000, 'USD', 120000, 'CA', 100, 'CA', 'M'), (2023, 'SE', 'FT', 'Applied Scientist', 222200, 'USD', 222200, 'US', 0, 'US', 'L'), (2023, 'SE', 'FT', 'Applied Scientist', 136000, 'USD', 136000, 'US', 0, 'US', 'L'), (2023, 'SE', 'FT', 'Data Scientist', 219000, 'USD', 219000, 'CA', 0, 'CA', 'M'), (2023, 'SE', 'FT', 'Data Scientist', 141000, 'USD', 141000, 'CA', 0, 'CA', 'M'), (2023, 'SE', 'FT', 'Data Scientist', 147100, 'USD', 147100, 'US', 0, 'US', 'M')]"

## Generate SQL Query Function

In [19]:
from langchain import hub

# Fetch a ready-made SQL-generation prompt from the LangChain Hub (network call).
# NOTE(review): `query_prompt_template` is not referenced by any other cell visible
# here, and this call prints a deprecation notice pointing to the langsmith SDK's
# `pull_prompt` -- confirm whether this cell is still needed.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

Please use the `langsmith sdk` instead:
  pip install langsmith
Use the `pull_prompt` method.
  res_dict = client.pull_repo(owner_repo_commit)


In [48]:
from typing import Optional
from langchain_groq import ChatGroq
from langchain_core.pydantic_v1 import BaseModel, Field

# Structured-output schema: generate_sql_query() hands this class to
# `llm.with_structured_output(...)` and reads `.query` from the result.
# NOTE(review): the docstring and the Field description below are presumably
# forwarded to the model as part of the schema -- treat them as prompt text
# and confirm before rewording.
class QueryOutput(BaseModel):
    """Generated SQL query."""
    # The SQL text produced by the model; `...` marks the field as required.
    query: str = Field(..., description="Syntactically valid SQL query.")

def generate_sql_query(user_question: str, table_info: str, dialect: str = 'sqlite') -> str:
    """Translate a natural-language question into a single SQL query.

    Builds a few-shot prompt from the question, the schema description and the
    SQL dialect, sends it to the Groq `llama3-8b-8192` chat model configured for
    structured output, and returns the query text from the model's reply.

    Args:
        user_question: The question to answer, in plain language.
        table_info: Textual description of the tables/columns the model may use.
        dialect: SQL dialect the generated query must be valid for.

    Returns:
        The SQL query string taken from the `QueryOutput.query` field.
    """
    # The template is deliberately a plain string, NOT an f-string. Values are
    # substituted by PromptTemplate at invoke() time, so curly braces inside the
    # user's question or the schema text are inserted verbatim instead of being
    # parsed as template placeholders, and the placeholders in the template now
    # match the declared `input_variables`.
    #
    # The literals in the examples follow the values stored in `chat_data`
    # (Title-Case job titles, single-letter company sizes, experience codes),
    # so the model is not taught filters that match no rows.
    # NOTE(review): Example 6 assumes the code 'SE' means senior level -- confirm
    # against the dataset documentation.
    prompt_template = PromptTemplate(
        input_variables=["dialect", "table_info", "input"],
        template="""
        Given an input question, create a syntactically correct {dialect} query to run to help find the answer. 
        Unless the user specifies in his question a specific number of examples they wish to obtain, 
        always limit your query to at most 5 results. You can order the results by a relevant column to return the most interesting examples in the database.

        Never query for all the columns from a specific table, only ask for the few relevant columns given the question.

        Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

        Only use the following tables:
        {table_info}

        Example 1:
        Question: "What is the average salary for a data scientist?"
        SQL Query:
        SELECT AVG(salary) 
        FROM chat_data
        WHERE job_title = 'Data Scientist'
        LIMIT 5;
    
        Example 2:
        Question: "Show me the salary details for employees working remotely."
        SQL Query:
        SELECT salary, salary_currency, remote_ratio
        FROM chat_data
        WHERE remote_ratio > 0
        LIMIT 5;
    
        Example 3:
        Question: "How many employees are there in large companies?"
        SQL Query:
        SELECT COUNT(*) 
        FROM chat_data
        WHERE company_size = 'L'
        LIMIT 5;
    
        Example 4:
        Question: "Give me the top 5 highest-paying job titles."
        SQL Query:
        SELECT job_title, salary 
        FROM chat_data
        ORDER BY salary DESC
        LIMIT 5;
    
        Example 5:
        Question: "What is the average salary in USD by job title?"
        SQL Query:
        SELECT job_title, AVG(salary_in_usd) 
        FROM chat_data
        GROUP BY job_title
        LIMIT 5;
    
        Example 6:
        Question: "What is the salary for senior-level software engineers?"
        SQL Query:
        SELECT salary, experience_level
        FROM chat_data
        WHERE job_title = 'Software Engineer' AND experience_level = 'SE'
        LIMIT 5;

        Question: {input}
        """
    )
    prompt = prompt_template.invoke(
        {
            "dialect": dialect,
            "table_info": table_info,
            "input": user_question,
        }
    )
    # Constrain the model's reply to the QueryOutput schema so the SQL text can
    # be read from a field instead of being scraped out of free-form prose.
    llm = ChatGroq(model="llama3-8b-8192")
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return result.query

In [49]:
# Hand-written schema summary of the `chat_data` table. It is passed to
# generate_sql_query() / chatbot() as the description of what may be queried.
_CHAT_DATA_COLUMNS = (
    ("work_year", "INTEGER"),
    ("experience_level", "TEXT"),
    ("employment_type", "TEXT"),
    ("job_title", "TEXT"),
    ("salary", "INTEGER"),
    ("salary_currency", "TEXT"),
    ("salary_in_usd", "INTEGER"),
    ("employee_residence", "TEXT"),
    ("remote_ratio", "INTEGER"),
    ("company_location", "TEXT"),
    ("company_size", "TEXT"),
)

# Assemble the text: header lines, one "- name: TYPE" bullet per column, and
# the same four-space trailing indent the triple-quoted form produced.
table_info = (
    "\n    Table: chat_data\n    Columns:\n"
    + "".join(f"    - {column}: {sql_type}\n" for column, sql_type in _CHAT_DATA_COLUMNS)
    + "    "
)

In [50]:
# Try the generator on one sample question and show the SQL it produces.
user_question = "What are the average salaries for each job title?"
sql_query = generate_sql_query(user_question=user_question, table_info=table_info)
print(sql_query)

SELECT job_title, AVG(salary_in_usd) AS average_salary FROM chat_data GROUP BY job_title LIMIT 5;


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


## Execute the SQL Query

In [51]:
def execute_sql_query(query: str, engine) -> pd.DataFrame:
    """Run a read-only SQL query and return the result set as a DataFrame.

    The query text comes from a language model, so it is treated as untrusted:
    only statements that start with SELECT (or WITH, for common table
    expressions) are sent to the database. This is a best-effort guard, not a
    full SQL parser.

    Args:
        query: SQL text to execute.
        engine: Connection or engine accepted by `pandas.read_sql_query`.

    Returns:
        The rows returned by the query.

    Raises:
        ValueError: If the query is empty or is not a SELECT/WITH statement.
    """
    statement = (query or "").strip()
    if not statement:
        raise ValueError("SQL query is empty.")
    # Refuse anything that is not a plain read (DROP, DELETE, UPDATE, ...)
    # before it ever reaches the database.
    if not statement.lower().startswith(("select", "with")):
        raise ValueError(f"Only read-only SELECT queries are allowed, got: {statement!r}")
    return pd.read_sql_query(statement, con=engine)

In [52]:
# Run the SQL generated above; the resulting DataFrame is the cell's output.
execute_sql_query(query=sql_query, engine=engine)

Unnamed: 0,job_title,average_salary
0,3D Computer Vision Researcher,21352.25
1,AI Developer,136666.090909
2,AI Programmer,55000.0
3,AI Scientist,110120.875
4,Analytics Engineer,152368.631068


## Implement Chatbot Functionality with Memory

Create a function to handle user inputs, generate SQL queries, execute them, and maintain chat history.

In [53]:
import json
from langchain.memory import ConversationBufferMemory
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


# Initialize memory to store chat history
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# chatbot() appends plain {"role": ..., "content": ...} dicts to `chat_history`.
# Take a copy so those dicts are not also pushed into the list held by `memory`.
# NOTE(review): load_memory_variables appears to return the memory's own message
# list when return_messages=True -- confirm; the copy is harmless either way.
chat_history = list(memory.load_memory_variables({})["chat_history"])

def generate_plot(df: pd.DataFrame, plot_flag: bool) -> Optional[plt.Figure]:
    """Draw a quick chart of a query result, choosing the chart by column count.

    * 1 column   -> bar chart
    * 2 columns  -> scatter plot (first column on x, second on y)
    * 3+ columns -> seaborn pairplot

    Args:
        df: Query result to visualise.
        plot_flag: When False, nothing is drawn.

    Returns:
        The matplotlib Figure that was drawn, or None when plotting is
        disabled, the frame is empty, or the chosen chart needs numeric data
        and the frame has none.
    """
    if not plot_flag or df.empty:
        return None

    num_columns = len(df.columns)
    has_numeric = df.select_dtypes(include="number").shape[1] > 0

    if num_columns == 2:
        df.plot(kind="scatter", x=df.columns[0], y=df.columns[1], title="Scatter Plot")
    elif not has_numeric:
        # A bar chart or pairplot of purely textual columns raises inside the
        # plotting library; there is nothing to draw, so skip quietly.
        return None
    elif num_columns == 1:
        df.plot(kind="bar", title="Single Column Plot")
    else:
        sns.pairplot(df)  # Create a pairplot for multiple columns

    # Both pandas and seaborn build their figure through pyplot, so the chart
    # just drawn is the current figure; grab it before show() so the return
    # value matches the annotated Figure type.
    fig = plt.gcf()
    plt.show()
    return fig

def chatbot(user_question: str, engine, table_info: str, plot_flag: bool = False, dialect: str = 'sqlite') -> str:
    """Answer one user question against the database and log the exchange.

    Steps: generate SQL for the question, execute it, optionally plot the
    result, and append the question and the CSV-formatted answer to the
    module-level `chat_history` list. The history is only recorded here; it is
    not fed back into query generation.

    Args:
        user_question: The question in plain language.
        engine: Database connection/engine passed to `execute_sql_query`.
        table_info: Schema description passed to `generate_sql_query`.
        plot_flag: When True, also draw a chart of the result.
        dialect: SQL dialect the generated query should target.

    Returns:
        The query result rendered as CSV text (header row included, no index).
    """
    # Forward `dialect` so the caller's choice is honoured rather than silently
    # falling back to generate_sql_query's default.
    sql_query = generate_sql_query(user_question, table_info, dialect=dialect)

    # Execute the SQL query
    query_result = execute_sql_query(sql_query, engine)

    # Draw a chart only when requested
    generate_plot(query_result, plot_flag)

    # Convert query result to a string (CSV format)
    result_str = query_result.to_csv(index=False)

    chat_history.append({"role": "user", "content": user_question})
    chat_history.append({"role": "assistant", "content": result_str})

    return result_str


In [54]:
if __name__ == "__main__":
    # Interactive chat loop: read a question, answer it, repeat until the user
    # types "exit".
    while True:
        user_input = input("You: ").strip()
        if not user_input:
            # Nothing typed -- don't spend an LLM call on an empty question.
            continue
        if user_input.lower() == 'exit':
            break
        try:
            response = chatbot(user_question=user_input,
                               engine=engine,
                               table_info=table_info,
                               plot_flag=False)
        except Exception as exc:
            # One bad generated query should not end the whole session (and
            # lose the chance to save the history); report it and carry on.
            print(f"Bot: Sorry, I could not answer that: {exc}")
            continue
        print(f"Bot: {response}")

You:  What are the average salaries for each job title?


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


Bot: job_title,AVG(salary_in_usd)
3D Computer Vision Researcher,21352.25
AI Developer,136666.0909090909
AI Programmer,55000.0
AI Scientist,110120.875
Analytics Engineer,152368.63106796116



You:  For the Analytics Enginner whats the min,max salary?


C:\Users\sayantghosh\AppData\Local\anaconda3\envs\genai\Lib\site-packages\pydantic\main.py:1114: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


Bot: min_salary,max_salary
7500,289800



You:  exit


In [55]:
# Persist the recorded conversation as pretty-printed JSON next to the notebook.
with open('chat_history.json', 'w') as history_file:
    history_file.write(json.dumps(chat_history, indent=4))