In [1]:
# ! pip install --upgrade pip
# ! pip install --upgrade jupyter ipywidgets
# ! pip install mysql-connector-python ollama scikit-learn sentence-transformers transformers torch torchvision torchaudio requests

In [2]:
import getpass
import os
import random
import string

import mysql.connector
import ollama
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
import requests
import json

In [4]:
from sentence_transformers import SentenceTransformer

In [5]:
# Load MPNet model for sentence embeddings
model = SentenceTransformer('all-mpnet-base-v2')

In [6]:
# Step 1: Connect to MySQL database and extract the schema
database = "VenueScope"

# Credentials come from the environment (or an interactive prompt) so no
# password is stored in the notebook. The password that used to be hard-coded
# here should be treated as exposed and rotated.
db = mysql.connector.connect(
    host=os.environ.get("MYSQL_HOST", "localhost"),
    user=os.environ.get("MYSQL_USER", "root"),
    password=os.environ.get("MYSQL_PASSWORD") or getpass.getpass("MySQL password: "),
    database=database
)
cursor = db.cursor()

# The schema name is sent as a bound parameter instead of being spliced into
# the SQL text.
query = """
SELECT 
    TABLE_NAME, 
    COLUMN_NAME 
FROM 
    INFORMATION_SCHEMA.COLUMNS 
WHERE 
    TABLE_SCHEMA = %s
ORDER BY 
    TABLE_NAME, ORDINAL_POSITION;
"""
cursor.execute(query, (database,))
# List of (TABLE_NAME, COLUMN_NAME) rows, ordered by table then column position.
schemaColumns = cursor.fetchall()

In [7]:
# Step 2: Group the (table, column) rows into {table name: [column names]},
# keeping the ORDINAL_POSITION order produced by the schema query.
schemaInfo = {}
for tableName, columnName in schemaColumns:
    schemaInfo.setdefault(tableName, []).append(columnName)

In [8]:
def tokenizeQuery(query):
    """
    Split a natural-language query into lowercase keyword tokens.

    Tokens are whitespace-separated words with leading/trailing punctuation
    stripped, so "clubs," or "to." match schema names the same way "clubs"
    and "to" would. Tokens made only of punctuation are dropped.

    Args:
        query: the user's question as a string.

    Returns:
        list[str]: lowercase tokens in their original order.
    """
    words = []
    for rawWord in query.lower().split():
        # Sentence punctuation glued to a word would otherwise block the
        # substring matching done against table/column names.
        word = rawWord.strip(string.punctuation)
        if word:
            words.append(word)
    return words

In [9]:
# Helper function to get MPNet embeddings
def get_MPNet_embeddings(text, encoder=None):
    """
    Encode text with the sentence-embedding model and return a 2-D array.

    Args:
        text: a single string, or a list of strings.
        encoder: object exposing ``encode(text, convert_to_numpy=True)``;
            defaults to the notebook-level ``model``.

    Returns:
        numpy.ndarray: a 1-D encoder result (single string) is lifted to shape
        (1, dim); a 2-D result (one row per input string) is returned as-is,
        so either form can be passed straight to ``cosine_similarity``.
    """
    if encoder is None:
        encoder = model
    embedding = encoder.encode(text, convert_to_numpy=True)
    # Only lift a lone vector to a row. Reshaping a batch with (1, -1) would
    # concatenate every row into one long vector and break per-item similarity.
    if embedding.ndim == 1:
        embedding = embedding.reshape(1, -1)
    return embedding

In [10]:
# Step 3: Embed every "table column" pair of the schema with MPNet
def getSchemaEmbeddings(schemaInfo):
    """
    Embed each (table, column) pair of the schema.

    Every pair is rendered as the text "<table> <column>" and encoded with
    get_MPNet_embeddings, in schemaInfo's table order and column order.

    Returns:
        tuple(list, list): the embeddings, and a parallel list naming the
        table each embedding belongs to.
    """
    tableColumnPairs = [
        (tableName, columnName)
        for tableName, columnNames in schemaInfo.items()
        for columnName in columnNames
    ]
    embeddings = [
        get_MPNet_embeddings(f"{tableName} {columnName}")
        for tableName, columnName in tableColumnPairs
    ]
    owningTables = [tableName for tableName, _ in tableColumnPairs]
    return embeddings, owningTables

schemaEmbeddings, tableKeys = getSchemaEmbeddings(schemaInfo)

In [11]:
def rankSchemas_V2(userKeywords, schemaInfo, userQueryEmbedding, schemaEmbeddings, tableKeys, minKeywordLength=1):
    """
    Rank schema tables by keyword substring matches plus embedding similarity.

    Each (keyword, table name) and (keyword, column name) substring match adds
    2 to the table's score. Then the cosine similarity between the query
    embedding and every (table, column) embedding is added to that table's
    score. Only tables with a keyword match or at least one embedding appear
    in the result.

    Args:
        userKeywords: iterable of query tokens (compared case-insensitively).
        schemaInfo: dict mapping table name -> list of column names.
        userQueryEmbedding: query embedding accepted by ``cosine_similarity``.
        schemaEmbeddings: embeddings parallel to ``tableKeys``.
        tableKeys: owning table name for each entry of ``schemaEmbeddings``.
        minKeywordLength: keywords shorter than this are ignored for substring
            matching (e.g. 3 skips "of"/"to"). Empty keywords are always
            ignored because they would match every name.

    Returns:
        list[str]: table names, highest combined score first.
    """
    keywordBoost = 2
    tableScores = {}

    loweredKeywords = (keyword.lower() for keyword in userKeywords)
    relevantKeywords = {
        keyword for keyword in loweredKeywords
        if keyword and len(keyword) >= minKeywordLength
    }

    # Step 1: keyword matches on table and column names (primary signal)
    for table, columns in schemaInfo.items():
        tableLower = table.lower()
        columnsLower = [column.lower() for column in columns]
        for keyword in relevantKeywords:
            hits = int(keyword in tableLower) + sum(keyword in column for column in columnsLower)
            if hits:
                tableScores[table] = tableScores.get(table, 0) + keywordBoost * hits

    # Step 2: embedding similarity as a secondary, additive signal
    for i, table in enumerate(tableKeys):
        similarityScore = cosine_similarity(userQueryEmbedding, schemaEmbeddings[i]).flatten()[0]
        tableScores[table] = tableScores.get(table, 0) + similarityScore

    # Step 3: dict keys are already unique, so sorting them is the full ranking
    return sorted(tableScores, key=tableScores.get, reverse=True)

In [12]:
# Step 6: Ask the duckdb-nsql model for SQL over the top-ranked tables
def construct_SQL_Query(rankedTableNames, schemaInfo, userQuery, top_N):
    """
    Generate SQL for ``userQuery`` with Ollama using the first ``top_N`` tables.

    The prompt is "This is the schema: \\n" followed by one
    "Table <name>: Columns (<c1>, <c2>, ...)" line per table, a blank line,
    and the user's question. The streamed reply is concatenated and returned
    as a single string.
    """
    schemaText = "".join(
        f"Table {tableName}: Columns ({', '.join(schemaInfo[tableName])})\n"
        for tableName in rankedTableNames[:top_N]
    )
    prompt = f"This is the schema: \n{schemaText}\n{userQuery}"

    replyStream = ollama.chat(
        model='duckdb-nsql',
        messages=[{'role': 'user', 'content': prompt}],
        stream=True,
    )
    return "".join(piece['message']['content'] for piece in replyStream)

In [13]:
def getForeignKeyRelations(cursor, schemaInfo, database):
    """
    Extract foreign key relationships from INFORMATION_SCHEMA for one database.

    Args:
        cursor: DB-API cursor (mysql.connector style, ``%s`` placeholders).
        schemaInfo: not used by this function; kept for call-site compatibility.
        database: schema name (TABLE_SCHEMA) to inspect.

    Returns:
        dict: table name -> list of (column, referenced_table, referenced_column)
        tuples, in the order the rows are returned.
    """
    foreignKeyQuery = """
    SELECT 
        TABLE_NAME, 
        COLUMN_NAME, 
        REFERENCED_TABLE_NAME, 
        REFERENCED_COLUMN_NAME
    FROM 
        INFORMATION_SCHEMA.KEY_COLUMN_USAGE 
    WHERE 
        TABLE_SCHEMA = %s 
        AND REFERENCED_TABLE_NAME IS NOT NULL;
    """
    # Bind the schema name rather than interpolating it into the SQL text.
    cursor.execute(foreignKeyQuery, (database,))

    foreignKeyRelations = {}
    for table, column, refTable, refColumn in cursor.fetchall():
        foreignKeyRelations.setdefault(table, []).append((column, refTable, refColumn))
    return foreignKeyRelations

In [14]:
def rankColumnsByRelevance(userQueryEmbedding, columnNames, columnEmbeddings):
    """
    Score each column against the query by cosine similarity, best first.

    Both the query embedding and every column embedding are reshaped to a
    single row before comparison. Columns and embeddings are paired
    positionally.

    Returns:
        list[tuple]: (column name, similarity) pairs sorted by similarity,
        descending; ties keep their original relative order.
    """
    queryRow = userQueryEmbedding.reshape(1, -1)

    scored = []
    for name, vector in zip(columnNames, columnEmbeddings):
        similarity = cosine_similarity(queryRow, vector.reshape(1, -1))[0][0]
        scored.append((name, similarity))

    return sorted(scored, key=lambda pair: pair[1], reverse=True)

In [15]:
def constructAndExecuteQuery(cursor, rankedTableNames, schemaInfo, userQuery, top_N, maxAttempts=5, relevanceThreshold=0.5):
    """
    Generate SQL for the most relevant table and execute it, retrying on SQL errors.

    For each of the first ``top_N`` ranked tables, the table's columns are
    ranked by similarity to the user query. Tables whose best column scores
    above ``relevanceThreshold`` become candidates. Each attempt asks Ollama
    (via ``construct_SQL_Query``) for SQL over one candidate table and executes
    it. After a MySQL error the next attempt uses the next candidate, wrapping
    around, until ``maxAttempts`` attempts have been made.

    Args:
        cursor: MySQL cursor used to run the generated SQL.
        rankedTableNames: table names, most relevant first.
        schemaInfo: dict mapping table name -> list of column names.
        userQuery: natural-language question.
        top_N: number of ranked tables to consider.
        maxAttempts: maximum number of generate-and-execute attempts.
        relevanceThreshold: minimum top-column similarity for a table to be used.

    Returns:
        The rows from ``cursor.fetchall()`` on success, otherwise ``None``.
    """
    userQueryEmbedding = get_MPNet_embeddings(userQuery)  # Embed the user's query

    # Column relevance does not change between attempts, so rank each table once.
    candidates = []
    for tableName in rankedTableNames[:top_N]:
        columnNames = schemaInfo[tableName]
        if not columnNames:
            print(f"Table {tableName} has no columns. Trying the next table...")
            continue

        # Embed each column on its own so there is exactly one embedding per
        # column name for rankColumnsByRelevance to pair up.
        columnEmbeddings = [get_MPNet_embeddings(columnName) for columnName in columnNames]
        rankedColumns = rankColumnsByRelevance(userQueryEmbedding, columnNames, columnEmbeddings)
        print(f"Ranked columns for table {tableName}: {rankedColumns}")

        topColumn, relevanceScore = rankedColumns[0]
        print(f"Top column: {topColumn}, Relevance score: {relevanceScore}")
        if relevanceScore > relevanceThreshold:
            candidates.append((tableName, topColumn))
        else:
            print(f"Relevance score too low for table {tableName}. Trying the next table...")

    if not candidates:
        print(f"Failed after {maxAttempts} attempts.")
        return None

    for attempt in range(1, maxAttempts + 1):
        print(f"Attempt {attempt} to generate and execute the query...")
        # Rotate through candidates so a table whose SQL keeps failing is not
        # the only one ever retried.
        tableName, topColumn = candidates[(attempt - 1) % len(candidates)]
        print(f"Proceeding with table {tableName} and top column {topColumn}")

        # Generate SQL query using Ollama with the relevant table and columns
        ollamaQuery = construct_SQL_Query([tableName], schemaInfo, userQuery, top_N=1)
        print("Generated Query from Ollama:", ollamaQuery)

        try:
            cursor.execute(ollamaQuery)
            results = cursor.fetchall()
        except mysql.connector.Error as err:
            print(f"Query execution failed with error: {err}")
            print("Trying the next candidate table...")
            continue

        print("Query executed successfully!")
        return results

    print(f"Failed after {maxAttempts} attempts.")
    return None

In [16]:
def getForeignKeys(cursor, schemaInfo, database=None):
    """
    Extract foreign key relationships for the tables in schemaInfo.

    Args:
        cursor: DB-API cursor (mysql.connector style, ``%s`` placeholders).
        schemaInfo: dict whose keys are the table names to inspect.
        database: schema (TABLE_SCHEMA) to search. When None, MySQL's
            ``DATABASE()`` (the connection's current default schema) is used.

    Returns:
        dict: table name -> rows of (COLUMN_NAME, REFERENCED_TABLE_NAME,
        REFERENCED_COLUMN_NAME) exactly as returned by ``cursor.fetchall()``.
    """
    # The schema filter used to be the literal placeholder 'your_database_name',
    # which matched no real schema; fall back to the current database instead.
    foreignKeyQuery = """
        SELECT 
            COLUMN_NAME, 
            REFERENCED_TABLE_NAME, 
            REFERENCED_COLUMN_NAME
        FROM 
            INFORMATION_SCHEMA.KEY_COLUMN_USAGE
        WHERE 
            TABLE_NAME = %s AND 
            TABLE_SCHEMA = COALESCE(%s, DATABASE()) AND 
            REFERENCED_TABLE_NAME IS NOT NULL;
    """
    foreignKeys = {}
    for table in schemaInfo:
        # Bound parameters keep table/schema names out of the SQL text.
        cursor.execute(foreignKeyQuery, (table, database))
        foreignKeys[table] = cursor.fetchall()
    return foreignKeys

In [17]:
def construct_SQL_QueryForOllama(topTables, schemaInfo, userQuery):
    """
    Build the prompt text sent to the SQL-generation model.

    The prompt is one "Table <name>: Columns (<c1>, <c2>, ...)" line per
    table in ``topTables``, then a blank line, then ``userQuery``.
    """
    tableDescriptions = "".join(
        "Table {}: Columns ({})\n".format(tableName, ", ".join(schemaInfo[tableName]))
        for tableName in topTables
    )
    return f"{tableDescriptions}\n{userQuery}"

In [18]:
def _streamChatCompletion(prompt):
    """Send ``prompt`` to duckdb-nsql via ``ollama.chat`` and return the concatenated streamed reply."""
    stream = ollama.chat(
        model='duckdb-nsql',
        messages=[{'role': 'user', 'content': prompt}],
        stream=True
    )
    return "".join(chunk['message']['content'] for chunk in stream)


def _generateWithTemperature(prompt, temperature, maxTokens=1024):
    """Call Ollama's /api/generate endpoint with an explicit sampling temperature and return the streamed text."""
    url = "http://127.0.0.1:11434/api/generate"
    headers = {"Content-Type": "application/json"}
    data = {
        "model": "duckdb-nsql",
        "prompt": prompt,
        # Ollama reads sampling settings from "options"; top-level
        # "temperature"/"max_tokens" keys are not applied by the API.
        "options": {"temperature": temperature, "num_predict": maxTokens},
        "stream": True
    }
    generated = ""
    buffer = ""
    with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as httpResponse:
        httpResponse.raise_for_status()
        for line in httpResponse.iter_lines():
            if not line:
                continue
            buffer += line.decode('utf-8')
            try:
                resultChunk = json.loads(buffer)
            except json.JSONDecodeError:
                # Incomplete JSON object: keep accumulating lines until it parses.
                continue
            buffer = ""
            if "response" in resultChunk:
                generated += resultChunk["response"]
    return generated


def processUserQuery(query):
    """
    Turn a natural-language question into SQL, run it, and describe the result.

    Tables are ranked with ``rankSchemas_V2``. Each attempt then sends the top
    ``top_N`` tables' schema plus the question to duckdb-nsql and executes the
    generated SQL.
    - After a MySQL error, one more table is added to the prompt.
    - After any other error, when ``top_N`` is odd, one extra generation is
      tried through the raw Ollama HTTP API with a small random temperature.

    Uses the notebook globals ``cursor``, ``schemaInfo``, ``schemaEmbeddings``
    and ``tableKeys``.

    Returns:
        str: "Query Generated: <sql>\\nOutput: <rows>\\n" on success, otherwise
        a failure message naming the attempt limit.
    """
    maxAttempts = 10
    top_N = 2

    userKeywords = tokenizeQuery(query)
    userQueryEmbedding = get_MPNet_embeddings(query)
    rankedTableNames = rankSchemas_V2(
        userKeywords, schemaInfo, userQueryEmbedding, schemaEmbeddings, tableKeys
    )

    attempt = 0
    success = False
    responseText = ""
    ollamaQuery = ""  # defined up front so the fallback branch can always read it

    while not success and attempt < maxAttempts:
        attempt += 1
        try:
            print(f"Attempt {attempt}: Considering top {top_N} tables")
            topTablesAfterReranking = rankedTableNames[:top_N]
            print(f"Top tables after re-ranking: {topTablesAfterReranking}")

            ollamaQuery = construct_SQL_QueryForOllama(topTablesAfterReranking, schemaInfo, query)
            generatedSql = _streamChatCompletion(ollamaQuery)

            cursor.execute(generatedSql)
            results = cursor.fetchall()
            responseText = f"Query Generated: {generatedSql}\nOutput: {str(results)}\n"
            success = True

        except mysql.connector.Error as err:
            # Give the model one more table of schema context next time.
            top_N += 1
            print(f"Query execution failed with error: {err}")

        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            if top_N % 2 != 0:
                entropySql = ""
                try:
                    temp = random.choice([0.05, 0.1, 0.15, 0.2])
                    entropySql = _generateWithTemperature(ollamaQuery, temp)
                    cursor.execute(entropySql)
                    results = cursor.fetchall()
                    responseText = f"Query Generated: {entropySql}\nOutput: {str(results)}\n"
                    success = True
                except mysql.connector.Error as err:
                    print(f"Query execution failed with entropy with error: {err}, {entropySql}")
                except Exception as entropy_error:
                    print(f"Entropy-based retry failed with error: {entropy_error}")

    if not success:
        responseText = f"Failed to execute query after {maxAttempts} attempts."

    return responseText


In [19]:
# Example usage: run one natural-language question end to end and show the result.
inputText = "List the names of the club heads and the clubs they belong to."
reply = processUserQuery(inputText)
print("Response: " + reply)

Attempt 1: Considering top 2 tables
Top tables after re-ranking: ['club_list', 'club_head_details']
Response: Query Generated:  SELECT c.club_name, h.club_head FROM club_list AS c JOIN club_head_details AS h ON c.club_id = h.head_id;
Output: [('AeroModeling Club', 'John Doe'), ('Animal Welfare Club', 'Jane Smith'), ('Anti Drug Club', 'Sam Johnson'), ('Artificial Intelligence & Robotics', 'Emily Davis'), ('Association of Serious Quizzers', 'Chris Brown'), ('Astronomy Club', 'Anna Lee'), ('Book Readers Club', 'Michael White'), ('CAP Nature Club', 'Emma Wilson'), ('Cyber Security Club', 'David Harris'), ('Dramatix Club', 'Sophia Thompson'), ('English Literary Society', 'James Martin'), ('Entrepreneurs Club', 'Olivia Taylor'), ('Fine Arts Club', 'Benjamin Walker'), ('Finverse Club', 'Ava Scott'), ('Global Leaders Forum', 'Liam Adams'), ('Higher Education Forum', 'Isabella Nelson'), ('Industry Interaction Forum', 'Noah Young'), ('Martial Arts Club', 'Mia Allen'), ('PSG Tech Chronicle Club',

In [18]:
# Release the cursor before closing the connection it was created from.
cursor.close()
db.close()