In [1]:
import pandas as pd
import mysql.connector
from dotenv import load_dotenv
import os

# Pull MYSQL_* (and any other) settings from a local .env file into the environment.
load_dotenv()

# Shared connection + cursor; the next cell reuses both before closing them.
mydb = mysql.connector.connect(
    database=os.getenv('MYSQL_DATABASE'),
    host=os.getenv('MYSQL_HOST'),
    password=os.getenv('MYSQL_PASSWORD'),
    user=os.getenv('MYSQL_USER'),
)
cursor = mydb.cursor()

# Sanity check that the connection works: show every row of `accounts`.
cursor.execute("select * from accounts")
results = cursor.fetchall()
for account_row in results:
    print(account_row)


(1, 1, 'Savings', Decimal('5000.00'), Decimal('1000.00'), datetime.date(2020, 2, 1))
(2, 2, 'Checking', Decimal('2500.00'), Decimal('500.00'), datetime.date(2019, 11, 15))
(3, 3, 'Credit Card', Decimal('300.00'), Decimal('100.00'), datetime.date(2021, 1, 20))
(4, 4, 'Savings', Decimal('800.00'), Decimal('200.00'), datetime.date(2023, 6, 10))
(5, 5, 'Checking', Decimal('600.00'), Decimal('300.00'), datetime.date(2020, 12, 30))


In [2]:
df = pd.read_csv('churn_data.csv')

create_table = """CREATE TABLE IF NOT EXISTS churn (
    customer_id VARCHAR(50) PRIMARY KEY,
    gender VARCHAR(10),
    customer_age INT,
    marital_status VARCHAR(20),
    occupation VARCHAR(50),
    location VARCHAR(50),
    customer_credit_score INT,
    customer_total_transactions INT,
    customer_total_cash_balance DECIMAL(15, 2),
    customer_total_cards INT,                    -- Total number of cards
    customer_total_accounts INT,                 -- Total number of accounts
    customer_total_debt DECIMAL(15, 2),          -- Total debt or outstanding balance
    customer_total_loans INT,                    -- Total number of loans
    customer_total_savings_accounts INT,         -- Total number of savings accounts
    customer_total_checking_accounts INT,        -- Total number of checking accounts
    customer_total_credit_accounts INT,          -- Total number of credit accounts
    customer_total_mortgage_accounts INT,        -- Total number of mortgage accounts
    customer_account_utilization_rate DECIMAL(15, 2), -- Utilization rate (credit utilization)
    customer_long_term_savings DECIMAL(15, 2),   -- Long-term savings balance
    customer_short_term_savings DECIMAL(15, 2),  -- Short-term savings balance
    customer_active_credit_cards INT,            -- Number of active credit cards
    customer_overdue_payments INT                -- Number of overdue payments
)"""

cursor.execute(create_table)

# INSERT target columns, in CREATE TABLE order. CSV values are bound positionally,
# so the CSV's columns must appear in this same order.
churn_columns = (
    "customer_id", "gender", "customer_age", "marital_status", "occupation", "location",
    "customer_credit_score", "customer_total_transactions", "customer_total_cash_balance",
    "customer_total_cards", "customer_total_accounts", "customer_total_debt",
    "customer_total_loans", "customer_total_savings_accounts", "customer_total_checking_accounts",
    "customer_total_credit_accounts", "customer_total_mortgage_accounts",
    "customer_account_utilization_rate", "customer_long_term_savings",
    "customer_short_term_savings", "customer_active_credit_cards", "customer_overdue_payments",
)
if df.shape[1] != len(churn_columns):
    raise ValueError(
        f"churn_data.csv has {df.shape[1]} columns; the churn INSERT expects {len(churn_columns)}"
    )

# Built once, outside any loop. Identifiers come from the constant tuple above; every
# value goes through a %s placeholder. INSERT IGNORE skips rows whose customer_id
# (the PRIMARY KEY) already exists, so re-running the cell does not fail on duplicates.
insert_query = (
    f"INSERT IGNORE INTO churn ({', '.join(churn_columns)}) "
    f"VALUES ({', '.join(['%s'] * len(churn_columns))})"
)

# astype(object) hands the driver plain Python scalars; empty CSV cells (NaN) are sent
# as None -> SQL NULL, because NaN is not a valid value for INT/DECIMAL columns.
churn_rows = [
    tuple(None if pd.isna(value) else value for value in row)
    for row in df.astype(object).itertuples(index=False, name=None)
]
cursor.executemany(insert_query, churn_rows)
mydb.commit()

# Show what the table now contains.
cursor.execute("select * from churn")
results = cursor.fetchall()
for r in results:
    print(r)
cursor.close()
mydb.close()

('CUST_1', 'Male', 30, 'Married', 'Physical Therapist', 'Columbus', 760, 34, Decimal('72230.44'), 3, 5, Decimal('6386.85'), 3, 2, 2, 2, 7, Decimal('0.41'), Decimal('62015.37'), Decimal('25426.30'), 1, 0)
('CUST_10', 'Female', 63, 'Widowed', 'Producer', 'New York', 755, 63, Decimal('15302.87'), 3, 4, Decimal('24318.83'), 3, 0, 15, 9, 7, Decimal('0.50'), Decimal('128820.78'), Decimal('64410.39'), 16, 0)
('CUST_100', 'Male', 64, 'Widowed', 'Waiter/Waitress', 'Phoenix', 581, 34, Decimal('10028.84'), 2, 5, Decimal('20476.51'), 1, 0, 7, 0, 4, Decimal('0.97'), Decimal('137007.54'), Decimal('132897.31'), 12, 0)
('CUST_11', 'Female', 54, 'Married', 'Digital Marketer', 'Los Angeles', 512, 79, Decimal('98289.80'), 4, 6, Decimal('46921.25'), 2, 2, 8, 3, 5, Decimal('0.33'), Decimal('139924.75'), Decimal('46175.17'), 13, 1)
('CUST_12', 'Male', 22, 'Married', 'Producer', 'Dallas', 385, 26, Decimal('9340.38'), 1, 5, Decimal('22747.30'), 4, 1, 18, 0, 8, Decimal('0.76'), Decimal('50762.89'), Decimal('38

In [3]:

from groq import Groq

# Prompt preamble that answer() prepends to every question: it lists the columns of
# the `churn` table created earlier and asks the model for a single bare MySQL query.
# NOTE(review): keep this column list in sync with the CREATE TABLE churn statement.
schema_info = """
You have access to the following database schema:
Table 'churn': 
    customer_id (VARCHAR),
    gender (VARCHAR),
    customer_age (INT),
    marital_status (VARCHAR),
    occupation (VARCHAR),
    location (VARCHAR),
    customer_credit_score (INT),
    customer_total_transactions (INT),
    customer_total_cash_balance (DECIMAL),
    customer_total_cards (INT),
    customer_total_accounts (INT),
    customer_total_debt (DECIMAL),
    customer_total_loans (INT),
    customer_total_savings_accounts (INT),
    customer_total_checking_accounts (INT),
    customer_total_credit_accounts (INT),
    customer_total_mortgage_accounts (INT),
    customer_account_utilization_rate (DECIMAL),
    customer_long_term_savings (DECIMAL),
    customer_short_term_savings (DECIMAL),
    customer_active_credit_cards (INT),
    customer_overdue_payments (INT)

Only return a **single SQL query** that answers the question. Do not include any explanations, additional text, or comments. The query should be formatted properly and valid for execution in a MySQL database.
"""

# Groq API client used by answer(). The key is read from the process environment,
# which load_dotenv() in the first cell may have populated from a .env file.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

def _strip_code_fences(text):
    """Return ``text`` with Markdown ``` fences (and an optional language tag) removed."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # The model re-opened the fence itself: drop that whole first line.
        cleaned = cleaned.partition("\n")[2]
    # Everything from a closing fence onward is not SQL; left in, MySQL rejects it.
    cleaned = cleaned.partition("```")[0]
    return cleaned.strip()


def answer(question, model="llama3-8b-8192"):
    """Ask the Groq LLM to translate a natural-language question into a MySQL query.

    The user message is ``schema_info`` followed by ``"Question: " + question``, and the
    assistant turn is pre-filled with an opening ```SQL fence so the reply continues
    inside a SQL code block.

    Args:
        question: Natural-language question about the ``churn`` table.
        model: Groq model name; defaults to the model this notebook was run with.

    Returns:
        The generated SQL with surrounding whitespace and any Markdown code fences
        removed, so it can be passed straight to ``execute_sql``. It is also printed.
    """
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "user", "content": schema_info + "Question: " + question},
            # Pre-filled assistant turn: the model continues inside a SQL code block.
            {"role": "assistant", "content": "```SQL"},
        ],
        model=model,
    )
    # `content` may be None on an empty reply; treat that as an empty query.
    raw_reply = chat_completion.choices[0].message.content or ""
    sql = _strip_code_fences(raw_reply)
    print(sql)
    return sql


In [4]:
def execute_sql(sql_query):
    """Run one SQL statement against the MySQL database and return its rows.

    Opens a fresh connection from the MYSQL_* environment variables, executes
    ``sql_query``, and fetches every result row. The cursor and the connection are
    always closed. Nothing is committed.

    Args:
        sql_query: SQL text to execute (in this notebook, generated by the LLM).

    Returns:
        The list of row tuples from ``cursor.fetchall()``, or ``None`` when the MySQL
        driver raises ``mysql.connector.Error`` (the error is printed).
    """
    # NOTE(security): the query text is executed verbatim, and it comes from an LLM.
    # Connect with a read-only MySQL user so a generated DROP/DELETE/UPDATE cannot
    # modify data.
    connection = mysql.connector.connect(
        host=os.getenv('MYSQL_HOST'),
        user=os.getenv('MYSQL_USER'),
        password=os.getenv('MYSQL_PASSWORD'),
        database=os.getenv('MYSQL_DATABASE'),
    )
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(sql_query)
            result = cursor.fetchall()
        except mysql.connector.Error as e:
            # Bad SQL or a server-side failure: report it and signal "no answer".
            print(f"Error: {e}")
            result = None
        finally:
            cursor.close()
    finally:
        # Outer finally: the connection is released even if cursor() itself fails.
        connection.close()
    return result
    

In [5]:
def get_answer_from_llm(question):
    """Answer a natural-language question by generating SQL with the LLM and running it.

    Args:
        question: Question about the ``churn`` table.

    Returns:
        ``"The answer is: <rows>"`` when the query produced a non-empty result,
        otherwise a fixed message covering both "no rows" and "query failed".
    """
    rows = execute_sql(answer(question))
    if not rows:
        return "No result found or an error occurred."
    return f"The answer is: {rows}"

In [6]:
# Example questions, from a plain row count up to filtered aggregates and a top-N ranking.
example_questions = [
    "How many rows of data are there in the table?",
    "How many customers have a credit score above 700?",
    "What is the average cash balance of customers who have more than 2 credit accounts and fewer than 3 overdue payments?",
    "List the top 5 locations with the highest average customer debt, showing the location and average debt.",
    "For customers with more than 2 active credit cards, what is the total number of transactions made by customers with a credit score below 600?",
    "how many people that are married have more than 2 active credit/debit cards",
]
for example_question in example_questions:
    print(get_answer_from_llm(example_question))


SELECT COUNT(*) AS total_rows FROM churn;
The answer is: [(100,)]

SELECT COUNT(*)
FROM churn
WHERE customer_credit_score > 700;
The answer is: [(23,)]

SELECT AVG(c.customer_total_cash_balance) 
FROM churn c 
WHERE c.customer_total_credit_accounts > 2 AND c.customer_overdue_payments < 3;
The answer is: [(Decimal('50834.962899'),)]

SELECT location, AVG(customer_total_debt) AS average_debt
FROM churn
GROUP BY location
ORDER BY average_debt DESC
LIMIT 5;
The answer is: [('Los Angeles', Decimal('40452.860000')), ('Denver', Decimal('37942.110000')), ('Philadelphia', Decimal('37613.285000')), ('Jacksonville', Decimal('36831.083333')), ('Dallas', Decimal('33218.710000'))]

SELECT SUM(churn.customer_total_transactions) 
FROM churn 
WHERE churn.customer_id IN (SELECT churn.customer_id 
                            FROM churn 
                            WHERE churn.customer_active_credit_cards > 2) 
  AND churn.customer_credit_score < 600;
The answer is: [(Decimal('2499'),)]

SELECT COUNT(*)
