In [2]:
import os 
from pathlib import Path
from transformers import pipeline
from pyspark.sql import SparkSession
import json
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss
from groq import Groq

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
groq_token = os.environ.get("GROQ_TOKEN")
local_warehouse = Path("/data/warehouse")

spark = SparkSession.builder \
    .appName("rag") \
    .config("spark.hadoop.hadoop.native.lib", "false") \
    .config("spark.jars.packages", "org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.9.0") \
    .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") \
    .config("spark.sql.catalog.local.type", "hadoop") \
    .config("spark.sql.catalog.local.warehouse", "data/warehouse") \
    .getOrCreate()

groq_client = Groq(api_key=groq_token)

model = SentenceTransformer("all-MiniLM-L6-v2")


ConnectionError: (MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /api/models/sentence-transformers/all-MiniLM-L6-v2/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x0000018A9FC8AD80>: Failed to resolve \'huggingface.co\' ([Errno 11001] getaddrinfo failed)"))'), '(Request ID: b9a1d373-3c41-433e-8afd-f87742e13a54)')

In [15]:
def get_column_metadata(table):
    df = spark.sql(f"DESCRIBE TABLE EXTENDED {table}")
    metadata = df.filter("col_name NOT IN ('# col_name', '')") \
                 .select("col_name", "comment") \
                 .dropna()
                 
    metadata = metadata.withColumn("col_name", metadata["col_name"].cast("string")) \
                       .withColumn("comment", metadata["comment"].cast("string"))
    #metadata = metadata.union(table_description_df)
    metadata_dict = {row['col_name'].lower(): row['comment'] for row in metadata.collect()}

    return metadata_dict



# Combine metadata and semantic info
def build_corpus(metadata, semantic_layer):
    corpus = []
    ids = []
    for table, info in metadata.items():
        corpus.append(f"description of table named {table}: {info['table_description']}")
        ids.append(("table_description", table))
        for col, desc in info["col_metadata"].items():
            corpus.append(f"description of column {table}.{col}: {desc}")
            ids.append((table, col))
    for kpi, kpi_info in semantic_layer.get("KPI definitions", {}).items():
        corpus.append(f"KPI: {kpi}, description: {kpi_info['description']}, formula: {kpi_info['formula']}")
        ids.append(("kpi", kpi))
    return corpus, ids

with open("semantic_layer.json") as f:
    semantic_layer = json.load(f)

tables = [
    "local.bronze.amazon_sale_report",
    "local.bronze.cloud_warehouse_compersion_chart",
    "local.bronze.expense_iigf",
    "local.bronze.international_sale_report",
    "local.bronze.may22",
    "local.bronze.p__l_march_2021",
    "local.bronze.sale_report"
]
metadata = {}
for table_name in tables:
    #get metadata and vectorize them
    table_description = spark.catalog.getTable(table_name).description

    col_metadata = get_column_metadata(table_name)
    metadata[table_name] = {
        "table_description": table_description,
        "col_metadata": col_metadata
    }
corpus, ids = build_corpus(metadata, semantic_layer)
embeddings = model.encode(corpus, convert_to_tensor=True)

dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)  # or IndexFlatIP for cosine similarity if normalized
index.add(embeddings)
faiss.write_index(index, "table_and_col_vectorized_metadata.index")



In [17]:
ids

[('table_description', 'local.bronze.amazon_sale_report'),
 ('local.bronze.amazon_sale_report', 'order_id'),
 ('local.bronze.amazon_sale_report', 'date'),
 ('local.bronze.amazon_sale_report', 'status'),
 ('local.bronze.amazon_sale_report', 'fulfilment'),
 ('local.bronze.amazon_sale_report', 'sales_channel_'),
 ('local.bronze.amazon_sale_report', 'ship_service_level'),
 ('local.bronze.amazon_sale_report', 'style'),
 ('local.bronze.amazon_sale_report', 'sku'),
 ('local.bronze.amazon_sale_report', 'category'),
 ('local.bronze.amazon_sale_report', 'size'),
 ('local.bronze.amazon_sale_report', 'asin'),
 ('local.bronze.amazon_sale_report', 'courier_status'),
 ('local.bronze.amazon_sale_report', 'qty'),
 ('local.bronze.amazon_sale_report', 'currency'),
 ('local.bronze.amazon_sale_report', 'amount'),
 ('local.bronze.amazon_sale_report', 'ship_city'),
 ('local.bronze.amazon_sale_report', 'ship_state'),
 ('local.bronze.amazon_sale_report', 'ship_postal_code'),
 ('local.bronze.amazon_sale_report'

In [None]:
def retrieve_context(question, top_k=10):
    # get content with similar vectors
    question_vec = model.encode([question], convert_to_tensor=True).cpu().numpy()
    _, indices = index.search(question_vec, top_k)

    # get all the table metadata that the content found earlier relates to. 
    # E.g. if the content is about a column, get the metadata of table that the col belongs to.
    added_kpis = []
    added_tables = []
    corpus_list = []
    for i in indices[0]:
        # if the content is about a table's description, get the table name from the id
        # otherwise, get the table name from the column id
        if ids[i][0] == "table_description":
            table_name = ids[i][1]
        elif ids[i][0] != "kpi":
            table_name = ids[i][0]
        elif ids[i][0] == "kpi" and ids[i][1] not in added_kpis:
            # if the content is about a KPI, get the KPI description
            kpi_name = ids[i][1]
            kpi_info = semantic_layer.get("KPI definitions", {}).get(kpi_name, {})
            corpus_list.append(f"KPI: {kpi_name}, description: {kpi_info['description']}, formula: {kpi_info['formula']}")
            added_kpis.append(kpi_name)
            continue
            
        
        # if the table is not already added get the human readable metadata of the table
        if table_name and table_name not in added_tables:
            table_description = spark.catalog.getTable(table_name).description
            corpus_list.append(f"description of table named {table_name}: {table_description}")

            col_metadata = get_column_metadata(table_name)
            for col, desc in col_metadata.items():
                corpus_list.append(f"description of column {table_name}.{col}: {desc}")

            added_tables.append(table_name)
        
    return corpus_list

def generate_sql(question, client):
    context = retrieve_context(question)
    prompt = f"Given the following metadata and semantic info:\n\n" + \
             "\n".join(context) + f"\n\nGenerate Spark SQL that answers this question (just SQL code, nothing else):\n{question}"
    
    completion = client.chat.completions.create(
        model="compound-beta",
        messages=[
            {
                "role": "system",
                "content": """
                    You are a Spark SQL generator bot with the skills of an expert data analyst with advanced SQL knowledge. 
                    You understand the need for SQL code that is efficient and readable. 
                    You understand CTEs and window functions.
                    You do not order the results unless it is explicitly asked in the question.
                    Your output must only contain valid SPARK SQL code that can be executed in a SQL engine.
                    You only respond with SQL code that is ready to be executed. Under no circumstances do you respond with anything else.
                    You know that to create efficient SQL queries you should do filtering as early as possible - but only if it is possible at part of the query.
                    You are well aware of the fact that when answering business related questions, it's important to be very precise on what the question is and how the KPIs are defined.
                """
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        temperature=1,
        max_completion_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )

    generated_text = ""
    for chunk in completion:
        generated_text += chunk.choices[0].delta.content or ""        

    return generated_text
    
sql_response = generate_sql("What is the total primary profit for each product category in the last quarter?", groq_client)

lines = sql_response.splitlines()
middle_lines = lines[1:-1]
sql_code = "\n".join(middle_lines)

print(sql_code)


```sql
SELECT 
    category,
    SUM(CASE 
        WHEN fulfilment = 'Amazon' THEN amount * 0.5 
        ELSE amount * 0.3 
    END) AS total_primary_profit
FROM 
    local.bronze.amazon_sale_report
WHERE 
    date >= DATE_SUB(CURRENT_DATE, INTERVAL 3 MONTH)
GROUP BY 
    category
```


In [4]:
df = spark.sql("""
    select date, cast(date as date) as date2 
    from local.bronze.amazon_sale_report
    WHERE 
        date >= DATE_SUB(CURRENT_DATE, INTERVAL 3 MONTH)
""")
df.show()

AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "date_sub(current_date(), INTERVAL '3' MONTH)" due to data type mismatch: Parameter 2 requires the ("INT" or "SMALLINT" or "TINYINT") type, however "INTERVAL '3' MONTH" has the type "INTERVAL MONTH".; line 5 pos 16;
'Project ['date, cast('date as date) AS date2#0]
+- 'Filter (date#3 >= date_sub(current_date(Some(Europe/Helsinki)), INTERVAL '3' MONTH))
   +- SubqueryAlias local.bronze.amazon_sale_report
      +- RelationV2[index#1, Order_ID#2, date#3, Status#4, Fulfilment#5, Sales_Channel_#6, ship_service_level#7, Style#8, SKU#9, Category#10, Size#11, ASIN#12, Courier_Status#13, Qty#14, currency#15, Amount#16, ship_city#17, ship_state#18, ship_postal_code#19, ship_country#20, promotion_ids#21, B2B#22, fulfilled_by#23, Unnamed_22#24] local.bronze.amazon_sale_report local.bronze.amazon_sale_report
