# Testing the SQL Q&A Agent

In this notebook, we will:
1. Import the `generate_valid_query` function from our `app.py`.
2. Define a set of nurse-style test queries, each with an expected outcome (positive vs. negative).
3. Run the agent on each query and record whether it returned a non-empty result (prediction=1) or an empty result/error (prediction=0).
4. Compare each prediction with its expected label (1 or 0) and compute precision, recall, F1, and the confusion matrix.

> **Note**: This is a **contrived** classification approach. In practice, you might prefer to check whether the returned rows match the expected data. For demonstration, though, we use a simple pass/fail on whether the query returns any rows.
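The row-matching alternative mentioned in the note can start very small. The sketch below assumes results come back as a sequence of row sequences (as with a Python DB-API `fetchall()`) and that cell values are hashable; `rows_match` is a hypothetical helper, not part of `app.py`. Row order is ignored because SQL does not guarantee an order without `ORDER BY`.

```python
from collections import Counter
from typing import Any, Sequence


def rows_match(actual_rows: Sequence[Sequence[Any]],
               expected_rows: Sequence[Sequence[Any]]) -> bool:
    """Return True if both result sets hold the same rows, ignoring row order.

    Rows are converted to tuples so they can be counted; a duplicated row
    must appear the same number of times in both result sets.
    """
    return Counter(map(tuple, actual_rows)) == Counter(map(tuple, expected_rows))


# Hypothetical usage: a COUNT query for test case 1 should return one row holding 11.
print(rows_match([(11,)], [(11,)]))
print(rows_match([("a", 1), ("b", 2)], [("b", 2), ("a", 1)]))  # order is ignored
print(rows_match([(0,)], [(11,)]))
```

The expected rows would have to be written down per test case, which is more work than a 0/1 label but catches queries that return the *wrong* rows rather than no rows.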


In [1]:
from sklearn.metrics import confusion_matrix, classification_report

# If your app module is named differently or lives elsewhere, adjust this import.
from app import generate_valid_query

# A test patient ID known to have data in the database.
# In your own testing, pick an ID that exists in your DB.
TEST_PATIENT_ID = "5e7e67c6-39d8-c775-aa42-8153e93fadb4"



## Defining Our Test Cases

- **prompt**: The nurse-style question.
- **expected_label**: `1` means we expect **non-empty** results (the question is valid and should return rows); `0` means we expect an **empty** result or no valid result.
- **description**: A short note explaining why we expect that label.

We'll store them in a list of dicts.
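Because every case is a plain dict, a typo in a key or a label other than 0/1 would only surface later as a confusing metric. A minimal sanity check, assuming the three keys used in this notebook; `validate_test_cases` is a hypothetical helper and the sample cases are made up:

```python
from typing import Any, Dict, List

# Keys used by every test case in this notebook.
REQUIRED_KEYS = {"prompt", "expected_label", "description"}


def validate_test_cases(cases: List[Dict[str, Any]]) -> List[str]:
    """Return a list of problems; an empty list means every case is well-formed."""
    problems = []
    for i, case in enumerate(cases, start=1):
        missing = REQUIRED_KEYS - set(case)
        if missing:
            problems.append(f"case {i}: missing keys {sorted(missing)}")
        if "expected_label" in case and case["expected_label"] not in (0, 1):
            problems.append(f"case {i}: expected_label must be 0 or 1")
    return problems


# Hypothetical cases: the second one lacks a description and has a bad label.
sample = [
    {"prompt": "Give me the patient's ZIP code.", "expected_label": 1, "description": "Has a ZIP."},
    {"prompt": "List any allergies for this patient.", "expected_label": 2},
]
for problem in validate_test_cases(sample):
    print(problem)
```

Once `test_cases` is defined, `assert not validate_test_cases(test_cases)` fails fast on a malformed entry.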

In [None]:
# Test cases

test_cases = [
    {
        "prompt": "How many encounters does this patient have?",
        "expected_label": 1, 
        "description": "We expect 11 encounters for this patient."
    },
    {
        "prompt": "Show me the patient's first, last name, and birthdate.",
        "expected_label": 1,
        "description": "We do have a real patient record, so at least 1 row should appear."
    },
    {
        "prompt": "List immunizations from 1900 to 1901 for this patient.",
        "expected_label": 0,
        "description": "No immunizations exist that far in the past."
    },
    {
        "prompt": "Give me the patient's ZIP code.",
        "expected_label": 1,
        "description": "The patient record likely has a ZIP code."
    },
    {
        "prompt": "List any allergies for this patient.",
        "expected_label": 1,
        "description": "The patient has one allergy record: 'Allergy to substance (finding)'."
    },
    {
        "prompt": "Show me the patient's city and state.",
        "expected_label": 1,
        "description": "Demographic info typically includes city & state."
    },
    {
        "prompt": "List all procedures for this patient after year 2000.",
        "expected_label": 1,
        "description": "Patient has 37 procedures after 2000."
    },
    {
        "prompt": "Show me the patient's current medications.",
        "expected_label": 1,
        "description": "Patient has 1 current medication."
    },
    {
        "prompt": "List the conditions that ended the same day they started for this patient.",
        "expected_label": 1,
        "description": "2 conditions started & ended on same day."
    },
    {
        "prompt": "How many conditions does this patient have that started after 2022?",
        "expected_label": 0,
        "description": "The patient has no conditions that started after 2022."
    }
]

test_cases


## Test Execution Logic

We'll define a helper function that:
1. Calls `generate_valid_query(prompt, patient_id)`, which returns `(final_sql, query_data)`.
2. Treats `query_data` as an error message if it is a string, and sets **prediction=0** because no result was produced.
3. Otherwise unpacks `query_data` as `(col_names, rows)` and sets **prediction=1** if `len(rows) > 0`, else `0`.
4. Returns a dict with `prompt`, `expected_label`, `prediction`, `final_sql`, `row_count`, and `col_names` for reporting.


In [6]:
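def run_test_case(prompt, expected_label, patient_id):
    """Run one prompt through the agent and label the outcome.

    prediction = 1 if the query returned at least one row,
    prediction = 0 if the agent returned an error string or an empty result.
    """
    final_sql, query_data = generate_valid_query(prompt, patient_id)

    if isinstance(query_data, str):
        # The agent returned an error message instead of query results.
        prediction = 0
        col_names = None
        rows = None
    else:
        col_names, rows = query_data
        # Any non-empty result set counts as a positive prediction.
        prediction = 1 if rows else 0

    return {
        "prompt": prompt,
        "expected_label": expected_label,
        "prediction": prediction,
        "final_sql": final_sql,
        "row_count": len(rows) if rows else 0,
        "col_names": col_names,
    }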
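The row-count rule has a blind spot for "How many" questions: an aggregate such as `COUNT(...)` without `GROUP BY` always returns exactly one row, even when the count is 0. Test case 1 is answered this way (its generated SQL is a `COUNT`), so if test case 10 is answered the same way it will be predicted positive despite an expected label of 0. Below is a minimal sketch of a stricter rule, assuming the same `query_data` shape that `run_test_case` unpacks; `predict_label` is a hypothetical helper, and the error message and column names in the usage lines are made up.

```python
from typing import Any, Sequence, Tuple, Union

# Assumed shape, taken from run_test_case: an error string or (col_names, rows).
QueryData = Union[str, Tuple[Sequence[str], Sequence[Sequence[Any]]]]


def predict_label(query_data: QueryData) -> int:
    """Return 1 if the query produced a meaningful result, else 0.

    Same as the row-count rule, except that a single-row, single-column
    result holding 0 or NULL (None) is treated as empty, because an
    aggregate such as COUNT(...) always returns exactly one row.
    """
    if isinstance(query_data, str):
        # Error message from the agent: no result.
        return 0
    _col_names, rows = query_data
    if not rows:
        # Empty result set.
        return 0
    if len(rows) == 1 and len(rows[0]) == 1 and rows[0][0] in (None, 0):
        # Scalar aggregate that found nothing (e.g. COUNT = 0, MAX = NULL).
        return 0
    return 1


print(predict_label("Error: no such column"))            # 0
print(predict_label((["NumberOfEncounters"], [(11,)])))  # 1
print(predict_label((["NumberOfConditions"], [(0,)])))   # 0
print(predict_label((["ZIP"], [])))                      # 0
```

To try it, the prediction branch in `run_test_case` could be replaced with `prediction = predict_label(query_data)`. The trade-off is that any legitimate single-value answer of 0 is now labeled negative, which matches how this notebook labels test case 10 but may not suit every test suite.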


## Running All Tests, Collecting Predictions

We'll run each test prompt and store the results in a list. Then we can compute precision, recall, F1, confusion matrix, etc. using scikit-learn.

In [7]:
all_results = []
for tc in test_cases:
    r = run_test_case(
        prompt=tc["prompt"], 
        expected_label=tc["expected_label"], 
        patient_id=TEST_PATIENT_ID
    )
    r["description"] = tc["description"]
    all_results.append(r)

all_results

[DEBUG Attempt 1] Proposed SQL:
SELECT COUNT(encounters.Id) AS NumberOfEncounters
FROM encounters
WHERE encounters.PATIENT = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';
[DEBUG Attempt 1] Proposed SQL:
SELECT FIRST, LAST, BIRTHDATE
FROM patients
WHERE Id = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';
[DEBUG Attempt 1] Proposed SQL:
SELECT DATE, CODE, DESCRIPTION 
FROM immunizations 
WHERE DATE BETWEEN '1900-01-01' AND '1901-12-31'
   AND PATIENT = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';
[DEBUG Attempt 1] Proposed SQL:
SELECT ZIP FROM patients WHERE Id = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';
[DEBUG Attempt 1] Proposed SQL:
SELECT 
    a.CODE,
    a.DESCRIPTION,
    p.Id
FROM 
    allergies a
LEFT JOIN 
    patients p ON 
        a.PATIENT = p.Id
WHERE 
    p.Id = '5e7e67c6-39d8-c775-aa42-8153e93fadb4'
ORDER BY 
    a.DESCRIPTION ASC;
[DEBUG Attempt 2] Proposed SQL:
SELECT CODE, DESCRIPTION FROM allergies WHERE PATIENT = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';
[DEBUG Attempt 1] Proposed SQL:
SEL

[{'prompt': 'How many encounters does this patient have?',
  'expected_label': 1,
  'prediction': 1,
  'final_sql': "SELECT COUNT(encounters.Id) AS NumberOfEncounters\nFROM encounters\nWHERE encounters.PATIENT = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';",
  'row_count': 1,
  'col_names': ['NumberOfEncounters'],
  'description': 'We expect 11 encounters for this patient.'},
 {'prompt': "Show me the patient's first, last name, and birthdate.",
  'expected_label': 1,
  'prediction': 1,
  'final_sql': "SELECT FIRST, LAST, BIRTHDATE\nFROM patients\nWHERE Id = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';",
  'row_count': 1,
  'col_names': ['FIRST', 'LAST', 'BIRTHDATE'],
  'description': 'We do have a real patient record, so at least 1 row should appear.'},
 {'prompt': 'List immunizations from 1900 to 1901 for this patient.',
  'expected_label': 0,
  'prediction': 0,
  'final_sql': "SELECT DATE, CODE, DESCRIPTION \nFROM immunizations \nWHERE DATE BETWEEN '1900-01-01' AND '1901-12-31'\n   AND PATIENT

## Generating Classification Metrics

We'll define:
- `y_true` = `[expected_label, ...]`
- `y_pred` = `[prediction, ...]`

Then we call `classification_report` and `confusion_matrix` from `sklearn`.

In [8]:
from sklearn.metrics import confusion_matrix, classification_report

y_true = [res["expected_label"] for res in all_results]
y_pred = [res["prediction"] for res in all_results]

cm = confusion_matrix(y_true, y_pred, labels=[1,0])
print("Confusion Matrix (labels=[1,0]):\n", cm)

report = classification_report(y_true, y_pred, labels=[1,0], target_names=["Positive","Negative"])
print("\nClassification Report:\n", report)

Confusion Matrix (labels=[1,0]):
 [[8 0]
 [1 1]]

Classification Report:
               precision    recall  f1-score   support

    Positive       0.89      1.00      0.94         8
    Negative       1.00      0.50      0.67         2

    accuracy                           0.90        10
   macro avg       0.94      0.75      0.80        10
weighted avg       0.91      0.90      0.89        10
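The report can be reproduced by hand from the confusion matrix. With `labels=[1,0]`, rows are the true labels and columns the predicted labels, so all 8 positives were predicted positive and 1 of the 2 negatives was predicted positive. The negatives are test cases 3 and 10, and the run output above shows prediction 0 for case 3, so the single false positive is case 10. The sketch below recomputes the per-class scores from those four counts in plain Python:

```python
# Confusion matrix printed above, with labels=[1, 0]:
# rows = true label, columns = predicted label.
cm_counts = [[8, 0],
             [1, 1]]

tp, fn = cm_counts[0]  # true positives, positives predicted as negative
fp, tn = cm_counts[1]  # negatives predicted as positive, true negatives


def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


pos_precision = tp / (tp + fp)  # 8 / 9
pos_recall = tp / (tp + fn)     # 8 / 8
neg_precision = tn / (tn + fn)  # 1 / 1
neg_recall = tn / (tn + fp)     # 1 / 2
accuracy = (tp + tn) / (tp + fn + fp + tn)

pos_f1 = f1(pos_precision, pos_recall)
neg_f1 = f1(neg_precision, neg_recall)

print(f"Positive: P={pos_precision:.2f} R={pos_recall:.2f} F1={pos_f1:.2f}")  # 0.89 1.00 0.94
print(f"Negative: P={neg_precision:.2f} R={neg_recall:.2f} F1={neg_f1:.2f}")  # 1.00 0.50 0.67
print(f"Accuracy: {accuracy:.2f}")                                            # 0.90
print(f"Macro F1: {(pos_f1 + neg_f1) / 2:.2f}")                               # 0.80
```

With only two negative cases, a single miss halves the negative-class recall, so these numbers are very sensitive to each individual test.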



## Detailed Report of Each Test Case
We can print out the final SQL, row counts, etc. for debugging.
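With ten cases the full printout is manageable, but as the suite grows it helps to surface only the disagreements. A minimal sketch, assuming each result is a dict with the `prompt`, `expected_label`, and `prediction` keys that `run_test_case` returns; `mismatches` is a hypothetical helper and the sample data is made up. In the notebook you would pass `all_results` instead of `sample_results`.

```python
from typing import Any, Dict, List, Tuple


def mismatches(results: List[Dict[str, Any]]) -> List[Tuple[int, Dict[str, Any]]]:
    """Return (1-based test number, result) pairs where prediction != expected_label."""
    return [
        (idx, res)
        for idx, res in enumerate(results, start=1)
        if res["prediction"] != res["expected_label"]
    ]


# Hypothetical results in the same shape as run_test_case's return value.
sample_results = [
    {"prompt": "question A", "expected_label": 1, "prediction": 1},
    {"prompt": "question B", "expected_label": 0, "prediction": 1},
    {"prompt": "question C", "expected_label": 0, "prediction": 0},
]

for idx, res in mismatches(sample_results):
    print(f"Test Case #{idx}: expected {res['expected_label']}, "
          f"got {res['prediction']} | {res['prompt']}")
```

Keeping the 1-based index matches the `Test Case #N` numbering used in the full report.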

In [9]:
for idx, res in enumerate(all_results, start=1):
    print(f"Test Case #{idx}")
    print("Prompt           :", res["prompt"])
    print("Description      :", res["description"])
    print("Expected Label   :", res["expected_label"])
    print("Prediction       :", res["prediction"])
    print("Final SQL        :", res["final_sql"])
    print("Row Count        :", res["row_count"])
    print("Column Names     :", res["col_names"])
    print("-"*50)


Test Case #1
Prompt           : How many encounters does this patient have?
Description      : We expect 11 encounters for this patient.
Expected Label   : 1
Prediction       : 1
Final SQL        : SELECT COUNT(encounters.Id) AS NumberOfEncounters
FROM encounters
WHERE encounters.PATIENT = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';
Row Count        : 1
Column Names     : ['NumberOfEncounters']
--------------------------------------------------
Test Case #2
Prompt           : Show me the patient's first, last name, and birthdate.
Description      : We do have a real patient record, so at least 1 row should appear.
Expected Label   : 1
Prediction       : 1
Final SQL        : SELECT FIRST, LAST, BIRTHDATE
FROM patients
WHERE Id = '5e7e67c6-39d8-c775-aa42-8153e93fadb4';
Row Count        : 1
Column Names     : ['FIRST', 'LAST', 'BIRTHDATE']
--------------------------------------------------
Test Case #3
Prompt           : List immunizations from 1900 to 1901 for this patient.
Description      