In [2]:
from langchain_groq import ChatGroq

from langgraph.graph import StateGraph ,START , END 
from langchain_huggingface import HuggingFaceEndpoint , ChatHuggingFace
from typing import TypedDict
from dotenv import load_dotenv


In [3]:
load_dotenv()

llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    temperature=0.2
)




In [4]:
# create the state 
class LLMState(TypedDict):
    
    question: str
    query: str
    

In [5]:
system_prompt = """You are a helpful assistant and smart SQL developer that can generate SQL queries.


FACT_TABLES = {
    "fact_sales": {
        "table": "fact_sales",
        "joins": {
            "dim_product": ("product_id", "product_id"),
            "dim_date": ("date_id", "date_id"),
            "dim_region": ("region_id", "region_id"),
            "dim_channel": ("channel_id", "channel_id"),
            "dim_employee": ("employee_id", "employee_id"),
            "dim_customer": ("customer_id", "customer_id")
        },
        "metrics": {
            "quantity": "SUM(f.quantity)",
            "gross_amount": "SUM(f.gross_amount)",
            "discount_amount": "SUM(f.discount_amount)",
            "net_amount": "SUM(f.net_amount)",
            "tax_amount": "SUM(f.tax_amount)"
        }
    },

    "fact_service_usage": {
        "table": "fact_service_usage",
        "joins": {
            "dim_product": ("service_id", "product_id"),
            "dim_date": ("date_id", "date_id"),
            "dim_region": ("region_id", "region_id"),
            "dim_customer": ("customer_id", "customer_id")
        },
        "metrics": {
            "usage_count": "SUM(f.usage_count)",
            "duration_minutes": "SUM(f.duration_minutes)",
            "cost_incurred": "SUM(f.cost_incurred)"
        }
    },

    "fact_finance_snapshot": {
        "table": "fact_finance_snapshot",
        "joins": {
            "dim_date": ("date_id", "date_id"),
            "dim_department": ("department_id", "department_id")
        },
        "metrics": {
            "total_revenue": "SUM(f.total_revenue)",
            "total_cost": "SUM(f.total_cost)",
            "profit": "SUM(f.profit)",
            "operational_expense": "SUM(f.operational_expense)"
        }
    }
}


DIMENSION_COLUMNS = {
    "product_name": ("dim_product", "product_name"),
    "product_type": ("dim_product", "product_type"),
    "category": ("dim_product", "category"),
    "sub_category": ("dim_product", "sub_category"),

    "customer_type": ("dim_customer", "customer_type"),
    "industry": ("dim_customer", "industry"),
    "customer_segment": ("dim_customer", "customer_segment"),

    "market_region": ("dim_region", "market_region"),
    "country": ("dim_region", "country"),
    "state": ("dim_region", "state"),
    "city": ("dim_region", "city"),

    "channel_name": ("dim_channel", "channel_name"),

    "role": ("dim_employee", "role"),
    "department": ("dim_employee", "department"),

    "year": ("dim_date", "year"),
    "quarter": ("dim_date", "quarter"),
    "month": ("dim_date", "month"),
    "full_date": ("dim_date", "full_date")
}




"""

In [9]:
def llm_qa(state: LLMState) -> LLMState:
    # extract the question from the state
    question = state['question']
    # form a prompt 
    prompt = f"{system_prompt} Answer the following question {question}"
    # ask the llm
    answer = llm.invoke(prompt).content 
    
    #update the answer in the state
    state['query'] = answer 
    
    return state

In [10]:
# create the graph 
graph = StateGraph(LLMState)
graph.add_node("llm_qa",llm_qa)


graph.add_edge(START,"llm_qa")
graph.add_edge("llm_qa",END)

workflow = graph.compile()

In [12]:
initial_state = {
    "question" : "generate the code for getting the total sales from the database",
    }

res=workflow.invoke(initial_state)

In [15]:
print(res['query'])

To get the total sales from the database, we need to query the `fact_sales` table and sum up the `gross_amount` metric. Here is the SQL query to achieve this:

```sql
SELECT 
    SUM(f.gross_amount) AS total_sales
FROM 
    fact_sales f
```

However, if you want to get the total sales for a specific date range, product, region, or any other dimension, you would need to join the `fact_sales` table with the corresponding dimension tables and apply the necessary filters. For example, to get the total sales for a specific date range, you would use the following query:

```sql
SELECT 
    SUM(f.gross_amount) AS total_sales
FROM 
    fact_sales f
JOIN 
    dim_date d ON f.date_id = d.date_id
WHERE 
    d.year = 2022 AND d.quarter = 1
```

To get the total sales for a specific product, you would use the following query:

```sql
SELECT 
    SUM(f.gross_amount) AS total_sales
FROM 
    fact_sales f
JOIN 
    dim_product p ON f.product_id = p.product_id
WHERE 
    p.product_name = 'Product A'
``