In [1]:
! pip install ollama mysql-connector-python transformers torch



In [2]:
import mysql.connector
import ollama
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BertTokenizer, BertModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Step 1: Connect to MySQL database and extract schema
database = "VenueScope"

db = mysql.connector.connect(
    host="localhost",
    user="root",
    password="Karaikudi-630002",
    database=database
)
cursor = db.cursor()

query = f"""
SELECT 
    TABLE_NAME, 
    COLUMN_NAME 
FROM 
    INFORMATION_SCHEMA.COLUMNS 
WHERE 
    TABLE_SCHEMA = '{database}'
ORDER BY 
    TABLE_NAME, ORDINAL_POSITION;
"""
cursor.execute(query)
schema_columns = cursor.fetchall()


In [4]:
# Step 2: Process schema to extract table and column names
schema_info = {}
for table_name, column_name in schema_columns:
    if table_name not in schema_info:
        schema_info[table_name] = []
    schema_info[table_name].append(column_name)

# Load BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Helper function to get BERT embeddings
def get_bert_embeddings(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.mean(dim=1).detach().numpy()  # Use mean pooling for sentence embeddings
    return embeddings



In [5]:
def extract_keywords(user_query):
    """
    Simple keyword extraction to match schema terms.
    """
    keywords = user_query.lower().split()  # Split the query into words
    return keywords


In [6]:
# Step 3: Convert table and column names to BERT embeddings
def get_schema_embeddings(schema_info):
    """
    Convert schema information (table and column names) into BERT embeddings.
    """
    schema_embeddings = []
    table_keys = []
    
    for table, columns in schema_info.items():
        for column in columns:
            text = f"{table} {column}"
            embedding = get_bert_embeddings(text)
            schema_embeddings.append(embedding)
            table_keys.append(table)
    
    return schema_embeddings, table_keys

schema_embeddings, table_keys = get_schema_embeddings(schema_info)

In [7]:
def get_query_embedding(user_query):
    """
    Convert user query into BERT embeddings.
    """
    return get_bert_embeddings(user_query)

In [15]:
def rank_schemas_v2(user_keywords, schema_info, user_query_embedding, schema_embeddings, table_keys):
    """
    Generalized function to rank schema tables based on keyword matches and BERT embeddings.
    """
    # Initialize table scores
    table_scores = {}

    # Step 1: Apply string matching across table and column names
    for table, columns in schema_info.items():
        table_lower = table.lower()
        table_scores[table] = 0  # Initialize score for the table
        
        # Boost score if user keywords match in the table name or its columns
        for keyword in user_keywords:
            # Boost for keyword in table name
            if keyword in table_lower:
                table_scores[table] += 2

            # Boost for keyword in column names
            for column in columns:
                column_lower = column.lower()
                if keyword in column_lower:
                    table_scores[table] += 1  # Smaller boost for column matches

    # Step 2: Apply embedding similarity as a secondary ranking factor
    for i, table in enumerate(table_keys):
        similarity_score = cosine_similarity(user_query_embedding, schema_embeddings[i]).flatten()[0]
        table_scores[table] = table_scores.get(table, 0) + similarity_score

    # Step 3: Sort the tables based on the combined score (higher is better)
    ranked_tables = sorted(table_scores.keys(), key=lambda x: table_scores[x], reverse=True)

    # Step 4: Remove duplicates, maintaining order
    unique_ranked_tables = []
    seen_tables = set()
    for table in ranked_tables:
        if table not in seen_tables:
            unique_ranked_tables.append(table)
            seen_tables.add(table)

    return unique_ranked_tables


In [9]:
# Step 6: Construct SQL query dynamically based on top-ranked table and columns
def construct_sql_query(ranked_table_names, schema_info, user_query, top_n):
    """
    Construct SQL query dynamically using Ollama based on top n-ranked tables and the user's query.
    """
    # Get the top n-ranked tables
    top_ranked_tables = ranked_table_names[:top_n]
    
    # Collect schema information for the top-ranked tables
    schema_info_str = ""
    for table in top_ranked_tables:
        columns = schema_info[table]
        schema_info_str += f"Table {table}: Columns ({', '.join(columns)})\n"

    # Pass the schema info and user query to Ollama
    stream = ollama.chat(
        model='duckdb-nsql',
        messages=[{'role': 'user', 'content': f"This is the schema: \n{schema_info_str}\n{user_query}"}],
        stream=True,
    )

    response = ""
    for chunk in stream:
        response += chunk['message']['content']

    return response

# top_n = 2  # Set how many top-ranked tables to include
# Get the response from Ollama based on the user's query and ranked schema
# ollama_query = construct_sql_query(ranked_table_names, schema_info, user_query, top_n)
# print("Ollama SQL Query:", ollama_query)

In [10]:
# Step 1: Modify ranking to consider foreign keys and relationships
def get_foreign_key_relations(cursor, schema_info, database):
    """
    Extract foreign key relationships from the INFORMATION_SCHEMA for the given database.
    Returns a dictionary mapping tables to their related tables via foreign keys.
    """
    foreign_key_query = f"""
    SELECT 
        TABLE_NAME, 
        COLUMN_NAME, 
        REFERENCED_TABLE_NAME, 
        REFERENCED_COLUMN_NAME
    FROM 
        INFORMATION_SCHEMA.KEY_COLUMN_USAGE 
    WHERE 
        TABLE_SCHEMA = '{database}' 
        AND REFERENCED_TABLE_NAME IS NOT NULL;
    """
    
    cursor.execute(foreign_key_query)
    foreign_keys = cursor.fetchall()
    
    fk_relations = {}
    for table, column, ref_table, ref_column in foreign_keys:
        if table not in fk_relations:
            fk_relations[table] = []
        fk_relations[table].append((column, ref_table, ref_column))
    
    return fk_relations

In [11]:
from sklearn.metrics.pairwise import cosine_similarity

def rank_columns_by_relevance(user_query_embedding, column_names, column_embeddings):
    """
    Compare user query embedding with column embeddings and rank columns based on relevance.
    """
    column_scores = []
    user_query_embedding = user_query_embedding.reshape(1, -1)  # Reshape user query embedding to 2D

    for column, embedding in zip(column_names, column_embeddings):
        embedding = embedding.reshape(1, -1)  # Reshape column embedding to 2D
        # Compute similarity between the user query and each column embedding (cosine similarity)
        similarity_score = cosine_similarity(user_query_embedding, embedding)[0][0]  # Extract scalar
        column_scores.append((column, similarity_score))

    # Sort columns by relevance (higher similarity score first)
    column_scores.sort(key=lambda x: x[1], reverse=True)
    return column_scores



def construct_and_execute_query(cursor, ranked_table_names, schema_info, user_query, top_n, max_attempts=5):
    """
    retry_construct_and_execute_query_with_column_reranking
    For each top_n ranked table, rank its columns by relevance to the user query,
    re-rank tables based on the relevance of columns, and generate SQL query if relevant.
    """
    attempt = 0
    success = False
    ollama_query = ""
    user_query_embedding = get_bert_embeddings(user_query)  # Embed the user's query

    while not success and attempt < max_attempts:
        try:
            # Increment attempt count
            attempt += 1
            print(f"Attempt {attempt} to generate and execute the query...")

            # Iterate over top-ranked tables to find the most relevant column match
            for table_name in ranked_table_names[:top_n]:
                column_names = schema_info[table_name]  # Get columns for the table
                column_embeddings = get_bert_embeddings(column_names)  # Embed the column names

                # Rank columns based on their relevance to the user's query
                ranked_columns = rank_columns_by_relevance(user_query_embedding, column_names, column_embeddings)
                print(f"Ranked columns for table {table_name}: {ranked_columns}")

                # Check if the top-ranked column has sufficient relevance
                top_column, relevance_score = ranked_columns[0]
                print(f"Top column: {top_column}, Relevance score: {relevance_score}")

                if relevance_score > 0.5:  # Threshold for relevance (can be adjusted)
                    print(f"Proceeding with table {table_name} and top column {top_column}")

                    # Generate SQL query using Ollama with the relevant table and columns
                    ollama_query = construct_sql_query([table_name], schema_info, user_query, top_n=1)
                    print("Generated Query from Ollama:", ollama_query)

                    # Try executing the query
                    cursor.execute(ollama_query)
                    results = cursor.fetchall()
                    success = True  # Mark success if query executes successfully
                    break
                else:
                    print(f"Relevance score too low for table {table_name}. Trying the next table...")

        except mysql.connector.Error as err:
            print(f"Query execution failed with error: {err}")
            print("Re-ranking columns and trying the next table...")

    # If successful, return the results
    if success:
        print("Query executed successfully!")
        return results
    else:
        print(f"Failed after {max_attempts} attempts.")
        return None


In [17]:
def get_foreign_keys(cursor, schema_info):
    """
    Extract foreign key relationships for the tables in schema_info.
    """
    foreign_keys = {}
    for table in schema_info.keys():
        cursor.execute(f"""
            SELECT 
                COLUMN_NAME, 
                REFERENCED_TABLE_NAME, 
                REFERENCED_COLUMN_NAME
            FROM 
                INFORMATION_SCHEMA.KEY_COLUMN_USAGE
            WHERE 
                TABLE_NAME = '{table}' AND 
                TABLE_SCHEMA = 'your_database_name' AND 
                REFERENCED_TABLE_NAME IS NOT NULL;
        """)
        foreign_keys[table] = cursor.fetchall()
    return foreign_keys

def construct_sql_query_for_ollama(top_tables, schema_info, user_query):
    """
    Construct SQL query using schema information for the top tables.
    """
    schema_info_str = ""
    for table in top_tables:
        columns = schema_info[table]
        schema_info_str += f"Table {table}: Columns ({', '.join(columns)})\n"

    # You can also include foreign key information if necessary
    foreign_keys = get_foreign_keys(cursor, schema_info)
    for table, keys in foreign_keys.items():
        if keys:
            schema_info_str += f"Foreign keys for {table}:\n"
            for key in keys:
                schema_info_str += f"  - {key[0]} -> {key[1]}({key[2]})\n"

    # Prepare the final query for Ollama
    query_for_ollama = f"{schema_info_str}\n{user_query}"
    return query_for_ollama


In [13]:
def getQueriesFromFile(file_path):
    """
    Read the queries from a file and generate BERT embeddings for each.
    """
    print("Reading file...")
    queries = []
    
    # Open the file and read queries
    with open(file_path, 'r') as file:
        queries = file.readlines()
    
    print(f"Total lines read: {len(queries)}")
    
    # Generate embeddings for each query
    for i in range (len(queries)):
        queries[i] = queries[i].strip()  # Remove any leading/trailing whitespace
    
    return queries

# Assuming the queries are in 'queries.txt' file
file_path = 'queries.txt'
queries = getQueriesFromFile(file_path)

Reading file...
Total lines read: 3


In [18]:
# After re-ranking the top tables
top_n = 3  # Adjust the number of top tables as needed

for query in queries:
    user_keywords = extract_keywords(query)
    # Step 6: Get query embedding
    user_query_embedding = get_bert_embeddings(query)

    # Rank schema tables based on query relevance and uniqueness
    ranked_table_names = rank_schemas_v2(user_keywords, schema_info, user_query_embedding, schema_embeddings, table_keys)
    print(ranked_table_names)

    top_tables_after_re_ranking = ranked_table_names[:top_n]  # Replace with your actual top tables

    # Construct the query for Ollama
    ollama_query = construct_sql_query_for_ollama(top_tables_after_re_ranking, schema_info, query)

    # Pass this to Ollama for query generation
    stream = ollama.chat(
        model='duckdb-nsql',
        messages=[{'role': 'user', 'content': ollama_query}],
        stream=True,
    )

    response = ""
    for chunk in stream:
        response += chunk['message']['content']

    # Execute the generated SQL query
    cursor.execute(response)
    results = cursor.fetchall()
    print(query)
    print(response, results)


['booked_venue', 'club_list', 'venue_list', 'club_head_details', 'club_head']
List out the names of the heads from all the clubs
 SELECT club_name FROM club_list [('AeroModeling Club',), ('Animal Welfare Club',), ('Anti Drug Club',), ('Artificial Intelligence & Robotics',), ('Association of Serious Quizzers',), ('Astronomy Club',), ('Book Readers Club',), ('CAP Nature Club',), ('Cyber Security Club',), ('Dramatix Club',), ('English Literary Society',), ('Entrepreneurs Club',), ('Fine Arts Club',), ('Finverse Club',), ('Global Leaders Forum',), ('Higher Education Forum',), ('Industry Interaction Forum',), ('Martial Arts Club',), ('PSG Tech Chronicle Club',), ('Paathshala Club',), ('Radio Hub',), ('Rotaract Club',), ('SPIC-MACAY Heritage Club',), ('Student Research Council',), ('Tech Music',), ('Women Development Cell',), ('Youth Outreach Club',), ('Youth Red Cross Society',), ('Yuva Tourism Club',)]
['club_head_details', 'booked_venue', 'club_list', 'club_head', 'venue_list']
Get the na

In [163]:
results

[('SPIC-MACAY Heritage Club',)]

In [164]:
# Close cursor and database connection
cursor.close()
db.close()