In [315]:
import pandas as pd
import numpy as np
import json
import psycopg2
from psycopg2.extras import execute_values
from openai import OpenAI
import os
import time
from dotenv import load_dotenv

# Configuration: OpenAI client and PostgreSQL connection settings

In [316]:
# Load environment variables (API key and optional database overrides) from a .env file
load_dotenv()

# OpenAI API configuration
# The key is read from the environment so it never appears in the notebook.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# PostgreSQL database connection
# Every setting can be overridden through the environment (or the .env file) so
# real credentials do not have to be written into the notebook. The fallbacks
# are the local development values this notebook previously hard-coded, so an
# environment without these variables behaves exactly as before.
DB_PARAMS = {
    "dbname": os.environ.get("POSTGRES_DB", "mydatabase"),
    "user": os.environ.get("POSTGRES_USER", "myuser"),
    "password": os.environ.get("POSTGRES_PASSWORD", "mypassword"),
    "host": os.environ.get("POSTGRES_HOST", "localhost"),
    "port": os.environ.get("POSTGRES_PORT", "5433"),  # Ensure this matches your PostgreSQL container port
}

In [317]:
def get_db_connection():
    """Open a new PostgreSQL connection using DB_PARAMS.

    Returns:
        tuple: (connection, cursor). The caller is responsible for closing both.
    """
    connection = psycopg2.connect(**DB_PARAMS)
    return connection, connection.cursor()

def fetch_all_metadata():
    """Load the metadata of every table stored in mimic_table_metadata.

    Returns:
        dict: table name -> dict holding the remaining selected columns
        (description, table_purpose, columns_info, primary_keys, foreign_keys,
        important_considerations, common_joins, example_questions).
    """
    # Names for the columns that follow table_name, in SELECT order
    field_names = (
        'description',
        'table_purpose',
        'columns_info',
        'primary_keys',
        'foreign_keys',
        'important_considerations',
        'common_joins',
        'example_questions',
    )

    conn, cur = get_db_connection()
    try:
        cur.execute("""
            SELECT table_name, description, table_purpose, columns_info, 
                   primary_keys, foreign_keys, important_considerations,
                   common_joins, example_questions
            FROM mimic_table_metadata;
        """)

        # First column is the key; pair the rest with their field names
        return {
            record[0]: dict(zip(field_names, record[1:]))
            for record in cur.fetchall()
        }
    finally:
        # Release database resources whether or not the query succeeded
        cur.close()
        conn.close()

# Format metadata into text suitable for LLM prompt

In [318]:
def format_metadata_for_prompt(metadata):
    """Format metadata into text suitable for LLM prompt

    Parameters:
        metadata (dict): table name -> metadata dict. 'description' and
            'table_purpose' are required; 'columns_info', 'primary_keys',
            'foreign_keys', 'important_considerations' and 'common_joins' are
            optional and may be None.

    Returns:
        str: Markdown-style description of every table, separated by '---'.
    """
    # Collect fragments and join once instead of repeatedly growing a string
    parts = ["# Database Schema Information\n\n"]

    for table_name, table_info in metadata.items():
        parts.append(f"## Table: {table_name}\n")
        parts.append(f"Description: {table_info['description']}\n")
        parts.append(f"Purpose: {table_info['table_purpose']}\n\n")

        # Add primary key information (str() so non-string keys cannot break join)
        primary_keys = table_info.get('primary_keys')
        if primary_keys:
            parts.append(f"Primary Keys: {', '.join(map(str, primary_keys))}\n\n")

        # Add foreign key information
        foreign_keys = table_info.get('foreign_keys')
        if foreign_keys:
            parts.append("Foreign Keys:\n")
            for fk_col, fk_info in foreign_keys.items():
                parts.append(f"- {fk_col} references {fk_info['table']}.{fk_info['column']}\n")
            parts.append("\n")

        # Add column information; a missing/NULL columns_info yields an empty list
        parts.append("Columns:\n")
        for col_name, col_info in (table_info.get('columns_info') or {}).items():
            parts.append(f"- {col_name} ({col_info['data_type']}): {col_info['description']}\n")

            # Add categorical value distribution if available and not too long.
            # Values may be numbers (e.g. years or codes), so coerce to str.
            if 'categorical_values' in col_info and len(col_info['categorical_values']) < 15:
                values = ', '.join(map(str, col_info['categorical_values']))
                parts.append(f"  Possible values: {values}\n")

            # Add value range if available
            if 'value_range' in col_info:
                value_range = col_info['value_range']
                parts.append(f"  Range: {value_range['min']} to {value_range['max']}\n")

        parts.append("\n")

        # Add important considerations
        considerations = table_info.get('important_considerations')
        if considerations:
            parts.append(f"Important Considerations: {considerations}\n\n")

        # Add common joins
        common_joins = table_info.get('common_joins')
        if common_joins:
            parts.append("Common Joins:\n")
            parts.extend(f"- {join}\n" for join in common_joins)
            parts.append("\n")

        parts.append("---\n\n")

    return "".join(parts)

# Fetch metadata for tables most relevant to the query embedding

In [None]:
def fetch_relevant_metadata(query_embedding, top_k):
    """
    Rank the stored tables by similarity to a query embedding and return the best ones

    Parameters:
        query_embedding (list): Vector representation of the user query
        top_k (int): Number of most relevant tables to return

    Returns:
        dict: Table name -> metadata dict (including 'similarity_score'),
              inserted in descending order of similarity
    """
    conn, cur = get_db_connection()

    try:
        print(f"Computing similarity to find top {top_k} relevant tables")
        cur.execute("""
            SELECT 
                table_name, description, table_purpose, columns_info, 
                primary_keys, foreign_keys, important_considerations,
                common_joins, example_questions, embedding
            FROM mimic_table_metadata
            WHERE embedding IS NOT NULL;
        """)

        # Score every table that carries an embedding
        scored_rows = []
        for record in cur.fetchall():
            stored_embedding = record[9]

            # Defensive: ignore rows without an embedding
            if stored_embedding is None:
                continue

            # Embeddings may arrive as text; parse them into a list of numbers
            if isinstance(stored_embedding, str):
                stored_embedding = json.loads(stored_embedding.replace("'", '"'))

            score = cosine_similarity(query_embedding, stored_embedding)
            scored_rows.append((record, score))

        # Highest similarity first (stable for ties), then keep the top_k
        ranked = sorted(scored_rows, key=lambda pair: pair[1], reverse=True)

        selected = {}
        for record, score in ranked[:top_k]:
            name = record[0]
            selected[name] = {
                'description': record[1],
                'table_purpose': record[2],
                'columns_info': record[3],
                'primary_keys': record[4],
                'foreign_keys': record[5],
                'important_considerations': record[6],
                'common_joins': record[7],
                'example_questions': record[8],
                'similarity_score': score  # kept so callers can apply a threshold
            }

            # Debug trace of the ranking
            print(f"Table: {name}, Similarity: {score:.4f}")

        return selected

    finally:
        cur.close()
        conn.close()

# Compute cosine similarity between two vectors

In [320]:
def cosine_similarity(vec1, vec2):
    """
    Compute cosine similarity between two vectors

    Parameters:
        vec1 (list): First vector
        vec2 (list): Second vector (same length as vec1)

    Returns:
        float: Cosine similarity (between -1 and 1); 0.0 if either vector
               has zero length. Always a builtin float so the value can be
               stored in result dicts and serialized without numpy types.
    """
    # Coerce to float arrays so numeric strings / Decimal elements are accepted
    vec1 = np.asarray(vec1, dtype=float)
    vec2 = np.asarray(vec2, dtype=float)

    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)

    # Cosine similarity is undefined for a zero vector; report "no similarity"
    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(np.dot(vec1, vec2) / (norm_a * norm_b))

# Create the prompt to send to the LLM

In [321]:
def create_llm_prompt(user_question, metadata_text):
    """Assemble the SQL-generation prompt from the user question and schema text.

    The prompt is built line by line: role statement, the question, the
    metadata, the numbered task rules, and the response-format instruction.
    It ends with "SQL Query:" followed by a newline.
    """
    prompt_lines = [
        "You are a professional medical database expert specializing in SQL and the MIMIC-IV database. Based on the user's question and the provided database metadata, generate a PostgreSQL query.",
        "",
        "## User Question",
        f"{user_question}",
        "",
        "## Database Metadata",
        f"{metadata_text}",
        "",
        "## Task",
        "1. Analyze the user question to determine which tables and columns need to be queried",
        "2. Design an effective SQL query based on the provided metadata",
        "3. Ensure the generated SQL is syntactically correct and considers table relationships",
        "4. Use ONLY columns that are explicitly mentioned in the metadata for each table",
        "5. If multiple table joins are needed, use the correct join conditions",
        "6. Handle any potential edge cases",
        "7. When dealing with medical codes (ICD diagnosis/procedure codes, medication codes), always join with their respective descriptor tables (d_icd_diagnoses, d_icd_procedures) to include both codes AND their human-readable descriptions",
        "8. For medications, include the actual drug names from prescriptions.drug or emar.medication rather than just codes",
        "",
        "## Response Format",
        'Please return ONLY the SQL query without any explanation or comments. Start your answer with "SELECT" or "WITH" and end with a semicolon. Do not include anything else.',
        "",
        "SQL Query:",
    ]
    # Every line, including the last, is newline-terminated
    return "\n".join(prompt_lines) + "\n"

# Generate SQL query using OpenAI API

In [322]:
def _clean_sql_response(raw_content):
    """Reduce a raw model reply to the SQL statement it contains."""
    import re

    text = (raw_content or "").strip()

    # If the reply contains a fenced code block, the SQL is its content; any
    # prose before or after the fence is discarded.
    fenced = re.search(r'```[a-zA-Z]*[ \t]*\r?\n(.*?)```', text, re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()
    else:
        # Unclosed or inline fences: just drop the fence markers
        text = re.sub(r'```(?:sql|postgresql)?|```', '', text).strip()

    # Remove common prefixes the model may echo from the prompt
    for prefix in ("SQL Query:", "Query:", "PostgreSQL Query:"):
        if text.startswith(prefix):
            text = text[len(prefix):].strip()

    # Cut at the first semicolon that ends a line; this removes explanatory
    # text that follows the query.
    terminator = re.search(r';[\s\n]*(\n|$)', text)
    if terminator:
        text = text[:terminator.end()].strip()

    return text


def generate_sql_with_openai(prompt):
    """Generate SQL query using OpenAI API with improved cleaning and validation

    Parameters:
        prompt (str): Complete prompt, normally built by create_llm_prompt

    Returns:
        str or None: The cleaned SQL statement, or None when the API call
        fails or the model returns no usable text.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4",  # or another suitable model
            messages=[
                {"role": "system", "content": "You are a medical database expert who converts natural language questions into PostgreSQL queries. Return ONLY the SQL query with no explanations or comments."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1,  # Low temperature for more deterministic output
            max_tokens=500
        )

        choice = response.choices[0]

        # A reply cut off by max_tokens is usually an incomplete statement
        if getattr(choice, "finish_reason", None) == "length":
            print("Warning: Model output was truncated by max_tokens; the SQL may be incomplete.")

        # message.content can be None; treat that as an empty reply
        sql_query = _clean_sql_response(choice.message.content)

        if not sql_query:
            print("Warning: Model returned no SQL.")
            return None

        # Basic SQL syntax validation
        if not sql_query.lower().startswith(('select', 'with')):
            print("Warning: Generated SQL may not be valid. It doesn't start with SELECT or WITH.")

        # Log the cleaned query for debugging
        print(f"Cleaned SQL query: {sql_query[:100]}...")

        return sql_query

    except Exception as e:
        # Best effort: the pipeline treats None as "generation failed"
        print(f"Error calling OpenAI API: {e}")
        return None

# Validate the generated SQL query against the actual table structure

In [323]:
def validate_table_structure(table_name):
    """Get the actual column structure of a table, returns a list of column names

    Parameters:
        table_name (str): Table name to look up in information_schema.columns

    Returns:
        list: Column names in table order; empty if the table is unknown or
        the lookup fails.
    """
    conn, cur = get_db_connection()

    try:
        # table_name comes from LLM-generated SQL, so it is passed as a bound
        # parameter rather than interpolated into the statement.
        cur.execute(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = %s
            ORDER BY ordinal_position
            """,
            (table_name,),
        )

        columns = [row[0] for row in cur.fetchall()]
        print(f"Columns in table {table_name}: {', '.join(columns)}")
        return columns

    except Exception as e:
        # Best effort: an unknown structure is reported as "no columns"
        print(f"Error getting structure for table {table_name}: {e}")
        return []
    finally:
        cur.close()
        conn.close()

def check_query_columns(sql_query):
    """Analyze SQL query, validate that all referenced tables and columns exist

    This is a regex heuristic, not a SQL parser. It inspects only
    "table.column" references whose qualifier is a table named after FROM or
    JOIN; references through aliases are not checked.

    Parameters:
        sql_query (str): SQL text to inspect

    Returns:
        tuple: (potential_issues, table_columns) where potential_issues is a
        list of warning strings and table_columns maps each referenced table
        name (lower-cased, schema prefix removed) to its list of columns.
    """
    import re

    identifier = r'[a-zA-Z_][a-zA-Z0-9_]*'

    # Tables after FROM / JOIN, optionally schema-qualified (schema.table)
    table_ref = re.compile(
        rf'\b(?:FROM|JOIN)\s+((?:{identifier}\.)?{identifier})', re.IGNORECASE
    )
    # Names introduced by "WITH name AS (" / ", name AS (" are CTEs, not tables
    cte_def = re.compile(
        rf'\b({identifier})\s+AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(', re.IGNORECASE
    )

    cte_names = {name.lower() for name in cte_def.findall(sql_query)}

    # Unquoted identifiers fold to lower case in PostgreSQL; drop the schema
    # prefix because the column lookup is keyed by bare table name.
    referenced = {
        qualified.split('.')[-1].lower() for qualified in table_ref.findall(sql_query)
    }
    tables = sorted(referenced - cte_names)  # sorted for deterministic output

    # Get the actual column structure for each table
    table_columns = {table: validate_table_structure(table) for table in tables}

    potential_issues = []

    for table in tables:
        known_columns = {column.lower() for column in table_columns[table]}

        # Look for patterns like "table.column"
        table_column_pattern = re.compile(
            rf'\b{re.escape(table)}\.({identifier})', re.IGNORECASE
        )

        reported = set()
        for col in table_column_pattern.findall(sql_query):
            folded = col.lower()
            if folded in known_columns or folded in reported:
                continue
            reported.add(folded)  # one warning per missing column
            potential_issues.append(f"Warning: Column '{col}' does not exist in table '{table}'")

    return potential_issues, table_columns

# Attempt to automatically fix column reference issues in the SQL query

In [324]:
def attempt_fix_sql(sql_query, table_columns):
    """Attempt to automatically fix column reference issues in the SQL query

    Parameters:
        sql_query (str): The SQL statement that failed validation
        table_columns (dict): Table name -> list of actual column names

    Returns:
        str or None: The repaired SQL, or None when the API call fails or the
        model returns no usable text.
    """
    import re

    prompt = f"""As a database expert, please fix the following SQL query to make it compatible with the provided table structures.
    
SQL query:
```sql
{sql_query}
```

Table structure information:
"""

    # Add table structure information
    for table, columns in table_columns.items():
        prompt += f"\nTable '{table}' columns: {', '.join(columns)}"

    prompt += """

Please provide the fixed SQL query, ensuring that:
1. Only use columns that exist in the tables
2. Fix any mismatches in join conditions
3. Maintain the basic logic and purpose of the query
4. Return only the fixed SQL query, without any explanations or comments

Fixed SQL query:
"""

    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are an SQL expert, specializing in fixing errors in SQL queries."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1
        )

        # message.content can be None; treat that as an empty reply
        raw_reply = (response.choices[0].message.content or "").strip()

        # When the model wraps the query in a fenced block, keep only the
        # block's content so surrounding explanations are not executed as SQL.
        fenced = re.search(r'```[a-zA-Z]*[ \t]*\r?\n(.*?)```', raw_reply, re.DOTALL)
        if fenced:
            fixed_sql = fenced.group(1)
        else:
            # Unclosed or inline fences: just drop the fence markers
            fixed_sql = re.sub(r'```(?:sql|postgresql)?|```', '', raw_reply)
        fixed_sql = fixed_sql.strip()

        # An empty reply means the repair failed
        return fixed_sql or None

    except Exception as e:
        # Best effort: the pipeline treats None as "could not fix"
        print(f"Error while attempting to fix SQL: {e}")
        return None

# Execute SQL query and return results

In [325]:
def execute_sql_query(sql_query, timeout_seconds=3000):
    """Execute SQL query with timeout and improved error handling

    Parameters:
        sql_query (str): SQL statement to run
        timeout_seconds (int or float): Server-side statement timeout

    Returns:
        pandas.DataFrame or None: Query rows (at most 10,000) for statements
        that produce a result set, an empty DataFrame for statements that do
        not, or None if the database reported an error.

    NOTE(review): the statement is committed, so a non-SELECT statement
    produced by the LLM would modify data. Consider connecting with a
    read-only role if this pipeline should never write.
    """
    max_rows = 10000   # cap on rows loaded into the DataFrame
    batch_size = 1000  # rows fetched per round trip

    conn = None
    cur = None

    try:
        # Get a fresh connection
        conn = psycopg2.connect(**DB_PARAMS)

        # Session parameters are set in autocommit mode, outside the query's transaction
        conn.autocommit = True
        cur = conn.cursor()

        # statement_timeout takes whole milliseconds; bind it as a parameter
        cur.execute("SET statement_timeout = %s;", (int(timeout_seconds * 1000),))

        # Switch to transaction mode for the actual query
        conn.autocommit = False

        # Log the query being executed
        print(f"Executing SQL (with {timeout_seconds}s timeout): {sql_query[:200]}...")

        cur.execute(sql_query)

        # cur.description is only set when the statement returns rows
        if not cur.description:
            # For queries that don't return results (e.g., INSERT, UPDATE)
            conn.commit()
            print("Query executed successfully (no results returned)")
            return pd.DataFrame()  # Empty DataFrame

        column_names = [desc[0] for desc in cur.description]

        # Fetch in batches up to the row cap
        results = []
        truncated = False
        while len(results) < max_rows:
            batch = cur.fetchmany(min(batch_size, max_rows - len(results)))
            if not batch:
                break
            results.extend(batch)
        else:
            # Cap reached: warn only if the query really has more rows
            truncated = cur.fetchone() is not None

        if truncated:
            print("Warning: Query returned more than 10,000 rows, truncating results")

        # Commit transaction
        conn.commit()

        # Convert results to DataFrame
        df = pd.DataFrame(results, columns=column_names)

        print(f"Query returned {len(df)} rows and {len(df.columns)} columns")

        return df

    except psycopg2.Error as e:
        if conn is not None:
            try:
                conn.rollback()
            except psycopg2.Error:
                # The connection may already be unusable; it is closed below
                pass

        error_msg = f"Error executing SQL query: {e}"
        print(error_msg)
        return None

    finally:
        # The connection is discarded, so its session settings need no reset
        try:
            if cur is not None:
                cur.close()
        finally:
            if conn is not None:
                conn.close()

# Convert a user's natural language query into a vector representation

In [326]:
def vectorize_user_query(query_text):
    """
    Convert a user's natural language query into a vector representation

    Parameters:
        query_text (str): The user's natural language query

    Returns:
        list: Vector representation of the query, or None if the query is
              empty or the embedding request fails
    """
    try:
        # Preprocess the query text - more processing steps can be added as needed
        processed_query = query_text.strip()

        # Nothing to embed; avoid a pointless API call
        if not processed_query:
            print("❌ Error vectorizing query: query text is empty")
            return None

        # Generate embedding vector using OpenAI API
        response = client.embeddings.create(
            input=processed_query,
            model="text-embedding-3-small"  # Use the same model as for table embeddings
        )

        # Extract the embedding vector
        query_embedding = response.data[0].embedding

        # Log at most the first 50 characters of the query
        preview = f"{query_text[:50]}..." if len(query_text) > 50 else query_text
        print(f"✅ Successfully vectorized query: '{preview}'")
        return query_embedding

    except Exception as e:
        # Best effort: the pipeline treats None as "vectorization failed"
        print(f"❌ Error vectorizing query: {e}")
        return None

# Format a simplified version of metadata for the answer generation prompt

In [327]:
def format_metadata_for_answer(metadata):
    """Format a simplified version of metadata for the answer generation prompt

    Parameters:
        metadata (dict): table name -> metadata dict with 'description' and,
            optionally, 'columns_info' and 'primary_keys' (either may be None)

    Returns:
        str: For each table, its description followed by its key columns:
        primary-key columns and columns whose name contains "key".
    """
    parts = []

    for table_name, table_info in metadata.items():
        parts.append(f"Table: {table_name}\n")
        parts.append(f"Description: {table_info['description']}\n\n")

        # A NULL primary_keys value arrives as None, which cannot be searched
        # with "in"; treat it as "no primary keys".
        primary_keys = table_info.get('primary_keys') or []

        # Add key columns (simplified)
        parts.append("Key columns:\n")
        for col_name, col_info in (table_info.get('columns_info') or {}).items():
            if col_name in primary_keys or 'key' in col_name.lower():
                parts.append(f"- {col_name}: {col_info['description']}\n")

    return "".join(parts)

# Generate a comprehensive natural language answer based on query results

In [328]:
def generate_natural_language_answer(user_question, metadata, sql_query, query_results):
    """
    Generate a comprehensive natural language answer based on query results

    Parameters:
        user_question (str): Original user question
        metadata (dict): Metadata of relevant tables used in the query
        sql_query (str): Generated SQL query
        query_results: DataFrame or other format containing query results

    Returns:
        str: Natural language answer explaining the results

    Errors raised by the OpenAI client are not caught here and propagate to
    the caller.
    """
    # Convert query results to a suitable format for LLM
    if isinstance(query_results, pd.DataFrame):
        # For large DataFrames, include sample and summary statistics
        if len(query_results) > 10:
            results_text = f"Results contain {len(query_results)} rows.\n\n"
            results_text += "Sample of first 5 rows:\n"
            results_text += query_results.head(5).to_string() + "\n\n"

            # Add summary statistics if numerical columns exist.
            # select_dtypes copes with pandas extension dtypes (categorical,
            # nullable integers), which np.issubdtype rejects with TypeError.
            has_numeric = query_results.select_dtypes(include="number").shape[1] > 0
            if has_numeric:
                results_text += "Summary statistics:\n"
                results_text += query_results.describe().to_string()
        else:
            # Small results are shown in full
            results_text = "Results:\n" + query_results.to_string()
    else:
        results_text = f"Results: {str(query_results)}"

    # Create a comprehensive prompt for the LLM
    prompt = f"""
    User question: {user_question}
    
    Database information used:
    {format_metadata_for_answer(metadata)}
    
    SQL query executed:
    ```sql
    {sql_query}
    ```
    
    {results_text}
    
    Based on the above information, please provide a comprehensive answer to the user's question.
    Explain the results in natural language, highlighting key insights, patterns, or important values.
    If appropriate, suggest any follow-up analyses that might be valuable.
    """

    # Call the LLM to generate the answer
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that explains database query results clearly."},
            {"role": "user", "content": prompt}
        ]
    )

    return response.choices[0].message.content

# Classify the query to determine whether it is answerable with the available data

In [329]:
def classify_query(user_question, metadata):
    """
    Classify the query to identify if it's answerable with available data

    Parameters:
        user_question (str): The user's question
        metadata (dict): Metadata of most relevant tables

    Returns:
        dict: 'classification' (label chosen by the model), 'reason',
              'message', and 'status' ("supported" or "not_supported").
              If the model's reply cannot be parsed, the question is treated
              as answerable.
    """
    # Create a prompt for query classification
    tables_summary = "\n".join([f"- {table}: {info['description']}"
                              for table, info in metadata.items()])

    prompt = f"""
    Based on the following database tables from a MEDICAL DATABASE (MIMIC-IV 2.2), 
    classify if this question can be answered:
    
    USER QUESTION: {user_question}
    
    AVAILABLE TABLES:
    {tables_summary}
    
    Please classify this question as one of:
    1. "answerable": Can be answered with the available tables
    2. "out_of_scope": Relates to medical data but not available in these tables
    3. "non_medical": Not related to medical data at all
    4. "future_data": Requires data from after the database collection period
    5. "private_data": Asks for personally identifiable information
    
    You must respond in valid JSON format with exactly these fields:
    {{
      "status": "one of the options above",
      "reason": "brief explanation why you classified it this way",
      "message": "user-friendly message explaining if/why the question can't be answered"
    }}
    """

    # Call LLM to classify, without specifying response_format
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that classifies database queries. Always respond in valid JSON format."},
            {"role": "user", "content": prompt}
        ]
    )

    # Used whenever the reply cannot be interpreted
    fallback = {
        "classification": "answerable",  # Default to answerable
        "reason": "Failed to parse classification response",
        "message": "I'll try to answer your question with the available data.",
        "status": "supported"
    }

    # message.content can be None; treat that as an empty reply
    raw_reply = response.choices[0].message.content or ""

    # Models often wrap JSON in a code fence or add a sentence around it, so
    # parse only the outermost {...} span when one is present.
    start, end = raw_reply.find("{"), raw_reply.rfind("}")
    candidate = raw_reply[start:end + 1] if 0 <= start < end else raw_reply

    try:
        classification = json.loads(candidate)
    except (json.JSONDecodeError, TypeError):
        return fallback

    # Valid JSON that is not an object (e.g. a list) carries no classification
    if not isinstance(classification, dict):
        return fallback

    # Map the classification to pipeline control values
    result = {
        # Normalise case/whitespace so "Answerable" is not treated as a refusal
        "classification": str(classification.get("status", "answerable")).strip().lower(),
        "reason": classification.get("reason", "No reason provided"),
        "message": classification.get("message", "No message provided")
    }

    # Set overall status
    result["status"] = "supported" if result["classification"] == "answerable" else "not_supported"

    return result

# Extract the highest similarity score and corresponding table

In [330]:
def get_highest_similarity(metadata):
    """Return (table_name, score) for the table with the largest similarity score.

    Only scores strictly greater than 0 are considered; when none qualify the
    result is (None, 0). On ties the first table in iteration order wins.
    """
    candidates = [
        (name, details["similarity_score"])
        for name, details in metadata.items()
        if "similarity_score" in details and details["similarity_score"] > 0
    ]

    if not candidates:
        return None, 0

    # max() keeps the first of several equal maxima
    return max(candidates, key=lambda pair: pair[1])

# Pipeline

In [331]:
def sqlrag_pipeline(user_question):
    """Execute the complete SQLRAG pipeline with enhanced safety checks

    The steps run in order: embed the question, select the most similar
    tables, reject low-similarity or unsupported questions, generate SQL,
    validate (and if needed repair) it against the live schema, execute it,
    and finally summarise the results in natural language.

    Parameters:
        user_question (str): The user's natural language question

    Returns:
        dict: On success, the keys "user_question", "generated_sql",
        "results" and "answer". Every early exit returns a dict containing an
        "error" key; all but the vectorization failure also carry an
        "answer" message for the user.
    """
    print(f"User Question: {user_question}")
    
    # 1. Vectorize user query
    print("Vectorizing user query...")
    query_embedding = vectorize_user_query(user_question)
    
    # vectorize_user_query returns None on failure
    if not query_embedding:
        return {"error": "Failed to vectorize user query"}
    
    # 2. Fetch relevant metadata using vector similarity
    print("Finding relevant tables...")
    metadata = fetch_relevant_metadata(query_embedding, top_k=10)
    
    # 3. Check if any relevant tables were found with good similarity
    if not metadata:
        return {
            "user_question": user_question,
            "error": "No relevant tables found in the database for this query.",
            "answer": "I don't have the necessary data to answer this question. The database doesn't contain information related to your query."
        }
    
    # 4. Check similarity scores to ensure they're above threshold
    # Get the highest similarity score
    best_table, best_similarity = get_highest_similarity(metadata)
    if best_similarity < 0.2:  # Adjust threshold as needed
        return {
            "user_question": user_question,
            "best_match": best_table,
            "similarity": best_similarity,
            "error": "The query doesn't seem to match well with available data.",
            "answer": f"Your question might not be answerable with the available medical data. The closest match I found was related to '{best_table}' but the relevance is low."
        }
    
    # 5. Use query classifier to identify query intent and feasibility
    query_classification = classify_query(user_question, metadata)
    if query_classification["status"] == "not_supported":
        return {
            "user_question": user_question,
            "error": query_classification["reason"],
            "answer": query_classification["message"]
        }
    
    # 6. Format metadata for prompt - ensure metadata includes complete column information
    metadata_text = format_metadata_for_prompt(metadata)
    
    # 7. Create LLM prompt with improved instructions
    prompt = create_llm_prompt(user_question, metadata_text)
    
    print("Generating SQL with LLM...")
    # 8. Generate SQL query
    sql_query = generate_sql_with_openai(prompt)
    
    # A falsy value means SQL generation failed
    if not sql_query:
        return {
            "user_question": user_question,
            "error": "Failed to generate SQL query",
            "answer": "I couldn't generate a SQL query to answer your question. Please try rephrasing it."
        }
    
    print(f"Generated SQL: \n{sql_query}\n")
    
    # 9. Validate the generated SQL against table structure
    print("Validating SQL against database structure...")
    issues, table_columns = check_query_columns(sql_query)
    
    if issues:
        print("Potential issues detected:")
        for issue in issues:
            print(f"  - {issue}")
        
        # Try to automatically fix the SQL
        print("Attempting to fix SQL...")
        fixed_sql = attempt_fix_sql(sql_query, table_columns)
        
        if fixed_sql:
            # NOTE(review): the repaired SQL is executed without being re-validated
            print(f"Fixed SQL: \n{fixed_sql}\n")
            sql_query = fixed_sql
        else:
            return {
                "user_question": user_question, 
                "error": f"Generated SQL query is incompatible with database structure: {'; '.join(issues)}",
                "generated_sql": sql_query,
                "answer": "I couldn't generate a valid query compatible with the database structure. There might be a mismatch in my understanding of the database schema."
            }
    
    print("Executing SQL query...")
    # 10. Execute SQL query
    results = execute_sql_query(sql_query)
    
    # execute_sql_query returns None when the database reports an error
    if results is None:
        return {
            "user_question": user_question,
            "error": "Error executing SQL query",
            "generated_sql": sql_query,
            "answer": "I apologize, but I couldn't execute the generated SQL query. There might be an issue with the database or the query structure."
        }
    
    print("Generating natural language answer...")
    # 11. Generate natural language answer
    answer = generate_natural_language_answer(user_question, metadata, sql_query, results)
    
    return {
        "user_question": user_question,
        "generated_sql": sql_query,
        "results": results,
        "answer": answer
    }

# Test 1: helper functions and a comprehensive end-to-end test of the pipeline

In [332]:
def extract_tables_from_sql(sql_query):
    """Return the distinct identifiers that follow FROM or JOIN in a SQL string.

    Regex-based heuristic, not a SQL parser. The order of the returned list
    is unspecified.
    """
    import re

    identifier = r'[a-zA-Z_][a-zA-Z0-9_]*'

    found = set()
    for keyword in ('FROM', 'JOIN'):
        keyword_pattern = re.compile(rf'\b{keyword}\s+({identifier})', re.IGNORECASE)
        found.update(match.group(1) for match in keyword_pattern.finditer(sql_query))

    return list(found)

def measure_execution_time(sql_query):
    """Measure execution time of a SQL query

    Runs the statement on a fresh connection and times execution plus the
    transfer of all result rows; the rows themselves are discarded.

    Parameters:
        sql_query (str): SQL statement that returns rows

    Returns:
        float: Elapsed time in seconds
    """
    conn, cur = get_db_connection()

    try:
        # perf_counter is monotonic, so the measurement cannot go negative or
        # jump if the system clock is adjusted while the query runs
        started = time.perf_counter()
        cur.execute(sql_query)
        cur.fetchall()  # include row transfer in the measurement
        return time.perf_counter() - started
    finally:
        cur.close()
        conn.close()

In [333]:
def test_comprehensive_sqlrag():
    """Run the SQLRAG pipeline on a multi-part readmission question and report on it.

    Prints the question, then either the pipeline's error or its
    natural-language answer. When the result carries both the generated SQL
    and its results, also prints the tables referenced, the time taken to
    run the query again, and the result size.

    Returns:
        dict: The pipeline result. If the pipeline returns a bare string, it
        is wrapped as ``{"error": <string>, "answer": ...}`` so callers
        always receive a dictionary.
    """

    # Complex query that requires multiple table joining and medical domain knowledge
    complex_query = """
    Which patients had the highest number of hospital readmissions within 30 days of discharge, 
    and what were their most common diagnoses and prescribed medications? 
    Also analyze if there's any correlation between length of stay and readmission rates.
    """

    print("\n" + "="*80)
    print("COMPREHENSIVE TEST CASE FOR MIMIC-IV 2.2")
    print("="*80 + "\n")
    print(f"QUERY: {complex_query}\n")

    # Run enhanced SQLRAG pipeline
    result = sqlrag_pipeline(complex_query)

    # Handle a plain-string result first: on a str, `"error" in result` would
    # be a substring test and `result['error']` would raise TypeError.
    if isinstance(result, str):
        print("\nERROR AS STRING:")
        print("-"*50)
        print(result)
        # Wrap only in this branch so a successful dict result is never
        # replaced by an error wrapper.
        return {
            "error": result,
            "answer": "An error occurred during query processing."
        }

    # Check for errors in the result
    if "error" in result:
        print("\nERROR DETECTED:")
        print("-"*50)
        print(f"Error: {result['error']}")
        print(f"Answer: {result.get('answer', 'No answer provided')}")
        return result

    # Print the answer
    print("\nNATURAL LANGUAGE ANSWER:")
    print("-"*50)
    print(result.get("answer", "No answer generated"))

    # Only continue with analysis if we have SQL and results
    if "generated_sql" in result and "results" in result:
        print("\nQUERY QUALITY ANALYSIS:")
        print("-"*50)

        # 1. Check tables used in SQL
        tables_used = extract_tables_from_sql(result["generated_sql"])
        print(f"Tables used in query: {', '.join(tables_used)}")

        # 2. Execution metrics (this runs the generated query a second time)
        execution_time = measure_execution_time(result["generated_sql"])
        print(f"Query execution time: {execution_time:.2f} seconds")

        # 3. Result size check
        if isinstance(result["results"], pd.DataFrame):
            print(f"Result size: {len(result['results'])} rows, {len(result['results'].columns)} columns")
    else:
        print("\nSKIPPING QUERY ANALYSIS: No SQL or results available")

    return result

In [334]:
# Run the readmission test case; the returned dict is displayed as the cell output.
test_comprehensive_sqlrag()


COMPREHENSIVE TEST CASE FOR MIMIC-IV 2.2

QUERY: 
    Which patients had the highest number of hospital readmissions within 30 days of discharge, 
    and what were their most common diagnoses and prescribed medications? 
    Also analyze if there's any correlation between length of stay and readmission rates.
    

User Question: 
    Which patients had the highest number of hospital readmissions within 30 days of discharge, 
    and what were their most common diagnoses and prescribed medications? 
    Also analyze if there's any correlation between length of stay and readmission rates.
    
Vectorizing user query...
✅ Successfully vectorized query: '
    Which patients had the highest number of hosp...' if len(query_text) > 50 else query_text
Finding relevant tables...
Computing similarity to find top 10 relevant tables
Table: admissions, Similarity: 0.4974
Table: diagnoses_icd, Similarity: 0.4940
Table: transfers, Similarity: 0.4487
Table: drgcodes, Similarity: 0.4329
Table: proce

{'user_question': "\n    Which patients had the highest number of hospital readmissions within 30 days of discharge, \n    and what were their most common diagnoses and prescribed medications? \n    Also analyze if there's any correlation between length of stay and readmission rates.\n    ",
 'error': 'Error executing SQL query',
 'generated_sql': "WITH readmissions AS (\n    SELECT a1.subject_id, COUNT(*) AS readmission_count\n    FROM admissions a1\n    JOIN admissions a2 ON a1.subject_id = a2.subject_id AND a1.hadm_id <> a2.hadm_id\n    WHERE a2.admittime BETWEEN a1.dischtime AND a1.dischtime + INTERVAL '30 days'\n    GROUP BY a1.subject_id\n),\ndiagnoses AS (\n    SELECT d.subject_id, d.icd_code, COUNT(*) AS diagnosis_count\n    FROM diagnoses_icd d\n    JOIN readmissions r ON d.subject_id = r.subject_id\n    GROUP BY d.subject_id, d.icd_code\n),\nmedications AS (\n    SELECT p.subject_id, p.drug, COUNT(*) AS medication_count\n    FROM prescriptions p\n    JOIN readmissions r ON p.

# Test2

In [337]:
def test_complex_medical_analysis():
    """Exercise the SQLRAG pipeline with a sepsis-focused question and print a report.

    Returns whatever ``sqlrag_pipeline`` returned: a string is echoed and
    handed back unchanged, a dict containing ``"error"`` is reported and
    returned, and any other dict has its answer printed, plus a query
    analysis when both SQL and results are present.
    """
    question = """
    For patients diagnosed with sepsis (ICD code 99591 or A41.9), 
    what is the average length of ICU stay, mortality rate, 
    and what are the top 5 most commonly administered medications?
    """
    heavy_rule = "="*80
    light_rule = "-"*50

    print("\n" + heavy_rule)
    print("COMPLEX MEDICAL ANALYSIS TEST CASE")
    print(heavy_rule + "\n")
    print(f"QUERY: {question}\n")

    outcome = sqlrag_pipeline(question)

    # A plain string is echoed and returned as-is.
    if isinstance(outcome, str):
        print("\nERROR OR STRING RESULT:")
        print(light_rule)
        print(outcome)
        return outcome

    # Anything that is neither a string nor a dict passes through silently.
    if not isinstance(outcome, dict):
        return outcome

    # Dictionary carrying an error: report it and stop.
    if "error" in outcome:
        print("\nERROR DETECTED:")
        print(light_rule)
        print(f"Error: {outcome['error']}")
        print(f"Answer: {outcome.get('answer', 'No answer provided')}")
        return outcome

    print("\nNATURAL LANGUAGE ANSWER:")
    print(light_rule)
    print(outcome.get("answer", "No answer generated"))

    # The analysis needs both the SQL text and the rows it produced.
    has_sql_and_rows = "generated_sql" in outcome and "results" in outcome
    if not has_sql_and_rows:
        print("\nSKIPPING QUERY ANALYSIS: No SQL or results available")
        return outcome

    print("\nQUERY QUALITY ANALYSIS:")
    print(light_rule)

    generated_sql = outcome["generated_sql"]

    referenced = extract_tables_from_sql(generated_sql)
    print(f"Tables used in query: {', '.join(referenced)}")

    elapsed = measure_execution_time(generated_sql)
    print(f"Query execution time: {elapsed:.2f} seconds")

    rows = outcome["results"]
    if isinstance(rows, pd.DataFrame):
        print(f"Result size: {len(rows)} rows, {len(rows.columns)} columns")

    return outcome

In [338]:
# Run the sepsis analysis test case; the returned value is displayed as the cell output.
test_complex_medical_analysis()


COMPLEX MEDICAL ANALYSIS TEST CASE

QUERY: 
    For patients diagnosed with sepsis (ICD code 99591 or A41.9), 
    what is the average length of ICU stay, mortality rate, 
    and what are the top 5 most commonly administered medications?
    

User Question: 
    For patients diagnosed with sepsis (ICD code 99591 or A41.9), 
    what is the average length of ICU stay, mortality rate, 
    and what are the top 5 most commonly administered medications?
    
Vectorizing user query...
✅ Successfully vectorized query: '
    For patients diagnosed with sepsis (ICD code ...' if len(query_text) > 50 else query_text
Finding relevant tables...
Computing similarity to find top 10 relevant tables
Table: diagnoses_icd, Similarity: 0.4668
Table: procedures_icd, Similarity: 0.4274
Table: transfers, Similarity: 0.4018
Table: d_icd_diagnoses, Similarity: 0.3855
Table: hcpcsevents, Similarity: 0.3839
Table: services, Similarity: 0.3797
Table: drgcodes, Similarity: 0.3590
Table: microbiologyevents, Sim

{'user_question': '\n    For patients diagnosed with sepsis (ICD code 99591 or A41.9), \n    what is the average length of ICU stay, mortality rate, \n    and what are the top 5 most commonly administered medications?\n    ',
 'generated_sql': "WITH sepsis_patients AS (\n    SELECT subject_id, hadm_id\n    FROM diagnoses_icd\n    WHERE icd_code IN ('99591', 'A419')\n),\nicu_stays AS (\n    SELECT subject_id, hadm_id, AVG(EXTRACT(EPOCH FROM (outtime - intime))/3600) AS avg_icu_length\n    FROM transfers\n    WHERE subject_id IN (SELECT subject_id FROM sepsis_patients)\n    GROUP BY subject_id, hadm_id\n),\nmortality_rate AS (\n    SELECT COUNT(*) FILTER (WHERE dod IS NOT NULL) * 1.0 / COUNT(*) AS mortality_rate\n    FROM patients\n    WHERE subject_id IN (SELECT subject_id FROM sepsis_patients)\n),\nmedications AS (\n    SELECT hcpcs_cd, COUNT(*) AS count\n    FROM hcpcsevents\n    WHERE subject_id IN (SELECT subject_id FROM sepsis_patients)\n    GROUP BY hcpcs_cd\n    ORDER BY count DE