In [None]:
!pip install chromadb
!pip install fuzzywuzzy[speedup]
!pip install openai
!pip install SQLAlchemy==1.4.46
!pip install sqlglot

In [4]:
# Imports and notebook-wide settings: ChromaDB, spaCy / fuzzywuzzy, the OpenAI client, SQLAlchemy / pandas, sqlglot and the plotting libraries. No Chroma client is created in this cell.
import chromadb
import spacy
from fuzzywuzzy import process ,fuzz
from openai import OpenAI
import re
import sqlalchemy
from sqlalchemy import create_engine, text, sql
import pandas as pd
from chromadb.config import Settings
import openai
import sqlglot
from sqlglot.optimizer import optimize
from chromadb.utils import embedding_functions
from typing import List
import warnings
import os
import plotly
import plotly.express as px
import plotly.graph_objects as go
from matplotlib import pyplot as plt
import logging
import plotly.express as px
import plotly.subplots as sp


logging.getLogger().setLevel(logging.ERROR)

warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)  # Unlimited columns
pd.set_option('display.width', 500)     # Width of the display in the notebook

In [89]:
# Example question -> SQL pairs (PostgreSQL dialect) over the orders,
# order_item and events tables. Every entry is a dict with exactly two keys:
#   "question": the natural-language question
#   "sql":      the query that answers it
#
# Conventions used by the queries below:
#   * "last 30 days" is expressed as `between current_date-30 and current_date`
#     or `>= current_date - interval '30 days'`.
#   * The `source_name is null or source_name in ('web','app')` filter is always
#     wrapped in its own parentheses; AND binds tighter than OR in SQL, so an
#     unparenthesised OR would let rows with a null source_name skip the date
#     and event filters that follow it.
#   * Denominators that can legitimately be zero are wrapped in nullif(..., 0)
#     so the query yields NULL instead of failing with a division-by-zero error.
dataset = [{
    "question":"total number of orders for the last 30 days",
    "sql":'''
            select date(created_at) as date,
            count(distinct id) as orders

            from orders
            where date(created_at) between current_date-30 and current_date
            group by 1
   '''},

    {"question" : "what is my retention monthly and store wise?" ,
     "sql" : '''
        with base as (
            select store ,date_trunc('MONTH',created_at)::DATE as month , email from orders group by 1,2,3
        ) ,

        ret_base as (
            select a.store , a.month as base_month, b.month as retained_month , count(distinct b.email) as retained_users
            from base a
            join base b on b.month >= a.month and b.email = a.email
            group by 1,2,3
        ),

        totals_base as (
        select a.store, a.month as base_month, count(distinct a.email) as total_users
        from base a
        group by 1,2)

        select a.store, a.base_month, a.retained_month, a.retained_users*100.00/b.total_users as percentage_retained_users
        from ret_base a
        left join totals_base b on a.store=b.store and a.base_month=b.base_month
        order by a.store, a.base_month, a.retained_month
     '''
     } ,

   {
       # The orders of the current calendar month are compared with each
       # user's acquisition month; tables are referenced through their aliases.
       "question" : "percentage of new users this month" ,
       "sql" : '''
                with new_user_base as(
                    select email , min(date_trunc('MONTH',created_at)::DATE) as acq_month  from orders group by 1
                ) ,

                orders_base as(
                    select date_trunc('MONTH',created_at)::DATE as month , email from orders where date(created_at) >= date_trunc('MONTH',current_date)::DATE group by 1,2
                )

                select 100.0*count(distinct case when a.month = b.acq_month then a.email end)/nullif(count(distinct a.email),0) as new_users_percentage from orders_base a join new_user_base b on a.email = b.email

       '''
   } ,

 {"question" : "how does the month on month retention percentage look like?",
    "sql" : '''

    with base as (
        select store ,(case when source_name = 'pos' then 'pos' else 'online' end) as source , date_trunc('MONTH',created_at)::DATE as month , email from orders group by 1,2,3,4
    ) ,

    ret_base as (
        select a.store , a.source , a.month as base_month, b.month as retained_month , count(distinct b.email) as retained_users
        from base a
        join base b on b.month >= a.month and b.email = a.email
        group by 1,2,3,4
    ),

    totals_base as (
    select a.store, a.source, a.month as base_month, count(distinct a.email) as total_users
    from base a
    group by 1,2,3)

    select a.store, a.source, a.base_month, a.retained_month, a.retained_users*100.00/b.total_users as percentage_retained_users
    from ret_base a
    left join totals_base b on a.store=b.store and a.source=b.source and a.base_month=b.base_month

    '''} ,

{
    # new_users / existing_users can be 0 for a given month (e.g. nobody is
    # "existing" in the first month of data), hence the nullif() denominators.
    "question" : "what is month , store and source wise retention and churn for new and existing users" ,
    "sql" : '''
    with user_base as (
        select email , min(date_trunc('MONTH',created_at)::DATE) as acq_month  from orders group by 1
    ) ,

    order_base as (
        select store ,(case when source_name = 'pos' then 'pos' else 'online' end) as source, email , date_trunc('MONTH',created_at)::DATE as month from orders group by 1,2,3,4
    ) ,

    base as (
        select a.store ,a.source , a.month , a.email , case when a.month = b.acq_month then 'new' else 'existing' end as type
        from order_base a join user_base b on a.email = b.email
    ),

    base1 as (
        select store ,source , month , count(distinct email) as total_users ,
        count(distinct case when type = 'new' then email end) as new_users ,
        count(distinct case when type = 'existing' then email end) as existing_users
        from base group by 1,2,3
    ) ,

     ret_base as (
        select a.store ,a.source , a.month as base_month ,b.month as ret_month ,
        count(distinct case when a.type = 'new' then a.email end) as new_retained ,
        count(distinct case when a.type = 'existing' then a.email end) as existing_retained ,
        count(distinct a.email) as total_retained
        from base a
        join base b on b.month-a.month between 25 and 32 and b.email = a.email
        group by 1,2,3,4
     )



    select a.store ,a.source , a.month as base_month , a.total_users , 100.0*a.new_users/a.total_users as new_users_share ,
    100.0*a.existing_users/a.total_users as existing_users_share  ,
    100.0*b.total_retained/a.total_users as retention , 100.0-(100.0*b.total_retained/a.total_users) as churn ,
    100.0*b.new_retained/nullif(a.new_users,0) as new_user_retention , 100.0-(100.0*b.new_retained/nullif(a.new_users,0)) as new_user_churn ,
    100.0*b.existing_retained/nullif(a.existing_users,0) as existing_user_retention , 100.0-(100.0*b.existing_retained/nullif(a.existing_users,0)) as existing_user_churn
    from base1 a left join ret_base b on a.month = b.base_month and a.store = b.store and a.source = b.source
  '''
} ,



{
    "question":"find the store wise number of orders for the last 30 days",
    "sql":'''
            select date(created_at) as date,
            store,
            count(distinct id) as orders

            from orders
            where date(created_at) between current_date-30 and current_date
            group by 1,2
   '''},

{
    "question":"find total sales of different product types in last 30 days",
    "sql":'''
             select product_type,

             sum(item_selling_price::float) as value_sold

             from order_item
             where product_type IN ('Personal & Home','Grains & Flour','Spices & Condiments','Instant Food & Beverages')
             and (cast(created_at as date) between current_date-30 and current_date)
             group by 1
   '''},

{
    "question":"find interacted users, logins, searched, collection_viewed, product_viewed, atc, ordered users information for last 30 days",
    "sql":'''with event_base as
          (
                select timestamp::date as date  , store , count(distinct email) as interacted_users ,
                count(distinct case when event = 'Login' and email is not null then email end) as login ,
                count(distinct case when event = 'Search Term' and email is not null then email end) as searched ,
                count(distinct case when event = 'Collection Viewed' and email is not null then email end) as collection_viewed ,
                count(distinct case when event in ('Viewed Product','Product Viewed') and email is not null then email end) as product_viewed ,
                count(distinct case when event in ('Added To Cart') and email is not null then email end) as atc ,
                count(distinct case when event in ('Placed Order') and email is not null then email end) as ordered
                from events
                where ((source_name is null) or  (source_name in ('web','app')))
                and timestamp::date between current_date-30 and current_date
                group by 1,2
          ) ,

          order_base as
          (
                select date ,store ,  count(distinct id) as orders, count(distinct email) as customers
                from order_item
                where ((source_name is null) or  (source_name in ('web','app')))
                and created_at::date between current_date-30 and current_date
                group by 1,2
          )

           select e.date , e.store , e.interacted_users , e.login ,
           e.searched ,e.collection_viewed , e.product_viewed ,
           e.atc , e.ordered , o.orders ,o.customers
           from event_base e
           left join order_base o on e.date = o.date and e.store = o.store
           order by 2,1'''},

{
    "question":"find storewise total user logins in the last 30 days",
    "sql":'''     select cast(timestamp as date) as date,
                store,
                count(distinct email) as login

                from events
                where ((source_name is null) or  (source_name in ('web','app')))
                and cast(timestamp as date) between current_date-30 and current_date
                and event = 'Login'
                group by 1,2'''},

{
    "question":"find storewise total user searches in the last 30 days",
    "sql":'''     select cast(timestamp as date) as date,
                store,
                count(distinct email) as Searches

                from events
                where ((source_name is null) or  (source_name in ('web','app')))
                and cast(timestamp as date) between current_date-30 and current_date
                and event = 'Search Term'
                group by 1,2'''},

{
    "question":"find storewise total users who viewed collections in the last 30 days",
    "sql":'''     select cast(timestamp as date) as date,
                store,
                count(distinct email) as count_collection_views

                from events
                where ((source_name is null) or  (source_name in ('web','app')))
                and cast(timestamp as date) between current_date-30 and current_date
                and event = 'Collection Viewed'
                group by 1,2'''},

{
    "question":"find storewise product views in the last 30 days",
    "sql":'''     select cast(timestamp as date) as date,
                store,
                count(distinct email) as count_product_views

                from events
                where ((source_name is null) or  (source_name in ('web','app')))
                and cast(timestamp as date) between current_date-30 and current_date
                and event IN ('Viewed Product','Product Viewed')
                group by 1,2'''},


{
    # This query counts 'Placed Order' events, so the question is worded to match.
    "question":"find storewise orders placed for the last 30 days",
    "sql":'''     select cast(timestamp as date) as date,
                store,
                count(distinct email) as orders_placed

                from events

                where ((source_name is null) or  (source_name in ('web','app')))
                and cast(timestamp as date) between current_date-30 and current_date
                and event IN ('Placed Order')
                group by 1,2'''},

{
    "question":"find storwise total user logins in last 30 days",
    "sql":'''     with base as
                (select cast(timestamp as date) as date,
                store,
                count(distinct email) as login

                from events

                where ((source_name is null) or  (source_name in ('web','app')))
                and cast(timestamp as date) between current_date-30 and current_date
                and event = 'Login'
                group by 1,2)

                select store, sum(login) as total_logins_30_days
                from base
                group by 1'''},

{
    "question":"can you find the store and source wise cart penetration for different product types for the last 30 days?",
    "sql": '''

    with ptype_occurences as
    (select store, source_name, product_type, id
    from order_item
    where cast(created_at as date) between current_date-30 and current_date
    group by 1,2,3,4),

    totals as
    (select store, source_name, count(distinct id) as total_orders
    from order_item
    where cast(created_at as date) between current_date-30 and current_date
    group by 1,2)

    select a.store, a.source_name, a.product_type, count(distinct a.id)*100.00/b.total_orders as occur_cart_pen
    from ptype_occurences a
    left join totals b on a.store = b.store and a.source_name = b.source_name
    group by 1,2,3,b.total_orders

    '''},

{
    "question":"What is the average store and source wise month on month retention%?",
    "sql": '''

    with base as (
        select store ,(case when source_name = 'pos' then 'pos' else 'online' end) as source , date_trunc('MONTH',created_at)::DATE as month , email from orders group by 1,2,3,4
    ) ,

     ret_base as (
        select a.store , a.source , a.month as base_month, b.month as retained_month , count(distinct b.email) as retained_users
        from base a
        join base b on b.month >= a.month and b.email = a.email
        group by 1,2,3,4
     ),

     totals_base as (
      select a.store, a.source, a.month as base_month, count(distinct a.email) as total_users
      from base a
      group by 1,2,3)

    select store, source, avg(1.00*percentage_retained_users) as avg_percentage_retained_users

    from

    (select a.store, a.source, a.base_month, a.retained_month, a.retained_users*100.00/b.total_users as percentage_retained_users
    from ret_base a
    left join totals_base b on a.store=b.store and a.source=b.source and a.base_month=b.base_month) a

    group by 1,2

    '''},

{
        "question":"Can you provide the breakdown of interacted users by store?",
        "sql":'''

        SELECT store, COUNT(DISTINCT email) AS interacted_users
        FROM events
        GROUP BY store'''},

{
        "question":"Can you provide the breakdown of interacted users by store for each event type?",
        "sql":'''

        select store, event, count(distinct email) as interacted_users
        from events
        group by 1,2'''},

{
        "question":"What are the top three stores with the highest number of interacted users?",
        "sql":'''
        WITH event_counts AS (
    SELECT store,
           COUNT(DISTINCT email) AS interacted_users
    FROM events
    WHERE source_name IN ('web', 'app')
      AND timestamp::DATE BETWEEN CURRENT_DATE - 30 AND CURRENT_DATE
    GROUP BY store
    )

    SELECT store, interacted_users
    FROM event_counts
    ORDER BY interacted_users DESC
    LIMIT 3'''},

{
        "question":"What are the top 10 selling items in the last 30 days?",
        "sql":'''with base as
        (select title_x as product_name, sum(item_selling_price::decimal) as total_sales
        from order_item
        where created_at::date between current_date-30 and current_date
        group by 1
        order by total_sales desc)

        select * from base
        limit 10'''},

{
        "question":"What are the store-wise top 10 selling items in the last 30 days?",
        "sql":'''with base as
        (select store, title_x as product_name, sum(item_selling_price::decimal) as total_sales,
        rank() OVER(PARTITION BY store order by sum(item_selling_price::decimal) desc) as sales_rank
        from order_item
        where created_at::date between current_date-30 and current_date
        group by 1,2)

        select store, product_name, total_sales, sales_rank
        from base
        where sales_rank<=10
        order by sales_rank, store'''},

{
        "question":"Which store has the highest sales in the last 30 days?",
        "sql":'''SELECT store, SUM(total_price) AS total_sales
                FROM orders
                WHERE created_at >= current_date - interval '30 days'
                GROUP BY store
                ORDER BY total_sales DESC
                LIMIT 1'''},

{
        "question":"What is the distribution of sales by day of the week for the highest-selling store in the last 30 days?",
        "sql":'''with highest_selling_store as (
        select store
        from orders
        where date(created_at) between current_date - 30 and current_date
        group by store
        order by sum(total_price) desc
        limit 1
    )

    select extract(dow from o.created_at::timestamp) as day_of_week,
           sum(o.total_price) as total_sales
    from orders o
    where o.store = (select * from highest_selling_store)
      and date(o.created_at) between current_date - 30 and current_date
    group by day_of_week
    order by day_of_week;'''},

{
        # The final join restricts the result to the single highest-selling store.
        "question":"What percentage of total sales in the last 30 days for the highest-selling store came from repeat customers?",
        "sql":'''WITH total_sales AS (
    SELECT store, SUM(total_price) AS total_sales
    FROM orders
    WHERE created_at >= current_date - INTERVAL '30 days'
    GROUP BY store
),
repeat_customer_sales AS (
    SELECT store, SUM(total_price) AS repeat_sales
    FROM orders
    WHERE created_at >= current_date - INTERVAL '30 days'
    AND email IN (
        SELECT email
        FROM orders
        WHERE created_at < current_date - INTERVAL '30 days'
    )
    GROUP BY store
),
highest_selling_store AS (
    SELECT store
    FROM total_sales
    ORDER BY total_sales DESC
    LIMIT 1
)

SELECT t.store,
    (r.repeat_sales / t.total_sales) * 100 AS repeat_customer_percentage
FROM total_sales t
JOIN repeat_customer_sales r ON t.store = r.store
JOIN highest_selling_store h ON t.store = h.store; '''},

{
        "question":"Can you identify the top 10 products that are frequently purchased by repeat customers in the last 30 days for each store?",
        "sql":'''
with repeat_customer_sales AS (
    SELECT store, title_x as product_name, SUM(item_selling_price::float) AS repeat_sales,
    rank() OVER(PARTITION BY store order by SUM(item_selling_price::float) desc) as product_rank
    FROM order_item
    WHERE created_at >= current_date - INTERVAL '30 days'
    AND email IN (
        SELECT email
        FROM orders
        WHERE created_at < current_date - INTERVAL '30 days'
    )
    GROUP BY store, title_x)

    select store, product_name, repeat_sales, product_rank
    from repeat_customer_sales
    where product_rank<=10
    order by product_rank, store

    '''},

{
        "question":"How do user engagement metrics differ between weekdays and weekends?",
        "sql":'''with engagement_metrics as (
        select
            store,
            case
                when extract(dow from timestamp::date) in (0, 6) then 'Weekend'
                else 'Weekday'
            end as day_type,
            count(distinct case when event = 'Login' then email end) as logins,
            count(distinct case when event = 'Search Term' then email end) as searches,
            count(distinct case when event = 'Collection Viewed' then email end) as collections_viewed,
            count(distinct case when event in ('Viewed Product', 'Product Viewed') then email end) as products_viewed,
            count(distinct case when event = 'Added To Cart' then email end) as added_to_cart,
            count(distinct case when event = 'Placed Order' then email end) as orders_placed
        from events
        group by 1, 2
    )
    select
        store,
        day_type,
        avg(logins) as avg_logins,
        avg(searches) as avg_searches,
        avg(collections_viewed) as avg_collections_viewed,
        avg(products_viewed) as avg_products_viewed,
        avg(added_to_cart) as avg_added_to_cart,
        avg(orders_placed) as avg_orders_placed
    from engagement_metrics
    group by 1, 2'''},

{"question":"What is the average order value for each product type in the last 30 days?",
        "sql":'''select product_type, avg(item_selling_price::float) as avg_order_value
    from order_item
    where date(created_at) between current_date-30 and current_date
    group by product_type

    '''},

{"question":"is there a trend for discount and weekdays?",
        "sql":'''WITH discount_trend AS (
    SELECT
        EXTRACT(DOW FROM created_at::date) AS day_of_week,
        id,
        sum(coalesce(discount::float,0)) AS order_discount
    FROM order_item
    GROUP BY day_of_week, id
)

SELECT
    day_of_week,
    avg(order_discount) as avg_discount
FROM discount_trend
group by day_of_week
ORDER BY day_of_week

    '''},

{"question":"what are the top discounted product categories?",
        "sql":'''
with base as(
    SELECT product_type, id, SUM(discount::float) AS total_discount
FROM order_item
GROUP BY product_type, id)

select product_type, avg(total_discount) as avg_total_discount
from base
group by 1
ORDER BY avg_total_discount DESC

    '''},

{"question":"does more discount mean more orders?",
        "sql":'''
with base as(
    SELECT created_at::date as date_, id, SUM(discount::float) AS total_discount
FROM order_item
GROUP BY 1, id)

select date_, avg(total_discount) as avg_total_discount
from base
group by 1
ORDER BY 1

    '''},]


In [6]:
# CREATE TABLE statements (one string per table) for the three tables the
# example queries reference: orders, order_item and events.
# Presumably passed to chroma_db.add_ddl so the DDL can be retrieved as
# schema context for the LLM prompt -- confirm against the calling cell.
ddl = [

    # Order header rows.
    """CREATE TABLE orders (
    id INT PRIMARY KEY,
    created_at TIMESTAMP,
    line_items JSON,
    store VARCHAR(255),
    email VARCHAR(255),
    shipping_lines JSON,
    source_name VARCHAR(255),
    source_identifier VARCHAR(255),
    total_line_items_price DECIMAL(10, 2),
    total_price DECIMAL(10, 2),
    total_discounts DECIMAL(10, 2)
    );"""  ,

    # Item-level rows (sku, product_type, item_selling_price, discount).
    """CREATE TABLE order_item (
    id INT PRIMARY KEY,
    created_at TIMESTAMP,
    store VARCHAR(255),
    email VARCHAR(255),
    source_name VARCHAR(255),
    source_identifier VARCHAR(255),
    sku VARCHAR(255),
    quantity INT,
    title_x VARCHAR(255),
    date DATE,
    product_type VARCHAR(255),
    vendor VARCHAR(255),
    weight DECIMAL(10, 2),
    weight_unit VARCHAR(50),
    item_selling_price DECIMAL(10, 2),
    discount DECIMAL(5, 2)
    );""" ,

    # Event rows. `timestamp` is declared VARCHAR here; the example queries
    # cast it (timestamp::date / cast(timestamp as date)) before filtering.
    """CREATE TABLE events (
    id INT PRIMARY KEY,
    timestamp VARCHAR(255),
    event VARCHAR(255),
    sku VARCHAR(255),
    product_name VARCHAR(255),
    store VARCHAR(255),
    email VARCHAR(255),
    source_name VARCHAR(255)
    );"""
]

In [7]:
# Free-text guidance strings, one rule or definition per element.
# Presumably stored via chroma_db.add_docs and appended to the LLM prompt --
# confirm against the calling cell. The strings are kept verbatim because they
# are runtime prompt text.
docs = [
"If similar question present in example straightforward use that query. " ,
"Week is calculated as : date_Trunc('week' , created_at)::date as week"
]

In [8]:
# One {'metric', 'value'} record per metric name; chroma_db.add_benchmark uses
# the 'metric' string as the document text and the whole record as metadata.
# NOTE(review): the units of each value are not stated anywhere -- confirm.
benchmark = [
    {'metric': metric_name, 'value': metric_value}
    for metric_name, metric_value in (
        ('order', 500),
        ('retention', 15),
        ('aov', 70),
        ('churn', 75),
        ('conversion', 10),
        ('sales', 50000),
    )
]

In [64]:
##############################################################################  Utility Functions   #############################################################################################

# Cache of loaded spaCy pipelines keyed by model name. Loading a model is slow
# and extract_details is called once per dataset entry and again for every
# lookup, so the pipeline is loaded only on first use.
_NLP_CACHE = {}

# Vocabulary of analytics terms that question tokens are fuzzy-matched against.
# Order is kept as-is because it decides which term wins when scores tie.
_ANALYTICS_WORDS = [
    'retention', 'churn', 'conversion', 'growth', 'engagement', 'acquisition', 'orders', 'penetration', 'sales', 'interaction', 'event',
    'logins', 'searched', 'collection_viewed', 'product_viewed', 'atc', 'new', 'existing', 'product views', 'storewise', 'store', 'source',
    'user logins', 'interacted users', 'top selling',
]


def _get_nlp(model_name="en_core_web_sm"):
    """Return the spaCy pipeline for ``model_name``, loading it on first use only."""
    if model_name not in _NLP_CACHE:
        _NLP_CACHE[model_name] = spacy.load(model_name)
    return _NLP_CACHE[model_name]


def extract_details(question, threshold=90):
    """Extract the analytics keywords mentioned in ``question``.

    The question is lower-cased and tokenised with spaCy. Each token is
    fuzzy-matched (``fuzzywuzzy.process.extractOne``) against the analytics
    vocabulary, and the best-matching vocabulary term is kept when its score is
    strictly greater than ``threshold``.

    Args:
        question: Natural-language question text.
        threshold: Minimum score (exclusive) a match must exceed to be kept.
            Defaults to 90, the value previously hard-coded.

    Returns:
        The matched vocabulary terms joined by single spaces, in token order.
        Repeated matches are kept; the result is an empty string when nothing
        matches.
    """
    doc = _get_nlp()(question.lower())  # lowercase for case-insensitive matching

    entities = []
    for token in doc:
        best = process.extractOne(token.text, _ANALYTICS_WORDS)
        # extractOne yields a (term, score) pair; guard against a None result.
        if best is not None and best[1] > threshold:
            entities.append(best[0])

    return " ".join(entities)


#######################################

def modify_data(dataset):
    """Split the example dataset into Chroma documents and their metadata.

    Args:
        dataset: Iterable of dicts carrying a "question" and a "sql" key.

    Returns:
        A ``(question_texts, sql)`` tuple of equal-length lists:
        ``question_texts[i]`` is the keyword string produced by
        ``extract_details`` for the i-th question, and ``sql[i]`` is a
        ``{"sql": ..., "question": ...}`` dict for the same entry.
    """
    question_texts = []
    sql = []
    for entry in dataset:
        original_question = entry["question"]
        # The document text is the extracted keyword string, not the raw question.
        question_texts.append(extract_details(original_question))
        sql.append({"sql": entry["sql"], "question": original_question})
    return question_texts, sql



##########################################

def execute_query(sql_query):
    """Run ``sql_query`` against the Postgres database and return a DataFrame.

    Connection settings are read from environment variables so that no
    credential lives in the notebook:

        POSTGRES_PASSWORD  (required)
        POSTGRES_USER      (default "adjay")
        POSTGRES_HOST      (default: the dookan-dev RDS host)
        POSTGRES_PORT      (default 5432)
        POSTGRES_DB        (default "dookan")

    NOTE: the password that used to be hard-coded here was stored in plain text
    in the notebook and should be rotated.

    Args:
        sql_query: SQL passed to ``pandas.read_sql``.

    Returns:
        pandas.DataFrame with the query result.

    Raises:
        RuntimeError: if POSTGRES_PASSWORD is not set.
    """
    password = os.environ.get("POSTGRES_PASSWORD")
    if not password:
        raise RuntimeError(
            "POSTGRES_PASSWORD is not set; export the database password "
            "before calling execute_query()."
        )

    # URL.create escapes special characters in the credentials, which plain
    # string concatenation does not.
    url = sqlalchemy.engine.URL.create(
        "postgresql",
        username=os.environ.get("POSTGRES_USER", "adjay"),
        password=password,
        host=os.environ.get("POSTGRES_HOST", "dookan-dev.claxyccbejgz.eu-west-1.rds.amazonaws.com"),
        port=int(os.environ.get("POSTGRES_PORT", "5432")),
        database=os.environ.get("POSTGRES_DB", "dookan"),
    )
    engine = create_engine(url)
    try:
        # The context managers close the connection and end the transaction.
        with engine.connect() as conn, conn.begin():
            return pd.read_sql(sql_query, conn)
    finally:
        # A new engine is created per call, so release its connection pool.
        engine.dispose()




################################################################################
def check_question_db(question , data_sql):
    """Return the stored SQL whose question best matches ``question``, if any.

    A stored example qualifies when both hold:
      * its fuzzywuzzy ``token_set_ratio`` against ``question`` (both
        lower-cased) is at least 75, and
      * ``extract_details`` yields the same keyword string for both questions.
    Among qualifying examples the one with the highest ratio wins; on a tie
    the first one encountered is kept.

    The questions are compared as plain strings. No spaCy parsing is needed,
    because fuzzywuzzy converts whatever it receives to a string before
    scoring.

    Args:
        question: The user's question.
        data_sql: Mapping whose 'metadatas' entry is a list of lists of
            metadata dicts, each with 'question' and 'sql' keys.

    Returns:
        The matching 'sql' string, or None when no example qualifies.
    """
    lower_question = question.lower()
    question_details = None  # computed lazily, at most once
    matched_sql = None
    max_match = 0

    for metadata_list in data_sql['metadatas']:
        for meta in metadata_list:
            score = fuzz.token_set_ratio(meta['question'].lower(), lower_question)
            if score < 75:
                continue
            if question_details is None:
                question_details = extract_details(question)
            if extract_details(meta['question']) == question_details and score > max_match:
                matched_sql = meta['sql']
                max_match = score

    return matched_sql

In [54]:
##################################################################  ChromaDB class  ##############################################################################
class chroma_db:
    """Thin wrapper around four Chroma collections: sql, ddl, docs, benchmark.

    Each collection is exposed as ``self.collection_<name>``. The ``add_*``
    methods upsert with positional string ids ("0", "1", ...), so adding a list
    again overwrites the entries that share an index.
    """

    # Names accepted by update_data() / delete().
    _COLLECTION_NAMES = ("sql", "ddl", "docs", "benchmark")

    def __init__(self , client_type = "temporary"):
        """Create the client and open (or create) the four collections.

        Args:
            client_type: "temporary" for an in-memory client; any other value
                selects a client persisted under "db/".
        """
        if client_type == "temporary":
            self.client = chromadb.Client()
        elif hasattr(chromadb, "PersistentClient"):
            # Current Chroma API for an on-disk store.
            self.client = chromadb.PersistentClient(path="db/")
        else:
            # Older Chroma releases: legacy duckdb+parquet configuration.
            self.client = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet",
                                    persist_directory="db/"
                                ))
        self.collection_sql = self.client.get_or_create_collection(name = "sql")
        self.collection_ddl = self.client.get_or_create_collection(name = "ddl")
        self.collection_docs = self.client.get_or_create_collection(name = "docs")
        self.collection_benchmark = self.client.get_or_create_collection(name = "benchmark")
        self.embedding_function = embedding_functions.DefaultEmbeddingFunction()

    # ------------------------------------------------------------------ helpers

    def _get_collection(self, name):
        """Return the collection called ``name``; raise ValueError if unknown."""
        if name not in self._COLLECTION_NAMES:
            raise ValueError(
                "unknown collection %r; expected one of %s" % (name, ", ".join(self._COLLECTION_NAMES))
            )
        return getattr(self, "collection_" + name)

    @staticmethod
    def _positional_ids(items):
        """Return ["0", "1", ...], one id per element of ``items``."""
        return [str(i) for i in range(len(items))]

    @staticmethod
    def _query(collection, text, n_results):
        """Query ``collection`` for ``text``, never asking for more rows than it holds."""
        available = collection.count()
        if available:
            # e.g. the docs collection may hold fewer entries than the default n_results
            n_results = min(n_results, available)
        return collection.query(query_texts=[text], n_results=n_results)

    # ---------------------------------------------------------------- embedding

    def generate_embedding(self, data, **kwargs) -> List[float]:
        """Embed ``data`` with the default embedding function.

        Returns the single embedding when exactly one is produced, otherwise
        the full list of embeddings.
        """
        embedding = self.embedding_function(data)
        if len(embedding) == 1:
            return embedding[0]
        return embedding

    # ------------------------------------------------------------------ writers

    def add_sql(self, question_texts , sql):
        """Upsert example questions (documents) with their SQL metadata."""
        self.collection_sql.upsert(
            documents = question_texts,
            metadatas = sql,
            ids = self._positional_ids(question_texts)
        )

    def add_ddl(self, ddl ):
        """Upsert DDL statements as documents."""
        self.collection_ddl.upsert(
            documents=ddl,
            ids=self._positional_ids(ddl),
        )

    def add_docs(self, docs ):
        """Upsert free-text documentation strings as documents."""
        self.collection_docs.upsert(
            documents=docs,
            ids=self._positional_ids(docs),
        )

    def add_benchmark(self,benchmark):
        """Upsert benchmark records; the 'metric' value is the document text."""
        metric_names = [record['metric'] for record in benchmark]
        self.collection_benchmark.upsert(
            documents = metric_names,
            metadatas = benchmark,
            ids = self._positional_ids(benchmark)
        )

    # ------------------------------------------------------------------ readers

    def view_sql_db(self):
        """Return everything stored in the sql collection."""
        return self.collection_sql.get()

    def view_ddl_db(self):
        """Return everything stored in the ddl collection."""
        return self.collection_ddl.get()

    def view_docs_db(self):
        """Return everything stored in the docs collection."""
        return self.collection_docs.get()

    def view_benchmark_db(self):
        """Return everything stored in the benchmark collection."""
        return self.collection_benchmark.get()

    # ------------------------------------------------------- update / delete

    def update_data(self , ids , metadatas , documents , collection = "sql"):
        """Update existing entries of one collection.

        Args:
            ids: Ids of the entries to update.
            metadatas: New metadata dicts, aligned with ``ids``.
            documents: New document texts, aligned with ``ids``.
            collection: Target collection name ("sql", "ddl", "docs" or
                "benchmark"); defaults to "sql".
        """
        self._get_collection(collection).update(ids = ids , documents = documents , metadatas = metadatas)

    def delete(self , ids , collection = "sql"):
        """Delete the entries with the given ``ids`` from one collection.

        Args:
            ids: Ids of the entries to delete.
            collection: Target collection name; defaults to "sql".
        """
        self._get_collection(collection).delete(ids = ids)

    # ---------------------------------------------------------------- retrieval

    def retrieve_sql_data(self , question , n_results = 3):
        """Query the sql collection with the keywords extracted from ``question``."""
        ques_details = extract_details(question)
        return self._query(self.collection_sql, ques_details, n_results)

    def retrieve_ddl_data(self , question ,n_results = 3):
        """Query the ddl collection with the raw ``question`` text."""
        return self._query(self.collection_ddl, question, n_results)

    def retrieve_docs_data(self , question ,n_results = 3):
        """Query the docs collection with the raw ``question`` text."""
        return self._query(self.collection_docs, question, n_results)

    def retrieve_benchmark_data(self , question ,n_results = 1):
        """Query the benchmark collection with the keywords extracted from ``question``."""
        ques_details = extract_details(question)
        return self._query(self.collection_benchmark, ques_details, n_results)

In [86]:
#######################################################################  LLM model class  #####################################################################################
class customllm:
    """Wrapper around an OpenAI-compatible chat endpoint used for text-to-SQL.

    It builds few-shot prompts from retrieved examples, DDL and documentation,
    submits them to the chat-completions API, and post-processes the replies
    into SQL, summaries, Plotly code and follow-up questions.
    """

    def __init__(self , api_key : str , model : str , base_url : str):
        """Create the API client.

        Args:
            api_key: credential for the endpoint (never hard-code it).
            model: model name sent with every completion request.
            base_url: base URL of the OpenAI-compatible API.
        """
        self.client = OpenAI(base_url = base_url , api_key = api_key )
        self.model = model

    def system_message(self , message: str) -> dict:
        """Wrap `message` as a chat message with role 'system'."""
        return {"role": "system", "content": message}


    def user_message(self , message: str) -> dict:
        """Wrap `message` as a chat message with role 'user'."""
        return {"role": "user", "content": message}


    def assistant_message(self , message: str) -> dict:
        """Wrap `message` as a chat message with role 'assistant'."""
        return {"role": "assistant", "content": message}

    def ask(self , question , sql_data , ddl_data , docs_data , temperature = 0.3):
        """Build the chat log that asks the model to write SQL for `question`.

        The system message carries the instructions followed by the retrieved
        documentation and DDL. Every retrieved (question, sql) pair is then
        replayed as a user/assistant exchange, and the real question is
        appended last.

        Args:
            question: the user's natural-language question.
            sql_data: query result whose 'metadatas' lists hold dicts with
                'question' and 'sql' keys.
            ddl_data: query result whose 'documents' lists hold DDL strings.
            docs_data: query result whose 'documents' lists hold text snippets.
            temperature: unused; kept so existing callers keep working.

        Returns:
            List of message dicts ready for `submit_prompt`.
        """
        prompt = "[Instruction]The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n[~Instruction]"

        prompt += "[Instruction]\nYou should always refer this defintions:\n\n[~Instruction]"
        for doc_list in docs_data['documents']:
            for doc in doc_list:
                prompt += doc + "\n"

        prompt += "[Instruction]\nYou may use the following DDL statements as a reference for what tables might be available:\n\n[~Instruction]"
        for ddl_list in ddl_data['documents']:
            for statement in ddl_list:
                prompt += statement + "\n "

        message_log = [self.system_message(prompt)]

        # Few-shot examples: each stored question/SQL pair becomes one exchange.
        for metadata_list in sql_data['metadatas']:
            for meta in metadata_list:
                message_log.append(self.user_message(meta['question']))
                message_log.append(self.assistant_message(meta['sql']))

        message_log.append(self.user_message(question))
        return message_log

    def _complete(self, prompt, temperature, max_tokens) -> str:
        """Send `prompt` to the chat endpoint and return the raw reply text.

        Raises:
            ValueError: if `prompt` is None or empty (a subclass of the plain
                Exception that was raised before, with the same messages).
        """
        if prompt is None:
            raise ValueError("Prompt is None")

        if len(prompt) == 0:
            raise ValueError("Prompt is empty")

        # Roughly 4 characters per token. Integer division keeps the value an
        # int: the API expects an integer and the old float could be rejected.
        approx_prompt_tokens = sum(len(message["content"]) for message in prompt) // 4

        response = self.client.chat.completions.create(
            model=self.model,
            messages=prompt,
            temperature=temperature,
            # NOTE(review): raising the completion budget to the prompt size is
            # inherited behaviour; confirm it is intended.
            max_tokens=max(int(max_tokens), approx_prompt_tokens),
        )

        choice = response.choices[0]
        content = choice.message.content
        if not content:
            # Completion-style payloads carry the reply in `text` instead.
            content = getattr(choice, "text", None) or ""
        return content

    def submit_prompt(self, prompt, temperature = 0.2 , max_tokens = 3000) -> str:
        """Submit a chat log and return the SQL contained in the reply.

        If the reply contains a ```sql fenced block, the first one is
        returned; otherwise the first generic fenced block; otherwise the
        whole reply. All backslashes are then removed (the model tends to
        emit markdown-escaped names such as `order\\_date`).

        Args:
            prompt: list of chat message dicts.
            temperature: sampling temperature.
            max_tokens: minimum completion budget.

        Raises:
            ValueError: if `prompt` is None or empty.
        """
        query = self._complete(prompt, temperature, max_tokens)

        # Non-greedy matching: with two fenced blocks the old greedy pattern
        # ended up returning the prose between them.
        fenced = re.search(r"```sql\b\s*(.*?)```", query, re.DOTALL | re.IGNORECASE)
        if fenced is None:
            fenced = re.search(r"```(.*?)```", query, re.DOTALL)
        if fenced:
            query = fenced.group(1)

        # Dropping every backslash also covers the "\_" -> "_" case.
        return query.replace("\\", "")

    def generate_summary(self, question: str, df: pd.DataFrame, temperature  ,benchmark , max_tokens , **kwargs) -> str:
        """Ask the model for insights about the query result `df`.

        Args:
            question: the question the data answers.
            df: result set; rendered into the prompt with `to_markdown()`.
            temperature: sampling temperature.
            benchmark: query result whose 'metadatas' entry is quoted in the prompt.
            max_tokens: completion budget passed to `submit_prompt`.
            **kwargs: forwarded to `submit_prompt`.
        """
        message_log = [
            self.system_message(
                f"You are a great Sherlock Holmes of Data. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
            ),
            self.user_message(
                f"[Instruction]Generate great insights from the data.Use benchmark numbers from following dictionary:{benchmark['metadatas']}\n.Don't output the whole same data and obvious facts. Do not respond with any additional explanation beyond the summary.[~Instruction]"
            ),
        ]

        summary = self.submit_prompt(prompt = message_log,temperature = temperature ,max_tokens = max_tokens , **kwargs)

        return summary

    def regenerate_query(self , prompt : any , error ,query , temperature = 0.1 , **kwargs):
        """Ask the model to repair `query` given the `error` it produced.

        Args:
            prompt: accepted for interface compatibility; not used.
            error: the exception (or message) raised by the failed query.
            query: the SQL that failed.
            temperature: sampling temperature; previously ignored in favour
                of a hard-coded 0.1, now honoured.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            The corrected SQL as extracted by `submit_prompt`.
        """
        repair_prompt = [
            self.system_message(f"You are a great SQL debugger and respond with only SQL query-no explanation and the user got the error with following SQL query : {query}\n\n"),
            self.user_message(f"[Instruction]Please correct this error :{error} \n in the SQL query generated and output correct SQL query[~Instruction]\n\n"),
        ]
        return self.submit_prompt(prompt = repair_prompt , temperature = temperature)


    def _extract_python_code(self, markdown_string: str) -> str:
        """Return the first fenced code block of `markdown_string`, stripped.

        Falls back to the input unchanged when it contains no fenced block.
        """
        # Regex pattern to match Python code blocks
        pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"

        for tagged_block, plain_block in re.findall(pattern, markdown_string, re.IGNORECASE):
            return (tagged_block if tagged_block else plain_block).strip()

        return markdown_string

    def _sanitize_plotly_code(self, raw_plotly_code: str) -> str:
        """Remove `fig.show()` calls so the caller decides when to display."""
        return raw_plotly_code.replace("fig.show()", "")



    def generate_plotly_code(
        self, question: str = None, sql: str = None, df_metadata: str = None, temperature = 0.3 , **kwargs) -> str:
        """Ask the model for Plotly code that charts a DataFrame named `df`.

        Args:
            question: the user's question, if known.
            sql: the query that produced the data, if known.
            df_metadata: textual description of the DataFrame (e.g. dtypes).
            temperature: sampling temperature.

        Returns:
            Python source with code fences and `fig.show()` removed.
        """
        if question is not None:
            system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'"
        else:
            system_msg = "The following is a pandas DataFrame "

        if sql is not None:
            system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n"

        system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}"

        message_log = [
            self.system_message(system_msg),
            self.user_message(
                "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
            ),
        ]

        # Use the raw reply: the SQL post-processing in `submit_prompt` turned
        # "```python\n..." into "python\n...", which can never be executed.
        raw_reply = self._complete(message_log, temperature, 3000)
        # Undo markdown-escaped underscores without touching other escapes.
        raw_reply = raw_reply.replace("\\_", "_")

        return self._sanitize_plotly_code(self._extract_python_code(raw_reply))

    def _fallback_figure(self, df: pd.DataFrame):
        """Build a chart from column dtypes when the generated code is unusable.

        With one to three datetime/categorical columns and at least one
        numeric column, draws one bar-chart subplot per numeric column;
        otherwise falls back to `px.line(df)`.
        """
        date_cols = df.select_dtypes(include=["datetime"]).columns.tolist()
        numeric_cols = df.select_dtypes(include=["number" ,"int" ,"float"]).columns.tolist()
        categorical_cols = df.select_dtypes(
            include=["object", "category"]
        ).columns.tolist()

        label_count = len(date_cols) + len(categorical_cols)
        # Without a numeric column there is nothing to put on a y axis, and
        # make_subplots(cols=0) would raise.
        if not (0 < label_count <= 3 and numeric_cols):
            # Default to a simple line plot if above conditions are not met
            return px.line(df)

        # x axis: the first datetime column, else the first categorical one.
        x_col = date_cols[0] if date_cols else categorical_cols.pop(0)

        fig = sp.make_subplots(rows=1, cols=len(numeric_cols))
        for position, y_col in enumerate(numeric_cols, start=1):
            if len(categorical_cols) == 0:
                express_fig = px.bar(df , x = x_col , y = y_col , text = y_col)
            elif len(categorical_cols) == 1:
                express_fig = px.bar(df , x = x_col , y = y_col , color = categorical_cols[0] , pattern_shape = categorical_cols[0] , text = y_col)
            else:
                express_fig = px.bar(df , x = x_col , y = y_col , color = categorical_cols[0] , pattern_shape = categorical_cols[1] , text = y_col)

            fig['layout']['xaxis{}'.format(position)]['title'] = x_col
            fig['layout']['yaxis{}'.format(position)]['title'] = y_col

            # Move the Express figure's traces into this subplot.
            for trace in express_fig["data"]:
                fig.append_trace(trace, row=1, col=position)

        return fig

    def get_plotly_figure(
        self, plotly_code: str, df: pd.DataFrame, dark_mode: bool = True) -> "plotly.graph_objs.Figure":
        """Execute generated Plotly code and return the resulting `fig`.

        Args:
            plotly_code: Python source expected to define a variable `fig`.
            df: DataFrame exposed to the code as `df`.
            dark_mode: accepted for interface compatibility; not used.

        Returns:
            The `fig` object defined by the code (None if it defines none),
            or a dtype-driven fallback figure when the code raises.
        """
        # SECURITY: this executes model-generated code in-process. Only run it
        # in a trusted or sandboxed environment.
        # A single namespace is used on purpose: with separate globals/locals
        # exec() scopes the code like a class body, so comprehensions and
        # functions in it cannot see `df`, `px` or `go`.
        namespace = dict(globals())
        namespace.update({"df": df, "px": px, "go": go})
        try:
            exec(plotly_code, namespace)
            return namespace.get("fig", None)
        except Exception:
            return self._fallback_figure(df)

    def generate_followup_questions(self, question: str, sql: str, df: pd.DataFrame, **kwargs ) -> list:
        """Ask the model for follow-up questions about a query result.

        Args:
            question: the question that was asked.
            sql: the SQL that answered it.
            df: result set; rendered into the prompt with `to_markdown()`.
            **kwargs: forwarded to `submit_prompt`.

        Returns:
            List of question strings with leading "1. "-style numbering and
            blank lines removed.
        """
        message_log = [
            self.system_message(
                f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
            ),
            self.user_message(
                "[Instruction]Generate a list of three followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions.[~Instruction]"
            ),
        ]

        llm_response = self.submit_prompt(message_log, **kwargs)

        numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
        # Skip blank lines so callers never look up an empty question.
        return [line.strip() for line in numbers_removed.split("\n") if line.strip()]


In [90]:
##########################################################################################  Main run  ################################################################################################
#running the script
def run(question , my_db : chroma_db , my_model : customllm):
    """Answer `question` end to end and print the results.

    Steps: retrieve similar examples, DDL, docs and benchmarks from `my_db`;
    reuse a stored query when `check_question_db` finds one, otherwise ask
    the model for SQL; execute it (with one model-assisted retry on failure);
    then print the data, a summary, a chart and follow-up questions.
    `check_question_db` and `execute_query` are defined elsewhere in the
    notebook.

    Args:
        question: the user's natural-language question.
        my_db: vector store providing the `retrieve_*` methods.
        my_model: LLM wrapper used for SQL, summary, chart and follow-ups.

    Returns:
        None; everything is reported via print / figure display.
    """
    #retrieving data
    print("retrieving data from database.........")
    data_sql = my_db.retrieve_sql_data(question = question)
    data_ddl = my_db.retrieve_ddl_data(question = question)
    data_docs = my_db.retrieve_docs_data(question = question)
    data_benchmark = my_db.retrieve_benchmark_data(question = question)
    #checking if question already exists
    print("checking if query already exists.........")
    sql_query = check_question_db(question = question , data_sql = data_sql)

    # Always bound, so the retry below also works for a stored query
    # (previously `prompt` was undefined there and the retry died silently).
    prompt = None

    # if no existing question exists
    if sql_query is None:
        print("query doesnt exist,Calling llm model.........")
        prompt = my_model.ask(question = question , sql_data = data_sql , ddl_data = data_ddl , docs_data = data_docs )
        sql_query = my_model.submit_prompt(prompt , temperature = 0.1)

    try:
        print("Query generated,executing query...........")
        sqlglot.transpile(sql_query)  # parse check before hitting the database
        df = execute_query(sql_query = sql_query)
    except Exception as first_error:
        print("Wrong query generated!")
        print("Trying again......")
        try:
            sql_query = my_model.regenerate_query(prompt = prompt , error = first_error , query = sql_query , temperature = 0.3)
            df = execute_query(sql_query = sql_query)
        except Exception as retry_error:
            print("Execution failed.Sending this report to administrator....")
            print(f"Reason: {retry_error}")
            return


    if df.shape[0] == 0:
        print(sql_query)
        print("Sorry the SQL query generated was correct but no data found!")
        return

    print("The query generated is : \n")
    print(sql_query)
    print("\n")
    print("------------------------------------------------------------------------------------------------------------\n")
    print(df)
    print("\n")

    # Summary, chart and follow-ups are best effort: report failures, don't abort.
    try:
        summary = my_model.generate_summary(question = question , df = df ,benchmark = data_benchmark , temperature = 0.4 , max_tokens = 500)
        print(summary)
    except Exception as summary_error:
        print(f"Could not generate a summary: {summary_error}")
    print("\n")

    if df.shape[0] > 1:
        try:
            df.reset_index(drop=True, inplace=True)
            plotly_code = my_model.generate_plotly_code(
                                question=question,
                                sql=sql_query,
                                df_metadata=f"Running df.dtypes gives:\n {df.dtypes}", temperature = 0.3
                            )
            fig = my_model.get_plotly_figure(plotly_code=plotly_code, df=df)
            # No plt.figure() here: the chart is a Plotly figure, and the stray
            # matplotlib call only produced an empty "<Figure ... 0 Axes>".
            if fig is not None:
                fig.show()
        except Exception as plot_error:
            print(f"Could not render a chart: {plot_error}")

    follow_up_questions = []
    try:
        suggestions = my_model.generate_followup_questions(question = question , sql = sql_query , df = df)
        for suggestion in suggestions:
            if not suggestion.strip():
                continue
            sql_data = my_db.retrieve_sql_data(question = suggestion)
            for metadata_list in sql_data['metadatas']:
                for meta in metadata_list:
                    known_question = meta['question']
                    if known_question not in follow_up_questions:
                        print(known_question)
                        # Append the question itself (the old code appended the
                        # whole list, so duplicates were never detected).
                        follow_up_questions.append(known_question)
    except Exception:
        # Fall back to the examples retrieved for the original question.
        for metadata_list in data_sql['metadatas']:
            for meta in metadata_list:
                known_question = meta['question']
                if known_question != question and known_question not in follow_up_questions:
                    print(known_question)
                    follow_up_questions.append(known_question)


In [91]:
# Initialising the database: load examples, DDL, docs and benchmarks into Chroma.
question_texts , sql = modify_data(dataset)
my_db = chroma_db(client_type = 'temporary')
my_db.add_sql(question_texts = question_texts , sql = sql)
my_db.add_ddl(ddl = ddl)
my_db.add_docs(docs = docs)
my_db.add_benchmark(benchmark = benchmark)

# Initialising the model.
# The API key is read from the environment instead of being committed with the
# notebook. The key that used to be hard-coded here should be treated as
# leaked and revoked.
api_key = os.environ.get("ANYSCALE_API_KEY")
if not api_key:
    raise RuntimeError("Set the ANYSCALE_API_KEY environment variable before running this cell.")

my_model = customllm(base_url = "https://api.endpoints.anyscale.com/v1" ,api_key = api_key , model = "mistralai/Mixtral-8x7B-Instruct-v0.1")

In [92]:
# Demo: answer one sample question end to end.
question = "what is my retention store and month wise?"
run(
    question=question,
    my_db=my_db,
    my_model=my_model,
)

retrieving data from database.........
checking if query already exists.........
Query generated,executing query...........
The query generated is : 


        with base as (
            select store ,date_trunc('MONTH',created_at)::DATE as month , email from orders group by 1,2,3
        ) ,

        ret_base as (
            select a.store , a.month as base_month, b.month as retained_month , count(distinct b.email) as retained_users
            from base a
            join base b on b.month >= a.month and b.email = a.email
            group by 1,2,3
        ),

        totals_base as (
        select a.store, a.month as base_month, count(distinct a.email) as total_users
        from base a
        group by 1,2)

        select a.store, a.base_month, a.retained_month, a.retained_users*100.00/b.total_users as percentage_retained_users
        from ret_base a
        left join totals_base b on a.store=b.store and a.base_month=b.base_month
        order by a.store, a.base_month, a.retain

what is my retention monthly and store wise?
how does the month on month retention percentage look like?
What is the average store and source wise month on month retention%?


<Figure size 500x500 with 0 Axes>