In [None]:
# import chromadb
# from llama_index.core import Settings
# from llama_index.core.readers import SimpleDirectoryReader
# from llama_index.readers.file import PagedCSVReader
# from llama_index.vector_stores.chroma import ChromaVectorStore
# from llama_index.core.ingestion import IngestionPipeline
# from llama_index.core import VectorStoreIndex
# import os

# # Set up the basic configuration
# collection_name = "sales_data_collection"
# persist_dir = "chroma_db"
# os.makedirs(persist_dir, exist_ok=True)

# def setup_vector_store_index(file_path: str) -> VectorStoreIndex:
#     """
#     Set up ChromaDB and create vector store index from a CSV file
    
#     Args:
#         file_path (str): Path to your CSV file
        
#     Returns:
#         VectorStoreIndex: Initialized vector store index
#     """
#     try:
#         # Initialize ChromaDB client
#         db = chromadb.PersistentClient(path=persist_dir)
        
#         # Delete existing collection if it exists
#         try:
#             db.delete_collection(name=collection_name)
#         except ValueError:
#             pass  # Collection doesn't exist, which is fine
        
#         # Create new collection
#         collection = db.create_collection(name=collection_name)
#         vector_store = ChromaVectorStore(chroma_collection=collection)
        
#         # Load and index the CSV data
#         csv_reader = PagedCSVReader()
#         reader = SimpleDirectoryReader(
#             input_files=[file_path],
#             file_extractor={".csv": csv_reader}
#         )
        
#         # Load documents and create nodes
#         docs = reader.load_data()
#         pipeline = IngestionPipeline(
#             vector_store=vector_store,
#             documents=docs
#         )
        
#         # Run the ingestion pipeline
#         nodes = pipeline.run()
        
#         # Create and return the vector store index
#         return VectorStoreIndex(nodes)
    
#     except Exception as e:
#         print(f"Error in setting up index: {str(e)}")
#         raise

# # Example usage:

# # First, set up your embedding model (example with Gemini)
# from llama_index.embeddings.gemini import GeminiEmbedding

# Settings.embed_model = GeminiEmbedding(
#     model_name="models/embedding-001",
#     api_key="api_key",
#     title="this is a document"
# )

# # Then create the index
# file_path = r"D:\ml_main\intern\app1_3\temp_sales_data.csv"
# vector_store_index = setup_vector_store_index(file_path)

# # # Create a query engine
# # query_engine = vector_store_index.as_query_engine(
# #     similarity_top_k=10000,
# #     response_mode="compact"
# # )

# # # Now you can query your data
# # response = query_engine.query("your query here")
# # print(response.response)


In [11]:
import os
from typing import Dict, Any, Optional, Tuple

import chromadb
from llama_index.core import Settings, VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.llms.gemini import Gemini

class SalesDataQueryEngine:
    """Natural-language question answering over a persisted ChromaDB sales collection.

    A user question is first classified by the LLM into one of a fixed set of
    analysis categories; the matching prompt template is then run through a
    llama_index query engine backed by the existing Chroma collection.

    NOTE(review): numbers in the answers are produced by the LLM reading
    retrieved chunks, not by exact aggregation over the data (no aggregation
    happens in this class). Verify them independently before relying on them.
    """

    def __init__(
        self,
        persist_dir: str = "chroma_db",
        collection_name: str = "sales_data_collection",
        gemini_api_key: Optional[str] = None,
        model: str = "models/gemini-1.5-flash",
        similarity_top_k: int = 10000,
    ):
        """
        Initialize query engine for existing ChromaDB.

        Args:
            persist_dir (str): Directory where ChromaDB is stored.
            collection_name (str): Name of the collection to query.
            gemini_api_key (Optional[str]): Gemini API key. When omitted, the
                ``GOOGLE_API_KEY`` environment variable is used.
            model (str): Gemini model name passed to the LLM.
            similarity_top_k (int): Number of chunks retrieved per query. The
                default is deliberately larger than the collection so every
                chunk is retrieved.

        Raises:
            ValueError: If no API key is given and ``GOOGLE_API_KEY`` is unset.
        """
        # Credentials must never be hard-coded in the notebook; read them from
        # the caller or from the environment.
        api_key = gemini_api_key or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError(
                "Gemini API key is required: pass gemini_api_key or set GOOGLE_API_KEY"
            )

        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.similarity_top_k = similarity_top_k

        # Initialize Gemini LLM
        self.llm = Gemini(model=model, api_key=api_key)

        # Set the LLM in global settings (affects every llama_index component
        # in this kernel, not just this instance).
        Settings.llm = self.llm

        # Set up the query engine
        self.query_engine = self._setup_query_engine()

    def _setup_query_engine(self):
        """Build a query engine over the existing Chroma collection.

        Raises:
            Exception: Re-raises whatever chromadb / llama_index raise (e.g. a
                missing collection) after printing a short message.
        """
        try:
            # Connect to existing ChromaDB
            db = chromadb.PersistentClient(path=self.persist_dir)
            collection = db.get_collection(name=self.collection_name)

            # Create vector store and index.
            # NOTE(review): from_vector_store embeds queries with
            # Settings.embed_model; presumably that must be the same model used
            # to build the collection — confirm.
            vector_store = ChromaVectorStore(chroma_collection=collection)
            index = VectorStoreIndex.from_vector_store(vector_store)

            # A very large top_k retrieves (almost) every chunk; "compact" mode
            # then packs them into as few LLM calls as fit the context window.
            return index.as_query_engine(
                similarity_top_k=self.similarity_top_k,
                response_mode="compact",
                llm=self.llm,  # Explicitly pass the LLM
            )

        except Exception as e:
            print(f"Error setting up query engine: {str(e)}")
            raise

    def _get_query_type(self, user_query: str) -> Tuple[str, dict]:
        """Classify ``user_query`` into an analysis category using the LLM.

        The LLM is asked to reply as ``category_name|param1=value1,param2=value2``.

        Args:
            user_query (str): Natural-language question from the user.

        Returns:
            Tuple[str, dict]: ``(query_type, params)``. ``query_type`` is
            lower-cased and may be empty or unknown if the LLM ignores the
            requested format; callers must handle that case.
        """
        prompt = f"""
        Analyze this query: "{user_query}"
        
        Categorize it into one of these types:
        1. yearly_orders - Questions about orders per year
        2. top_customer_by_orders - Questions about customers with most orders
        3. top_5_customers_by_sales - Questions about top customers by sales
        4. product_line_sales - Questions about product line performance
        5. country_sales_distribution - Questions about sales across countries
        6. order_status_counts - Questions about order statuses
        7. monthly_orders - Questions about orders per month
        8. customers_by_country - Questions about customer distribution
        
        Return only the category name and any specific parameters (like year) needed.
        Format: category_name|param1=value1,param2=value2
        Example: monthly_orders|year=2023
        """

        response = self.llm.complete(prompt)
        text = (response.text or "").strip()

        # Only the first non-empty line follows the requested format; any
        # explanation the model appends after it is ignored.
        line = next((ln for ln in text.splitlines() if ln.strip()), "")
        # Models often wrap the answer in back-ticks or quotes.
        line = line.strip().strip("`'\"").strip()

        head, _, tail = line.partition("|")
        query_type = head.strip().strip("`'\"").strip().lower()

        params: Dict[str, str] = {}
        for pair in tail.split(","):
            # partition (not split) so values containing '=' do not raise.
            key, sep, value = pair.partition("=")
            if sep and key.strip():
                params[key.strip()] = value.strip().strip("`'\"")

        return query_type, params

    def _get_analytical_query(
        self,
        query_type: str,
        params: Optional[Dict[str, Any]] = None,
        default: str = "Invalid query type",
    ) -> str:
        """Get the structured prompt for ``query_type``.

        Args:
            query_type (str): Category name returned by ``_get_query_type``.
            params (Optional[Dict[str, Any]]): Extracted parameters; only
                ``year`` is used (by ``monthly_orders``). ``None`` is allowed.
            default (str): Returned when ``query_type`` is not a known category.

        Returns:
            str: The prompt template text, or ``default``.
        """
        # The templates are built eagerly, so resolve params before the
        # monthly f-string is evaluated (params=None previously crashed here).
        params = params or {}
        year = params.get('year', 'each year')

        query_templates = {
            "yearly_orders": """
                Analyze the entire dataset and provide:
                - Total number of orders for each year
                - Present in chronological order
                - Include percentage of total orders
                Use exact numbers and format consistently.
            """,
            "top_customer_by_orders": """
                Analyze the dataset to find:
                - Customer with highest number of orders
                - Include their total order count
                - Include customer name and ID
                - Show what percentage of total orders they represent
                Use exact numbers and format consistently.
            """,
            "top_5_customers_by_sales": """
                Find and list the top 5 customers by total sales value:
                For each customer provide:
                - Customer name
                - Total sales amount
                - Number of orders
                - Average order value
                Sort by total sales descending and use exact numbers.
            """,
            "product_line_sales": """
                Analyze sales performance by product line:
                - List each product line with total sales
                - Sort by sales value (highest first)
                - Include percentage of total sales
                - Include total number of orders
                Use exact numbers and format consistently.
            """,
            "country_sales_distribution": """
                Analyze sales distribution across countries:
                For each country show:
                - Total sales value
                - Number of orders
                - Number of unique customers
                - Percentage of global sales
                Sort by total sales and use exact numbers.
            """,
            "order_status_counts": """
                Provide exact counts for:
                - Total number of orders
                - Orders by status (shipped, disputed, canceled)
                - Include percentage for each status
                Use exact numbers and format consistently.
            """,
            "monthly_orders": f"""
                Analyze orders for {year}:
                For each month provide:
                - Order count
                - Total sales value
                - Average order value
                Sort chronologically and use exact numbers.
            """,
            "customers_by_country": """
                Analyze customer distribution by country:
                - Number of unique customers per country
                - Sort by customer count (highest first)
                - Include percentage of total customers
                - Include total sales per country
                Use exact numbers and format consistently.
            """
        }

        return query_templates.get(query_type, default)

    def query(self, user_query: str) -> str:
        """
        Process and execute a user query.

        Args:
            user_query (str): Natural language query from user.

        Returns:
            str: The engine's response text, or an ``"Error processing query: ..."``
            message if any step raises (errors are reported, not propagated).
        """
        try:
            # Classify query and get parameters
            query_type, params = self._get_query_type(user_query)

            # Unrecognised categories fall back to the user's own question
            # instead of sending a placeholder string to the engine.
            structured_query = self._get_analytical_query(
                query_type, params, default=user_query
            )

            # Execute query
            response = self.query_engine.query(structured_query)

            return response.response

        except Exception as e:
            return f"Error processing query: {str(e)}"


# Example usage (kept as comments so the cell does not echo a string as output):
#
#     gemini_api_key = os.environ["GOOGLE_API_KEY"]  # never hard-code API keys
#     query_engine = SalesDataQueryEngine(gemini_api_key=gemini_api_key)
#
#     # Make a query
#     result = query_engine.query("Show me the top 5 customers by sales value")
#     print(result)

'\n# Initialize the query engine with your Gemini API key\ngemini_api_key = "your-gemini-api-key"\nquery_engine = SalesDataQueryEngine(gemini_api_key=gemini_api_key)\n\nMake a query\nresult = query_engine.query("Show me the top 5 customers by sales value")\nprint(result)\n'

In [None]:
# Build the engine over the persisted "chroma_db" / "sales_data_collection" store
# (the same values as the constructor defaults, spelled out for the reader).
query_engine = SalesDataQueryEngine(
    persist_dir="chroma_db",
    collection_name="sales_data_collection",
)

In [16]:
# Ask for per-year order totals; the answer is free text produced by the LLM.
yearly_orders_question = "How many total orders were placed in each year?"
result = query_engine.query(yearly_orders_question)
print(result)

Number of requested results 10000 is greater than number of elements in index 2823, updating n_results = 2823


2003: 71 orders (36%)
2004: 87 orders (44%)
2005: 37 orders (19%)


In [3]:
# NOTE(review): `query_engine` is also the name of the SalesDataQueryEngine
# instance created earlier in this notebook; a plain `import query_engine`
# would rebind that name to the module on Restart & Run All. Import under an
# alias instead — and delete this cell if no local `query_engine.py` exists.
import query_engine as query_engine_module