In [1]:
import chromadb
from typing import Dict, List, Any
from pathlib import Path
import importlib.util

class SchemaVectorDB:
    """Vector index over database schema descriptions, backed by ChromaDB.

    Every table and every column of a schema becomes one document in a single
    Chroma collection that uses cosine distance. Queries return the closest
    documents, regrouped into table-level and column-level recommendations.
    """

    # Shared by __init__ and populate_database so the collection is always
    # (re)created with identical settings.
    _COLLECTION_NAME = "schema_collection"
    _COLLECTION_METADATA = {"hnsw:space": "cosine"}

    def __init__(self, db_path: str = "./chroma_db"):
        """Initialize the ChromaDB client and collection.

        Args:
            db_path: Directory passed to ``chromadb.PersistentClient`` for storage.
        """
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self.client.get_or_create_collection(
            name=self._COLLECTION_NAME,
            metadata=dict(self._COLLECTION_METADATA)
        )

    def load_schema_from_file(self, file_path: str) -> Dict[str, Any]:
        """Load the module-level ``schema`` dict from a Python file.

        WARNING: the file is executed as Python code; only load trusted files.

        Args:
            file_path: Path to a ``.py`` file that defines ``schema``.

        Returns:
            The file's ``schema`` mapping. Returns ``{}`` (after printing the
            error) if the file cannot be imported, lacks ``schema``, or its
            ``schema`` is not a dict.
        """
        try:
            spec = importlib.util.spec_from_file_location("schema_module", file_path)
            if spec is None or spec.loader is None:
                raise ImportError("cannot build an import spec for this path")
            schema_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(schema_module)
            schema = schema_module.schema
            if not isinstance(schema, dict):
                raise TypeError(f"'schema' must be a dict, got {type(schema).__name__}")
            return schema
        except Exception as e:
            # Best-effort: a broken schema file is reported and skipped instead
            # of aborting the whole indexing run.
            print(f"Error loading schema from {file_path}: {e}")
            return {}

    @staticmethod
    def _metadata_value(value: Any) -> Any:
        """Coerce ``value`` to a type Chroma accepts as metadata (str/int/float/bool).

        ``None`` becomes ``""``; any other unsupported type is stringified.
        """
        if value is None:
            return ""
        if isinstance(value, (str, int, float, bool)):
            return value
        return str(value)

    def create_embeddings_from_schema(self, schema: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Turn a schema mapping into table-level and column-level documents.

        Args:
            schema: ``{table_name: {"description": ..., "columns": {col_name: col_type}}}``.

        Returns:
            A list of dicts. Each dict holds ``id``, ``text`` (the string that
            gets embedded) and metadata fields. Table docs use id
            ``table_<table>`` and column docs use ``column_<table>_<column>``.
            These ids are NOT guaranteed unique; ``populate_database`` makes
            them unique.
        """
        documents = []

        for table_name, table_info in schema.items():
            # Normalise missing/None values so every metadata field is storable.
            description = self._metadata_value(table_info.get("description"))
            columns = table_info.get("columns") or {}
            # Chroma metadata cannot hold lists, so the column names are joined
            # into one comma-separated string.
            columns_str = ", ".join(columns.keys())

            # Table-level document
            documents.append({
                "id": f"table_{table_name}",
                "type": "table",
                "table_name": table_name,
                "description": description,
                "columns": columns_str,
                "text": f"Table: {table_name}. Description: {description}. Columns: {columns_str}"
            })

            # Column-level documents
            for col_name, raw_type in columns.items():
                col_type = self._metadata_value(raw_type)
                documents.append({
                    "id": f"column_{table_name}_{col_name}",
                    "type": "column",
                    "table_name": table_name,
                    "column_name": col_name,
                    "column_type": col_type,
                    "table_description": description,
                    "text": f"Column: {col_name} (Type: {col_type}) in table {table_name}. Table description: {description}"
                })

        return documents

    @staticmethod
    def _unique_ids(ids: List[str]) -> List[str]:
        """Return ``ids`` in order, with repeats suffixed ``#2``, ``#3``, ... so all are unique.

        Chroma's ``add`` rejects duplicate ids within a batch. Repeats happen
        when two schema files define the same table, or when underscores make
        ids ambiguous: table ``order`` + column ``item_id`` and table
        ``order_item`` + column ``id`` both yield ``column_order_item_id``.
        """
        taken = set()
        unique = []
        for doc_id in ids:
            candidate, n = doc_id, 1
            while candidate in taken:
                n += 1
                candidate = f"{doc_id}#{n}"
            taken.add(candidate)
            unique.append(candidate)
        return unique

    def populate_database(self, schemas_folder: str):
        """Rebuild the collection from every ``*.py`` schema file in a folder.

        The existing collection is dropped first, so afterwards the index
        reflects only the files currently in ``schemas_folder``. Files whose
        names start with ``__`` (e.g. ``__init__.py``) are skipped.

        Args:
            schemas_folder: Directory containing the schema modules.
        """
        schemas_path = Path(schemas_folder)

        # Clear the existing collection and recreate it empty.
        self.client.delete_collection(self._COLLECTION_NAME)
        self.collection = self.client.create_collection(
            name=self._COLLECTION_NAME,
            metadata=dict(self._COLLECTION_METADATA)
        )

        all_documents = []

        # glob() order depends on the filesystem; sorting keeps indexing (and
        # any duplicate-id suffixes) reproducible.
        for file_path in sorted(schemas_path.glob("*.py")):
            if file_path.name.startswith("__"):
                continue

            print(f"Processing schema file: {file_path}")
            schema = self.load_schema_from_file(str(file_path))

            if schema:
                all_documents.extend(self.create_embeddings_from_schema(schema))

        if not all_documents:
            return

        ids = self._unique_ids([doc["id"] for doc in all_documents])
        renamed = sum(1 for doc, doc_id in zip(all_documents, ids) if doc["id"] != doc_id)
        if renamed:
            print(f"Warning: {renamed} duplicate document id(s) were suffixed to keep ids unique")

        texts = [doc["text"] for doc in all_documents]
        metadatas = [{k: v for k, v in doc.items() if k not in ("id", "text")}
                     for doc in all_documents]

        self.collection.add(
            documents=texts,
            metadatas=metadatas,
            ids=ids
        )

        print(f"Added {len(all_documents)} documents to the vector database")

    def query_relevant_tables(self, user_query: str, n_results: int = 10) -> Dict[str, Any]:
        """Query the vector database for tables and columns related to a question.

        Args:
            user_query: Natural-language question.
            n_results: Maximum number of documents to retrieve. Clamped to the
                number of documents in the collection.

        Returns:
            ``{'query', 'table_recommendations', 'column_recommendations'}``.
            ``table_recommendations`` maps table name to its description,
            column-name list and ``relevance_score``. ``column_recommendations``
            maps table name to its description and
            ``{column: {'type', 'relevance_score'}}``. ``relevance_score`` is
            ``1 - distance`` as reported by Chroma.
        """
        table_recommendations: Dict[str, Any] = {}
        column_recommendations: Dict[str, Any] = {}

        available = self.collection.count()
        if available > 0:
            # Clamp so we never ask for more neighbours than exist; Chroma's
            # handling of that case has varied across versions.
            results = self.collection.query(
                query_texts=[user_query],
                n_results=min(n_results, available)
            )

            if results['documents'] and results['documents'][0]:
                for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
                    relevance = 1 - distance

                    if metadata['type'] == 'table':
                        # The columns were stored as one string; split back into a list.
                        columns_list = [col.strip() for col in metadata.get('columns', '').split(",")
                                        if col.strip()]
                        table_recommendations[metadata['table_name']] = {
                            'description': metadata['description'],
                            'columns': columns_list,
                            'relevance_score': relevance
                        }

                    elif metadata['type'] == 'column':
                        entry = column_recommendations.setdefault(metadata['table_name'], {
                            'description': metadata['table_description'],
                            'columns': {}
                        })
                        entry['columns'][metadata['column_name']] = {
                            'type': metadata['column_type'],
                            'relevance_score': relevance
                        }

        return {
            'query': user_query,
            'table_recommendations': table_recommendations,
            'column_recommendations': column_recommendations
        }

    def get_recommendations(
        self,
        user_query: str,
        top_tables: int = 3,
        n_results: int = 20,
        table_weight: float = 0.3,
        column_weight: float = 0.7,
    ) -> Dict[str, Any]:
        """Rank tables for a question and list their most relevant columns.

        Each table is scored as
        ``table_weight * table_score + column_weight * mean(column scores)``.
        A missing table-level or column-level hit counts as 0.

        Args:
            user_query: Natural-language question.
            top_tables: Number of highest-scoring tables to return.
            n_results: Number of documents retrieved from the vector store.
            table_weight: Weight of the table document's own relevance.
            column_weight: Weight of the mean relevance of the matched columns.

        Returns:
            ``{'query', 'recommended_tables'}``. Each entry has ``table_name``,
            ``description``, a rounded ``relevance_score`` and
            ``recommended_columns`` sorted by relevance.
        """
        results = self.query_relevant_tables(user_query, n_results=n_results)

        all_tables: Dict[str, Dict[str, Any]] = {}

        # Seed with tables that matched directly.
        for table_name, info in results['table_recommendations'].items():
            all_tables[table_name] = {
                'description': info['description'],
                'relevant_columns': {},
                'table_score': info['relevance_score'],
                'column_score': 0
            }

        # Merge column-level hits; tables found only through columns get table_score 0.
        for table_name, info in results['column_recommendations'].items():
            entry = all_tables.setdefault(table_name, {
                'description': info['description'],
                'relevant_columns': {},
                'table_score': 0,
                'column_score': 0
            })
            entry['relevant_columns'] = info['columns']

            if info['columns']:
                scores = [col['relevance_score'] for col in info['columns'].values()]
                entry['column_score'] = sum(scores) / len(scores)

        for entry in all_tables.values():
            entry['combined_score'] = (entry['table_score'] * table_weight) + (entry['column_score'] * column_weight)

        sorted_tables = sorted(all_tables.items(), key=lambda item: item[1]['combined_score'], reverse=True)

        recommendations = {
            'query': user_query,
            'recommended_tables': []
        }

        for table_name, info in sorted_tables[:top_tables]:
            ranked_columns = sorted(info['relevant_columns'].items(),
                                    key=lambda item: item[1]['relevance_score'], reverse=True)
            recommendations['recommended_tables'].append({
                'table_name': table_name,
                'description': info['description'],
                'relevance_score': round(info['combined_score'], 3),
                'recommended_columns': [
                    {
                        'column_name': col_name,
                        'column_type': col_info['type'],
                        'relevance_score': round(col_info['relevance_score'], 3)
                    }
                    for col_name, col_info in ranked_columns
                ]
            })

        return recommendations


# Build the schema vector store, then rebuild its index from every schema
# module in the folder below (populate_database drops the old collection first).
schemas_folder = "./schemas"  # Update this path to your schemas folder
vector_db = SchemaVectorDB()
vector_db.populate_database(schemas_folder)

Processing schema file: schemas\chinook_db.py
Processing schema file: schemas\netflix_db.py
Added 117 documents to the vector database


In [2]:

# Questions to run through the recommender, and how much of each answer to show.
test_queries = [
    "Give me track in pop genre",
]
TOP_TABLES = 2         # tables returned per query
MAX_COLUMNS_SHOWN = 5  # columns printed per table

print("\n" + "="*80)
print("SCHEMA RECOMMENDATION SYSTEM - TEST RESULTS")
print("="*80)

for query in test_queries:
    print(f"\nQuery: {query}")
    print("-" * 50)

    # Recommendations must be computed inside the loop: previously this ran
    # once after the loop, so only the last query was ever evaluated.
    recommendations = vector_db.get_recommendations(query, top_tables=TOP_TABLES)

    for table_rec in recommendations['recommended_tables']:
        print(f"\nTable: {table_rec['table_name']} (Score: {table_rec['relevance_score']})")
        print(f"Description: {table_rec['description']}")

        if table_rec['recommended_columns']:
            print("Recommended columns:")

        for col in table_rec['recommended_columns'][:MAX_COLUMNS_SHOWN]:
            print(f"  - {col['column_name']} ({col['column_type']}) - Score: {col['relevance_score']}")

        print()



SCHEMA RECOMMENDATION SYSTEM - TEST RESULTS

Query: Give me track in pop genre
--------------------------------------------------

Table: PlaylistTrack (Score: 0.366)
Description: This table records the association between playlists and tracks, linking each track to the playlists it belongs to. It supports the organization and management of music collections by enabling the construction of playlists from individual tracks within a music application.
Recommended columns:
  - TrackId (INTEGER) - Score: 0.373
  - PlaylistId (INTEGER) - Score: 0.345


Table: Track (Score: 0.365)
Description: This table holds detailed information about individual music tracks, including their titles, associated album, media type, genre, composer, duration, file size, and unit price. It serves as a central repository for cataloging and managing music inventory, supporting media libraries, sales transactions, and music discovery features within the application.
Recommended columns:
  - GenreId (INTEGER) - Sc