# Automated Reasoning Test Case Playground

This notebook demonstrates how to create and run test cases for an Amazon Bedrock automated reasoning policy using the Bedrock API. It covers the following steps:

1. Setting up the Bedrock clients
2. Creating an automated reasoning policy test case
3. Listing all existing test cases
4. Running a test case
5. Getting a test case result

## Setup

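First, let's install the dependencies, import the necessary libraries, and set up the Bedrock clients. The automated reasoning operations used below are only available in recent boto3 releases, which is why boto3 is upgraded.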

In [None]:
%pip install -r requirements.txt

In [None]:
%pip install --upgrade boto3

In [None]:
import os
import json
import boto3
import uuid
import time
import pandas as pd
from IPython.display import display, HTML, JSON
import ipywidgets as widgets
from datetime import datetime

In [None]:
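# Create the Bedrock clients in the session's default region.
# If the region prints as None, configure a default region (for example, via AWS_DEFAULT_REGION).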
my_session = boto3.session.Session()
REGION_NAME = my_session.region_name
print(f'The region is {REGION_NAME}')
runtime_client = my_session.client('bedrock-runtime', region_name=REGION_NAME)
bedrock_client = my_session.client('bedrock', region_name=REGION_NAME)
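The boto3 upgrade matters because the automated reasoning operations used in this notebook only exist in recent releases. The following is a minimal sketch of a fail-fast sanity check. The helper `missing_operations` is hypothetical. It lists which of the client methods called later in this notebook are not available on a given client object.

```python
# Method names this notebook calls on the `bedrock` client (taken from the cells below).
REQUIRED_OPERATIONS = [
    "create_automated_reasoning_policy_test_case",
    "list_automated_reasoning_policy_test_cases",
    "list_automated_reasoning_policy_build_workflows",
    "start_automated_reasoning_policy_test_workflow",
    "get_automated_reasoning_policy_test_result",
]


def missing_operations(client, required=REQUIRED_OPERATIONS):
    """Return the names in `required` that are not callable methods on `client`."""
    return [name for name in required if not callable(getattr(client, name, None))]


# Usage (hypothetical helper):
# missing = missing_operations(bedrock_client)
# if missing:
#     print(f"Upgrade boto3; missing operations: {missing}")
```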

## Create a Test Case

Now, let's create a test case for an automated reasoning policy using the Bedrock API.

A test case simulates a user interaction. It pairs a user query and an LLM answer with the result you expect the policy to return, so that running it confirms whether the automated reasoning policy behaves as intended.
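Each test case combines several request fields:

- `guardContent`: the text to check, which is the LLM answer.
- `queryContent`: an optional field holding the user's request.
- `expectedAggregatedFindingsResult`: the result you expect the policy to return.
- An optional confidence threshold.

The following is a minimal sketch that assembles those fields and rejects obviously invalid values locally, before any API call. The helper `build_test_case_request` is hypothetical. The set of allowed results is copied from the `create_test_case` docstring in this notebook; treat it as an assumption and check it against the current API reference.

```python
# Allowed expected results, copied from the create_test_case docstring in this notebook.
VALID_FINDINGS_RESULTS = {
    "VALID", "INVALID", "SATISFIABLE", "IMPOSSIBLE",
    "TRANSLATION_AMBIGUOUS", "TOO_COMPLEX", "NO_TRANSLATION",
}


def build_test_case_request(guard_content, expected_result,
                            query_content=None, confidence_threshold=None):
    """Build the test case request fields, validating inputs locally (hypothetical helper)."""
    if not guard_content:
        raise ValueError("guard_content (the LLM answer) must not be empty")
    if expected_result not in VALID_FINDINGS_RESULTS:
        raise ValueError(f"unknown expected result: {expected_result!r}")
    if confidence_threshold is not None and not 0.0 <= confidence_threshold <= 1.0:
        raise ValueError("confidence_threshold must be between 0.0 and 1.0")

    request = {
        "guardContent": guard_content,
        "expectedAggregatedFindingsResult": expected_result,
    }
    # Optional fields are only sent when provided, mirroring create_test_case.
    if query_content is not None:
        request["queryContent"] = query_content
    if confidence_threshold is not None:
        request["confidenceThreshold"] = confidence_threshold
    return request
```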

In [None]:
# Provide the ARN of the policy you want to test, without the version suffix (ID or number)
policy_arn="<POLICY_ARN>"
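If you copied a versioned ARN, the hypothetical helper below strips the version. It assumes that a versioned policy ARN ends in `:<version>` after the policy ID (for example, `.../automated-reasoning-policy/<id>:1`). Verify that format against your own ARNs before relying on it.

```python
def strip_policy_version(policy_arn):
    """Drop a trailing ':<version>' from the resource part of an ARN, if present.

    Assumed versioned form (not confirmed by this notebook):
    arn:aws:bedrock:<region>:<account>:automated-reasoning-policy/<id>:<version>
    """
    prefix, sep, resource = policy_arn.rpartition("/")
    if not sep:
        return policy_arn  # no resource path; leave the string unchanged
    policy_id = resource.split(":", 1)[0]  # keep only the part before any version suffix
    return f"{prefix}/{policy_id}"


# Usage (hypothetical helper):
# policy_arn = strip_policy_version(policy_arn)
```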

In [None]:
def create_test_case(policy_arn, guard_content, expected_aggregated_findings_result, query_content=None, confidence_threshold=None):
    """
    Creates a test case for an automated reasoning policy
    
    Args:
        policy_arn (str): ARN of the automated reasoning policy to test against
        guard_content (str): The LLM answer to be evaluated by the policy
        expected_aggregated_findings_result (str): Expected result of the automated reasoning check
            Valid values: "VALID", "INVALID", "SATISFIABLE", "IMPOSSIBLE", 
            "TRANSLATION_AMBIGUOUS", "TOO_COMPLEX", "NO_TRANSLATION"
        query_content (str, optional): User query to test against the policy
        confidence_threshold (float, optional): Confidence threshold for the test (0.0 to 1.0)

    Returns:
        dict: Response from the API call
    """
    try:
        kwargs = {}
        if query_content is not None:
            kwargs['queryContent'] = query_content
        if confidence_threshold is not None:
            kwargs['confidenceThreshold'] = confidence_threshold

        return bedrock_client.create_automated_reasoning_policy_test_case(
            policyArn=policy_arn,
            guardContent=guard_content,
            expectedAggregatedFindingsResult=expected_aggregated_findings_result,
            clientRequestToken=str(uuid.uuid4()),
            **kwargs
        )
    except Exception as e:
        print(f"Error creating a test case: {str(e)}")
        raise

In [None]:
# Create a test case for the automated reasoning policy
# This example demonstrates creating a test case that expects a VALID result
# with maximum confidence threshold (1.0)

# Example values to test with (uncomment to use instead of the placeholders below)

# query_content=""" Patient Profile:
#     Age: 58 years
#     Length of stay: 6 days
#     Has chronic kidney disease stage 4
#     Two ED visits in last 6 months
#     Uninsured status
#     New requirement for durable medical equipment
#     """
# guard_content="Classification: High Risk"

guard_content="<GUARD_CONTENT>"
query_content="<QUERY_CONTENT>"

create_test_case_response = create_test_case(
    policy_arn=policy_arn,                              # ARN of the policy to test
    query_content=query_content,                        # Replace with the user query (the request sent to the LLM)
    guard_content=guard_content,                        # Replace with actual LLM response to validate
    expected_aggregated_findings_result="VALID",        # Expected validation outcome
    confidence_threshold=1.0                            # Maximum confidence (100%)
)

# Extract the test case ID from the response for future operations
test_case_id = create_test_case_response['testCaseId']

# Display the full API response in JSON format
JSON(create_test_case_response)

Let's list all existing test cases. We should be able to see the test case that we created above.

In [None]:
def list_existing_test_cases(policy_arn):
    """
    Returns a list of all test cases for an automated reasoning policy
    
    Args:
        policy_arn (str): ARN of the policy

    Returns:
        List[dict]: List of test cases
    """
    try:
        test_cases = []
        is_first_run = True
        pagination_token = None

        while is_first_run or pagination_token:
            if pagination_token:
                response = bedrock_client.list_automated_reasoning_policy_test_cases(
                    policyArn=policy_arn,
                    nextToken=pagination_token
                )
            else:
                response = bedrock_client.list_automated_reasoning_policy_test_cases(
                    policyArn=policy_arn,
                )

            test_cases.extend(response['testCases'])

            is_first_run = False
            pagination_token = response.get('nextToken', None)

        return test_cases
    except Exception as e:
        print(f"Error listing all test cases: {str(e)}")
        raise
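The `nextToken` loop above is the usual AWS pagination pattern: keep requesting pages until the response no longer contains a token. The following is a minimal sketch of the same idea as a reusable generator. The name `iterate_items` is hypothetical. It takes any page-fetching callable, plus the key that holds the items (for example, `'testCases'`).

```python
def iterate_items(fetch_page, items_key):
    """Yield every item across pages, following 'nextToken' until it is absent."""
    token = None
    while True:
        # Send nextToken only after the first page, as list_existing_test_cases does.
        response = fetch_page(nextToken=token) if token else fetch_page()
        yield from response.get(items_key, [])
        token = response.get("nextToken")
        if not token:
            break


# Usage (hypothetical helper):
# from functools import partial
# fetch = partial(bedrock_client.list_automated_reasoning_policy_test_cases, policyArn=policy_arn)
# test_cases = list(iterate_items(fetch, "testCases"))
```

Because the generator is lazy, you can stop early (for example, after finding a specific test case ID) without fetching the remaining pages.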

In [None]:
# Retrieve all existing test cases for the specified policy
test_cases = list_existing_test_cases(policy_arn=policy_arn)

# Convert the test cases list to a pandas DataFrame for better visualization
# This creates a tabular format with columns for each test case attribute
test_cases_df = pd.DataFrame(test_cases)

# Display the DataFrame in a formatted table within the Jupyter notebook
# Shows test case IDs, content, expected results, and other metadata
display(test_cases_df)

## Running a Test Case

Before running a test case, we need the ID of a build workflow for the policy, because the test workflow is started against that build.

In [None]:
def list_build_workflows(policy_arn):
    """
    Lists the build workflows for an automated reasoning policy (first page of results only; pagination is not handled)
    
    Args:
        policy_arn (str): ARN of the policy

    Returns:
        dict: Response from the API call
    """
    try:
        return bedrock_client.list_automated_reasoning_policy_build_workflows(
            policyArn=policy_arn
        )
    except Exception as e:
        print(f"Error listing build workflow: {str(e)}")
        raise

In [None]:
# Retrieve all build workflows associated with the specified policy
list_build_workflows_response = list_build_workflows(policy_arn=policy_arn)

# Extract the build workflow ID from the first workflow in the response
# This ID is required to run test cases against the policy
build_workflow_id = list_build_workflows_response['automatedReasoningPolicyBuildWorkflowSummaries'][0]['buildWorkflowId']

# Display the complete API response in formatted JSON
# Shows all workflow summaries with their status, creation time, and other metadata
JSON(list_build_workflows_response)
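The cell above takes the first summary in the response. That summary may not be the most recent build, or one that finished successfully. The following is a minimal sketch of a more deliberate choice, based on two assumptions; check both against the JSON output above:

- Each summary carries `buildWorkflowId`, `status`, and `createdAt` keys.
- A successful build reports the status `COMPLETED`.

The helper `pick_latest_completed_workflow` is hypothetical.

```python
def pick_latest_completed_workflow(summaries):
    """Return the buildWorkflowId of the most recently created COMPLETED workflow, or None.

    Assumes each summary dict has 'buildWorkflowId', 'status', and 'createdAt' keys.
    """
    completed = [s for s in summaries if s.get("status") == "COMPLETED"]
    if not completed:
        return None
    # createdAt values only need to be mutually comparable (boto3 returns datetimes).
    latest = max(completed, key=lambda s: s["createdAt"])
    return latest["buildWorkflowId"]


# Usage (hypothetical helper):
# build_workflow_id = pick_latest_completed_workflow(
#     list_build_workflows_response['automatedReasoningPolicyBuildWorkflowSummaries'])
```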

Now, using the `build_workflow_id`, let's run the test case that we created above.

In [None]:
def run_test_cases(policy_arn, build_workflow_id, test_case_ids):
    """
    Starts a test workflow that runs the given test cases

    Args:
        policy_arn (str): ARN of the policy
        build_workflow_id (str): ID of the build workflow
        test_case_ids (List[str]): IDs of the test cases to run

    Returns:
        dict: Response from the API call
    """
    try:
        response = bedrock_client.start_automated_reasoning_policy_test_workflow(
            policyArn=policy_arn,
            buildWorkflowId=build_workflow_id,
            testCaseIds=test_case_ids,
            clientRequestToken=str(uuid.uuid4()),
        )

        print("Test workflow started successfully!")

        return response
    except Exception as e:
        print(f"Error starting a test workflow: {str(e)}")
        raise

In [None]:
# Start running the test cases for the given policy and build workflow.
# Note: this call only starts the test run asynchronously; its return value
# does not indicate that the tests have finished.
# To check the run status and results, we use get_test_case_result() below.
run_test_cases(
    policy_arn=policy_arn,               # ARN of the policy to test against
    build_workflow_id=build_workflow_id, # ID of the build workflow
    test_case_ids=[test_case_id],        # List of test case IDs to run
)
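`run_test_cases` generates a fresh `clientRequestToken` on every call, so calling it again after a network error could start a second test workflow. The token is intended to make the request idempotent: a retry that reuses the same token should be treated as the same request. The following is a minimal sketch of that approach. The helper `call_with_stable_token` is hypothetical. It creates one token and reuses it across retries, and leaves the choice of retryable exceptions to the caller.

```python
import uuid


def call_with_stable_token(start_fn, retryable, max_attempts=3):
    """Call start_fn(clientRequestToken=token), reusing one token across retries.

    `retryable` is a tuple of exception types that trigger another attempt;
    any other exception propagates immediately. No backoff is applied here.
    """
    token = str(uuid.uuid4())  # generated once, so every retry sends the same token
    for attempt in range(1, max_attempts + 1):
        try:
            return start_fn(clientRequestToken=token)
        except retryable:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error
    raise ValueError("max_attempts must be at least 1")


# Usage (hypothetical helper):
# from functools import partial
# from botocore.exceptions import EndpointConnectionError
# start = partial(bedrock_client.start_automated_reasoning_policy_test_workflow,
#                 policyArn=policy_arn, buildWorkflowId=build_workflow_id,
#                 testCaseIds=[test_case_id])
# call_with_stable_token(start, retryable=(EndpointConnectionError,))
```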

## Retrieve the Test Case Result

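Now, let's retrieve the result of the test run we just started. Because the run is asynchronous, the monitoring cell below polls until the run reaches a terminal status (`COMPLETED` or `FAILED`) or the attempt limit is reached.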

In [None]:
def get_test_case_result(policy_arn, build_workflow_id, test_case_id):
    """
    Returns the test case result
    
    Args:
        policy_arn (str): ARN of the policy
        build_workflow_id (str): ID of the build workflow
        test_case_id (str): ID of the test case

    Returns:
        dict: Response from the API call
    """
    try:
        return bedrock_client.get_automated_reasoning_policy_test_result(
            policyArn=policy_arn,
            buildWorkflowId=build_workflow_id,
            testCaseId=test_case_id,
        )
    except Exception as e:
        print(f"Error returning test case result: {str(e)}")
        raise
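Because the test run is asynchronous, its result has to be polled. With the widgets stripped away, the polling logic in the next cell reduces to the sketch below. The helper `wait_for_test_result` is hypothetical. It treats `COMPLETED` and `FAILED` as terminal statuses, matching the monitoring cell, and takes `sleep` as a parameter so it can be exercised without waiting.

```python
import time

# Terminal statuses, matching the monitoring cell in this notebook.
TERMINAL_STATUSES = {"COMPLETED", "FAILED"}


def wait_for_test_result(fetch, max_attempts=30, poll_interval=10, sleep=time.sleep):
    """Poll fetch() until testResult.testRunStatus is terminal or attempts run out.

    Returns (status, response) from the last call.
    """
    status, response = "UNKNOWN", None
    for attempt in range(max_attempts):
        response = fetch()
        status = response.get("testResult", {}).get("testRunStatus", "UNKNOWN")
        if status in TERMINAL_STATUSES:
            break
        if attempt < max_attempts - 1:  # don't sleep after the final attempt
            sleep(poll_interval)
    return status, response


# Usage (hypothetical helper):
# status, result = wait_for_test_result(
#     lambda: get_test_case_result(policy_arn, build_workflow_id, test_case_id))
```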

In [None]:
print(f'Policy ARN: {policy_arn}')

# Create widgets for monitoring
status_output = widgets.Output()
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    orientation='horizontal'
)
status_text = widgets.HTML(value="<b>Status:</b> Initializing...")

# Display the widgets
display(status_text)
display(progress_bar)
display(status_output)

# Monitor the test status
max_attempts = 30
poll_interval = 10  # seconds

test_status = None
test_run_result = None

for attempt in range(max_attempts):
    # Check test status
    response = get_test_case_result(policy_arn, build_workflow_id, test_case_id)

    # Check status
    test_status = response.get('testResult', {}).get('testRunStatus', 'UNKNOWN')
    test_run_result = response.get('testResult', {}).get('testRunResult', 'UNKNOWN')
    
    # Calculate progress
    progress = 0
    if test_status == 'NOT_STARTED':
        progress = 0
    elif test_status == 'SCHEDULED':
        progress = 10
    elif test_status == 'IN_PROGRESS':
        progress = 20
    elif test_status == 'TESTING':
        progress = 75
    elif test_status == 'COMPLETED':
        progress = 100
    elif test_status == 'FAILED':
        progress = 100
    
    # Update the widgets
    progress_bar.value = progress
    status_text.value = f"<b>Status:</b> Test status: {test_status}, Test Case ID: {test_case_id}"
    
    with status_output:
        print(f"Check {attempt + 1}: Test status: {test_status}, Test Case ID: {test_case_id}")
    
    # If the test is complete, then we are done
    if progress >= 100:
        if test_status == 'COMPLETED' and test_run_result == 'PASSED':
            progress_bar.bar_style = 'success'
        elif test_status == 'FAILED' or test_run_result == 'FAILED':
            progress_bar.bar_style = 'danger'
            with status_output:
                print("Test evaluation failed: the run did not complete, or the result did not match expectations.")
        break
    
    # Wait before the next check
    if attempt < max_attempts - 1:
        time.sleep(poll_interval)

# Final status update
if progress >= 100:
    if test_status == 'COMPLETED' and test_run_result == 'PASSED':
        status_text.value = f"<b>Status:</b> Test passed!"
    elif test_status == 'FAILED' or test_run_result == 'FAILED':
        status_text.value = f"<b>Status:</b> Test failed!"
    else:
        status_text.value = f"<b>Status:</b> UNKNOWN!"
else:
    status_text.value = f"<b>Status:</b> Test is not yet complete, Test Status: {test_status}"

If the test has neither completed nor failed yet, re-run the previous cell to continue monitoring.
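The `if`/`elif` chain in the monitoring cell maps each status to a progress value and then to a final outcome. The same rules can be written as a lookup table, which keeps the status list in one place. `STATUS_PROGRESS` and `describe_outcome` are hypothetical names, and the percentages are copied from the monitoring cell.

```python
# Progress percentages copied from the monitoring cell in this notebook.
STATUS_PROGRESS = {
    "NOT_STARTED": 0,
    "SCHEDULED": 10,
    "IN_PROGRESS": 20,
    "TESTING": 75,
    "COMPLETED": 100,
    "FAILED": 100,
}


def describe_outcome(test_status, test_run_result):
    """Return (progress, label) using the same rules as the monitoring cell."""
    progress = STATUS_PROGRESS.get(test_status, 0)  # unknown statuses count as not started
    if progress < 100:
        return progress, "pending"
    if test_status == "COMPLETED" and test_run_result == "PASSED":
        return progress, "passed"
    if test_status == "FAILED" or test_run_result == "FAILED":
        return progress, "failed"
    return progress, "unknown"
```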

In [None]:
# Retrieve test case result using policy ARN, build workflow ID and test case ID
test_case_result_response = get_test_case_result(
    policy_arn=policy_arn,
    build_workflow_id=build_workflow_id,
    test_case_id=test_case_id,
)

# Display the test result response as an interactive JSON view
JSON(test_case_result_response)
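To compare what you expected with what the policy returned, it can help to flatten the response into a few fields. The following is a minimal sketch that assumes these keys; check them against the JSON output above:

- The `testResult` object contains `testRunStatus`, `testRunResult`, and `aggregatedTestFindingsResult`.
- It also contains a nested `testCase` with `testCaseId` and `expectedAggregatedFindingsResult`.

The helper `summarize_test_result` is hypothetical.

```python
def summarize_test_result(response):
    """Flatten the fields most useful for a quick check into one dict.

    Assumed response shape (verify against your own output):
    {"testResult": {"testRunStatus": ..., "testRunResult": ...,
                    "aggregatedTestFindingsResult": ...,
                    "testCase": {"testCaseId": ...,
                                 "expectedAggregatedFindingsResult": ...}}}
    """
    result = response.get("testResult", {})
    test_case = result.get("testCase", {})
    expected = test_case.get("expectedAggregatedFindingsResult")
    actual = result.get("aggregatedTestFindingsResult")
    return {
        "testCaseId": test_case.get("testCaseId"),
        "status": result.get("testRunStatus"),
        "result": result.get("testRunResult"),
        "expected": expected,
        "actual": actual,
        # Only report a match when an expected value is actually present.
        "matches": expected is not None and expected == actual,
    }


# Usage (hypothetical helper):
# display(pd.DataFrame([summarize_test_result(test_case_result_response)]))
```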