In [9]:
import re
import sqlite3
from contextlib import closing
from typing import Any, Dict, List, Optional, Tuple

import dspy
import pandas as pd

In [10]:
import os

from helper import openai_api_key

# Publish the key through the OPENAI_API_KEY environment variable so libraries
# that read it from the environment (presumably the OpenAI-backed dspy.LM
# configured in main) can authenticate without the key being hardcoded here.
os.environ.update(OPENAI_API_KEY=openai_api_key)

In [11]:
class HealthcareDBTool:
    """Read-only query helper over the ``patients`` table of a SQLite database.

    Every search opens a fresh connection to ``db_path`` and closes it again
    afterwards. The table must contain at least the columns in ``RESULT_COLUMNS``.
    """

    # Columns returned by search_patients, in output order (also the dict keys).
    RESULT_COLUMNS = (
        "Name",
        "Medical Condition",
        "Medication",
        "Test Results",
        "Age",
        "Gender",
        "Date of Admission",
    )

    def __init__(self, db_path: str):
        """
        Args:
            db_path: Path of the SQLite database file opened on every search.
        """
        self.db_path = db_path

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE wildcards so text is matched literally (paired with ESCAPE '\\')."""
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def build_where_clause(self, filters: Dict[str, Any]) -> Tuple[str, List[Any]]:
        """Translate a filter dict into a parameterized SQL WHERE clause.

        Recognised keys (missing keys and ``None`` values impose no filter):
          - ``condition``: case-insensitive substring match on ``Medical Condition``.
          - ``year``: admission year, compared with ``strftime('%Y', [Date of Admission])``.
          - ``age_min``: inclusive lower bound on ``Age`` (compared as INTEGER).
          - ``gender``: case-insensitive exact match on ``Gender`` (whitespace ignored).
          - ``test_result``: case-insensitive *prefix* match on ``Test Results``, so
            ``"normal"`` matches "Normal" but not "Abnormal".

        Returns:
            ``(where, params)``: ``where`` is ``""`` or begins with ``" WHERE "``;
            ``params`` holds the values for its ``?`` placeholders, in order.
        """
        clauses: List[str] = []
        params: List[Any] = []

        condition = filters.get("condition")
        if condition is not None:
            clauses.append("LOWER([Medical Condition]) LIKE ? ESCAPE '\\'")
            params.append(f"%{self._escape_like(str(condition).strip().lower())}%")

        year = filters.get("year")
        if year is not None:
            clauses.append("strftime('%Y', [Date of Admission]) = ?")
            params.append(str(year))

        age_min = filters.get("age_min")
        if age_min is not None:
            clauses.append("CAST(Age AS INTEGER) >= ?")
            params.append(age_min)

        gender = filters.get("gender")
        if gender is not None:
            clauses.append("LOWER(TRIM(Gender)) = ?")
            params.append(str(gender).strip().lower())

        test_result = filters.get("test_result")
        if test_result is not None:
            # Prefix match (not substring): '%normal%' would also hit 'Abnormal'.
            clauses.append("LOWER(TRIM([Test Results])) LIKE ? ESCAPE '\\'")
            params.append(f"{self._escape_like(str(test_result).strip().lower())}%")

        where = " WHERE " + " AND ".join(clauses) if clauses else ""
        return where, params

    def search_patients(
        self,
        condition: Optional[str] = None,
        year: Optional[int] = None,
        age_min: Optional[int] = None,
        gender: Optional[str] = None,
        test_result: Optional[str] = None,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """Return up to ``k`` patients matching every non-``None`` filter.

        See ``build_where_clause`` for how each filter is matched. Rows come back
        in SQLite's natural order (no ORDER BY).

        Returns:
            A list of dicts keyed by ``RESULT_COLUMNS``. On any ``sqlite3.Error``
            (e.g. missing table) the error is printed and ``[]`` is returned.
        """
        filters = {
            "condition": condition,
            "year": year,
            "age_min": age_min,
            "gender": gender,
            "test_result": test_result,
        }
        where_clause, params = self.build_where_clause(filters)

        cols = list(self.RESULT_COLUMNS)
        # Bracket-quote every identifier, matching build_where_clause's style.
        select_list = ", ".join(f"[{col}]" for col in cols)
        sql = f"SELECT {select_list} FROM patients{where_clause} LIMIT ?"
        params.append(k)

        try:
            # sqlite3's own context manager only commits/rolls back; closing()
            # actually releases the connection.
            with closing(sqlite3.connect(self.db_path)) as conn:
                rows = conn.execute(sql, params).fetchall()
        except sqlite3.Error as e:
            print(f"Database error: {e}")
            return []
        return [dict(zip(cols, row)) for row in rows]


In [12]:
# DSPy signature for the patient-search task: a natural-language `query` in,
# a formatted `results` string out.
# NOTE(review): the visible PatientSearchModule does not run this signature
# through an LM; its search is rule-based. The class docstring and field `desc`
# strings are left untouched because DSPy uses them as prompt text.
class PatientSearchSignature(dspy.Signature):
    """Search for patients in the healthcare database based on various criteria."""
    query: str = dspy.InputField(desc="Natural language query about patients")
    results: str = dspy.OutputField(desc="Formatted search results")

In [13]:
class PatientSearchModule(dspy.Module):
    """Answer patient-search questions by parsing them into database filters.

    The query is parsed with regular expressions and keyword lists (no LM call),
    passed to ``db_tool.search_patients(**filters)``, and the returned rows are
    rendered as a numbered plain-text list.
    """

    # Condition keywords looked for in a query; the first one found (in this order) wins.
    DEFAULT_CONDITIONS = ("cancer", "diabetes", "hypertension", "asthma", "arthritis")

    _YEAR_RE = re.compile(r"\b(20\d{2})\b")
    # "above/over/older than N" is a strict bound; search_patients treats age_min
    # as inclusive (>=), so these map to age_min = N + 1. \d{1,3} keeps years out.
    _AGE_EXCLUSIVE_RE = re.compile(
        r"\b(?:above|over|older than)\s+(?:the\s+)?(?:age\s+)?(?:of\s+)?(\d{1,3})\b"
    )
    # "age N" / "aged N" / "at least N" are inclusive lower bounds.
    _AGE_INCLUSIVE_RE = re.compile(r"\b(?:at least|aged?)\s+(?:of\s+)?(\d{1,3})\b")
    _FEMALE_RE = re.compile(r"\b(?:females?|wom[ae]n)\b")
    _MALE_RE = re.compile(r"\b(?:males?|m[ae]n)\b")
    # Checked in order; word boundaries stop "normal" from matching inside "abnormal".
    _TEST_RESULT_PATTERNS = (
        ("abnormal", re.compile(r"\babnormal\b")),
        ("normal", re.compile(r"\bnormal\b")),
        ("inconclusive", re.compile(r"\binconclusive\b")),
    )

    def __init__(self, db_tool, conditions: Optional[List[str]] = None):
        """
        Args:
            db_tool: Object exposing ``search_patients(**filters) -> list[dict]``
                (e.g. ``HealthcareDBTool``).
            conditions: Condition keywords to recognise (case-insensitive);
                defaults to ``DEFAULT_CONDITIONS``.
        """
        super().__init__()
        self.db_tool = db_tool
        source = self.DEFAULT_CONDITIONS if conditions is None else conditions
        self.conditions = tuple(c.lower() for c in source)

    def forward(self, query: str) -> dspy.Prediction:
        """Parse ``query``, search the database, and return ``Prediction(results=<text>)``."""
        params = self._parse_query(query)
        patients = self.db_tool.search_patients(**params)
        return dspy.Prediction(results=self._format_results(patients))

    @staticmethod
    def _format_results(patients: List[Dict[str, Any]]) -> str:
        """Render patient rows as a numbered list (each line newline-terminated)."""
        if not patients:
            return "No patients found matching the criteria."
        lines = ["Found patients:"]
        for i, patient in enumerate(patients, 1):
            lines.append(
                f"{i}. {patient['Name']} - {patient['Medical Condition']} "
                f"(Age: {patient['Age']}, Gender: {patient['Gender']}, "
                f"Test Results: {patient['Test Results']}, "
                f"Admission: {patient['Date of Admission']})"
            )
        return "\n".join(lines) + "\n"

    def _parse_query(self, query: str) -> Dict[str, Any]:
        """Extract ``search_patients`` keyword arguments from a natural-language query.

        Only recognised aspects are returned, so anything not understood imposes
        no filter. Keyword/regex based; enhance with NLP if needed.
        """
        params: Dict[str, Any] = {}
        text = query.lower()

        year_match = self._YEAR_RE.search(text)
        if year_match:
            params["year"] = int(year_match.group(1))

        # Exclusive phrasing takes precedence ("above age 40" also contains "age 40").
        exclusive = self._AGE_EXCLUSIVE_RE.search(text)
        if exclusive:
            params["age_min"] = int(exclusive.group(1)) + 1
        else:
            inclusive = self._AGE_INCLUSIVE_RE.search(text)
            if inclusive:
                params["age_min"] = int(inclusive.group(1))

        is_female = bool(self._FEMALE_RE.search(text))
        is_male = bool(self._MALE_RE.search(text))
        # Mentioning both genders (or neither) means no gender filter.
        if is_female != is_male:
            params["gender"] = "female" if is_female else "male"

        for label, pattern in self._TEST_RESULT_PATTERNS:
            if pattern.search(text):
                params["test_result"] = label
                break

        for condition in self.conditions:
            if condition in text:
                params["condition"] = condition
                break

        return params

In [14]:
def create_data(
    csv_path: Optional[str] = None,
    drop_columns: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Load the healthcare CSV and drop columns not used by the patient search.

    Args:
        csv_path: CSV file to read. If omitted, the ``HEALTHCARE_CSV_PATH``
            environment variable is used when set, otherwise the original
            machine-specific default path.
        drop_columns: Columns to remove; defaults to billing amount, room number,
            insurance provider, doctor and hospital.

    Returns:
        A new DataFrame without the dropped columns.

    Raises:
        FileNotFoundError: if the CSV does not exist.
        KeyError: if a column in ``drop_columns`` is absent from the CSV.
    """
    if csv_path is None:
        # NOTE(review): absolute, machine-specific fallback kept for backward
        # compatibility; prefer passing csv_path or setting HEALTHCARE_CSV_PATH.
        csv_path = os.environ.get(
            "HEALTHCARE_CSV_PATH",
            "/Users/priyanka./Documents/agentic-security/DSPy/healthcare_data.csv",
        )
    if drop_columns is None:
        drop_columns = ["Billing Amount", "Room Number", "Insurance Provider", "Doctor", "Hospital"]

    raw = pd.read_csv(csv_path)
    return raw.drop(columns=list(drop_columns))

In [None]:
# DSPy signature for a top-level request -> response flow.
# NOTE(review): not referenced anywhere in the visible code (main() does not use it).
# No class docstring is added on purpose: DSPy uses a Signature's docstring as
# prompt instructions, so adding one would change runtime behavior. Likewise the
# stray space in "request ,potentially" is part of the prompt and left as-is.
class MultiServerSignature(dspy.Signature):
    request: str = dspy.InputField(desc="The user's request ,potentially requiring external tools.")
    response: str = dspy.OutputField(desc="Final response")

async def main(
    db_path: str = "healthcare.db",
    queries: Optional[List[str]] = None,
    model: str = "openai/gpt-4o-mini",
):
    """Build the SQLite patient database from the CSV and run sample queries.

    Args:
        db_path: SQLite file to (re)create; the ``patients`` table is replaced.
        queries: Natural-language queries to run; defaults to four demo queries.
        model: DSPy LM identifier configured as the global default.

    Any exception is printed with its traceback rather than raised, so the
    notebook cell completes.
    """
    try:
        # NOTE(review): PatientSearchModule as written does not call the LM; this
        # only matters for other DSPy components that rely on the global LM.
        dspy.settings.configure(lm=dspy.LM(model))

        df = create_data()
        # closing() releases the connection even if to_sql raises.
        with closing(sqlite3.connect(db_path)) as conn:
            df.to_sql("patients", conn, if_exists="replace", index=False)
        print("Database created successfully!")

        db_tool = HealthcareDBTool(db_path)
        patient_module = PatientSearchModule(db_tool)

        if queries is None:
            queries = [
                "Find patients with abnormal results in 2024",
                "List all cancer patients above age 40",
                "Who was admitted for diabetes in 2022?",
                "List female patients whose test results are abnormal",
            ]

        banner = "=" * 50
        print("\n" + banner)
        print("TESTING HEALTHCARE DATABASE QUERIES")
        print(banner)

        for i, query in enumerate(queries, 1):
            print(f"\nQuery {i}: {query}")
            print("-" * 40)
            # Invoke the module itself (not .forward) so dspy.Module.__call__ runs.
            result = patient_module(query=query)
            print(result.results)

    except Exception as e:
        print(f"Error in main: {e}")
        import traceback
        traceback.print_exc()

In [16]:

# Top-level `await` relies on IPython/Jupyter's autoawait support; in a plain
# Python script use `asyncio.run(main())` instead.
await main()

Database created successfully!

TESTING HEALTHCARE DATABASE QUERIES

Query 1: Find patients with abnormal results in 2024
----------------------------------------
No patients found matching the criteria.

Query 2: List all cancer patients above age 40
----------------------------------------
Found patients:
1. adrIENNE bEll - Cancer (Age: 43, Gender: Female, Test Results: Abnormal, Admission: 2022-09-19)
2. ChRISTopher BerG - Cancer (Age: 58, Gender: Female, Test Results: Inconclusive, Admission: 2021-05-23)
3. mIchElLe daniELs - Cancer (Age: 72, Gender: Male, Test Results: Normal, Admission: 2020-04-19)
4. bROOkE brady - Cancer (Age: 44, Gender: Female, Test Results: Normal, Admission: 2021-10-08)
5. Erin oRTEga - Cancer (Age: 43, Gender: Male, Test Results: Normal, Admission: 2023-05-24)


Query 3: Who was admitted for diabetes in 2022?
----------------------------------------
Found patients:
1. mr. KenNEth MoORE - Diabetes (Age: 34, Gender: Female, Test Results: Abnormal, Admission: