In [1]:
import os
from dotenv import load_dotenv

# Pull settings from a local .env file into the process environment.
load_dotenv()

# Fail fast when the key is missing, but do not end the cell with the bare
# lookup: its value would be echoed into the saved notebook output.
if "GEMINI_API_KEY" not in os.environ:
    raise KeyError("GEMINI_API_KEY")

'<redacted>'

In [2]:
from langchain_google_genai import GoogleGenerativeAI

# Gemini model client; the API key is read from the environment rather than
# being written into the notebook.
llm = GoogleGenerativeAI(
    model='gemini-pro',
    google_api_key=os.environ["GEMINI_API_KEY"],
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

In [4]:
# Connection settings are read from the environment (the notebook already loads
# a .env file) so that credentials are not stored in the notebook itself.
# If DB_PASSWORD is unset, the password is requested interactively instead.
from getpass import getpass
from urllib.parse import quote_plus

db_user = os.environ.get("DB_USER", "gaurav")
db_password = os.environ.get("DB_PASSWORD") or getpass("MySQL password: ")
db_host = os.environ.get("DB_HOST", "localhost")
db_name = os.environ.get("DB_NAME", "retail_sales_db")

# URL-escape the credentials: characters such as '@', ':' or '/' in a user name
# or password would otherwise corrupt the connection URI.
db = SQLDatabase.from_uri(
    f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}",
    sample_rows_in_table_info=3,
)

# Display the table definitions and sample rows reported by the database wrapper.
print(db.table_info)


CREATE TABLE sales_tb (
	`TransactionID` INTEGER, 
	`Date` DATE, 
	`CustomerID` VARCHAR(10), 
	`Gender` VARCHAR(10), 
	`Age` INTEGER, 
	`ProductCategory` VARCHAR(50), 
	`Quantity` INTEGER, 
	`PriceperUnit` DECIMAL(10, 2), 
	`TotalAmount` DECIMAL(10, 2)
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from sales_tb table:
TransactionID	Date	CustomerID	Gender	Age	ProductCategory	Quantity	PriceperUnit	TotalAmount
1	2023-11-24	CUST001	Male	34	Beauty	3	50.00	150.00
2	2023-02-27	CUST002	Female	26	Clothing	2	500.00	1000.00
3	2023-01-13	CUST003	Male	50	Electronics	1	30.00	30.00
*/


In [5]:
from langchain.chains import create_sql_query_chain

# Build the question-to-SQL chain from the LLM and the database wrapper,
# then ask it a sample question and display the raw reply.
chain = create_sql_query_chain(llm, db)
response = chain.invoke(
    {"question": "How many male customers are there?"}
)
response

"```sql\nSELECT COUNT(DISTINCT `CustomerID`) AS `Number of Male Customers`\nFROM sales_tb\nWHERE `Gender` = 'Male';\n```"

In [6]:
# Remove the Markdown code fence that wraps the generated SQL.
# str.strip(chars) treats its argument as a *set of characters*, not as a
# prefix/suffix, so strip('```sql\n') would also eat leading or trailing
# 's', 'q' and 'l' characters that belong to the query itself.
clean_query = response.strip()
if clean_query.startswith("```"):
    clean_query = clean_query[3:]
    # The remainder of the opening fence line is an optional language tag.
    fence_tag, newline, remainder = clean_query.partition("\n")
    if newline and fence_tag.strip().lower() in ("", "sql", "mysql"):
        clean_query = remainder
if clean_query.endswith("```"):
    clean_query = clean_query[:-3]
clean_query = clean_query.strip()
print(clean_query)

SELECT COUNT(DISTINCT `CustomerID`) AS `Number of Male Customers`
FROM sales_tb
WHERE `Gender` = 'Male';


In [7]:
# Execute the cleaned statement against the database and display what comes back.
result = db.run(clean_query)
result

'[(14,)]'

In [10]:

from sqlalchemy.exc import ProgrammingError
def _strip_sql_fence(text):
    """Return ``text`` without a surrounding Markdown code fence.

    Handles an opening fence with or without a ``sql``/``mysql`` language tag,
    a closing fence, and replies that are not fenced at all. Only the fence is
    removed; the characters of the query itself are left untouched.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        # The remainder of the opening fence line is an optional language tag.
        fence_tag, newline, remainder = cleaned.partition("\n")
        if newline and fence_tag.strip().lower() in ("", "sql", "mysql"):
            cleaned = remainder
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


def sql_chain(question):
    """Generate SQL for ``question``, run it, and print the query and result.

    Uses the notebook-level ``chain`` to produce the SQL and ``db`` to execute
    it. A ``ProgrammingError`` raised while doing so is reported with a
    printed message instead of propagating.

    Args:
        question: Natural-language question passed to the chain.

    Returns:
        None. The SQL and its result are printed.
    """
    try:
        response = chain.invoke({"question": question})
        # str.strip('```sql\n') would strip a character set and could eat
        # leading/trailing 's', 'q', 'l' of the query; remove the fence exactly.
        clean_query = _strip_sql_fence(response)
        print("SQL Query:", clean_query)
        result = db.run(clean_query)
        print("Result:", result)
    except ProgrammingError as e:
        print(f"An error occured: {e}")

In [11]:
# Count distinct customers per product category.
# NOTE(review): the generated SQL carries `LIMIT 5`, presumably a default row
# cap of the chain — confirm it cannot truncate when there are more categories.
sql_chain("How many unique customers are there for each product category?")

SQL Query: SELECT 
  `ProductCategory`, 
  COUNT(DISTINCT `CustomerID`) AS `UniqueCustomers`
FROM 
  `sales_tb`
GROUP BY 
  `ProductCategory`
ORDER BY 
  `UniqueCustomers` DESC
LIMIT 
  5;
Result: [('Clothing', 13), ('Beauty', 8), ('Electronics', 8)]


In [12]:
# Sum the sales amount for each product category.
# NOTE(review): the generated SQL carries `LIMIT 5`, presumably a default row
# cap of the chain — confirm it cannot truncate when there are more categories.
sql_chain(
    question="Calculate the total sales amount per product category"
)

SQL Query: SELECT `ProductCategory`, SUM(`TotalAmount`) AS `TotalSalesAmount`
FROM sales_tb
GROUP BY `ProductCategory`
ORDER BY `TotalSalesAmount` DESC
LIMIT 5;
Result: [('Electronics', Decimal('5310.00')), ('Clothing', Decimal('5040.00')), ('Beauty', Decimal('1455.00'))]
