# Automated Reasoning Valid at N Playground

This notebook demonstrates how to run a "Valid at N" experiment to evaluate how many iterations of response rewriting are needed before the response complies with our policy rules.

It includes the following:

1. Setting up the Bedrock client with custom API models
2. Running the "Valid at N" experiment

## Setup

First, let's import the necessary libraries and set up the Bedrock client with custom API models.

In [None]:

%pip install -r requirements.txt

In [None]:
import os
import json
import boto3
import uuid
import time
import pandas as pd
from IPython.display import display, HTML, JSON
import ipywidgets as widgets
from datetime import datetime
from policy_definition import get_policy_definition
from rewrite import summarize_results

In [None]:
# Create the Bedrock client
REGION_NAME="us-west-2" # Fill in the AWS Region
my_session = boto3.session.Session()
runtime_client = my_session.client('bedrock-runtime', region_name=REGION_NAME)
bedrock_client = my_session.client('bedrock', region_name=REGION_NAME)

In [None]:
# ARN of the Automated Reasoning policy whose definition will be retrieved
AR_POLICY_ARN = "<AR_POLICY_ARN>"

# Guardrail ID used when calling the ApplyGuardrail API
GUARDRAIL_ID = "<GUARDRAIL_ID>"

# Guardrail version used when calling the ApplyGuardrail API
GUARDRAIL_VERSION = "<GUARDRAIL_VERSION>"

# ID of the model Bedrock uses to generate LLM responses
MODEL_ID = "<MODEL_ID>"

# Valid at N Experiment

## Purpose
This code implements a "Valid at N" experiment to evaluate how many iterations of response rewriting are needed before the response complies with our policy rules.

## Concept
- **Valid at N**: N represents the number of iterations required before a response is deemed "valid" by the automated reasoning guardrails.
- N=1 means the original response was valid without any rewriting
- Higher N values indicate more complex policy violations that required multiple rewrites
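
To make the counting convention concrete, here is a minimal sketch that replays a sequence of guardrail verdicts and reports N. The `n_from_verdicts` helper and the verdict lists are hypothetical illustrations, not part of this notebook's code. In a real run the verdicts come from the guardrail.

```python
def n_from_verdicts(verdicts):
    """Return the 1-based position of the first 'valid' verdict, or None if there is none.

    Hypothetical helper: verdicts[0] is the verdict for the original response,
    verdicts[1] for the first rewrite, and so on.
    """
    for n, verdict in enumerate(verdicts, start=1):
        if verdict == "valid":
            return n
    return None


print(n_from_verdicts(["valid"]))                        # original response was valid
print(n_from_verdicts(["invalid", "invalid", "valid"]))  # two rewrites were needed
print(n_from_verdicts(["invalid", "invalid"]))           # never became valid
```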

## Methodology
1. Start with an initial query and LLM response
2. Apply guardrails to check if response is valid
3. If valid or too_complex, record N and stop
4. If not valid, rewrite the response and repeat the process
5. Continue until a valid response is found or max iterations reached
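
Step 2 relies on the ApplyGuardrail API. In this notebook that call is wrapped inside `summarize_results`, whose internals are not shown here, so the following is only a sketch of what such a request might look like. The `build_apply_guardrail_request` helper is hypothetical. Passing the query and the response with the `query` and `guard_content` qualifiers is an assumption about how the check is set up.

```python
def build_apply_guardrail_request(user_query, llm_response, guardrail_id, guardrail_version):
    """Assemble keyword arguments for the bedrock-runtime apply_guardrail call.

    Hypothetical helper; the choice of qualifiers is an assumption about how
    the query and the response are handed to the guardrail.
    """
    return {
        "guardrailIdentifier": guardrail_id,
        "guardrailVersion": guardrail_version,
        "source": "OUTPUT",  # validate model output rather than user input
        "content": [
            {"text": {"text": user_query, "qualifiers": ["query"]}},
            {"text": {"text": llm_response, "qualifiers": ["guard_content"]}},
        ],
    }


# Placeholder values for illustration only
request = build_apply_guardrail_request(
    "example query", "example response", "example-guardrail-id", "1"
)
print(request["source"], len(request["content"]))

# With a real client and guardrail, the call would look like:
# response = runtime_client.apply_guardrail(**request)
```

Keeping the request construction separate from the network call makes it possible to inspect the payload without AWS credentials.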

## Analysis Value
- **Policy Compliance**: Understand how well initial LLM responses comply with policies
- **Rewriting Effectiveness**: Measure how efficiently the rewriting process resolves policy violations
- **Error Distribution**: Identify common error patterns and their resolution complexity
- **Optimization Opportunities**: Determine where to focus prompt engineering efforts
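
As a rough illustration of the analysis described above, the sketch below tallies N values and finding types across several experiment runs. The `summarize_experiments` helper and the sample data are hypothetical. It assumes each result has the shape returned by `valid_at_n_experiment` (an `n_value` plus an `iterations` list whose entries carry `finding_types`).

```python
from collections import Counter


def summarize_experiments(all_results):
    """Tally N values and finding types across several experiment results."""
    # n_value may be an int or a string such as ">= 5", so normalize to str
    n_distribution = Counter(str(r["n_value"]) for r in all_results)
    finding_counts = Counter(
        finding_type
        for r in all_results
        for iteration in r["iterations"]
        for finding_type in iteration["finding_types"]
    )
    return n_distribution, finding_counts


# Made-up results for illustration only; real ones come from valid_at_n_experiment.
sample_results = [
    {"n_value": 1, "iterations": [{"finding_types": ["valid"]}]},
    {"n_value": 2, "iterations": [{"finding_types": ["invalid"]}, {"finding_types": ["valid"]}]},
    {"n_value": ">= 5", "iterations": [{"finding_types": ["invalid"]}] * 5},
]
n_distribution, finding_counts = summarize_experiments(sample_results)
print(dict(n_distribution))
print(dict(finding_counts))
```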

### How to Interpret Results

- **N=1**: Perfect! The original response is already valid.
- **N=2**: Good. Only one rewrite was needed to make the response valid.
- **N=3+**: Concerning. Multiple rewrites were needed, suggesting significant misalignment with policy.
- **N reported as ">= max_iterations"**: Critical. No valid response was found, either because the iteration limit was reached or because no rewrite was available to try next.
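
The interpretation scale can also be applied programmatically. This is a minimal sketch with a hypothetical `interpret_n` helper. It assumes `n_value` is an integer when a valid response was found and a string such as `">= 5"` when it was not.

```python
def interpret_n(n_value):
    """Map an n_value to the interpretation labels listed above (hypothetical helper)."""
    if not isinstance(n_value, int):
        # A string such as ">= 5" means no valid response was found
        return "Critical: no valid response within the iteration limit"
    if n_value == 1:
        return "Perfect: the original response was already valid"
    if n_value == 2:
        return "Good: one rewrite was needed"
    return "Concerning: multiple rewrites were needed"


for n in (1, 2, 4, ">= 5"):
    print(n, "->", interpret_n(n))
```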

In [None]:
def valid_at_n_experiment(
    user_query, 
    initial_llm_response, 
    policy_definition, 
    guardrail_id, 
    guardrail_version, 
    bedrock_runtime_client,
    model_id=MODEL_ID, 
    max_iterations=5
):
    """
    Performs a 'Valid at N' experiment by repeatedly applying summarize_results
    until a valid response is produced.
    """
    # Initialize result tracking
    results = {
        "query": user_query,
        "original_response": initial_llm_response,
        "iterations": [],
        "n_value": None,
        "final_valid_response": None
    }
    
    # Current response to validate (starts with initial response)
    current_response = initial_llm_response
    
    # Iterate until we find a valid response or reach max iterations
    for iteration in range(1, max_iterations + 1):
        # Apply summarize_results to get findings and potentially rewrite
        iteration_result = summarize_results(
            user_query=user_query,
            llm_response=current_response,
            policy_definition=policy_definition,
            guardrail_id=guardrail_id,
            guardrail_version=guardrail_version,
            bedrock_runtime_client=bedrock_runtime_client,
            model_id=model_id
        )
        
        # Store the iteration result
        iteration_data = {
            "iteration": iteration,
            "response": current_response,
            "findings": iteration_result["findings"],
            "finding_types": []
        }
        
        # Extract finding types by matching the "**Finding Type:**" markers (case-insensitive).
        # Each key is the recorded type; the tuple lists the labels that map to it.
        findings_text = iteration_result["findings"].lower()
        finding_labels = {
            "valid": ("valid",),
            "too_complex": ("too_complex",),
            "invalid": ("invalid",),
            "satisfiable": ("satisfiable", "satisfied"),
            "impossible": ("impossible",),
            "translation_ambiguous": ("translation_ambiguous",),
            "no_translations": ("no_translations",),
        }
        for finding_type, labels in finding_labels.items():
            if any(f"**finding type:** {label}" in findings_text for label in labels):
                iteration_data["finding_types"].append(finding_type)
        
        # Add to results
        results["iterations"].append(iteration_data)
        
        # Check if valid
        if "valid" in iteration_data["finding_types"] or "too_complex" in iteration_data["finding_types"]:
            results["n_value"] = iteration
            results["final_valid_response"] = current_response
            break
        
        # If not valid and we have a rewritten response, use that for next iteration
        if iteration_result["rewritten_response"]:
            iteration_data["rewritten_to"] = iteration_result["rewritten_response"]
            current_response = iteration_result["rewritten_response"]
        else:
            # If no rewritten response but not valid, we can't continue
            break
    
    # No valid response was found: either max_iterations was reached
    # or the loop stopped early because no rewrite was available
    if results["n_value"] is None:
        results["n_value"] = f">= {max_iterations}"
    
    return results

def display_valid_at_n_results(results):
    """Display the results of a Valid at N experiment in a formatted way."""
    print(f"\n## Valid at N Experiment Results (N = {results['n_value']})")
    print(f"Query: {results['query']}")
    
    print("\n### Original Response")
    print(results['original_response'])
    
    for i, iteration in enumerate(results['iterations']):
        print(f"\n### Iteration {iteration['iteration']}")
        print(f"Finding Types: {', '.join(iteration['finding_types']) if iteration['finding_types'] else 'None'}")
        
        if i > 0:  # Skip printing first response again
            print("\nResponse:")
            print(iteration['response'])
        
        print("\nFindings:")
        print(iteration['findings'])
        
        if 'rewritten_to' in iteration:
            print("\nRewritten To:")
            print(iteration['rewritten_to'])
    
    if results['final_valid_response'] and results['n_value'] != 1:
        print("\n### Final Valid Response")
        print(results['final_valid_response'])

In [None]:
# Example usage
user_query = "<USER_QUERY>"
llm_response = "<LLM_RESPONSE>"

# Use the 'bedrock' client created during setup to fetch the policy definition
policy_definition = get_policy_definition(bedrock_client, AR_POLICY_ARN)

# Run the Valid at N experiment
results = valid_at_n_experiment(
    user_query,
    llm_response,
    policy_definition,
    GUARDRAIL_ID,
    GUARDRAIL_VERSION,
    runtime_client  # the 'bedrock-runtime' client created during setup
)

# Display the results
display_valid_at_n_results(results)