In [None]:
import sqlite3
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import pickle
import importlib
from typing import Dict, List, Tuple, Any

class SchemaVectorDB:
    """Store table-description embeddings in SQLite and rank tables by similarity.

    Each table entry of a schema dict is turned into a short text
    ("Table: <name> | Description: <description>"), encoded with a
    SentenceTransformer model, and stored in ``embedding_table`` inside the
    SQLite file at ``db_path``. Queries are encoded with the same model and
    ranked by cosine similarity.

    NOTE: ``embedding_table`` is created inside the file given by ``db_path``;
    when that is an application database the table is added next to the
    application's own tables.
    """

    def __init__(self, db_path, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the embedding model and open (or create) the SQLite database.

        Args:
            db_path: Path passed to ``sqlite3.connect``.
            model_name: SentenceTransformer model used for both stored texts
                and queries (defaults to the previously hard-coded model).
        """
        self.db_path = db_path
        self.model = SentenceTransformer(model_name)
        self.conn = None
        self.init_database()

    def __enter__(self):
        """Allow ``with SchemaVectorDB(...) as db:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the connection when leaving a ``with`` block; never suppress errors."""
        self.close()
        return False

    def init_database(self):
        """Open the connection to ``self.db_path`` and create ``embedding_table`` if missing."""
        print("Initializing database...", self.db_path)
        # Re-initialising must not leak the previously opened connection.
        previous = getattr(self, 'conn', None)
        if previous is not None:
            previous.close()
        self.conn = sqlite3.connect(self.db_path)

        # ``table_name`` is UNIQUE, which already makes SQLite maintain an
        # index on it; a separate CREATE INDEX would only duplicate that work.
        with self.conn:
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS embedding_table (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    table_name TEXT NOT NULL UNIQUE,
                    description TEXT NOT NULL,
                    embedding BLOB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    def load_schema_from_file(self, file_path: str) -> Dict[str, Any]:
        """Import a Python file and return its module-level ``schema`` object.

        WARNING: the file is executed as Python code; only load trusted files.

        Args:
            file_path: Path to a ``.py`` file that defines ``schema``.

        Returns:
            The ``schema`` object, or ``{}`` (after printing the error) when the
            file cannot be imported or defines no ``schema``.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly.
        import importlib.util

        try:
            spec = importlib.util.spec_from_file_location("schema_module", file_path)
            if spec is None or spec.loader is None:
                # e.g. a path without a recognised Python source suffix
                raise ImportError(f"cannot create an import spec for {file_path!r}")
            schema_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(schema_module)
            return schema_module.schema
        except Exception as e:
            # Deliberately best-effort: report and fall back to an empty schema.
            print(f"Error loading schema from {file_path}: {e}")
            return {}

    def process_schema(self, schema: Dict[str, Any]):
        """Embed and store every schema entry that carries a ``description``.

        Entries whose value is not a dict, or that have no ``'description'``
        key, are skipped. All texts are encoded in one batch and written in a
        single transaction (rolled back if any write fails).

        Args:
            schema: Mapping of table name -> table info dict.
        """
        table_names = []
        texts = []
        for table_name, table_info in schema.items():
            if isinstance(table_info, dict) and 'description' in table_info:
                table_names.append(table_name)
                texts.append(self._create_embedding_text(table_name, table_info['description']))

        if not texts:
            return

        # One batched encode call instead of one model invocation per table.
        embeddings = self.model.encode(texts)

        # One commit for the whole schema instead of one per table.
        with self.conn:
            for table_name, text, embedding in zip(table_names, texts, embeddings):
                self._store_embedding(table_name, text, embedding, commit=False)

    def _create_embedding_text(self, table_name: str,
                               description: str) -> str:
        """Build the text that is embedded for one table.

        Returns ``"Table: <name> | Description: <description>"``; the
        description part is omitted when ``description`` is falsy.
        """
        text_parts = [
            f"Table: {table_name}",
            f"Description: {description}" if description else "",
        ]

        return " | ".join(filter(None, text_parts))

    def _store_embedding(self, table_name: str, description: str, embedding: np.ndarray,
                         commit: bool = True):
        """Insert the row for ``table_name``, or overwrite it if it already exists.

        Args:
            table_name: Unique key of the row.
            description: Text stored in the ``description`` column
                (``process_schema`` passes the full embedding text).
            embedding: Vector to store; serialized with pickle.
            commit: Commit immediately. Pass False when the caller manages the
                surrounding transaction.
        """
        # NOTE(security): blobs are pickled and later unpickled by
        # search_relevant_tables; only use database files you trust.
        embedding_blob = pickle.dumps(embedding)
        print(f"Storing embedding for table: {table_name}")

        cursor = self.conn.cursor()
        cursor.execute('SELECT id FROM embedding_table WHERE table_name = ?', (table_name,))

        if cursor.fetchone():
            # ``created_at`` is refreshed here, so it records the last write.
            cursor.execute('''
                UPDATE embedding_table
                SET description = ?, embedding = ?, created_at = CURRENT_TIMESTAMP
                WHERE table_name = ?
            ''', (description, embedding_blob, table_name))
            print(f"Updated existing embedding for table: {table_name}")
        else:
            cursor.execute('''
                INSERT INTO embedding_table
                (table_name, description, embedding)
                VALUES (?, ?, ?)
            ''', (table_name, description, embedding_blob))
            print(f"Inserted new embedding for table: {table_name}")

        if commit:
            self.conn.commit()

    def search_relevant_tables(self, user_prompt: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Rank stored tables by cosine similarity to ``user_prompt``.

        Args:
            user_prompt: Free-text query, encoded with ``self.model``.
            top_k: Maximum number of results.

        Returns:
            Up to ``top_k`` ``(table_name, similarity)`` pairs, highest first.
            Ties keep the order in which SQLite returned the rows. Returns
            ``[]`` when nothing is stored or ``top_k <= 0``. A zero-norm vector
            scores 0.0.
        """
        if top_k <= 0:
            return []

        rows = self.conn.execute('SELECT table_name, embedding FROM embedding_table').fetchall()
        if not rows:
            return []

        # ``table_name`` is UNIQUE, so every row is a distinct table.
        names = [name for name, _ in rows]
        # NOTE(security): pickle.loads can execute code embedded in a crafted
        # blob; the database at self.db_path must be trusted.
        matrix = np.vstack([
            np.asarray(pickle.loads(blob), dtype=np.float64).ravel() for _, blob in rows
        ])
        query = np.asarray(self.model.encode(user_prompt), dtype=np.float64).ravel()

        # Vectorised cosine similarity: one matrix-vector product for all rows.
        # Rows where either norm is zero are left at 0.0 rather than dividing by zero.
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
        scores = np.divide(matrix @ query, norms, out=np.zeros(len(names)), where=norms > 0)

        ranked = sorted(zip(names, scores.tolist()), key=lambda pair: pair[1], reverse=True)
        return ranked[:top_k]

    def close(self):
        """Close the database connection if one was opened."""
        if self.conn is not None:
            self.conn.close()


# Example usage and testing
def main(db_path='database/chinook.db', schema_path='schemas/chinook_db.py'):
    """Build the schema vector index and return the open SchemaVectorDB.

    Args:
        db_path: SQLite file that receives ``embedding_table``.
        schema_path: Python file defining a module-level ``schema`` dict.

    Returns:
        The SchemaVectorDB instance with its connection still open; the
        caller is responsible for calling ``close()``.
    """
    vector_db = SchemaVectorDB(db_path)
    try:
        schema = vector_db.load_schema_from_file(schema_path)
        if not schema:
            # load_schema_from_file returns {} on failure; make the resulting
            # empty index visible instead of continuing silently.
            print(f"Warning: no schema entries loaded from {schema_path}")

        print("Processing schema and creating embeddings...")
        vector_db.process_schema(schema)
    except Exception:
        # Do not leak the SQLite connection if indexing fails.
        vector_db.close()
        raise
    return vector_db

In [27]:
# Build (or refresh) the index, then show the top matches for each sample query.
vector_db = main()
test_queries = [
    "Give me playlist in pop genre",
]

separator = "=" * 50
print("\n" + separator)
print("TESTING VECTOR DATABASE")
print(separator)

for query in test_queries:
    print(f"\nQuery: {query}")
    print("-" * 30)

    # Highest-scoring tables first, at most five of them.
    relevant_tables = vector_db.search_relevant_tables(query, top_k=5)
    print("Relevant tables:")
    for table, similarity in relevant_tables:
        print(f"  - {table}: {similarity:.3f}")

Initializing database... database/chinook.db
Processing schema and creating embeddings...
Storing embedding for table: Album
Inserted new embedding for table: Album
Storing embedding for table: Artist
Inserted new embedding for table: Artist
Storing embedding for table: Customer
Inserted new embedding for table: Customer
Storing embedding for table: Employee
Inserted new embedding for table: Employee
Storing embedding for table: Genre
Inserted new embedding for table: Genre
Storing embedding for table: Invoice
Inserted new embedding for table: Invoice
Storing embedding for table: InvoiceLine
Inserted new embedding for table: InvoiceLine
Storing embedding for table: MediaType
Inserted new embedding for table: MediaType
Storing embedding for table: Playlist
Inserted new embedding for table: Playlist
Storing embedding for table: PlaylistTrack
Inserted new embedding for table: PlaylistTrack
Storing embedding for table: Track
Inserted new embedding for table: Track

TESTING VECTOR DATABASE
