### Analyze the experiment for different intervention strenghts (alpha) from -0.1 to 0.1 and evaluate the results

In [3]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datasets import load_from_disk

# Define the base directory
base_dir = "/home/jholshuijsen/reasoning-reciting-probing/results/chess/intervention/liref"

# Define the alphas to analyze
alphas = [-0.10, -0.05, 0.00, 0.05, 0.10, 0.15]

# Function to load data for a specific alpha
def load_data_for_alpha(alpha):
    # Format alpha with 2 decimal places to preserve trailing zeros
    alpha_str = f"{alpha:.2f}"
    alpha_dir = os.path.join(base_dir, alpha_str)
    
    # Check if directory exists
    if not os.path.exists(alpha_dir):
        print(f"Directory for alpha={alpha_str} not found: {alpha_dir}")
        return None

    # Load the first file found (assuming there's only one combined file)
    dataset = load_from_disk(alpha_dir)
    
    print(f"Loaded data for alpha={alpha_str}, found {len(dataset)} examples")
    return dataset

# Load data for all alphas
datasets = {}
for alpha in alphas:
    data = load_data_for_alpha(alpha)
    if data:
        datasets[alpha] = data


print(f"Successfully loaded data for {len(datasets)} alpha values")


Loaded data for alpha=-0.10, found 800 examples
Loaded data for alpha=-0.05, found 800 examples
Loaded data for alpha=0.00, found 800 examples
Loaded data for alpha=0.05, found 800 examples
Loaded data for alpha=0.10, found 800 examples
Loaded data for alpha=0.15, found 800 examples
Successfully loaded data for 6 alpha values


In [4]:
# load the jsonl file with the original prompt questions:
import json
# Print the path to help debug file not found errors
import os
print(f"Current working directory: {os.getcwd()}")
print(f"Does the file exist? {os.path.exists('inputs/chess/data/chess_data.jsonl')}")
print(f"Checking parent directories:")
for i in range(4):  # Check up to 3 levels up
    path = os.path.join(*(['..'] * i), 'inputs/chess/data/chess_data.jsonl')
    print(f"  {path}: {os.path.exists(path)}")


# Load the original prompt questions from the jsonl file
chess_data_path = "../../../../inputs/chess/data/chess_data.jsonl"
chess_questions = []

with open(chess_data_path, 'r') as f:
    for line in f:
        chess_questions.append(json.loads(line))

print(f"Loaded {len(chess_questions)} chess questions from {chess_data_path}")

Current working directory: /gpfs/home5/jholshuijsen/reasoning-reciting-probing/code/probes/notebooks/utils
Does the file exist? False
Checking parent directories:
  inputs/chess/data/chess_data.jsonl: False
  ../inputs/chess/data/chess_data.jsonl: False
  ../../inputs/chess/data/chess_data.jsonl: False
  ../../../inputs/chess/data/chess_data.jsonl: False
Loaded 800 chess questions from ../../../../inputs/chess/data/chess_data.jsonl


In [5]:
def evaluate_llm_response(dataset):    
    yes_count = 0
    no_count = 0
    invalid_count = 0
    invalid_answers = []
    
    results = []
    
    for i in range(len(dataset)):
        response = dataset[i]['intervention_response']
        
        # Extract the assistant's response part
        if "assistant" in response:
            answer_part = response.split("assistant", 1)[1].strip()
        else:
            print(f"Warning: Could not identify assistant part in response for row {dataset[i]['row']}")
            invalid_count += 1
            invalid_answers.append(response)
            results.append("invalid")
            continue
        
        # Look for boxed yes/no
        has_yes = r"\boxed{yes}" in answer_part
        has_no = r"\boxed{no}" in answer_part
        
        # Check if both yes and no are present
        answer_part = answer_part.replace(".", " ")
        answer_part = answer_part.strip()
        last_line = answer_part.split("\n")[-1]
        success = False
        
        if has_yes and has_no:
            #print(f"Warning: Both \\boxed{{yes}} and \\boxed{{no}} found in answer for row {dataset[i]['row']}")
            invalid_answers.append(answer_part)
            invalid_count += 1
            results.append('invalid')
            continue
            
        if has_yes:
            results.append('yes')
            success = True
        elif answer_part.endswith('is legal'):
            results.append("yes")
            success = True
        elif 'are legal' in last_line or 'is legal' in last_line:
            success = True
            results.append("yes")

        if has_no:
            results.append("no")
            success = True
        elif answer_part.endswith('illegal') or answer_part.endswith('not legal'):
            results.append("no")
            success = True
        elif 'not legal' in last_line or 'illegal' in last_line or 'not valid' in last_line or 'not a valid' in last_line:
            results.append("no")
            success = True

        if not success:
            #print(f"Warning: Could not identify an answer for row {i}")
            invalid_answers.append(answer_part)
            results.append("invalid")

    yes_count = results.count("yes")
    no_count = results.count("no")
    invalid_count = results.count("invalid")
    evaluated_answer_count = len(results)
    if evaluated_answer_count != len(dataset):
        print("Warning: some answers seem to be ambiguous as total evaluated answer length" 
              f"{evaluated_answer_count} does not match the dataset length {len(dataset)}"
              )
    
    print(f"Yes answers: {yes_count}, No answers: {no_count}, Invalid answers: {invalid_count}")
    
    return results, invalid_answers


In [6]:
def calculate_accuracy_metrics(dataset, results):
    rw_correct_yes = 0
    rw_correct_no = 0
    rw_incorrect_yes = 0
    rw_incorrect_no = 0
    cf_correct_yes = 0
    cf_correct_no = 0
    cf_incorrect_yes = 0
    cf_incorrect_no = 0

    for i in range(len(dataset)):
        if results[i] == 'invalid':
            continue
        if dataset[i]['mode'] == 'real_world':
            if results[i] == 'yes':
                if dataset[i]['real_world_answer']:
                    rw_correct_yes += 1
                else:
                    rw_incorrect_yes += 1
            elif results[i] == 'no':
                if not dataset[i]['real_world_answer']:
                    rw_correct_no += 1
                else:
                    rw_incorrect_no += 1
          
        if dataset[i]['mode'] == 'counter_factual':
            if results[i] == 'yes':
                if not dataset[i]['counter_factual_answer']:
                    cf_correct_yes += 1
                else:
                    cf_incorrect_yes += 1
            elif results[i] == 'no':
                if dataset[i]['counter_factual_answer']:
                    cf_correct_no += 1
                else:
                    cf_incorrect_no += 1
     
    rw_correct = rw_correct_yes + rw_correct_no
    cf_correct = cf_correct_yes + cf_correct_no
    rw_incorrect = rw_incorrect_yes + rw_incorrect_no   
    cf_incorrect = cf_incorrect_yes + cf_incorrect_no
    
    print(f"Real world correct: {rw_correct_yes} yes, {rw_correct_no} no")
    print(f"Real world incorrect: {rw_incorrect_yes} yes, {rw_incorrect_no} no")
    print(f"Counter factual correct: {cf_correct_yes} yes, {cf_correct_no} no")
    print(f"Counter factual incorrect: {cf_incorrect_yes} yes, {cf_incorrect_no} no")
    print(f"Accuracy real world: {rw_correct / (rw_correct + rw_incorrect) if (rw_correct + rw_incorrect) > 0 else 'N/A'}")
    print(f"Accuracy counter factual: {cf_correct / (cf_correct + cf_incorrect) if (cf_correct + cf_incorrect) > 0 else 'N/A'}")
    
    return {
        'rw_correct_yes': rw_correct_yes,
        'rw_correct_no': rw_correct_no,
        'rw_incorrect_yes': rw_incorrect_yes,
        'rw_incorrect_no': rw_incorrect_no,
        'cf_correct_yes': cf_correct_yes,
        'cf_correct_no': cf_correct_no,
        'cf_incorrect_yes': cf_incorrect_yes,
        'cf_incorrect_no': cf_incorrect_no,
        'rw_accuracy': rw_correct / (rw_correct + rw_incorrect) if (rw_correct + rw_incorrect) > 0 else None,
        'cf_accuracy': cf_correct / (cf_correct + cf_incorrect) if (cf_correct + cf_incorrect) > 0 else None
    }


In [7]:
for alpha, dataset in datasets.items():
    print(f"Evaluating dataset for alpha={alpha}")
    results, invalid_answers = evaluate_llm_response(dataset)
    calculate_accuracy_metrics(chess_questions, results)


Evaluating dataset for alpha=-0.1
Yes answers: 315, No answers: 469, Invalid answers: 24
Real world correct: 95 yes, 134 no
Real world incorrect: 66 yes, 104 no
Counter factual correct: 74 yes, 108 no
Counter factual incorrect: 80 yes, 115 no
Accuracy real world: 0.5739348370927319
Accuracy counter factual: 0.4827586206896552
Evaluating dataset for alpha=-0.05
Yes answers: 343, No answers: 434, Invalid answers: 25
Real world correct: 96 yes, 130 no
Real world incorrect: 70 yes, 104 no
Counter factual correct: 85 yes, 97 no
Counter factual incorrect: 91 yes, 102 no
Accuracy real world: 0.565
Accuracy counter factual: 0.48533333333333334
Evaluating dataset for alpha=0.0
Yes answers: 303, No answers: 485, Invalid answers: 14
Real world correct: 95 yes, 127 no
Real world incorrect: 72 yes, 102 no
Counter factual correct: 65 yes, 125 no
Counter factual incorrect: 71 yes, 129 no
Accuracy real world: 0.5606060606060606
Accuracy counter factual: 0.48717948717948717
Evaluating dataset for alpha

In [8]:
data = datasets[0.1]

for i in range(len(data)):
    print(data[i]['intervention_response'].split("assistant")[1].strip())

Dataset({
    features: ['prompt', 'intervention_response'],
    num_rows: 800
})