In [1]:
import os

from dotenv import load_dotenv

# Populate os.environ from a local .env file (presumably holding the
# Anthropic API key) before anything below reads the environment.
load_dotenv()

# Kept after load_dotenv() so the .env values are already present when the
# SDK is imported and the client is constructed.
from anthropic import Anthropic

# One shared API client and model alias, used by every helper in this notebook.
client = Anthropic()
model = "claude-sonnet-4-0"

In [2]:
# Helper functions
def add_user_message(messages, text):
    """Append a user turn whose content is ``text`` to ``messages`` in place."""
    messages.append(dict(role="user", content=text))


def add_assistant_message(messages, text):
    """Append an assistant turn whose content is ``text`` to ``messages`` in place.

    Besides recording prior replies, this is how the notebook prefills the start
    of the model's answer (``grade_by_model`` adds an opening code fence).
    """
    messages.append(dict(role="assistant", content=text))


def chat(messages, system=None, temperature=1.0, stop_sequences=None, max_tokens=1000):
    """Send ``messages`` to the model and return the text of its reply.

    Args:
        messages: Conversation so far, as a list of ``{"role", "content"}`` dicts.
        system: Optional system prompt; omitted from the request when falsy.
        temperature: Sampling temperature forwarded to the API.
        stop_sequences: Optional list of strings at which generation stops.
            ``None`` (the default) sends an empty list, as before.
        max_tokens: Upper bound on generated tokens (defaults to the previously
            hard-coded 1000).

    Returns:
        The concatenated text of every ``"text"`` content block in the response,
        or an empty string if the response contains no text blocks.
    """
    params = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": messages,
        "temperature": temperature,
        # None default avoids a shared mutable default argument.
        "stop_sequences": [] if stop_sequences is None else stop_sequences,
    }

    if system:
        params["system"] = system

    message = client.messages.create(**params)
    # The response content is a list of blocks; it can be empty or hold
    # non-text blocks, so indexing [0].text is fragile. Join the text blocks.
    return "".join(
        block.text
        for block in message.content
        if getattr(block, "type", None) == "text"
    )

In [3]:
# Function to grade a test case + output using a model
import json
def grade_by_model(test_case, output):
    """Ask the model to grade ``output`` as a solution to ``test_case["task"]``.

    The assistant turn is prefilled with an opening ```json fence and generation
    stops at the next ``` so the reply should be a bare JSON object.

    Args:
        test_case: Dict with at least a ``"task"`` key describing the problem.
        output: Candidate solution text to evaluate.

    Returns:
        The parsed grading dict. The prompt requests ``strengths``,
        ``weaknesses``, ``reasoning`` and a numeric ``score``; ``score`` and
        ``reasoning`` are validated because downstream cells rely on them.

    Raises:
        ValueError: If the reply is not valid JSON, is not a JSON object, has a
            missing or non-numeric ``score``, or lacks ``reasoning``.
    """
    eval_prompt = f"""
You are an expert AWS code reviewer. Your task is to evaluate the following AI-generated solution.

Original Task:
<task>
{test_case["task"]}
</task>

Solution to Evaluate:
<solution>
{output}
</solution>

Output Format
Provide your evaluation as a structured JSON object with the following fields, in this specific order:
- "strengths": An array of 1-3 key strengths
- "weaknesses": An array of 1-3 key areas for improvement
- "reasoning": A concise explanation of your overall assessment
- "score": A number between 1-10

Respond with JSON. Keep your response concise and direct.
Example response shape:
{{
    "strengths": string[],
    "weaknesses": string[],
    "reasoning": string,
    "score": number
}}
    """

    messages = []
    add_user_message(messages, eval_prompt)
    # Prefill the reply with an opening JSON fence; the matching stop sequence
    # ends generation at the closing fence, leaving only the JSON body.
    add_assistant_message(messages, "```json")
    eval_text = chat(messages, stop_sequences=["```"])

    try:
        evaluation = json.loads(eval_text)
    except json.JSONDecodeError as exc:
        # JSONDecodeError is itself a ValueError; re-raise with the raw reply
        # so a bad grade is debuggable.
        raise ValueError(f"Grader did not return valid JSON: {eval_text!r}") from exc

    if not isinstance(evaluation, dict):
        raise ValueError(f"Grader returned JSON that is not an object: {evaluation!r}")

    score = evaluation.get("score")
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(score, bool) or not isinstance(score, (int, float)):
        raise ValueError(f"Grader returned a missing or non-numeric score: {score!r}")
    if "reasoning" not in evaluation:
        raise ValueError("Grader response is missing the 'reasoning' field")

    return evaluation

In [4]:
# Passes a test case into Claude
def run_prompt(test_case):
    """Send the test case's ``"task"`` to the model and return the raw reply text."""
    task_prompt = f"""
Please solve the following task:

{test_case["task"]}
"""
    conversation = []
    add_user_message(conversation, task_prompt)
    return chat(conversation)

In [5]:
# Function to execute a single test case and grade the output
def run_test_case(test_case):
    """Run one test case end to end: generate a solution, then grade it.

    Returns a dict with the generated ``output``, the original ``test_case``,
    and the grader's ``score`` and ``reasoning``.
    """
    solution = run_prompt(test_case)
    grade = grade_by_model(test_case, solution)
    # Values are evaluated left to right, so a missing "score" is reported
    # before a missing "reasoning".
    return {
        "output": solution,
        "test_case": test_case,
        "score": grade["score"],
        "reasoning": grade["reasoning"],
    }

In [6]:
from statistics import mean


def run_eval(dataset):
    """Run and grade every test case in ``dataset`` and print the mean score.

    Args:
        dataset: Iterable of test-case dicts, each passed to ``run_test_case``.

    Returns:
        List of per-case result dicts, in dataset order.

    Prints ``Average score: <mean>``. For an empty dataset a notice is printed
    instead of letting ``statistics.mean`` raise ``StatisticsError``.
    """
    results = [run_test_case(test_case) for test_case in dataset]

    if not results:
        # mean([]) raises StatisticsError; report the empty run instead.
        print("Average score: n/a (no test cases)")
        return results

    average_score = mean([result["score"] for result in results])
    print(f"Average score: {average_score}")

    return results

In [8]:
# Load the evaluation dataset (each entry is a test case whose "task" is sent
# to the model) and run the full eval over it. JSON is UTF-8 by spec, so decode
# explicitly instead of relying on the platform's default encoding.
with open("test_data.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

results = run_eval(dataset)

Average score: 7


In [9]:
# Display the per-case result dicts (keys: output, test_case, score, reasoning).
# NOTE(review): this renders every full model output, which is very long; for a
# shareable notebook consider a compact view such as
# [(r["score"], r["reasoning"]) for r in results].
results

[{'output': 'Here\'s a Python function to extract the AWS region from an ARN:\n\n```python\ndef extract_region_from_arn(arn):\n    """\n    Extract the AWS region from an Amazon Resource Name (ARN).\n    \n    Args:\n        arn (str): The Amazon Resource Name\n        \n    Returns:\n        str: The AWS region code, or None if not found or invalid ARN\n        \n    Raises:\n        ValueError: If the ARN format is invalid\n    """\n    if not arn or not isinstance(arn, str):\n        raise ValueError("ARN must be a non-empty string")\n    \n    # Split the ARN by colons\n    arn_parts = arn.split(\':\')\n    \n    # ARN format: arn:partition:service:region:account-id:resource\n    # Minimum 6 parts required\n    if len(arn_parts) < 6:\n        raise ValueError("Invalid ARN format. ARN must have at least 6 parts separated by colons")\n    \n    # Check if it starts with \'arn\'\n    if arn_parts[0] != \'arn\':\n        raise ValueError("Invalid ARN format. ARN must start with \'arn:\