In [49]:
import pandas as pd
import numpy as np
import json
import psycopg2
from psycopg2.extras import execute_values
from openai import OpenAI
import os
import time
from dotenv import load_dotenv

In [50]:
# Load environment variables from .env file so secrets stay out of the notebook
load_dotenv()

# OpenAI API configuration
# Initialize the client; the key is read from the OPENAI_API_KEY environment variable
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# PostgreSQL database connection
# Every setting can be overridden through the environment (or the .env file)
# using the standard libpq variable names; the fallbacks are the original
# local-development values, so an unconfigured environment behaves as before.
DB_PARAMS = {
    "dbname": os.environ.get("PGDATABASE", "mydatabase"),
    "user": os.environ.get("PGUSER", "myuser"),
    "password": os.environ.get("PGPASSWORD", "mypassword"),
    "host": os.environ.get("PGHOST", "localhost"),
    "port": os.environ.get("PGPORT", "5433"),  # Ensure this matches your PostgreSQL container port
}

In [51]:
def get_db_connection():
    """Open a PostgreSQL connection and a cursor on it.

    Connection settings are taken from the module-level ``DB_PARAMS`` dict.

    Returns:
        tuple: ``(conn, cur)``. The caller is responsible for closing both.

    Raises:
        Exception: whatever psycopg2 raises while connecting or creating the
            cursor is propagated unchanged.
    """
    conn = psycopg2.connect(**DB_PARAMS)
    try:
        cur = conn.cursor()
    except Exception:
        # Do not leak the freshly opened connection if no cursor can be created
        conn.close()
        raise
    return conn, cur

def fetch_all_metadata():
    """Load every row of ``mimic_table_metadata`` into a dictionary.

    Returns:
        dict: Maps each table name to a dict holding that row's remaining
        selected columns, keyed by column name.
    """
    # Selected columns after table_name, in SELECT order
    field_names = (
        'description',
        'table_purpose',
        'columns_info',
        'primary_keys',
        'foreign_keys',
        'important_considerations',
        'common_joins',
        'example_questions',
    )

    conn, cur = get_db_connection()
    try:
        cur.execute("""
            SELECT table_name, description, table_purpose, columns_info, 
                   primary_keys, foreign_keys, important_considerations,
                   common_joins, example_questions
            FROM mimic_table_metadata;
        """)

        # First column is the key; the other eight pair up with field_names
        return {
            record[0]: dict(zip(field_names, record[1:]))
            for record in cur.fetchall()
        }
    finally:
        cur.close()
        conn.close()

def _join_as_text(values):
    """Join ``values`` with ', ' after converting every item to ``str``.

    A bare string is returned unchanged rather than being joined character
    by character.
    """
    if isinstance(values, str):
        return values
    return ', '.join(str(value) for value in values)


def format_metadata_for_prompt(metadata, max_categorical_values=15):
    """Format metadata into text suitable for LLM prompt

    Parameters:
        metadata (dict): Maps table names to dicts with the keys
            'description', 'table_purpose', 'columns_info', 'primary_keys',
            'foreign_keys', 'important_considerations' and 'common_joins'.
        max_categorical_values (int): A column's categorical values are listed
            only when there are fewer than this many of them.

    Returns:
        str: Markdown-style description of every table, each section
        terminated by a '---' separator line.

    Raises:
        KeyError: if a table entry, foreign key entry or column entry lacks
            one of the keys read below.
    """
    # Collect fragments and join once instead of repeated string concatenation
    parts = ["# Database Schema Information\n\n"]

    for table_name, table_info in metadata.items():
        parts.append(f"## Table: {table_name}\n")
        parts.append(f"Description: {table_info['description']}\n")
        parts.append(f"Purpose: {table_info['table_purpose']}\n\n")

        # Add primary key information
        primary_keys = table_info['primary_keys']
        if primary_keys:
            parts.append(f"Primary Keys: {_join_as_text(primary_keys)}\n\n")

        # Add foreign key information
        foreign_keys = table_info['foreign_keys']
        if foreign_keys:
            parts.append("Foreign Keys:\n")
            for fk_col, fk_info in foreign_keys.items():
                parts.append(f"- {fk_col} references {fk_info['table']}.{fk_info['column']}\n")
            parts.append("\n")

        # Add column information; a missing (None) columns_info yields an empty list
        parts.append("Columns:\n")
        columns_info = table_info['columns_info'] or {}
        for col_name, col_info in columns_info.items():
            parts.append(f"- {col_name} ({col_info['data_type']}): {col_info['description']}\n")

            # Add categorical values if available and not too long.
            # Items are stringified so numeric categories do not break the join.
            categorical_values = col_info.get('categorical_values')
            if categorical_values is not None and len(categorical_values) < max_categorical_values:
                parts.append(f"  Possible values: {_join_as_text(categorical_values)}\n")

            # Add value range if available (a None/empty range is skipped)
            value_range = col_info.get('value_range')
            if value_range:
                parts.append(f"  Range: {value_range['min']} to {value_range['max']}\n")

        parts.append("\n")

        # Add important considerations
        if table_info['important_considerations']:
            parts.append(f"Important Considerations: {table_info['important_considerations']}\n\n")

        # Add common joins
        if table_info['common_joins']:
            parts.append("Common Joins:\n")
            for join in table_info['common_joins']:
                parts.append(f"- {join}\n")
            parts.append("\n")

        parts.append("---\n\n")

    return "".join(parts)

In [52]:
def _embedding_to_floats(raw_embedding):
    """Coerce a stored embedding into a list of floats.

    Handles a JSON-encoded string, a mapping (its values are used in
    iteration order) and any plain iterable of numbers.
    """
    if isinstance(raw_embedding, str):
        raw_embedding = json.loads(raw_embedding)
    # Checked after decoding so a JSON string holding an object is handled too
    if isinstance(raw_embedding, dict) or hasattr(raw_embedding, 'keys'):
        raw_embedding = list(raw_embedding.values())
    return [float(x) for x in raw_embedding]


def fetch_relevant_metadata(query_embedding, top_k=10, verbose=False):
    """
    Fetch metadata for tables most relevant to the query embedding using Python-based similarity

    Every row of ``mimic_table_metadata`` with a non-NULL embedding is scored
    against ``query_embedding`` with ``cosine_similarity``; rows whose embedding
    cannot be converted or compared are reported and skipped.

    Parameters:
        query_embedding (list): Vector representation of the user query
        top_k (int): Number of most relevant tables to return; values below
            zero are treated as zero, ``None`` returns every scored table
        verbose (bool): When True, print the type and a preview of each stored
            embedding (debugging aid, off by default)

    Returns:
        dict: Dictionary mapping table names to their metadata, inserted in
        order of decreasing similarity
    """
    conn, cur = get_db_connection()

    try:
        print(f"Computing similarity to find top {top_k} relevant tables")
        cur.execute("""
            SELECT 
                table_name, description, table_purpose, columns_info, 
                primary_keys, foreign_keys, important_considerations,
                common_joins, example_questions, embedding
            FROM mimic_table_metadata
            WHERE embedding IS NOT NULL;
        """)

        rows = cur.fetchall()
    finally:
        # All rows are in memory, so release the connection before scoring
        cur.close()
        conn.close()

    # Calculate similarity for each table
    table_similarities = []
    for row in rows:
        table_name = row[0]
        raw_embedding = row[9]

        # Skip if embedding is NULL
        if raw_embedding is None:
            print(f"Skipping table {table_name} as embedding is NULL")
            continue

        if verbose:
            print(f"Table: {table_name}")
            print(f"Embedding type: {type(raw_embedding)}")
            print(f"Embedding first few elements: {str(raw_embedding)[:100]}...")

        try:
            table_embedding = _embedding_to_floats(raw_embedding)
            similarity = cosine_similarity(query_embedding, table_embedding)
        except Exception as e:
            # Best effort: one malformed embedding must not abort the whole search
            print(f"Error processing embedding for table {table_name}: {str(e)}")
            continue

        table_similarities.append((row, similarity))

    # Sort by similarity (descending) and take top_k
    table_similarities.sort(key=lambda x: x[1], reverse=True)
    # A negative slice bound would silently drop items from the end instead
    limit = None if top_k is None else max(top_k, 0)
    top_tables = table_similarities[:limit]

    # Format as dictionary
    tables_metadata = {}
    for row, similarity in top_tables:
        table_name = row[0]
        tables_metadata[table_name] = {
            'description': row[1],
            'table_purpose': row[2],
            'columns_info': row[3],
            'primary_keys': row[4],
            'foreign_keys': row[5],
            'important_considerations': row[6],
            'common_joins': row[7],
            'example_questions': row[8]
        }

        # Report which tables were selected and how well they matched
        print(f"Table: {table_name}, Similarity: {similarity:.4f}")

    return tables_metadata

def cosine_similarity(vec1, vec2):
    """
    Compute cosine similarity between two vectors

    Parameters:
        vec1 (list): First vector
        vec2 (list): Second vector (must have the same length as vec1)

    Returns:
        float: Cosine similarity (between -1 and 1); 0.0 when either vector
        has zero norm

    Raises:
        ValueError: if the two vectors have different lengths
    """
    # Convert to numeric numpy arrays
    vec1 = np.array(vec1, dtype=float)
    vec2 = np.array(vec2, dtype=float)

    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)

    # Similarity is undefined for a zero vector; report "no similarity"
    if norm_a == 0 or norm_b == 0:
        return 0.0

    # Always hand back a plain Python float, as documented
    similarity = float(np.dot(vec1, vec2) / (norm_a * norm_b))

    # Rounding can push the ratio marginally outside the mathematical range
    if similarity > 1.0:
        return 1.0
    if similarity < -1.0:
        return -1.0
    return similarity

In [53]:
def create_llm_prompt(user_question, metadata_text):
    """Assemble the full text-to-SQL prompt for the LLM.

    The prompt consists of a role statement, the user's question, the schema
    metadata, a numbered task list and the required response format, and it
    ends with a "SQL Query:" cue line.
    """
    task_steps = (
        "1. Analyze the user question to determine which tables and columns need to be queried",
        "2. Design an effective SQL query based on the provided metadata",
        "3. Ensure the generated SQL is syntactically correct and considers table relationships",
        "4. If multiple table joins are needed, use the correct join conditions",
        "5. Handle any potential edge cases",
    )

    prompt_lines = [
        "You are a professional medical database expert specializing in SQL and the MIMIC-IV database. Based on the user's question and the provided database metadata, generate a PostgreSQL query.",
        "",
        "## User Question",
        f"{user_question}",
        "",
        "## Database Metadata",
        f"{metadata_text}",
        "",
        "## Task",
        *task_steps,
        "",
        "## Response Format",
        'Please return ONLY the SQL query without any explanation. Start your answer with "SELECT".',
        "",
        "SQL Query:",
        "",  # trailing empty entry yields the final newline
    ]
    return "\n".join(prompt_lines)

In [54]:
def _extract_sql(raw_text):
    """Strip the wrappers a chat model commonly puts around a SQL answer.

    Removes surrounding whitespace, a leading "SQL Query:" label and a
    Markdown code fence (with or without a language tag on the opening line).
    """
    sql = raw_text.strip()

    # Remove "SQL Query:" prefix if present
    if sql.startswith("SQL Query:"):
        sql = sql[len("SQL Query:"):].strip()

    # Remove a Markdown fence: drop the opening fence line (and its language
    # tag) and everything from the closing fence onwards
    if sql.startswith("```"):
        _fence_line, _, fenced_body = sql.partition("\n")
        if fenced_body:
            sql = fenced_body.split("```", 1)[0].strip()

    return sql


def generate_sql_with_openai(prompt, model="gpt-4"):
    """Generate SQL query using OpenAI API (updated for version 1.0+)

    Parameters:
        prompt (str): Complete prompt text sent as the user message
        model (str): Chat model name passed to the API

    Returns:
        str or None: The SQL text with label/code-fence wrappers removed, or
        None when the API call fails or the response carries no text
    """
    try:
        # Using the new client format
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical database expert who converts natural language questions into PostgreSQL queries."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1,  # Low temperature for more deterministic output
            max_tokens=500
        )

        content = response.choices[0].message.content
        if content is None:
            # Report the real cause instead of an AttributeError on .strip()
            print("Error calling OpenAI API: response contained no text content")
            return None

        return _extract_sql(content)

    except Exception as e:
        print(f"Error calling OpenAI API: {e}")
        return None


def execute_sql_query(sql_query):
    """Execute SQL query and return results

    The statement runs in a read-only session on a fresh connection, which is
    always closed afterwards.

    Parameters:
        sql_query (str): SQL text to execute

    Returns:
        pandas.DataFrame or None: One column per result column; an empty
        DataFrame when the statement produces no result set; None when
        execution fails (the error is printed)
    """
    conn, cur = get_db_connection()

    try:
        # The SQL typically comes from an LLM, so treat it as untrusted: in a
        # read-only session PostgreSQL rejects statements that would write.
        conn.set_session(readonly=True)

        cur.execute(sql_query)

        # description is None for statements that return no rows at all
        if cur.description is None:
            return pd.DataFrame()

        # Get column names
        column_names = [desc[0] for desc in cur.description]

        # Convert results to DataFrame
        return pd.DataFrame(cur.fetchall(), columns=column_names)

    except Exception as e:
        print(f"Error executing SQL query: {e}")
        return None

    finally:
        cur.close()
        conn.close()

In [55]:
def vectorize_user_query(query_text, model="text-embedding-3-small"):
    """
    Convert a user's natural language query into a vector representation

    Parameters:
        query_text (str): The user's natural language query
        model (str): Embedding model name; use the same model as for the
            stored table embeddings so the vectors are comparable

    Returns:
        list: Vector representation of the query, or None when the query is
        empty or the embedding request fails (the error is printed)
    """
    try:
        # Preprocess the query text - more processing steps can be added as needed
        processed_query = query_text.strip()

        # Reject blank input up front rather than sending it to the API
        if not processed_query:
            print("❌ Error vectorizing query: query text is empty")
            return None

        # Generate embedding vector using OpenAI API
        response = client.embeddings.create(
            input=processed_query,
            model=model
        )

        # Extract the embedding vector
        query_embedding = response.data[0].embedding

        # Truncate long queries in the log line; the conditional must be
        # evaluated outside the string literal to take effect
        preview = f"{query_text[:50]}..." if len(query_text) > 50 else query_text
        print(f"✅ Successfully vectorized query: '{preview}'")
        return query_embedding

    except Exception as e:
        print(f"❌ Error vectorizing query: {e}")
        return None

In [56]:
def sqlrag_pipeline(user_question, top_k=5):
    """Execute the complete SQLRAG pipeline

    Steps: embed the question, select the most similar table metadata, build
    the prompt, ask the LLM for SQL, execute it and print the outcome.

    Parameters:
        user_question (str): Natural language question to answer
        top_k (int): Number of most relevant tables whose metadata is put
            into the prompt

    Returns:
        dict or str: On success a dict with the keys 'user_question',
        'generated_sql' and 'results' (a DataFrame); on failure a short
        message string naming the step that failed
    """
    print(f"User Question: {user_question}")

    # 1. Vectorize user query
    print("Vectorizing user query...")
    query_embedding = vectorize_user_query(user_question)

    if not query_embedding:
        return "Failed to vectorize user query"

    # 2. Fetch relevant metadata using vector similarity
    print("Finding relevant tables...")
    metadata = fetch_relevant_metadata(query_embedding, top_k=top_k)

    # Without any schema information the LLM could only guess table names
    if not metadata:
        return "No relevant table metadata found"

    # 3. Format metadata for prompt
    metadata_text = format_metadata_for_prompt(metadata)

    # 4. Create LLM prompt
    prompt = create_llm_prompt(user_question, metadata_text)

    print("Generating SQL with LLM...")
    # 5. Generate SQL query
    sql_query = generate_sql_with_openai(prompt)

    if not sql_query:
        return "Failed to generate SQL query"

    print(f"Generated SQL: \n{sql_query}\n")

    print("Executing SQL query...")
    # 6. Execute SQL query
    results = execute_sql_query(sql_query)

    if results is None:
        return "Error executing SQL query"

    print("Query Results:")
    print(results)

    return {
        "user_question": user_question,
        "generated_sql": sql_query,
        "results": results
    }