In [1]:
import json

In [95]:
def _read_json(path):
    """Load and return the parsed contents of one JSON file of model generations."""
    with open(path) as fh:
        return json.load(fh)


# Each file presumably maps a generation id to the raw model text (calculate_metrics
# iterates it as a mapping); order is assumed to follow the test-set order — TODO confirm.
base_1_5b_output = _read_json("Qwen_generation_1dot5_B_Base.json")
base_7b_output = _read_json("Qwen_generation_7B_Base.json")
grpo_nl_7b_output = _read_json("Qwen_generation_GRPO_NL_7B.json")
grpo_sft_1_5b_output = _read_json("Qwen_generation_SFT_GRPO_1dot5_B.json")
gpt_4o_zero_cot_output = _read_json("../output/gpt-4o-zero-CoT-combined.json")

In [8]:
# Sanity check: number of generations per model (displayed as a tuple).
tuple(map(len, (base_1_5b_output, base_7b_output, grpo_nl_7b_output, grpo_sft_1_5b_output)))

(450, 450, 450, 450)

In [38]:
# Gold puzzles, one JSON object per line, one file per number of characters.
_test_sets = []
for _jsonl_path in ("test_3_characters.jsonl", "test_4_characters.jsonl", "test_5_characters.jsonl"):
    with open(_jsonl_path, "r") as fh:
        _test_sets.append([json.loads(row) for row in fh])
test_3_characters, test_4_characters, test_5_characters = _test_sets

In [39]:
# Sanity check: number of puzzles in each test split (displayed as a tuple).
tuple(len(split) for split in (test_3_characters, test_4_characters, test_5_characters))

(150, 150, 150)

In [None]:
import re

def extract_answers(text: str) -> dict:
    """Parse one model generation into a ``{letter: is_truth_teller}`` mapping.

    Only the text after the last ``<answer>`` tag and before the next ``</answer>`` is
    inspected; when the tags are absent the whole string (up to any ``</answer>``) is used.
    Every ``X: ...`` segment with X in A-E yields one entry, the segment running until the
    next ``[A-E]:`` or the end of the text (newlines included). The value is True when the
    segment contains the substring ``'truth'`` — the check is case-sensitive, so
    ``"Truth-teller"`` maps to False. If a letter appears twice, its last segment wins.
    """
    answer_span = text.split("<answer>")[-1].split("</answer>")[0]
    segments = re.findall(r"([A-E]):\s*(.*?)(?=(?:[A-E]:|$))", answer_span, re.DOTALL)

    def _normalise(raw: str) -> str:
        # Trim whitespace, then trailing debris such as '}', ']', newlines, '\' and ','.
        return re.sub(r'[\}\]\n\\,]+$', '', raw.strip()).strip()

    return {letter: 'truth' in _normalise(raw) for letter, raw in segments}

In [107]:
# Index ranges of each puzzle-size group inside the concatenated 450 generations
# (150 puzzles each for 3, 4 and 5 characters).
test_3_range, test_4_range, test_5_range = (range(start, start + 150) for start in (0, 150, 300))

In [108]:
# Parsed {letter: is_truth_teller} answers for every base-1.5B generation, in file order.
extracted_output = list(map(extract_answers, base_1_5b_output.values()))

In [109]:
import json

def _load_jsonl(path: str) -> list:
    """Read a JSON-lines file into a list of parsed records, skipping blank lines."""
    with open(path, "r") as f:
        return [json.loads(line) for line in f if line.strip()]


def _instance_prf(predicted: dict, solution: dict) -> tuple:
    """Precision, recall and F1 for one puzzle, with truth-teller (True) as the positive class.

    Every letter present in either the prediction or the gold solution is scored, and a
    letter missing on one side counts as False there. So an omitted truth-teller is a
    false negative and a hallucinated letter predicted as truth-teller is a false
    positive, rather than being silently skipped or raising ``KeyError``.
    NOTE(review): gold values are assumed to be booleans — confirm against the test files.

    Returns ``(precision, recall, f1)`` as fractions in [0, 1]; each is 0.0 when its
    denominator is zero.
    """
    tp = fp = fn = 0
    for letter in set(predicted) | set(solution):
        pred = bool(predicted.get(letter, False))
        gold = bool(solution.get(letter, False))
        if pred and gold:
            tp += 1
        elif pred:
            fp += 1
        elif gold:
            fn += 1

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


def _score_group(predictions: list, references: list) -> tuple:
    """Sum exact matches and per-instance P/R/F1 over one puzzle-size group.

    ``predictions[i]`` is scored against ``references[i]['solutions'][0]``. If there are
    fewer predictions than references, the missing instances contribute nothing.
    Returns ``(n_exact, precision_sum, recall_sum, f1_sum)``.
    """
    n_exact = 0
    precision_sum = recall_sum = f1_sum = 0.0
    for predicted, reference in zip(predictions, references):
        # NOTE(review): only the first listed solution is used; if a puzzle can have
        # several valid solutions, a match against a later one is scored as wrong.
        solution = reference["solutions"][0]
        # Dict equality == same set of letters and the same value for every letter.
        if predicted == solution:
            n_exact += 1
        precision, recall, f1 = _instance_prf(predicted, solution)
        precision_sum += precision
        recall_sum += recall
        f1_sum += f1
    return n_exact, precision_sum, recall_sum, f1_sum


def calculate_metrics(output, test_paths=None, extractor=None):
    """Score one model's generations on the 3/4/5-character test sets.

    Args:
        output: mapping of generation id -> raw model text. Its iteration order must
            match the concatenation of the test files in ``test_paths`` order (all
            3-character puzzles, then 4, then 5). NOTE(review): this positional
            alignment is assumed, not checked — confirm the JSON keys were written in
            test-set order.
        test_paths: optional ordered mapping ``group name -> JSON-lines test file``.
            Defaults to ``character_3/4/5`` -> ``test_{3,4,5}_characters.jsonl``. Each
            group consumes the next ``len(file)`` generations of ``output``.
        extractor: optional callable turning one generation into a
            ``{letter: is_truth_teller}`` dict; defaults to ``extract_answers``.

    Returns:
        ``{"accuracy"|"precision"|"recall"|"f1": {<group>: pct, ..., "overall": pct}}``,
        percentages rounded to 2 decimals. Accuracy is exact match of the whole
        assignment; precision/recall/F1 are averaged over instances. Each group is
        divided by its test-set size, so missing generations count as wrong (0).
    """
    if test_paths is None:
        test_paths = {
            "character_3": "test_3_characters.jsonl",
            "character_4": "test_4_characters.jsonl",
            "character_5": "test_5_characters.jsonl",
        }
    if extractor is None:
        extractor = extract_answers

    predictions = [extractor(text) for text in output.values()]

    metric_names = ("accuracy", "precision", "recall", "f1")
    results = {name: {} for name in metric_names}
    grand_sums = [0, 0.0, 0.0, 0.0]  # exact, precision, recall, f1 summed over all groups
    total_instances = 0
    offset = 0

    def _pct(value, denominator):
        # Guard empty groups instead of raising ZeroDivisionError.
        return round(value / denominator * 100, 2) if denominator else 0.0

    for group, path in test_paths.items():
        references = _load_jsonl(path)
        # Slice by the real file length instead of hard-coded 150/300/450 offsets.
        group_predictions = predictions[offset:offset + len(references)]
        offset += len(references)

        group_sums = _score_group(group_predictions, references)
        for name, value in zip(metric_names, group_sums):
            results[name][group] = _pct(value, len(references))

        grand_sums = [acc + part for acc, part in zip(grand_sums, group_sums)]
        total_instances += len(references)

    for name, value in zip(metric_names, grand_sums):
        results[name]["overall"] = _pct(value, total_instances)
    return results

In [110]:
# Model label -> loaded generations; each entry is scored and pretty-printed below.
outputs = {
    "base_1_5b_output": base_1_5b_output,
    "base_7b_output": base_7b_output,
    "grpo_nl_7b_output": grpo_nl_7b_output,
    "grpo_sft_1_5b_output": grpo_sft_1_5b_output,
    "gpt-4o-zero-CoT": gpt_4o_zero_cot_output,
}

for output, generations in outputs.items():
    metrics = calculate_metrics(generations)
    # Build the whole report first, then emit it with a single print.
    report_lines = [f"Output for {output}:\n", "------------Evaluation Metrics:------------\n"]
    for metric_type, values in metrics.items():
        report_lines.append(f"{metric_type.upper()}:")
        report_lines.extend(f"  {case.capitalize()}: {value}%" for case, value in values.items())
        report_lines.append("")
    report_lines.append("=" * 80)
    print("\n".join(report_lines))


Output for base_1_5b_output:

------------Evaluation Metrics:------------

ACCURACY:
  Character_3: 29.33%
  Character_4: 16.0%
  Character_5: 9.33%
  Overall: 18.22%

PRECISION:
  Character_3: 46.78%
  Character_4: 40.83%
  Character_5: 44.49%
  Overall: 44.03%

RECALL:
  Character_3: 50.89%
  Character_4: 45.17%
  Character_5: 49.39%
  Overall: 48.48%

F1:
  Character_3: 46.91%
  Character_4: 40.63%
  Character_5: 44.15%
  Overall: 43.9%

Output for base_7b_output:

------------Evaluation Metrics:------------

ACCURACY:
  Character_3: 24.0%
  Character_4: 18.0%
  Character_5: 7.33%
  Overall: 16.44%

PRECISION:
  Character_3: 42.44%
  Character_4: 42.44%
  Character_5: 45.29%
  Overall: 43.39%

RECALL:
  Character_3: 40.78%
  Character_4: 45.78%
  Character_5: 48.5%
  Overall: 45.02%

F1:
  Character_3: 38.89%
  Character_4: 41.13%
  Character_5: 43.2%
  Overall: 41.07%

Output for grpo_nl_7b_output:

------------Evaluation Metrics:------------

ACCURACY:
  Character_3: 25.33%
  Char