In [108]:
import os
import yaml
import glob
import Levenshtein as lev
from sklearn.metrics import precision_score, recall_score, f1_score

In [109]:
def load_yaml(file_path):
    with open(file_path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)

In [110]:
def compare_items(true_items, detected_items):
    TP = FP = FN = 0
    
    for item in detected_items:
        if item in true_items:
            TP += 1
        else:
            FP += 1
    
    for item in true_items:
        if item not in detected_items:
            FN += 1
            
    return TP, FP, FN

In [111]:
def compare_constructs(true_constructs, detected_constructs, max_distance=3):
    true_set = set(true_constructs.values())
    detected_set = set(detected_constructs.values())
    TP = sum(1 for det in detected_set if any(is_similar(det, tru, max_distance) for tru in true_set))
    FP = len(detected_set) - TP
    FN = len(true_set) - TP
    return TP, FP, FN

# maybe remove special characters first

In [112]:
def compare_hypotheses(true_constructs, detected_constructs, true_hypotheses, detected_hypotheses):
    # Translate hypothesis keys to construct names for true data
    true_hypotheses_translated = {(true_constructs[h['cause']], true_constructs[h['effect']]) for h in true_hypotheses.values()}
    # Translate hypothesis keys to construct names for detected data
    detected_hypotheses_translated = {(detected_constructs[h['cause']], detected_constructs[h['effect']]) for h in detected_hypotheses.values()}

    TP = len(true_hypotheses_translated.intersection(detected_hypotheses_translated))
    FP = len(detected_hypotheses_translated - true_hypotheses_translated)
    FN = len(true_hypotheses_translated - detected_hypotheses_translated)

    return TP, FP, FN

In [113]:
"""
def compare_texts(true_texts, detected_texts, max_distance=3):
    TP = sum(1 for det_text in detected_texts if any(is_similar(det_text, true_text, max_distance) for true_text in true_texts))
    FP = len(detected_texts) - TP
    FN = len(true_texts) - TP
    return TP, FP, FN
"""

'\ndef compare_texts(true_texts, detected_texts, max_distance=3):\n    TP = sum(1 for det_text in detected_texts if any(is_similar(det_text, true_text, max_distance) for true_text in true_texts))\n    FP = len(detected_texts) - TP\n    FN = len(true_texts) - TP\n    return TP, FP, FN\n'

In [114]:
def is_similar(str1, str2, max_distance=3):
    return lev.distance(str1, str2) <= max_distance

In [115]:
def calculate_metrics(TP, FP, FN):
    precision = TP / (TP + FP) if TP + FP > 0 else 0
    recall = TP / (TP + FN) if TP + FN > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
    return precision, recall, f1

In [116]:
# Load the YAML files
ground_truth_files = sorted(glob.glob('chatGPT_results/*.yaml'))
extracted_files = sorted(glob.glob('true_results/*.yaml'))

In [117]:
ground_truth_files = [element.replace("\\", "/") for element in ground_truth_files]
extracted_files = [element.replace("\\", "/") for element in extracted_files]

extracted_files

['true_results/diagram1.yaml',
 'true_results/diagram2.yaml',
 'true_results/diagram28.yaml',
 'true_results/diagram3.yaml',
 'true_results/diagram30.yaml',
 'true_results/diagram4.yaml',
 'true_results/diagram5.yaml',
 'true_results/diagram50.yaml',
 'true_results/diagram6.yaml',
 'true_results/diagram7.yaml',
 'true_results/diagram8.yaml',
 'true_results/diagram87.yaml',
 'true_results/diagram99.yaml']

In [118]:
ground_truth = []
extracted_data = []

# Load files
for gt_file, ex_file in zip(ground_truth_files, extracted_files):
    ground_truth.append(load_yaml(gt_file))
    extracted_data.append(load_yaml(ex_file))

In [119]:

# Initialize counters for constructs and hypotheses
constructs_TP = constructs_FP = constructs_FN = 0
hypotheses_TP = hypotheses_FP = hypotheses_FN = 0

# Process each file
for gt_file, ex_file in zip(ground_truth_files, extracted_files):
    ground_truth = load_yaml(gt_file)
    extracted_data = load_yaml(ex_file)

    # If either file failed to load properly, skip this pair
    if ground_truth is None or extracted_data is None:
        print(f"Error loading files: {gt_file}, {ex_file}")
        continue
    if 'constructs' not in ground_truth or 'constructs' not in extracted_data:
        print(f"Missing 'constructs' in files: {gt_file}, {ex_file}")
        continue
    
    # Compare constructs
    true_constructs = ground_truth.get('constructs', {})
    detected_constructs = extracted_data.get('constructs', {})
    TP, FP, FN = compare_constructs(true_constructs, detected_constructs)
    constructs_TP += TP
    constructs_FP += FP
    constructs_FN += FN

    # Compare hypotheses using the updated function
    true_hypotheses = ground_truth.get('hypotheses', {})
    detected_hypotheses = extracted_data.get('hypotheses', {})
    TP, FP, FN = compare_hypotheses(true_constructs, detected_constructs, true_hypotheses, detected_hypotheses)
    hypotheses_TP += TP
    hypotheses_FP += FP
    hypotheses_FN += FN

    # Compare texts
    #TP, FP, FN = compare_texts(ground_truth['texts'], extracted_data['texts'])
    #all_TP += TP
    #all_FP += FP
    #all_FN += FN


In [120]:
# Calculate and print metrics for constructs
constructs_precision, constructs_recall, constructs_f1 = calculate_metrics(constructs_TP, constructs_FP, constructs_FN)
print(f"Constructs - Precision: {constructs_precision:.2f}, Recall: {constructs_recall:.2f}, F1 Score: {constructs_f1:.2f}")

# Calculate and print metrics for hypotheses
hypotheses_precision, hypotheses_recall, hypotheses_f1 = calculate_metrics(hypotheses_TP, hypotheses_FP, hypotheses_FN)
print(f"Hypotheses - Precision: {hypotheses_precision:.2f}, Recall: {hypotheses_recall:.2f}, F1 Score: {hypotheses_f1:.2f}")

Constructs - Precision: 0.85, Recall: 0.90, F1 Score: 0.87
Hypotheses - Precision: 0.47, Recall: 0.51, F1 Score: 0.49
