# In [None]:
import json
from pathlib import Path
from typing import List, Dict
import importlib.util
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.utils import embedding_functions

class VectorDatabaseMapper:
    """
    Vector database system using ChromaDB for mapping user prompts to relevant database tables and columns.
    """
    
    def __init__(self, model_name='all-MiniLM-L6-v2', collection_name='schema_embeddings',
                 persist_path='./chroma_db'):
        """
        Initialize the vector database mapper with ChromaDB.
        
        Args:
            model_name (str): Name of the sentence transformer model to use
            collection_name (str): Name of the ChromaDB collection
            persist_path (str): Directory handed to ``chromadb.PersistentClient``;
                defaults to the previously hard-coded "./chroma_db"
        """
        # Kept as a public attribute for backward compatibility; no other method
        # visible in this class reads it (embedding goes through embedding_function).
        self.model = SentenceTransformer(model_name)
        self.client = chromadb.PersistentClient(path=persist_path)
        self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=self.embedding_function
        )
        # build_vector_database() replaces this with the list of generated
        # description dicts, so start from an empty list rather than a dict.
        self.schema_metadata = []
        
    def load_schema_from_file(self, schema_file_path: str) -> Dict:
        """Load schema from a Python file."""
        try:
            spec = importlib.util.spec_from_file_location("schema_module", schema_file_path)
            schema_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(schema_module)
            return schema_module.schema
        except Exception as e:
            print(f"Error loading schema from {schema_file_path}: {e}")
            return None
    
    def generate_text_descriptions(self, schema: Dict, db_name: str) -> List[Dict]:
        """
        Generate natural language descriptions for tables and columns.
        
        Args:
            schema (dict): Database schema
            db_name (str): Database name
            
        Returns:
            List[Dict]: List of text descriptions with metadata
        """
        descriptions = []
        
        for table_name, table_info in schema.items():
            # Handle both old and new schema formats
            if isinstance(table_info, dict) and 'columns' in table_info:
                columns = table_info['columns']
                sample_data = table_info.get('sample', [])
            else:
                columns = table_info
                sample_data = []
            
            # Generate table-level description
            column_names = []
            for col_info in columns:
                if isinstance(col_info, dict):
                    column_names.extend(col_info.keys())
                else:
                    column_names.append(str(col_info))
            
            table_description = f"Table {table_name} contains information about {self._infer_table_purpose(table_name, column_names)}. "
            table_description += f"It has columns: {', '.join(column_names)}."
            
            # Add sample data context if available
            if sample_data:
                sample_values = []
                for sample in sample_data[:2]:
                    for key, value in sample.items():
                        if value is not None:
                            sample_values.append(f"{key}: {value}")
                
                if sample_values:
                    table_description += f" Example data includes: {'; '.join(sample_values[:5])}."
            
            descriptions.append({
                'text': table_description,
                'type': 'table',
                'database': db_name,
                'table': table_name,
                'column': None,
                'metadata': {
                    'column_count': len(columns),
                    'has_samples': len(sample_data) > 0
                }
            })
            
            # Generate column-level descriptions
            for col_info in columns:
                if isinstance(col_info, dict):
                    for col_name, col_type in col_info.items():
                        col_description = f"Column {col_name} in table {table_name} is of type {col_type}. "
                        col_description += self._infer_column_purpose(col_name, col_type, table_name)
                        
                        if sample_data:
                            sample_values = [str(sample.get(col_name, '')) for sample in sample_data if sample.get(col_name) is not None]
                            if sample_values:
                                col_description += f" Example values: {', '.join(sample_values[:3])}."
                        
                        descriptions.append({
                            'text': col_description,
                            'type': 'column',
                            'database': db_name,
                            'table': table_name,
                            'column': col_name,
                            'metadata': {
                                'data_type': col_type,
                                'has_samples': len(sample_data) > 0
                            }
                        })
        
        return descriptions
    
    def _infer_table_purpose(self, table_name: str, column_names: List[str]) -> str:
        """Infer the purpose of a table based on its name and columns."""
        table_lower = table_name.lower()
        cols_lower = [col.lower() for col in column_names]
        
        if 'user' in table_lower or 'customer' in table_lower:
            return "users or customers"
        elif 'order' in table_lower or 'purchase' in table_lower:
            return "orders or purchases"
        elif 'product' in table_lower or 'item' in table_lower:
            return "products or items"
        elif 'payment' in table_lower or 'transaction' in table_lower:
            return "payments or transactions"
        elif 'employee' in table_lower or 'staff' in table_lower:
            return "employees or staff"
        elif 'track' in table_lower or 'song' in table_lower:
            return "music tracks or songs"
        elif 'album' in table_lower:
            return "music albums"
        elif 'artist' in table_lower:
            return "artists or musicians"
        elif 'genre' in table_lower:
            return "music genres"
        elif 'invoice' in table_lower:
            return "invoices or billing"
        elif 'playlist' in table_lower:
            return "playlists"
        else:
            return f"{table_name.replace('_', ' ')}"
    
    def _infer_column_purpose(self, col_name: str, col_type: str, table_name: str) -> str:
        """Infer the purpose of a column based on its name and type."""
        col_lower = col_name.lower()
        
        if 'id' in col_lower:
            return f"It serves as an identifier for {table_name}."
        elif 'name' in col_lower:
            return "It stores name information."
        elif 'email' in col_lower:
            return "It stores email addresses."
        elif 'phone' in col_lower:
            return "It stores phone numbers."
        elif 'address' in col_lower:
            return "It stores address information."
        elif 'date' in col_lower or 'time' in col_lower:
            return "It stores date/time information."
        elif 'price' in col_lower or 'amount' in col_lower or 'cost' in col_lower:
            return "It stores monetary values."
        elif 'quantity' in col_lower or 'count' in col_lower:
            return "It stores quantity or count information."
        elif 'description' in col_lower:
            return "It stores descriptive text."
        elif 'status' in col_lower:
            return "It stores status information."
        else:
            return f"It stores {col_name.replace('_', ' ').lower()} data."
    
    def build_vector_database(self, schema_folder: str = "schemas") -> None:
        """
        Build ChromaDB vector database from schema files.
        
        Args:
            schema_folder (str): Path to folder containing schema files
        """
        print("Building ChromaDB vector database...")
        
        schema_files = list(Path(schema_folder).glob("*.py"))
        
        if not schema_files:
            print(f"No schema files found in {schema_folder}")
            return
        
        all_descriptions = []
        ids = []
        metadatas = []
        documents = []
        
        for schema_file in schema_files:
            db_name = schema_file.stem
            if db_name.endswith('_db'):
                db_name = db_name[:-3]
            
            print(f"Processing schema: {schema_file.name}")
            
            schema = self.load_schema_from_file(str(schema_file))
            if not schema:
                continue
            
            descriptions = self.generate_text_descriptions(schema, db_name)
            all_descriptions.extend(descriptions)
        
        for i, desc in enumerate(all_descriptions):
            ids.append(f"desc_{i}")
            documents.append(desc['text'])
            metadatas.append({
                'type': desc['type'],
                'database': desc['database'],
                'table': desc['table'],
                'column': desc['column'],
                'metadata': json.dumps(desc['metadata'])
            })
        
        print(f"Generated {len(all_descriptions)} descriptions")
        
        # Clear existing collection and add new embeddings
        self.collection.delete()
        if documents:
            self.collection.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids
            )
        
        self.schema_metadata = all_descriptions
        print(f"ChromaDB vector database built with {len(all_descriptions)} entries")
    
    def find_relevant_schema(self, user_prompt: str, top_k: int = 10) -> List[Dict]:
        """
        Find the most relevant tables and columns for a user prompt using ChromaDB.
        
        Args:
            user_prompt (str): User's natural language query
            top_k (int): Number of top results to return
            
        Returns:
            List[Dict]: Ranked list of relevant schema elements
        """
        results = self.collection.query(
            query_texts=[user_prompt],
            n_results=top_k
        )
        
        relevant_schema = []
        for i in range(len(results['ids'][0])):
            metadata = results['metadatas'][0][i]
            relevant_schema.append({
                'text': results['documents'][0][i],
                'type': metadata['type'],
                'database': metadata['database'],
                'table': metadata['table'],
                'column': metadata['column'],
                'similarity_score': results['distances'][0][i],
                'metadata': json.loads(metadata['metadata'])
            })
        
        return relevant_schema
    
    def get_relevant_tables_columns(self, user_prompt: str, top_k: int = 5) -> Dict:
        """
        Get relevant tables and columns for a user prompt.
        
        Args:
            user_prompt (str): User's natural language query
            top_k (int): Number of top results to consider
            
        Returns:
            Dict: Structured context with tables and columns
        """
        relevant_schema = self.find_relevant_schema(user_prompt, top_k)
        
        if not relevant_schema:
            return {}
        
        context = {}
        
        for item in relevant_schema:
            db_name = item['database']
            table_name = item['table']
            
            if db_name not in context:
                context[db_name] = {}
            
            if table_name not in context[db_name]:
                context[db_name][table_name] = {
                    'columns': [],
                    'relevance_score': 0,
                    'description': ''
                }
            
            if item['type'] == 'table':
                context[db_name][table_name]['description'] = item['text']
                context[db_name][table_name]['relevance_score'] = max(
                    context[db_name][table_name]['relevance_score'],
                    1.0 - item['similarity_score']  # Convert distance to similarity
                )
            elif item['type'] == 'column':
                context[db_name][table_name]['columns'].append({
                    'name': item['column'],
                    'description': item['text'],
                    'relevance_score': 1.0 - item['similarity_score'],
                    'data_type': item['metadata'].get('data_type', 'UNKNOWN')
                })
        
        # Sort columns by relevance
        for db_name in context:
            for table_name in context[db_name]:
                context[db_name][table_name]['columns'].sort(
                    key=lambda x: x['relevance_score'], 
                    reverse=True
                )
        
        return context

if __name__ == "__main__":
    # Demo run: rebuild the index from ./schemas, then print the schema
    # elements matched for a handful of sample questions.
    mapper = VectorDatabaseMapper()
    mapper.build_vector_database("schemas")
    
    sample_prompts = [
        "Show me all the artists and their names",
        "How many tracks are there in the database?",
        "What is the total sales amount?",
        "List all customers from USA",
        "Find the most expensive album",
        "Group sales by country"
    ]
    
    banner = "="*50
    print("\n" + banner)
    print("TESTING CHROMADB VECTOR MAPPING")
    print(banner)
    
    for prompt in sample_prompts:
        print(f"\nQuery: {prompt}")
        print("-" * 30)
        
        matches = mapper.get_relevant_tables_columns(prompt)
        
        # Nothing matched: report it and move on to the next prompt
        if not matches:
            print("No relevant schema found")
            continue
        
        for database, tables in matches.items():
            print(f"Database: {database}")
            for table, details in tables.items():
                print(f"Table: {table} (Relevance: {details['relevance_score']:.3f})")
                print(f"Description: {details['description']}")
                print("Relevant Columns:")
                for column in details['columns']:
                    print(f"  - {column['name']} (Relevance: {column['relevance_score']:.3f}, Type: {column['data_type']})")