In [1]:
import json
from datetime import datetime, timedelta

def get_relative_date(days: int, past: bool = True) -> str:
    """Calculates a date in YYYY-MM-DD format relative to today."""
    if past:
        target_date = datetime.now() - timedelta(days=days)
    else:
        target_date = datetime.now() + timedelta(days=days)
    return target_date.strftime("%Y-%m-%d")

def generate_test_cases():
    """
    Generates a list of test cases from a set of templates.
    """
    test_cases = []

    # --- Define Templates ---
    # Each template defines a type of filter and provides variations.
    templates = [
        # --- EQUALITY / NEGATION ---
        {
            "field": "gender", "operator": "=", "type": str,
            "prompts_en": ["Find {value} customers", "Show me the {value} clients", "{value} users"],
            "prompts_ar": ["اعثر على العملاء من {value}", "أرني العملاء الـ {value}", "المستخدمون الـ {value}"],
            "values": [{"en": "male", "ar": "ذكور"}, {"en": "female", "ar": "إناث"}]
        },
        {
            "field": "country", "operator": "!=", "type": str,
            "prompts_en": ["Everyone except those from {value}", "All users not in {value}"],
            "prompts_ar": ["الجميع ما عدا أولئك من {value}", "كل المستخدمين ليسوا في {value}"],
            "values": [{"en": "Egypt", "ar": "مصر"}, {"en": "KSA", "ar": "السعودية"}]
        },
        # --- NUMERIC COMPARISONS ---
        {
            "field": "total_orders", "operator": ">", "type": int,
            "prompts_en": ["Customers with more than {value} orders", "Users who ordered over {value} times"],
            "prompts_ar": ["العملاء الذين لديهم أكثر من {value} طلبات", "المستخدمون الذين طلبوا أكثر من {value} مرات"],
            "values": [5, 50, 100]
        },
        {
            "field": "total_sales", "operator": "<=", "type": float,
            "prompts_en": ["Clients with total sales of {value} or less", "Sales at most {value}"],
            "prompts_ar": ["العملاء الذين تبلغ مبيعاتهم الإجمالية {value} أو أقل", "المبيعات بحد أقصى {value}"],
            "values": [100.50, 5000.0, 999.99]
        },
        # --- DATE COMPARISONS ---
        {
            "field": "joining_date", "operator": ">=", "type": "date",
            "prompts_en": ["Customers who joined in the last {value} days", "Signups from the past {value} days"],
            "prompts_ar": ["العملاء الذين انضموا في آخر {value} يومًا", "التسجيلات من آخر {value} يوم"],
            "values": [7, 30, 90]
        },
        # --- BETWEEN OPERATOR ---
        {
            "field": "store_rating", "operator": "between", "type": float,
            "prompts_en": ["Stores rated between {v1} and {v2} stars", "Find ratings from {v1} to {v2}"],
            "prompts_ar": ["المتاجر التي تقييمها بين {v1} و {v2} نجوم", "اعثر على التقييمات من {v1} إلى {v2}"],
            "values": [[3.0, 5.0], [1.5, 3.5]]
        },
        # --- BOOLEAN FLAGS ---
        {
            "field": "have_cancelled_orders", "operator": "=", "type": bool,
            "prompts_en": ["Users who have cancelled orders", "Find people with cancelled orders"],
            "prompts_ar": ["المستخدمون الذين لديهم طلبات ملغاة", "اعثر على الأشخاص أصحاب الطلبات الملغاة"],
            "values": [True]
        },
        # --- LISTS (OR logic) ---
        {
            "field": "city", "operator": "=", "type": list,
            "prompts_en": ["Users in {v1} or {v2}", "Customers from {v1}, {v2}"],
            "prompts_ar": ["المستخدمون في {v1} أو {v2}", "عملاء من {v1}، {v2}"],
            "values": [["Riyadh", "Jeddah"], ["Cairo", "Alexandria"]]
        }
    ]

    # --- Generate Test Cases ---
    for template in templates:
        for value in template["values"]:
            for i, prompt_en_template in enumerate(template["prompts_en"]):
                prompt_ar_template = template["prompts_ar"][i]

                # Handle different value structures
                if template["operator"] == "between":
                    v1, v2 = value
                    prompt_en = prompt_en_template.format(v1=v1, v2=v2)
                    prompt_ar = prompt_ar_template.format(v1=v1, v2=v2)
                    expected_value = [template["type"](v) for v in value]
                elif template["type"] == list:
                    v1, v2 = value
                    prompt_en = prompt_en_template.format(v1=v1, v2=v2)
                    prompt_ar = prompt_ar_template.format(v1=v1, v2=v2)
                    expected_value = value
                elif template["type"] == str:
                    prompt_en = prompt_en_template.format(value=value["en"])
                    prompt_ar = prompt_ar_template.format(value=value["ar"])
                    expected_value = value["en"]
                elif template["type"] == "date":
                    prompt_en = prompt_en_template.format(value=value)
                    prompt_ar = prompt_ar_template.format(value=value)
                    expected_value = get_relative_date(value)
                else:
                    prompt_en = prompt_en_template.format(value=value)
                    prompt_ar = prompt_ar_template.format(value=value)
                    expected_value = template["type"](value)

                case_id = f"{template['field']}_{template['operator']}_{i}_{str(value).replace(' ', '')}"
                
                test_case = {
                    "id": case_id,
                    "prompt_en": prompt_en,
                    "prompt_ar": prompt_ar,
                    "expected_status": 200,
                    "expected_output": {
                        "filters": [{
                            "field": template["field"],
                            "operator": template["operator"],
                            "value": expected_value
                        }]
                    }
                }
                test_cases.append(test_case)

    # --- Manually add complex/error cases ---
    test_cases.append({
        "id": "multi_filter_complex_ar",
        "prompt_en": "Female users from Egypt with sales over 1000",
        "prompt_ar": "المستخدمات الإناث من مصر بمبيعات تزيد عن 1000",
        "expected_status": 200,
        "expected_output": {
            "filters": [
                {"field": "gender", "operator": "=", "value": "female"},
                {"field": "country", "operator": "=", "value": "Egypt"},
                {"field": "total_sales", "operator": ">", "value": 1000.0}
            ]
        }
    })
    test_cases.append({
        "id": "error_unsupported_field",
        "prompt_en": "Customers with a high net promoter score",
        "prompt_ar": "العملاء الذين لديهم درجة رضا عالية",
        "expected_status": 400,
        "expected_output": {"detail": "No valid filters were found. The prompt may contain unsupported fields or be too ambiguous."}
    })

    return test_cases

if __name__ == "__main__":
    # Generate the test data
    all_test_cases = generate_test_cases()
    
    # Write to a JSON file
    with open("tests/test_data_2.json", "w", encoding="utf-8") as f:
        json.dump(all_test_cases, f, indent=4, ensure_ascii=False)

    print(f"✅ Successfully generated {len(all_test_cases)} test cases and saved to tests/test_data_2.json")


✅ Successfully generated 40 test cases and saved to tests/test_data_2.json


In [37]:
import os
import json
import requests
from tqdm import tqdm
import sys


TEST_DATA_PATH = "tests/test_data.json"

def run_single_test(case, language):
    """
    Runs a single test case against the API.

    Args:
        case (dict): A dictionary representing a single test case.
        language (str): 'en' for English prompt, 'ar' for Arabic.

    Returns:
        bool: True if the test passed, False otherwise.
    """
    prompt_key = f"prompt_{language}"
    prompt_text = case[prompt_key]
    
    headers = {
        "Content-Type": "application/json",
        "x-api-key": API_KEY
    }
    payload = {"prompt": prompt_text}

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
        
        # 1. Check if the status code matches
        if response.status_code != case["expected_status"]:
            print(f"\n❌ FAILED: {case['id']} ({language.upper()})")
            print(f"  Reason: Status Code Mismatch")
            print(f"  Expected: {case['expected_status']}")
            print(f"  Got: {response.status_code}")
            print(f"  Response: {response.text}")
            return False
            
        # 2. Check if the JSON body matches
        response_json = response.json()
        if response_json != case["expected_output"]:
            print(f"\n❌ FAILED: {case['id']} ({language.upper()})")
            print(f"  Reason: JSON Output Mismatch")
            # Use json.dumps with ensure_ascii=False to handle Unicode characters
            print(f"  Expected: {json.dumps(case['expected_output'], ensure_ascii=False)}")
            print(f"  Got: {json.dumps(response_json, ensure_ascii=False)}")
            return False

        # If both checks pass
        return True

    except requests.exceptions.RequestException as e:
        print(f"\n❌ FAILED: {case['id']} ({language.upper()})")
        print(f"  Reason: API Request Failed")
        print(f"  Error: {e}")
        return False

def main():
    """
    Main function to load test data and run the accuracy test suite.
    """
    if not API_URL or not API_KEY:
        print("🚨 ERROR: Please set the API_URL and API_KEY environment variables.")
        sys.exit(1)

    try:
        with open(TEST_DATA_PATH, "r", encoding="utf-8") as f:
            test_cases = json.load(f)
    except FileNotFoundError:
        print(f"🚨 ERROR: Test data file not found at {TEST_DATA_PATH}")
        sys.exit(1)

    passed_tests = 0
    # Each case has an English and an Arabic prompt, so we double the count
    total_tests = len(test_cases) * 2
    
    print(f"🚀 Starting API accuracy test with {total_tests} total prompts...")
    
    # Use tqdm for a progress bar
    with tqdm(total=total_tests, unit="prompt") as pbar:
        for case in test_cases:
            # Test English prompt
            if run_single_test(case, "en"):
                passed_tests += 1
            pbar.update(1)
            
            # Test Arabic prompt
            if run_single_test(case, "ar"):
                passed_tests += 1
            pbar.update(1)

    # --- Final Report ---
    print("\n" + "="*50)
    print("📊 FINAL ACCURACY REPORT")
    print("="*50)
    
    accuracy = (passed_tests / total_tests) * 100 if total_tests > 0 else 0
    
    print(f"  Tests Passed: {passed_tests}")
    print(f"  Total Prompts: {total_tests}")
    print(f"  Accuracy: {accuracy:.2f}%")
    print("="*50)

    if accuracy >= 90:
        print("✅ Success! Accuracy meets the >= 90% requirement.")
    else:
        print("⚠️ Warning! Accuracy is below the 90% requirement.")

if __name__ == "__main__":
    main()



JSONDecodeError: Extra data: line 72 column 1 (char 1916)

In [None]:
import os
import json
import requests
from tqdm import tqdm
import sys



def run_single_test(case, language):
    """
    Runs a single test case against the API.

    Args:
        case (dict): A dictionary representing a single test case.
        language (str): 'en' for English prompt, 'ar' for Arabic.

    Returns:
        bool: True if the test passed, False otherwise.
    """
    prompt_key = f"prompt_{language}"
    prompt_text = case[prompt_key]
    
    headers = {
        "Content-Type": "application/json; charset=utf-8", # Be explicit about charset
        # "x-api-key": API_KEY
    }
    if API_KEY:
        headers["x-api-key"] = API_KEY
    payload = {"prompt": prompt_text}

    # Manually serialize the dictionary to a JSON string and encode it to UTF-8
    data_to_send = json.dumps(payload, ensure_ascii=False).encode('utf-8')

    try:
        # Use the `data` parameter instead of `json` since we manually encoded it
        response = requests.post(API_URL, headers=headers, data=data_to_send, timeout=30)
        
        # 1. Check if the status code matches
        if response.status_code != case["expected_status"]:
            print(f"\n❌ FAILED: {case['id']} ({language.upper()})")
            print(f"  Reason: Status Code Mismatch")
            print("data_to_send:", data_to_send)
            print(f"  Expected: {case['expected_status']}")
            print(f"  Got: {response.status_code}")
            print(f"  Response: {response.text}")
            return False
            
        # 2. Check if the JSON body matches
        response_json = response.json()
        if response_json != case["expected_output"]:
            print(f"\n❌ FAILED: {case['id']} ({language.upper()})")
            print("data_to_send:", data_to_send)
            print(f"  Reason: JSON Output Mismatch")
            # Use json.dumps with ensure_ascii=False to handle Unicode characters
            print(f"  Expected: {json.dumps(case['expected_output'], ensure_ascii=False)}")
            print(f"  Got: {json.dumps(response_json, ensure_ascii=False)}")
            return False

        # If both checks pass
        return True

    except requests.exceptions.RequestException as e:
        print(f"\n❌ FAILED: {case['id']} ({language.upper()})")
        print(f"  Reason: API Request Failed")
        print(f"  Error: {e}")
        return False





In [3]:
if not API_URL or not API_KEY:
    print("🚨 ERROR: Please set the API_URL and API_KEY environment variables.")
    sys.exit(1)

In [18]:
try:
    with open(TEST_DATA_PATH, "r", encoding="utf-8") as f:
        test_cases = json.load(f)
except FileNotFoundError:
    print(f"🚨 ERROR: Test data file not found at {TEST_DATA_PATH}")
    sys.exit(1)

In [12]:
TEST_DATA_PATH

'tests/test_data_2.json'

In [13]:
len(test_cases)

21

In [14]:
passed_tests = 0
# Each case has an English and an Arabic prompt, so we double the count
total_tests = len(test_cases) * 2

print(f"🚀 Starting API accuracy test with {total_tests} total prompts...")

🚀 Starting API accuracy test with 42 total prompts...


In [None]:
#deployment
API_URL = "https://w2xhkmywby.us-east-2.awsapprunner.com/parse_prompt" #os.getenv("API_URL")
API_KEY = "blabla" #os.getenv("API_KEY")

#local
API_URL = "http://127.0.0.1:8000/parse_prompt" #os.getenv("API_URL")
API_KEY = "" #os.getenv("API_KEY")

TEST_DATA_PATH = "tests/test_data_2.json"

In [None]:
# Use tqdm for a progress bar
#deployment
passed_tests = 0

with tqdm(total=total_tests, unit="prompt") as pbar:
    for case in test_cases:
        # Test English prompt
        if run_single_test(case, "en"):
            passed_tests += 1
        pbar.update(1)
        
        # Test Arabic prompt
        if run_single_test(case, "ar"):
            passed_tests += 1
        pbar.update(1)

 24%|██▍       | 10/42 [00:11<00:39,  1.24s/prompt]


❌ FAILED: country_!=_0_{'en':'Saudi Arabia','ar':'السعودية'} (AR)
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "country", "operator": "!=", "value": "Saudi Arabia"}]}
  Got: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}, {"field": "joining_date", "operator": ">", "value": "2023-01-01"}, {"field": "total_orders", "operator": ">", "value": 5}]}


 29%|██▊       | 12/42 [00:14<00:40,  1.36s/prompt]


❌ FAILED: total_orders_>_0_5 (AR)
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "total_orders", "operator": ">", "value": 5}]}
  Got: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}, {"field": "joining_date", "operator": ">", "value": "2023-01-01"}, {"field": "total_orders", "operator": ">", "value": 5}]}


 33%|███▎      | 14/42 [00:16<00:36,  1.30s/prompt]


❌ FAILED: total_orders_>_0_50 (AR)
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "total_orders", "operator": ">", "value": 50}]}
  Got: {"filters": [{"field": "city", "operator": "=", "value": "Riyadh"}, {"field": "total_orders", "operator": ">", "value": 50}]}


 48%|████▊     | 20/42 [00:24<00:28,  1.32s/prompt]


❌ FAILED: total_sales_<=_1_5000.0 (AR)
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "total_sales", "operator": "<=", "value": 5000.0}]}
  Got: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}, {"field": "joining_date", "operator": ">", "value": "2023-01-01"}, {"field": "total_orders", "operator": ">", "value": 5}]}


 52%|█████▏    | 22/42 [00:27<00:28,  1.43s/prompt]


❌ FAILED: total_sales_<=_1_999.99 (AR)
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "total_sales", "operator": "<=", "value": 999.99}]}
  Got: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}, {"field": "joining_date", "operator": ">", "value": "2023-01-01"}, {"field": "total_orders", "operator": ">", "value": 5}]}


 83%|████████▎ | 35/42 [00:42<00:09,  1.29s/prompt]


❌ FAILED: city_=_0_['Riyadh','Jeddah'] (EN)
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}]}
  Got: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}, {"field": "joining_date", "operator": ">", "value": "2023-01-01"}, {"field": "total_orders", "operator": ">", "value": 5}]}


 86%|████████▌ | 36/42 [00:44<00:08,  1.42s/prompt]


❌ FAILED: city_=_0_['Riyadh','Jeddah'] (AR)
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}]}
  Got: {"filters": [{"field": "city", "operator": "=", "value": ["Riyadh", "Jeddah"]}, {"field": "joining_date", "operator": ">", "value": "2023-01-01"}, {"field": "total_orders", "operator": ">", "value": 5}]}


 98%|█████████▊| 41/42 [00:51<00:01,  1.40s/prompt]


❌ FAILED: error_unsupported_field (EN)
  Reason: Status Code Mismatch
  Expected: 400
  Got: 200
  Response: {"filters":[{"field":"store_rating","operator":">","value":8.0}]}


100%|██████████| 42/42 [00:52<00:00,  1.25s/prompt]


❌ FAILED: error_unsupported_field (AR)
  Reason: Status Code Mismatch
  Expected: 400
  Got: 200
  Response: {"filters":[{"field":"store_rating","operator":">","value":4.0}]}





In [None]:
# Use tqdm for a progress bar
#local
passed_tests = 0
with tqdm(total=total_tests, unit="prompt") as pbar:
    for case in test_cases:
        # Test English prompt
        if run_single_test(case, "en"):
            passed_tests += 1
        pbar.update(1)
        
        # Test Arabic prompt
        if run_single_test(case, "ar"):
            passed_tests += 1
        pbar.update(1)

print("\n" + "="*50)
print("📊 FINAL ACCURACY REPORT")

accuracy = (passed_tests / total_tests) * 100 if total_tests > 0 else 0

print(f"  Tests Passed: {passed_tests}")
print(f"  Total Prompts: {total_tests}")
print(f"  Accuracy: {accuracy:.2f}%")
if accuracy >= 90:
    print("✅ Success! Accuracy meets the >= 90% requirement.")
else:
    print("⚠️ Warning! Accuracy is below the 90% requirement.")


 98%|█████████▊| 41/42 [00:45<00:01,  1.15s/prompt]


❌ FAILED: error_unsupported_field (EN)
data_to_send: b'{"prompt": "Customers with a high net promoter score"}'
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "store_rating", "operator": ">", "value": 4.0}]}
  Got: {"detail": "No valid filters were found. The prompt may contain unsupported fields or be too ambiguous."}


100%|██████████| 42/42 [00:46<00:00,  1.10s/prompt]


❌ FAILED: error_unsupported_field (AR)
  Reason: Status Code Mismatch
data_to_send: b'{"prompt": "\xd8\xa7\xd9\x84\xd8\xb9\xd9\x85\xd9\x84\xd8\xa7\xd8\xa1 \xd8\xa7\xd9\x84\xd8\xb0\xd9\x8a\xd9\x86 \xd9\x84\xd8\xaf\xd9\x8a\xd9\x87\xd9\x85 \xd8\xaf\xd8\xb1\xd8\xac\xd8\xa9 \xd8\xb1\xd8\xb6\xd8\xa7 \xd8\xb9\xd8\xa7\xd9\x84\xd9\x8a\xd8\xa9"}'
  Expected: 400
  Got: 200
  Response: {"filters":[{"field":"store_rating","operator":">","value":4.0}]}

📊 FINAL ACCURACY REPORT
  Tests Passed: 40
  Total Prompts: 42
  Accuracy: 95.24%
✅ Success! Accuracy meets the >= 90% requirement.





In [23]:
# Use tqdm for a progress bar
#local
passed_tests = 0

with tqdm(total=total_tests, unit="prompt") as pbar:
    for case in test_cases:
        # Test English prompt
        if run_single_test(case, "en"):
            passed_tests += 1
        pbar.update(1)
        
        # Test Arabic prompt
        if run_single_test(case, "ar"):
            passed_tests += 1
        pbar.update(1)

print("\n" + "="*50)
print("📊 FINAL ACCURACY REPORT")

accuracy = (passed_tests / total_tests) * 100 if total_tests > 0 else 0

print(f"  Tests Passed: {passed_tests}")
print(f"  Total Prompts: {total_tests}")
print(f"  Accuracy: {accuracy:.2f}%")
if accuracy >= 90:
    print("✅ Success! Accuracy meets the >= 90% requirement.")
else:
    print("⚠️ Warning! Accuracy is below the 90% requirement.")


 98%|█████████▊| 41/42 [00:44<00:01,  1.14s/prompt]


❌ FAILED: error_unsupported_field (EN)
data_to_send: b'{"prompt": "Customers with a high net promoter score"}'
  Reason: JSON Output Mismatch
  Expected: {"filters": [{"field": "store_rating", "operator": ">", "value": 4.0}]}
  Got: {"detail": "No valid filters were found. The prompt may contain unsupported fields or be too ambiguous."}


100%|██████████| 42/42 [00:45<00:00,  1.09s/prompt]


❌ FAILED: error_unsupported_field (AR)
  Reason: Status Code Mismatch
data_to_send: b'{"prompt": "\xd8\xa7\xd9\x84\xd8\xb9\xd9\x85\xd9\x84\xd8\xa7\xd8\xa1 \xd8\xa7\xd9\x84\xd8\xb0\xd9\x8a\xd9\x86 \xd9\x84\xd8\xaf\xd9\x8a\xd9\x87\xd9\x85 \xd8\xaf\xd8\xb1\xd8\xac\xd8\xa9 \xd8\xb1\xd8\xb6\xd8\xa7 \xd8\xb9\xd8\xa7\xd9\x84\xd9\x8a\xd8\xa9"}'
  Expected: 400
  Got: 200
  Response: {"filters":[{"field":"store_rating","operator":">","value":4.0}]}

📊 FINAL ACCURACY REPORT
  Tests Passed: 40
  Total Prompts: 42
  Accuracy: 95.24%
✅ Success! Accuracy meets the >= 90% requirement.



