In [12]:
import dspy
import re
import pandas as pd
import sqlite3
import os
from typing import List, Optional, Dict, Any

In [13]:
from helper import openai_api_key
os.environ["OPENAI_API_KEY"] = openai_api_key

In [14]:
class TextToSql(dspy.Signature):
    """
    Converts natural language queries into SQL queries.
    """
    sql_prompt = dspy.InputField(desc ="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query including database schema")
    sql: str = dspy.OutputField(desc="Generated SQL query")

In [15]:
class HealthcareDBTool:
    """Tool for interacting with a healthcare database using LLM-generated SQL queries."""

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.text_to_sql = dspy.Predict(TextToSql)

    def get_schema_context(self) -> str:
        """Get database schema information for context."""
        schema_info = """
        Database Schema:
        Table: patients
        Columns:
        - Name(TEXT): Patient's full name
        - "Medical Condition"(TEXT): Patientt's medical condition (e.g., Cancer , Diabetes ,Hypertension ,Asthma ,Arthritis)
        - Medication(TEXT): Prescribed Medication
        - "Test Results"(TEXT): Test results (Normal , Abnormal ,Inconclusive)
        - Age(INTEGER): Patient's age
        -Gender(TEXT): Patient's gender (Male ,Female)
        - "Date of Admission"(TEXT): Date of admission in YYYY-MM-DD format
        
        Important Notes:
        - Column names with spaces must be quoted with double quotes
        - Use LOWER() for case-insensitive string matching
        - Use LIKE with % wildcards for partial string matching
        - Use strftime('%Y', "Date of Admission") to extract year from date
        - Always limit results to prevent excessive output
        """

        return schema_info
    
    def generate_sql_query(self, natural_query: str) -> str:
        """Generate SQL query from natural language input."""
        context = self.get_schema_context()
        
        result = self.text_to_sql(
            sql_prompt=natural_query,
            sql_context=context
        )

        #clean up the SQL query
        sql_query = result.sql.strip()

        if sql_query.startswith('```sql'):
            sql_query = sql_query[6:]
        if sql_query.startswith('```'):
            sql_query = sql_query[3:]
        if sql_query.endswith('```'):
            sql_query = sql_query[:-3]

        return sql_query.strip()

    def execute_sql_query(self, sql_query: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Execute SQL query and return results."""
        try:
            # Add LIMIT if not present in query
            if 'LIMIT' not in sql_query.upper():
                sql_query += f' LIMIT {limit}'
                
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.execute(sql_query)
                columns = [description[0] for description in cursor.description]
                result = cursor.fetchall()
                
            return [dict(zip(columns, row)) for row in result]
            
        except sqlite3.Error as e:
            print(f"Database error: {e}")
            print(f"Query: {sql_query}")
            return []
        except Exception as e:
            print(f"Unexpected error: {e}")
            print(f"Query: {sql_query}")
            return []
        
    def search_patients_with_llm(self, natural_query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Search patients using LLM-generated SQL query."""
        sql_query = self.generate_sql_query(natural_query)
        print(f"Generated SQL: {sql_query}")  # For debugging
        return self.execute_sql_query(sql_query, limit)

In [16]:
class PatientSearchSignature(dspy.Signature):
    """Search for patients in the healthcare database based on various criteria."""
    query: str = dspy.InputField(desc="Natural language query about patients")
    results: str = dspy.OutputField(desc="Formatted search results")

In [17]:
class PatientSearchModule(dspy.Module):
    """DSPy module for searching patients using natural language queries."""
    
    def __init__(self, db_tool):
        super().__init__()
        self.db_tool = db_tool

    def forward(self, query: str) -> PatientSearchSignature:
        """Process natural language query and return formatted results."""
        # Use LLM to generate and execute SQL query
        results = self.db_tool.search_patients_with_llm(query)
        
        # Format results as a string
        if not results:
            formatted_results = "No patients found matching the criteria."
        else:
            formatted_results = "Found patients:\n"
            for i, patient in enumerate(results, 1):
                # Handle different column formats that might be returned
                name = patient.get('Name', patient.get('name', 'Unknown'))
                condition = patient.get('Medical Condition', patient.get('medical_condition', 'Unknown'))
                age = patient.get('Age', patient.get('age', 'Unknown'))
                gender = patient.get('Gender', patient.get('gender', 'Unknown'))
                test_results = patient.get('Test Results', patient.get('test_results', 'Unknown'))
                admission_date = patient.get('Date of Admission', patient.get('date_of_admission', 'Unknown'))
                
                formatted_results += f"{i}. {name} - {condition} "
                formatted_results += f"(Age: {age}, Gender: {gender}, "
                formatted_results += f"Test Results: {test_results}, "
                formatted_results += f"Admission: {admission_date})\n"
        
        return dspy.Prediction(results=formatted_results)

In [18]:
def create_data():
    """Load healthcare data from CSV and preprocess it."""
    df = pd.read_csv("/Users/priyanka./Documents/agentic-security/DSPy/healthcare_data.csv")
    df = df.drop(['Billing Amount' ,'Room Number' ,'Insurance Provider' ,'Doctor' ,'Hospital'],axis = 1)
    return df

In [19]:
class MultiServerSignature(dspy.Signature):
    """Define a simple DSPy signature for main flow."""
    request: str = dspy.InputField(desc="The user's request, potentially requiring external tools.")
    response: str = dspy.OutputField(desc="Final response")

In [20]:
async def main():
    """Main function to run the healthcare database query system."""
    try:
        # Configure DSPy
        dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))
        
        # Create database
        df = create_data()
        if df is None:
            return
            
        db_path = 'healthcare.db'
        
        conn = sqlite3.connect(db_path)
        df.to_sql('patients', conn, if_exists='replace', index=False)
        conn.close()
        print("Database created successfully!")
        
        # Initialize tools
        db_tool = HealthcareDBTool(db_path)
        patient_module = PatientSearchModule(db_tool)
        
        # Test queries
        test_queries = [
            # "Find patients with abnormal test results in 2024",
            "List all cancer patients above age 40",
            # "Who was admitted for diabetes in 2022?",
            # "List female patients whose test results are abnormal",
            # "Show me all patients with hypertension who are male",
            # "Find patients admitted between 2021 and 2023 with normal test results"
        ]
        
        print("\n" + "="*50)
        print("TESTING HEALTHCARE DATABASE QUERIES WITH LLM-GENERATED SQL")
        print("="*50)
        
        for i, query in enumerate(test_queries, 1):
            print(f"\nQuery {i}: {query}")
            print("-" * 40)
            result = patient_module.forward(query)
            print(result.results)
            print()
        
    except Exception as e:
        print(f"Error in main: {e}")
        import traceback
        traceback.print_exc()

In [21]:

await main()

Database created successfully!

TESTING HEALTHCARE DATABASE QUERIES WITH LLM-GENERATED SQL

Query 1: List all cancer patients above age 40
----------------------------------------
Generated SQL: SELECT * FROM patients 
WHERE "Medical Condition" = 'Cancer' 
AND Age > 40 
LIMIT 100;
Found patients:
1. adrIENNE bEll - Cancer (Age: 43, Gender: Female, Test Results: Abnormal, Admission: 2022-09-19)
2. ChRISTopher BerG - Cancer (Age: 58, Gender: Female, Test Results: Inconclusive, Admission: 2021-05-23)
3. mIchElLe daniELs - Cancer (Age: 72, Gender: Male, Test Results: Normal, Admission: 2020-04-19)
4. bROOkE brady - Cancer (Age: 44, Gender: Female, Test Results: Normal, Admission: 2021-10-08)
5. Erin oRTEga - Cancer (Age: 43, Gender: Male, Test Results: Normal, Admission: 2023-05-24)
6. pAUL wILLiAmS - Cancer (Age: 81, Gender: Female, Test Results: Abnormal, Admission: 2020-08-23)
7. lYNn MaRtinez - Cancer (Age: 65, Gender: Male, Test Results: Normal, Admission: 2022-10-12)


