In [15]:
import pandas as pd
import numpy as np
import json
import psycopg2
from psycopg2.extras import execute_values
from openai import OpenAI
import os
import time
from dotenv import load_dotenv

In [16]:
# Load environment variables (such as OPENAI_API_KEY and the PG* settings
# below) from a local .env file, if one exists.
load_dotenv()

# OpenAI client; the API key is read from the environment, never hardcoded.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# PostgreSQL connection settings. Each value can be overridden with the
# standard libpq environment variable (e.g. in .env) so real credentials are
# not committed with the notebook; the literals are local-development fallbacks.
DB_PARAMS = {
    "dbname": os.environ.get("PGDATABASE", "mydatabase"),
    "user": os.environ.get("PGUSER", "myuser"),
    "password": os.environ.get("PGPASSWORD", "mypassword"),
    "host": os.environ.get("PGHOST", "localhost"),
    # Must match the port published by the PostgreSQL container
    "port": os.environ.get("PGPORT", "5433"),
}

In [17]:
def get_db_connection():
    """Open a new PostgreSQL connection using ``DB_PARAMS``.

    Returns:
        tuple: ``(connection, cursor)``. The caller is responsible for
        closing both.
    """
    connection = psycopg2.connect(**DB_PARAMS)
    return connection, connection.cursor()

def fetch_all_metadata():
    """Load every row of ``mimic_table_metadata``, keyed by table name.

    Returns:
        dict: ``{table_name: {field: value, ...}}``, where the inner dict holds
        the eight metadata columns listed in ``field_names``. If a table name
        appears more than once, the last row wins. The cursor and connection
        are closed even if the query fails.
    """
    # Order must match the SELECT list below (after table_name).
    field_names = (
        'description', 'table_purpose', 'columns_info', 'primary_keys',
        'foreign_keys', 'important_considerations', 'common_joins',
        'example_questions',
    )
    conn, cur = get_db_connection()
    try:
        cur.execute("""
            SELECT table_name, description, table_purpose, columns_info, 
                   primary_keys, foreign_keys, important_considerations,
                   common_joins, example_questions
            FROM mimic_table_metadata;
        """)
        return {
            record[0]: dict(zip(field_names, record[1:]))
            for record in cur.fetchall()
        }
    finally:
        cur.close()
        conn.close()

def format_metadata_for_prompt(metadata, max_categorical_values=15):
    """Render table metadata as Markdown text suitable for an LLM prompt.

    Args:
        metadata: Mapping of table name to a metadata dict, shaped like the
            output of ``fetch_all_metadata``. Recognised keys are
            ``description``, ``table_purpose``, ``columns_info``,
            ``primary_keys``, ``foreign_keys``, ``important_considerations``
            and ``common_joins``. Optional sections that are missing or
            ``None`` are skipped.
        max_categorical_values: A column's ``categorical_values`` are listed
            only when there are fewer than this many, to keep the prompt
            short. Defaults to 15.

    Returns:
        str: A schema header followed by one Markdown section per table, each
        terminated by a ``---`` separator line.
    """
    parts = ["# Database Schema Information\n\n"]

    for table_name, table_info in metadata.items():
        parts.append(f"## Table: {table_name}\n")
        parts.append(f"Description: {table_info.get('description')}\n")
        parts.append(f"Purpose: {table_info.get('table_purpose')}\n\n")

        primary_keys = table_info.get('primary_keys')
        if primary_keys:
            # str() so non-text key entries don't make join() raise TypeError
            parts.append(f"Primary Keys: {', '.join(map(str, primary_keys))}\n\n")

        foreign_keys = table_info.get('foreign_keys')
        if foreign_keys:
            parts.append("Foreign Keys:\n")
            for fk_col, fk_info in foreign_keys.items():
                parts.append(f"- {fk_col} references {fk_info['table']}.{fk_info['column']}\n")
            parts.append("\n")

        parts.append("Columns:\n")
        # columns_info may be NULL in the metadata table; treat that as "no columns"
        for col_name, col_info in (table_info.get('columns_info') or {}).items():
            parts.append(f"- {col_name} ({col_info.get('data_type')}): {col_info.get('description')}\n")

            categorical_values = col_info.get('categorical_values')
            # Only enumerate short, non-empty value lists; values may be numeric codes
            if categorical_values and len(categorical_values) < max_categorical_values:
                parts.append(f"  Possible values: {', '.join(map(str, categorical_values))}\n")

            value_range = col_info.get('value_range')
            if value_range:
                parts.append(f"  Range: {value_range.get('min')} to {value_range.get('max')}\n")

        parts.append("\n")

        considerations = table_info.get('important_considerations')
        if considerations:
            parts.append(f"Important Considerations: {considerations}\n\n")

        common_joins = table_info.get('common_joins')
        if common_joins:
            parts.append("Common Joins:\n")
            parts.extend(f"- {join}\n" for join in common_joins)
            parts.append("\n")

        parts.append("---\n\n")

    # Build with a list + join rather than repeated string concatenation
    return "".join(parts)

In [18]:
def create_llm_prompt(user_question, metadata_text):
    """Build the text-to-SQL prompt sent to the LLM.

    Args:
        user_question: Natural-language question from the user.
        metadata_text: Schema description, e.g. from ``format_metadata_for_prompt``.

    Returns:
        str: Markdown-sectioned prompt containing the question, the metadata, a
        task list and response-format instructions. It ends with a
        ``SQL Query:`` line and a trailing newline.
    """
    prompt_lines = [
        "You are a professional medical database expert specializing in SQL and the MIMIC-IV database. Based on the user's question and the provided database metadata, generate a PostgreSQL query.",
        "",
        "## User Question",
        f"{user_question}",
        "",
        "## Database Metadata",
        f"{metadata_text}",
        "",
        "## Task",
        "1. Analyze the user question to determine which tables and columns need to be queried",
        "2. Design an effective SQL query based on the provided metadata",
        "3. Ensure the generated SQL is syntactically correct and considers table relationships",
        "4. If multiple table joins are needed, use the correct join conditions",
        "5. Handle any potential edge cases",
        "",
        "## Response Format",
        'Please return ONLY the SQL query without any explanation. Start your answer with "SELECT".',
        "",
        "SQL Query:",
    ]
    return "\n".join(prompt_lines) + "\n"

In [19]:
def generate_sql_with_openai(prompt, model="gpt-4", temperature=0.1, max_tokens=500, llm_client=None):
    """Ask an OpenAI chat model to turn ``prompt`` into a SQL query.

    Args:
        prompt: Full user prompt, e.g. from ``create_llm_prompt``.
        model: Chat-completions model name. Defaults to ``"gpt-4"``.
        temperature: Sampling temperature; kept low for more deterministic SQL.
        max_tokens: Upper bound on the completion length.
        llm_client: Object exposing ``chat.completions.create``. Defaults to the
            module-level ``client``; it is injectable for testing or for a
            differently configured client.

    Returns:
        str | None: The SQL text, with any Markdown code fence and leading
        ``"SQL Query:"`` label removed. Returns ``None`` (after printing a
        message) if the API call fails or the model returns no content.
    """
    active_client = llm_client if llm_client is not None else client
    try:
        response = active_client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical database expert who converts natural language questions into PostgreSQL queries."},
                {"role": "user", "content": prompt}
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        content = response.choices[0].message.content
    except Exception as e:
        # Best-effort: the pipeline treats None as "could not generate SQL"
        print(f"Error calling OpenAI API: {e}")
        return None

    # The SDK types message.content as optional (e.g. refusals); don't .strip() None
    if not content:
        print("OpenAI API returned an empty response")
        return None

    return _clean_sql_response(content)


def _clean_sql_response(text):
    """Strip a Markdown code fence and a leading 'SQL Query:' label from model output."""
    import re

    cleaned = text.strip()
    # Chat models often wrap SQL in ```sql ... ``` despite instructions; keep only the fenced body
    fenced = re.search(r"```[\w-]*[ \t]*\n(.*?)```", cleaned, re.DOTALL)
    if fenced:
        cleaned = fenced.group(1).strip()
    if cleaned.startswith("SQL Query:"):
        cleaned = cleaned[len("SQL Query:"):].strip()
    return cleaned


def execute_sql_query(sql_query, read_only=True):
    """Execute ``sql_query`` and return its rows as a pandas DataFrame.

    The SQL is generated by an LLM, so by default the transaction psycopg2
    opens is marked read-only and PostgreSQL rejects writes/DDL inside it.
    This is defense in depth, not a security boundary: statements placed
    after an explicit COMMIT in the same string may not be covered. A
    SELECT-only database role is the real safeguard. Nothing is committed
    here; closing the connection discards the open transaction.

    Args:
        sql_query: SQL text to execute.
        read_only: Set False to allow data-modifying statements. They are
            still not committed by this function.

    Returns:
        pandas.DataFrame | None: The result rows, with cursor column names as
        columns. The DataFrame is empty when the statement produces no
        result set. Returns ``None`` (after printing the error) if execution
        fails.
    """
    conn, cur = get_db_connection()

    try:
        if read_only:
            # Must run before any statement starts a transaction
            conn.set_session(readonly=True)

        cur.execute(sql_query)

        # cursor.description is None for statements that return no rows
        if cur.description is None:
            return pd.DataFrame()

        column_names = [desc[0] for desc in cur.description]
        return pd.DataFrame(cur.fetchall(), columns=column_names)

    except Exception as e:
        print(f"Error executing SQL query: {e}")
        return None

    finally:
        cur.close()
        conn.close()

In [20]:
def sqlrag_pipeline(user_question):
    """Run the full text-to-SQL (SQLRAG) pipeline for a single question.

    Steps: load table metadata, render it into a prompt, ask the LLM for SQL,
    then execute that SQL. Progress and results are printed along the way.

    Returns:
        dict | str: ``{"user_question", "generated_sql", "results"}`` on
        success. On failure, an error message string instead.
    """
    print(f"User Question: {user_question}")
    print("Fetching metadata...")
    schema_text = format_metadata_for_prompt(fetch_all_metadata())
    llm_prompt = create_llm_prompt(user_question, schema_text)

    print("Generating SQL with LLM...")
    generated_sql = generate_sql_with_openai(llm_prompt)
    if not generated_sql:
        return "Failed to generate SQL query"
    print(f"Generated SQL: \n{generated_sql}\n")

    print("Executing SQL query...")
    result_frame = execute_sql_query(generated_sql)
    if result_frame is None:
        return "Error executing SQL query"

    print("Query Results:")
    print(result_frame)
    return dict(
        user_question=user_question,
        generated_sql=generated_sql,
        results=result_frame,
    )

In [22]:
# Smoke test: run the whole pipeline end-to-end on one sample question.
if __name__ == "__main__":
    question = (
        "Which patient has the highest number of hospital admissions?"
    )
    # sqlrag_pipeline returns a dict on success, an error string otherwise
    result = sqlrag_pipeline(question)

User Question: Which patient has the highest number of hospital admissions?
Fetching metadata...
Generating SQL with LLM...
Generated SQL: 
SELECT subject_id, COUNT(hadm_id) AS num_admissions
FROM admissions
GROUP BY subject_id
ORDER BY num_admissions DESC
LIMIT 1;

Executing SQL query...
Query Results:
   subject_id  num_admissions
0    15496609             238
