### 1. ADE dataset

# In [None]:
# 1111 update micro G-BERTScore
# 1110 updated

import json
import time
import logging
import os
from collections import Counter
from typing import List, Tuple, Dict, Set
from kg_gen import KGGen
import bert_score
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
import re
import numpy as np
from scipy.optimize import linear_sum_assignment  # For Hungarian matching
import datetime

# --- Configuration ---
# The API key is read from the environment instead of being hard-coded in the notebook.
# In your terminal, run: export OPENAI_API_KEY='your_real_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY", "")
TEST_DATASET_PATH = '../datasets/ade_split_1_test_converted.json'
MODEL_NAME = "openai/gpt-4o"
SENTENCE_MODEL_NAME = 'all-mpnet-base-v2'  # sentence-transformer used by the soft semantic score
NUM_EXAMPLES_TO_TEST = None  # None (or 0) -> evaluate the whole test set

# --- Paths ---
# Output files are prefixed with the run date as MMDD.
current_date = datetime.datetime.now().strftime('%m%d')
LOG_FILE_PATH = f'{current_date}_benchmark_kggen_fewshot_test_ade_1.log'
FAILED_ENTRIES_PATH = f'{current_date}_ade_failed_entries.json'

# --- Set up logging ---
# Plain messages (no level/timestamp prefix) appended to this run's log file.
# force=True replaces handlers left on the root logger by earlier cells or
# libraries; without it basicConfig silently does nothing and the messages
# end up in whatever file was configured first.
logging.basicConfig(level=logging.INFO, filename=LOG_FILE_PATH,
                    filemode='a', format='%(message)s', force=True)

# --- Main Script ---
# Warn, but keep going, when the key is empty or still a placeholder value.
api_key_missing = (not API_KEY) or ("YOUR_API_KEY" in API_KEY)
if api_key_missing:
    logging.warning("API key not set. Please use an environment variable.")
    print("WARNING: API key not found. Please set the API_KEY variable.")

# KGGen client used by the benchmark loop (low sampling temperature of 0.1).
kg = KGGen(model=MODEL_NAME, temperature=0.1, api_key=API_KEY)

# --- Optimized Few-Shot Prompt (Further Refined: Added Emphasis on Flipping, Reduced Redundancy) ---
# Passed as `context` to kg.generate() for every item in the benchmark loop.
# The model is asked for (effect, "adverse_effect", drug) triples. Predictions
# are compared to the gold triples position by position after normalization,
# so this order must match the dataset's `triple_list` order.
# NOTE(review): assumes the gold triples are stored as (effect, relation, drug) — confirm against the dataset.
EXTRACT_CONTEXT = """
You are a strict data conversion bot. Extract adverse drug events (ADEs) from scientific sentences into triples: ("Specific Adverse Effect", "adverse_effect", "Drug/Chemical").

### CRITICAL RULES
1. **RELATION:** Always use literal "adverse_effect" (lowercase). Only extract causal ADEs (drug causes effect); ignore others.
2. **ORDER:** Always (Effect, "adverse_effect", Drug). Identify effect (symptom/condition) and drug (chemical). Flip order if extracted reversely.
3. **OUTPUT:** ONLY unique triples list. Return [] if none. Avoid generics (e.g., "patient", ages), non-ADEs, negated/uncertain.
- Use precise phrasing; trim unnecessary modifiers.
- Focus on causation indicators (e.g., "induced", "caused", "associated with").

Examples:

Sentence: "Cyclophosphamide is a human teratogen."
Triples: [("human teratogen", "adverse_effect", "cyclophosphamide")]

Sentence: "Lethal anuria complicating high dose ifosfamide chemotherapy."
Triples: [("lethal anuria", "adverse_effect", "ifosfamide")]

Sentence: "Gemcitabine-induced pulmonary toxicity."
Triples: [("pulmonary toxicity", "adverse_effect", "gemcitabine")]

Sentence: "The patient took aspirin for headache without issues."
Triples: []

Sentence: "Senna caused subacute cholestatic hepatitis."
Triples: [("subacute cholestatic hepatitis", "adverse_effect", "senna")]

Sentence: "Fulminant hepatic failure associated with didanosine."
Triples: [("fulminant hepatic failure", "adverse_effect", "didanosine")]

Sentence: "A 34-year-old lady developed a constellation of dermatitis, fever, lymphadenopathy and hepatitis, beginning on the 17th day of a course of oral sulphasalazine for sero-negative rheumatoid arthritis."
Triples: [("constellation of dermatitis", "adverse_effect", "sulphasalazine"), ("fever", "adverse_effect", "sulphasalazine"), ("lymphadenopathy", "adverse_effect", "sulphasalazine"), ("hepatitis", "adverse_effect", "sulphasalazine")]

Apply rules to the text. Extract triples only in ("Effect", "adverse_effect", "Drug") format.
"""

# --- Normalization Logic (Updated to match SciERC) ---
def normalize_entity(entity_text: str) -> str:
    """Normalize an entity mention for exact triple comparison.

    Strips surrounding whitespace and quote characters, lowercases, and
    removes one leading article ("a", "an", "the") and a trailing period.

    Args:
        entity_text: Raw entity text. Non-string values are tolerated.

    Returns:
        The normalized text, or "" for non-string or effectively empty input.
        The benchmark loop drops triples whose head or tail normalizes to "".
    """
    if not isinstance(entity_text, str):
        return ""

    # Remove quotes *before* the article/period rules, so that a quoted
    # mention such as '"The aspirin"' normalizes like the unquoted one and no
    # whitespace that sat inside the quotes survives.
    text = entity_text.strip().strip('\'"').strip().lower()
    text = re.sub(r'^(a|an|the)\s+', '', text)
    text = re.sub(r'\s*\.$', '', text)
    # A trailing period can hide a closing quote (e.g. '"aspirin".'), so
    # strip quotes and whitespace once more.
    return text.strip().strip('\'"').strip()

# --- Load dataset ---
# The benchmark loop below slices and iterates `data` and reads item['text']
# and item.get('triple_list'), so a JSON list of such objects is expected.
try:
    with open(TEST_DATASET_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)
except FileNotFoundError:
    logging.error(f"FATAL: Test dataset not found at path: {TEST_DATASET_PATH}")
    print(f"FATAL: Test dataset not found at path: {TEST_DATASET_PATH}")
    # Stop with a non-zero status. The bare exit() helper exits with status 0
    # and only exists when the `site` module has been loaded.
    raise SystemExit(1)

# Optional subset for quick runs (None or 0 keeps everything).
if NUM_EXAMPLES_TO_TEST:
    data = data[:NUM_EXAMPLES_TO_TEST]

print(f"Loading sentence transformer model: {SENTENCE_MODEL_NAME}...")
sentence_model = SentenceTransformer(SENTENCE_MODEL_NAME)
print("Model loaded.")

# --- Benchmark Loop ---
# Accumulators filled once per dataset item:
#   all_gold_triples / all_pred_triples: flat lists of normalized triples
#     (each item's triples are deduplicated before being appended)
#   all_pred_strings / all_gold_strings: one " | "-joined triple string per item
#   item_details: per-item dict holding the triples as sets, joined strings and sorted lists
#   success_count / fail_count / failed_entries: extraction outcome bookkeeping
all_gold_triples, all_pred_triples, all_pred_strings, all_gold_strings = [], [], [], []
item_details = []
success_count = 0
fail_count = 0
failed_entries = []

for i, item in enumerate(data):
    text = item['text']
    gold_triples_raw = item.get('triple_list', [])
    
    # Normalize gold triples the same way as predictions: entities through
    # normalize_entity; relation stripped, lowercased, spaces -> '-'.
    # Triples whose head or tail normalizes to "" are dropped, the rest are
    # deduplicated and sorted.
    # NOTE(review): assumes every gold entry unpacks into exactly three strings — confirm against the dataset.
    gold_triples = []
    for subj, rel, obj in gold_triples_raw:
        subj_norm = normalize_entity(subj)
        obj_norm = normalize_entity(obj)
        rel_norm = rel.strip().lower().replace(" ", "-")
        if subj_norm and obj_norm:
            gold_triples.append((subj_norm, rel_norm, obj_norm))
    gold_triples = sorted(list(set(gold_triples)))

    logging.info(f"\n----- Processing Item {i+1}/{len(data)} -----")
    logging.info(f"Text: {text}")
    print(f"Processing Item {i+1}/{len(data)}...")
    
    pred_triples = []
    try:
        # Single-Step Extraction with Few-Shot Prompt
        response = kg.generate(input_data=text, context=EXTRACT_CONTEXT)
        # A falsy response or one without a `relations` attribute yields no predictions.
        pred_triples_raw = response.relations if response and hasattr(response, 'relations') else []
        logging.info(f"KGGen Extraction: {pred_triples_raw}")

        # Entries that are not 3-element sequences are skipped; the rest get the
        # same normalization as the gold triples and are deduplicated via a set.
        corrected_triples_set = set()
        for triple in pred_triples_raw:
            if len(triple) != 3: continue
            subj, rel, obj = triple
            subj_norm = normalize_entity(subj)
            obj_norm = normalize_entity(obj)
            rel_norm = rel.strip().lower().replace(" ", "-")
            if subj_norm and obj_norm:
                corrected_triples_set.add((subj_norm, rel_norm, obj_norm))

        pred_triples = sorted(list(corrected_triples_set))
        success_count += 1

    except Exception as e:
        # Any error from the generate call or from post-processing: the item is
        # still scored, with an empty prediction list, and saved to
        # failed_entries for later inspection.
        logging.error(f"An API or parsing error occurred for item {i+1}: {e}", exc_info=True)
        pred_triples = []
        fail_count += 1
        failed_entries.append({
            'index': i+1,
            'text': text,
            'gold_triples_raw': gold_triples_raw,
            'error': str(e)
        })

    logging.info(f"Gold (Normalized): {gold_triples}")
    logging.info(f"Pred (Normalized): {pred_triples}")

    # Flat triple lists feed the strict metrics; the joined strings feed base BERTScore.
    all_gold_triples.extend(gold_triples)
    all_pred_triples.extend(pred_triples)

    pred_str = " | ".join(sorted([f"{h} {r} {t}" for h, r, t in set(pred_triples)]))
    gold_str = " | ".join(sorted([f"{h} {r} {t}" for h, r, t in set(gold_triples)]))
    all_pred_strings.append(pred_str)
    all_gold_strings.append(gold_str)

    # 'gold_list'/'pred_list' are sets (used for intersections);
    # 'gold_triples'/'pred_triples' are the sorted lists.
    item_details.append({
        'index': i+1,
        'gold_list': set(gold_triples), 'pred_list': set(pred_triples),
        'gold_str': gold_str, 'pred_str': pred_str,
        'gold_triples': gold_triples, 'pred_triples': pred_triples
    })

# Persist the items whose extraction raised, so they can be inspected or re-run.
if not failed_entries:
    logging.info("No failed entries.")
else:
    with open(FAILED_ENTRIES_PATH, 'w') as failed_fh:
        json.dump(failed_entries, failed_fh, indent=4)
    logging.info("Failed entries written to " + FAILED_ENTRIES_PATH)

# Report extraction outcomes to the log file first, then to stdout.
logging.info("\n----- Processing Summary -----")
summary_lines = (f"Successful items: {success_count}", f"Failed items: {fail_count}")
for summary_line in summary_lines:
    logging.info(summary_line)
for summary_line in summary_lines:
    print(summary_line)

# --- Metrics Calculation ---
logging.info("\n----- FINAL METRICS -----")

# Strict Metrics: exact match on normalized (head, relation, tail) triples.
# Micro: counts pooled over all items. A prediction is a true positive only
# when the SAME item's gold set contains it. A dataset-wide multiset
# intersection would also credit a prediction against an identical gold
# triple that belongs to a different sentence.
pred_total = len(all_pred_triples)
gold_total = len(all_gold_triples)
tp = sum(len(detail['pred_list'] & detail['gold_list']) for detail in item_details)
micro_p_strict = tp / pred_total if pred_total > 0 else 0.0
micro_r_strict = tp / gold_total if gold_total > 0 else 0.0
micro_f1_strict = (2 * micro_p_strict * micro_r_strict / (micro_p_strict + micro_r_strict)
                   if micro_p_strict + micro_r_strict > 0 else 0.0)

# Macro: per-item P/R/F1, then averaged. An empty side scores 1.0 only when
# the other side is empty too (nothing expected and nothing predicted).
item_ps_strict, item_rs_strict, item_f1s_strict = [], [], []
for detail in item_details:
    n_pred = len(detail['pred_triples'])
    n_gold = len(detail['gold_triples'])
    tp_item = len(detail['pred_list'] & detail['gold_list'])
    p_item = tp_item / n_pred if n_pred else (1.0 if not n_gold else 0.0)
    r_item = tp_item / n_gold if n_gold else (1.0 if not n_pred else 0.0)
    # Both-empty items have P = R = 1.0, hence F1 = 1.0 from the formula.
    f1_item = 2 * p_item * r_item / (p_item + r_item) if p_item + r_item > 0 else 0.0

    item_ps_strict.append(p_item)
    item_rs_strict.append(r_item)
    item_f1s_strict.append(f1_item)
    logging.info(
        f"Item {detail['index']} Strict: P={p_item:.4f}, R={r_item:.4f}, F1={f1_item:.4f}")

macro_p_strict = np.nanmean(item_ps_strict) if item_ps_strict else 0.0
macro_r_strict = np.nanmean(item_rs_strict) if item_rs_strict else 0.0
macro_f1_strict = np.nanmean(item_f1s_strict) if item_f1s_strict else 0.0

# BERTScore-based metrics (roberta-large):
#   * Base BERTScore: one " | "-joined triple string per item.
#   * Greedy BS: every triple scored against the other side's triples of the same item.
#   * G-BERTScore: one-to-one (Hungarian) matching of triple strings within each item.
from bert_score import score as bert_score_compute

if any(all_pred_strings) and any(all_gold_strings):
    # Base BERTScore (P, R, F1), one score per item. With idf=True, bert_score
    # derives the IDF weights from all_gold_strings.
    P_macro, R_macro, F1_macro = bert_score_compute(
        all_pred_strings, all_gold_strings, lang="en", verbose=False, model_type="roberta-large", idf=True
    )
    macro_p_bert = P_macro.mean().item()
    macro_r_bert = R_macro.mean().item()
    macro_f1_bert = F1_macro.mean().item()

    # Micro base BERTScore: pool the per-item scores, weighting precision by
    # each item's number of predicted triples and recall by its number of
    # gold triples (triple-level pooling, like the strict micro scores).
    # Scoring one concatenation of the whole test set is not usable:
    # bert_score truncates every input to the model's maximum length (512
    # tokens for roberta-large), so only the first few items would be compared.
    item_pred_counts = np.array([len(d['pred_triples']) for d in item_details], dtype=float)
    item_gold_counts = np.array([len(d['gold_triples']) for d in item_details], dtype=float)
    item_p_bert = np.nan_to_num(P_macro.cpu().numpy())
    item_r_bert = np.nan_to_num(R_macro.cpu().numpy())
    micro_p_bert = float(item_p_bert @ item_pred_counts) / pred_total if pred_total > 0 else 0.0
    micro_r_bert = float(item_r_bert @ item_gold_counts) / gold_total if gold_total > 0 else 0.0
    micro_f1_bert = (2 * micro_p_bert * micro_r_bert / (micro_p_bert + micro_r_bert)
                     if micro_p_bert + micro_r_bert > 0 else 0.0)

    # Greedy BERTScore. One BERTScorer is reused for every triple-level call,
    # so roberta-large is loaded once. Its IDF weights come from the whole gold
    # triple corpus. Calling bert_score.score(idf=True) per item would instead
    # derive IDF from that item's duplicated reference list, which gives every
    # token shared by all of the item's gold triples (the entire triple when
    # there is only one) a weight of zero.
    all_pred_triple_strs = [f"{h} {r} {t}" for d in item_details for h, r, t in d['pred_triples']]
    all_gold_triple_strs = [f"{h} {r} {t}" for d in item_details for h, r, t in d['gold_triples']]
    greedy_scorer = bert_score.BERTScorer(
        model_type="roberta-large", lang="en", idf=True, idf_sents=all_gold_triple_strs
    )

    item_gbs_p, item_gbs_r, item_gbs_f1 = [], [], []
    # Sums of per-triple best-match scores over the whole test set, used for
    # the micro averages. Matching never crosses item boundaries.
    greedy_pred_score_sum, greedy_gold_score_sum = 0.0, 0.0

    for detail in item_details:
        pred_strings = [f"{h} {r} {t}" for h, r, t in detail['pred_triples']]
        gold_strings = [f"{h} {r} {t}" for h, r, t in detail['gold_triples']]

        if not pred_strings and not gold_strings:
            item_gbs_p.append(1.0)
            item_gbs_r.append(1.0)
            item_gbs_f1.append(1.0)
            continue

        if not pred_strings:  # Has gold, but no preds (all missed)
            item_gbs_p.append(1.0)  # Vacuously true: all 0 preds are "correct"
            item_gbs_r.append(0.0)
            item_gbs_f1.append(0.0)
            continue

        if not gold_strings:  # Has preds, but no gold (all false pos)
            item_gbs_p.append(0.0)
            item_gbs_r.append(1.0)  # Vacuously true: all 0 golds are "found"
            item_gbs_f1.append(0.0)
            continue

        # G-BS-P (Precision): each *predicted* triple against all of this item's
        # gold triples; bert_score keeps the maximum over the reference group.
        P_gbs, _, _ = greedy_scorer.score(pred_strings, [gold_strings] * len(pred_strings))
        gbs_p = P_gbs.mean().item()

        # G-BS-R (Recall): roles flipped, so each *gold* triple is the candidate
        # and its precision against all predictions is its recall contribution.
        P_recall, _, _ = greedy_scorer.score(gold_strings, [pred_strings] * len(gold_strings))
        gbs_r = P_recall.mean().item()

        gbs_f1 = 2 * gbs_p * gbs_r / (gbs_p + gbs_r) if (gbs_p + gbs_r) > 0 else 0.0

        item_gbs_p.append(gbs_p)
        item_gbs_r.append(gbs_r)
        item_gbs_f1.append(gbs_f1)
        greedy_pred_score_sum += float(np.nansum(P_gbs.cpu().numpy()))
        greedy_gold_score_sum += float(np.nansum(P_recall.cpu().numpy()))

    # Macro averages (mean over items)
    macro_gbs_p = np.nanmean(item_gbs_p) if item_gbs_p else 0.0
    macro_gbs_r = np.nanmean(item_gbs_r) if item_gbs_r else 0.0
    macro_gbs_f1 = np.nanmean(item_gbs_f1) if item_gbs_f1 else 0.0

    # Micro Greedy BS: pooled per-triple best-match scores. Triples of items
    # whose other side is empty have nothing to match and contribute 0.
    micro_gbs_p = greedy_pred_score_sum / pred_total if pred_total > 0 else 0.0
    micro_gbs_r = greedy_gold_score_sum / gold_total if gold_total > 0 else 0.0
    micro_gbs_f1 = (2 * micro_gbs_p * micro_gbs_r / (micro_gbs_p + micro_gbs_r)
                    if micro_gbs_p + micro_gbs_r > 0 else 0.0)

    # G-BERTScore (Hungarian Matching) from EXPLAGRAPHS
    def get_g_bert_score(all_gold_edges, all_pred_edges, idf=False):
        """Per-item G-BERTScore with one-to-one (Hungarian) edge matching.

        Every (gold edge, predicted edge) pair of the same item is scored with
        BERTScore F1 in one batched call. Each item's score matrix is then
        assigned with linear_sum_assignment(maximize=True), and the matched
        scores are summed.

        Args:
            all_gold_edges: per-item lists of gold edge strings.
            all_pred_edges: per-item lists of predicted edge strings, same item order.
            idf: forwarded to bert_score as its ``idf`` argument.

        Returns:
            Three numpy arrays (precision, recall, F1) with one value per item:
            precision = matched sum / #pred and recall = matched sum / #gold.
            Items with an empty side use the greedy conventions (both empty:
            1/1/1, no preds: 1/0/0, no gold: 0/1/0). If no item has both gold
            and predicted edges, all three arrays are zeros.
        """
        references = []
        candidates = []
        ref_cand_index = {}
        for graph_idx, (gold_edges, pred_edges) in enumerate(zip(all_gold_edges, all_pred_edges)):
            for gold_edge in gold_edges:
                for pred_edge in pred_edges:
                    references.append(gold_edge)
                    candidates.append(pred_edge)
                    ref_cand_index[(graph_idx, gold_edge, pred_edge)] = len(references) - 1

        if not references:
            return np.zeros(len(all_gold_edges)), np.zeros(len(all_gold_edges)), np.zeros(len(all_gold_edges))

        _, _, bs_F1 = bert_score_compute(candidates, references, lang='en', verbose=False, model_type="roberta-large", idf=idf)
        bs_F1 = bs_F1.cpu().numpy()

        precisions, recalls, f1s = [], [], []
        for graph_idx, (gold_edges, pred_edges) in enumerate(zip(all_gold_edges, all_pred_edges)):
            num_gold = len(gold_edges)
            num_pred = len(pred_edges)
            if num_gold == 0 and num_pred == 0:
                precisions.append(1.0)
                recalls.append(1.0)
                f1s.append(1.0)
                continue
            if num_pred == 0:
                precisions.append(1.0)
                recalls.append(0.0)
                f1s.append(0.0)
                continue
            if num_gold == 0:
                precisions.append(0.0)
                recalls.append(1.0)
                f1s.append(0.0)
                continue

            score_matrix = np.zeros((num_gold, num_pred))
            for i, gold_edge in enumerate(gold_edges):
                for j, pred_edge in enumerate(pred_edges):
                    idx = ref_cand_index.get((graph_idx, gold_edge, pred_edge))
                    if idx is not None:
                        score_matrix[i, j] = bs_F1[idx]

            row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)
            matched_sim = score_matrix[row_ind, col_ind]

            sample_precision = matched_sim.sum() / num_pred
            sample_recall = matched_sim.sum() / num_gold
            sample_f1 = 2 * sample_precision * sample_recall / (sample_precision + sample_recall) if sample_precision + sample_recall > 0 else 0.0

            precisions.append(sample_precision)
            recalls.append(sample_recall)
            f1s.append(sample_f1)

        return np.array(precisions), np.array(recalls), np.array(f1s)

    all_gold_edges = [[f"{h} {r} {t}" for h, r, t in detail['gold_triples']] for detail in item_details]
    all_pred_edges = [[f"{h} {r} {t}" for h, r, t in detail['pred_triples']] for detail in item_details]

    precisions_h, recalls_h, f1s_h = get_g_bert_score(all_gold_edges, all_pred_edges, idf=False)

    macro_gbs_p_h = np.nanmean(precisions_h) if len(precisions_h) > 0 else 0.0
    macro_gbs_r_h = np.nanmean(recalls_h) if len(recalls_h) > 0 else 0.0
    macro_gbs_f1_h = np.nanmean(f1s_h) if len(f1s_h) > 0 else 0.0

    # Micro G-BERTScore: precision_i * #pred_i recovers item i's matched-score
    # sum (0 for items without predictions); pool those sums over the test set.
    total_matched_sum = sum(p * len(edges) for p, edges in zip(precisions_h, all_pred_edges))
    micro_gbs_p_h = total_matched_sum / pred_total if pred_total > 0 else 0.0
    micro_gbs_r_h = total_matched_sum / gold_total if gold_total > 0 else 0.0
    micro_gbs_f1_h = 2 * micro_gbs_p_h * micro_gbs_r_h / (micro_gbs_p_h + micro_gbs_r_h) if micro_gbs_p_h + micro_gbs_r_h > 0 else 0.0

else:
    # All zeros for empties
    macro_p_bert, macro_r_bert, macro_f1_bert = 0.0, 0.0, 0.0
    micro_p_bert, micro_r_bert, micro_f1_bert = 0.0, 0.0, 0.0
    macro_gbs_p, macro_gbs_r, macro_gbs_f1 = 0.0, 0.0, 0.0
    micro_gbs_p, micro_gbs_r, micro_gbs_f1 = 0.0, 0.0, 0.0
    macro_gbs_p_h, macro_gbs_r_h, macro_gbs_f1_h = 0.0, 0.0, 0.0
    micro_gbs_p_h, micro_gbs_r_h, micro_gbs_f1_h = 0.0, 0.0, 0.0

# --- Soft Semantic Score ---
# Each item's predicted and gold triple strings are embedded with
# `sentence_model`, paired one-to-one by Hungarian assignment on cosine
# similarity (negatives clipped to 0), and an assigned pair counts as a match
# when its similarity is strictly greater than `threshold`.
threshold = 0.8  # Keep as is, or tune per dataset (literature often uses 0.75-0.9)
global_matched_preds = 0
global_matched_golds = 0
item_ps_soft = []
item_rs_soft = []
item_f1s_soft = []

for detail in item_details:
    pred_triples = detail['pred_triples']
    gold_triples = detail['gold_triples']
    n_pred = len(pred_triples)
    n_gold = len(gold_triples)

    if n_pred == 0 and n_gold == 0:
        # Nothing expected, nothing predicted: perfect item (not logged).
        for bucket in (item_ps_soft, item_rs_soft, item_f1s_soft):
            bucket.append(1.0)
        continue

    matched_preds = 0
    if n_pred and n_gold:
        pred_strings = [f"{h} {r} {t}" for h, r, t in pred_triples]
        gold_strings = [f"{h} {r} {t}" for h, r, t in gold_triples]
        pred_embs = sentence_model.encode(pred_strings, convert_to_tensor=True).cpu().numpy()
        gold_embs = sentence_model.encode(gold_strings, convert_to_tensor=True).cpu().numpy()

        cosine = np.dot(pred_embs, gold_embs.T) / (np.linalg.norm(pred_embs, axis=1)[:, np.newaxis] * np.linalg.norm(gold_embs, axis=1))
        cosine = np.maximum(cosine, 0)
        rows, cols = linear_sum_assignment(cosine, maximize=True)
        matched_preds = int(np.count_nonzero(cosine[rows, cols] > threshold))
    # One-to-one assignment: each matched prediction covers exactly one gold triple.
    matched_golds = matched_preds

    # Only one side can be empty at this point, and that side scores 0.0.
    p_item = matched_preds / n_pred if n_pred else 0.0
    r_item = matched_golds / n_gold if n_gold else 0.0
    if p_item + r_item > 0:
        f1_item = 2 * p_item * r_item / (p_item + r_item)
    else:
        f1_item = 0.0

    item_ps_soft.append(p_item)
    item_rs_soft.append(r_item)
    item_f1s_soft.append(f1_item)
    global_matched_preds += matched_preds
    global_matched_golds += matched_golds
    logging.info(f"Item {detail['index']} Semantic: P={p_item:.4f}, R={r_item:.4f}, F1={f1_item:.4f}")

# Micro: pooled match counts; macro: mean of the per-item scores.
micro_p_soft = global_matched_preds / pred_total if pred_total > 0 else 0.0
micro_r_soft = global_matched_golds / gold_total if gold_total > 0 else 0.0
if micro_p_soft + micro_r_soft > 0:
    micro_f1_soft = 2 * micro_p_soft * micro_r_soft / (micro_p_soft + micro_r_soft)
else:
    micro_f1_soft = 0.0

macro_p_soft = np.nanmean(item_ps_soft) if item_ps_soft else 0.0
macro_r_soft = np.nanmean(item_rs_soft) if item_rs_soft else 0.0
macro_f1_soft = np.nanmean(item_f1s_soft) if item_f1s_soft else 0.0

# --- Final Output ---
# One report covering every metric family (strict, BERTScore, Greedy BS,
# Hungarian G-BERTScore, semantic), printed to stdout and appended to the log
# file. `len(data)` is the number of evaluated items (after the optional
# NUM_EXAMPLES_TO_TEST cut).
output_str = f"""
Benchmark Results:
---------------------------------
Number of examples: {len(data)}
Strict Averages:
  - Macro: P={macro_p_strict:.4f}, R={macro_r_strict:.4f}, F1={macro_f1_strict:.4f}
  - Micro: P={micro_p_strict:.4f}, R={micro_r_strict:.4f}, F1={micro_f1_strict:.4f}

BERTScore Averages:
  - Macro: P={macro_p_bert:.4f}, R={macro_r_bert:.4f}, F1={macro_f1_bert:.4f}
  - Micro: P={micro_p_bert:.4f}, R={micro_r_bert:.4f}, F1={micro_f1_bert:.4f}

Greedy BS Averages:
  - Macro: P={macro_gbs_p:.4f}, R={macro_gbs_r:.4f}, F1={macro_gbs_f1:.4f}
  - Micro: P={micro_gbs_p:.4f}, R={micro_gbs_r:.4f}, F1={micro_gbs_f1:.4f}

G-BERTScore Averages:
  - Macro: P={macro_gbs_p_h:.4f}, R={macro_gbs_r_h:.4f}, F1={macro_gbs_f1_h:.4f}
  - Micro: P={micro_gbs_p_h:.4f}, R={micro_gbs_r_h:.4f}, F1={micro_gbs_f1_h:.4f}

Semantic Score Averages (Threshold: {threshold}):
  - Macro: P={macro_p_soft:.4f}, R={macro_r_soft:.4f}, F1={macro_f1_soft:.4f}
  - Micro: P={micro_p_soft:.4f}, R={micro_r_soft:.4f}, F1={micro_f1_soft:.4f}
"""
print(output_str)
logging.info(output_str)

### 2. CONLL2004 dataset

# In [None]:
import json
import time
import logging
import os
from collections import Counter
from typing import List, Tuple, Dict, Set
from kg_gen import KGGen
import bert_score
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
import re
import numpy as np
from scipy.optimize import linear_sum_assignment  # For Hungarian matching
import datetime

# --- Configuration ---
# Never hard-code the API key; it is read from the environment.
# In your terminal, run: export OPENAI_API_KEY='your_real_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY", "")
# Using the path for the uploaded file
DATASET_PATH = '../datasets/conll04_test_triples.json'
current_date = datetime.datetime.now().strftime('%m%d')  # MMDD prefix for output files
LOG_FILE_PATH = f'{current_date}_benchmark_conll04_test_all.log'
MODEL_NAME = "openai/gpt-4o"
SENTENCE_MODEL_NAME = 'all-mpnet-base-v2'
# Set to None to run on the full dataset
NUM_EXAMPLES_TO_TEST = None
FAILED_ENTRIES_PATH = f'{current_date}_conll04_failed_entries.json'

# --- Set up logging ---
# force=True is required when this cell runs in the same kernel as an earlier
# cell that already configured the root logger. Otherwise basicConfig is a
# no-op and these messages keep going to the earlier cell's log file.
logging.basicConfig(level=logging.INFO, filename=LOG_FILE_PATH,
                    filemode='a', format='%(message)s', force=True)

# --- Main Script ---
# Warn (the run continues) when the key does not look like an OpenAI secret key.
# All OpenAI secret keys start with "sk-". "sk-proj-" project keys are only
# one variant; user and service-account keys use other "sk-" prefixes. The
# previous "sk-proj-" check therefore warned that the script would fail even
# with a valid key.
if not API_KEY.startswith("sk-"):
    logging.warning(
        "OpenAI API key not set. The script will fail. Please set the OPENAI_API_KEY environment variable.")
    print("WARNING: OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")

# KGGen client used by the benchmark loop (low sampling temperature of 0.1).
kg = KGGen(
    model=MODEL_NAME,
    temperature=0.1,
    api_key=API_KEY,
)

# --- Stricter Prompt ---
# Passed as `context` to kg.generate() for every item. The relation names
# listed here must equal the gold labels exactly. The loop only strips,
# lowercases and replaces spaces with '_' before comparing, so any other
# relation name can never match a gold triple.
CONTEXT = """You MUST extract relationships ONLY using these exact relation names: kill, work_for, organization_based_in, live_in, located_in. You MUST NOT use any other relation names—any deviation is invalid. Subject and object MUST be complete noun phrases from the text. Output ONLY unique triples as (Subject, relation, Object) or [] if none. Infer relations where implied (e.g., 'born in' implies 'live_in'; 'director of' implies 'work_for').

Invalid Example (Do NOT do this):
Sentence: "Recognition of proper nouns in Japanese text has been studied as a part of the more general problem of morphological analysis in Japanese text processing."
Wrong Triples: [("Recognition", "has-been-studied-as", "part"), ("Recognition", "of", "proper nouns")]  # Invalid relations

Valid Examples:
Sentence: "Newspaper ` Explains ' U.S. Interests Section Events FL1402001894 Havana Radio Reloj Network in Spanish 2100 GMT 13 Feb 94"
Triples: [("radio reloj network", "organization_based_in", "havana")]

Sentence: "Annie Oakley , also known as Little Miss Sure Shot , was born Phoebe Ann Moses in Willowdell , Darke County , in 1860 ."
Triples: [("annie oakley", "live_in", "willowdell , darke county"), ("little miss sure shot", "live_in", "willowdell , darke county"), ("phoebe ann moses", "live_in", "willowdell , darke county")]

Apply to the text. Remember: ONLY use relations from the list."""

# --- Normalization Logic ---
def normalize_entity(entity_text: str) -> str:
    """Normalize an entity mention for exact triple comparison.

    Strips surrounding whitespace and quote characters, lowercases, and
    removes one leading article ("a", "an", "the") and a trailing period.

    Args:
        entity_text: Raw entity text. Non-string values are tolerated.

    Returns:
        The normalized text, or "" for non-string or effectively empty input.
        The benchmark loop drops triples whose head or tail normalizes to "".
    """
    if not isinstance(entity_text, str):
        return ""

    # Remove quotes *before* the article/period rules, so that a quoted
    # mention such as '"The aspirin"' normalizes like the unquoted one and no
    # whitespace that sat inside the quotes survives.
    text = entity_text.strip().strip('\'"').strip().lower()
    text = re.sub(r'^(a|an|the)\s+', '', text)
    text = re.sub(r'\s*\.$', '', text)
    # A trailing period can hide a closing quote (e.g. '"aspirin".'), so
    # strip quotes and whitespace once more.
    return text.strip().strip('\'"').strip()

# Load dataset
# The benchmark loop below slices and iterates `data` and reads item['text']
# and item.get('triple_list'), so a JSON list of such objects is expected.
try:
    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)
except FileNotFoundError:
    logging.error(f"FATAL: Dataset not found at path: {DATASET_PATH}")
    print(f"FATAL: Dataset not found at path: {DATASET_PATH}")
    # Stop with a non-zero status. The bare exit() helper exits with status 0
    # and only exists when the `site` module has been loaded.
    raise SystemExit(1)

# Optional subset for quick runs (None or 0 keeps everything).
if NUM_EXAMPLES_TO_TEST:
    data = data[:NUM_EXAMPLES_TO_TEST]

print(f"Loading sentence transformer model: {SENTENCE_MODEL_NAME}...")
sentence_model = SentenceTransformer(SENTENCE_MODEL_NAME)
print("Model loaded.")

# --- Benchmark Loop ---
# Accumulators filled once per dataset item:
#   all_gold_triples / all_pred_triples: flat lists of normalized triples
#     (each item's triples are deduplicated before being appended)
#   all_pred_strings / all_gold_strings: one " | "-joined triple string per item
#   item_details: per-item dict holding the triples as sets, joined strings and sorted lists
#   success_count / fail_count / failed_entries: extraction outcome bookkeeping
all_gold_triples, all_pred_triples, all_pred_strings, all_gold_strings = [], [], [], []
item_details = []
success_count = 0
fail_count = 0
failed_entries = []

for i, item in enumerate(data):
    text = item['text']
    gold_triples_raw = item.get('triple_list', [])
    
    # Normalize gold triples the same way as predictions: entities through
    # normalize_entity; relation stripped, lowercased, spaces -> '_' (CoNLL04
    # relation names use underscores). Triples whose head or tail normalizes
    # to "" are dropped, the rest are deduplicated and sorted.
    # NOTE(review): assumes every gold entry unpacks into exactly three strings — confirm against the dataset.
    gold_triples = []
    for subj, pred, obj in gold_triples_raw:
        subj_norm = normalize_entity(subj)
        obj_norm = normalize_entity(obj)
        pred_norm = pred.strip().lower().replace(" ", "_")
        if subj_norm and obj_norm:
            gold_triples.append((subj_norm, pred_norm, obj_norm))
    gold_triples = sorted(list(set(gold_triples)))

    logging.info(f"\n----- Processing Item {i+1}/{len(data)} -----")
    logging.info(f"Text: {text}")
    print(f"Processing Item {i+1}/{len(data)}...")
    
    pred_triples = []
    try:
        response = kg.generate(
            input_data=text, 
            context=CONTEXT
        )
        # A falsy response or one without a `relations` attribute yields no predictions.
        pred_triples_raw = response.relations if response and hasattr(response, 'relations') else []
        logging.info(f"KGGen Extraction: {pred_triples_raw}")

        # Entries that are not 3-element sequences are skipped; the rest get the
        # same normalization as the gold triples and are deduplicated via a set.
        corrected_triples_set = set()
        for triple in pred_triples_raw:
            if len(triple) != 3: continue
            subj, pred, obj = triple
            subj_norm = normalize_entity(subj)
            obj_norm = normalize_entity(obj)
            pred_norm = pred.strip().lower().replace(" ", "_")
            if subj_norm and obj_norm:
                corrected_triples_set.add((subj_norm, pred_norm, obj_norm))

        pred_triples = sorted(list(corrected_triples_set))
        success_count += 1

    except Exception as e:
        # Any error from the generate call or from post-processing: the item is
        # still scored, with an empty prediction list, and saved to
        # failed_entries for later inspection.
        logging.error(f"An API or parsing error occurred for item {i+1}: {e}", exc_info=True)
        pred_triples = []
        fail_count += 1
        failed_entries.append({
            'index': i+1,
            'text': text,
            'gold_triples_raw': gold_triples_raw,
            'error': str(e)
        })

    logging.info(f"Gold (Normalized): {gold_triples}")
    logging.info(f"Pred (Normalized): {pred_triples}")

    # Flat triple lists feed the strict metrics; the joined strings feed base BERTScore.
    all_gold_triples.extend(gold_triples)
    all_pred_triples.extend(pred_triples)

    pred_str = " | ".join(sorted([f"{h} {r} {t}" for h, r, t in set(pred_triples)]))
    gold_str = " | ".join(sorted([f"{h} {r} {t}" for h, r, t in set(gold_triples)]))
    all_pred_strings.append(pred_str)
    all_gold_strings.append(gold_str)

    # 'gold_list'/'pred_list' are sets (used for intersections);
    # 'gold_triples'/'pred_triples' are the sorted lists.
    item_details.append({
        'index': i+1,
        'gold_list': set(gold_triples), 'pred_list': set(pred_triples),
        'gold_str': gold_str, 'pred_str': pred_str,
        'gold_triples': gold_triples, 'pred_triples': pred_triples
    })

# Persist the items whose extraction raised, so they can be inspected or re-run.
if not failed_entries:
    logging.info("No failed entries.")
else:
    with open(FAILED_ENTRIES_PATH, 'w') as failed_fh:
        json.dump(failed_entries, failed_fh, indent=4)
    logging.info("Failed entries written to " + FAILED_ENTRIES_PATH)

# Report extraction outcomes to the log file first, then to stdout.
logging.info("\n----- Processing Summary -----")
summary_lines = (f"Successful items: {success_count}", f"Failed items: {fail_count}")
for summary_line in summary_lines:
    logging.info(summary_line)
for summary_line in summary_lines:
    print(summary_line)

# --- Metrics Calculation ---
logging.info("\n----- FINAL METRICS -----")

# Strict Metrics: exact match on normalized (head, relation, tail) triples.
# Micro: counts pooled over all items. A prediction is a true positive only
# when the SAME item's gold set contains it. A dataset-wide multiset
# intersection would also credit a prediction against an identical gold
# triple that belongs to a different sentence.
pred_total = len(all_pred_triples)
gold_total = len(all_gold_triples)
tp = sum(len(detail['pred_list'] & detail['gold_list']) for detail in item_details)
micro_p_strict = tp / pred_total if pred_total > 0 else 0.0
micro_r_strict = tp / gold_total if gold_total > 0 else 0.0
micro_f1_strict = (2 * micro_p_strict * micro_r_strict / (micro_p_strict + micro_r_strict)
                   if micro_p_strict + micro_r_strict > 0 else 0.0)

# Macro: per-item P/R/F1, then averaged. An empty side scores 1.0 only when
# the other side is empty too (nothing expected and nothing predicted).
item_ps_strict, item_rs_strict, item_f1s_strict = [], [], []
for detail in item_details:
    n_pred = len(detail['pred_triples'])
    n_gold = len(detail['gold_triples'])
    tp_item = len(detail['pred_list'] & detail['gold_list'])
    p_item = tp_item / n_pred if n_pred else (1.0 if not n_gold else 0.0)
    r_item = tp_item / n_gold if n_gold else (1.0 if not n_pred else 0.0)
    # Both-empty items have P = R = 1.0, hence F1 = 1.0 from the formula.
    f1_item = 2 * p_item * r_item / (p_item + r_item) if p_item + r_item > 0 else 0.0
    item_ps_strict.append(p_item)
    item_rs_strict.append(r_item)
    item_f1s_strict.append(f1_item)
    logging.info(
        f"Item {detail['index']} Strict: P={p_item:.4f}, R={r_item:.4f}, F1={f1_item:.4f}")

macro_p_strict = np.nanmean(item_ps_strict) if item_ps_strict else 0.0
macro_r_strict = np.nanmean(item_rs_strict) if item_rs_strict else 0.0
macro_f1_strict = np.nanmean(item_f1s_strict) if item_f1s_strict else 0.0

# Base BERTScore (with IDF added for robustness)
from bert_score import score as bert_score_compute
from bert_score import BERTScorer

if any(all_pred_strings) and any(all_gold_strings):
    # Base BERTScore (P, R, F1): each item's "|"-joined predicted triples are
    # scored against its "|"-joined gold triples. With idf=True the IDF weights
    # are estimated from all items' gold strings within this single call.
    P_macro, R_macro, F1_macro = bert_score_compute(
        all_pred_strings, all_gold_strings, lang="en", verbose=False, model_type="roberta-large", idf=True
    )
    macro_p_bert = P_macro.mean().item()
    macro_r_bert = R_macro.mean().item()
    macro_f1_bert = F1_macro.mean().item()

    # NOTE(review): this scores the whole dataset as ONE sentence pair. bert_score
    # presumably truncates each input to the model's maximum length (512 tokens
    # for roberta-large), so only the leading triples may contribute -- verify
    # before relying on this number.
    all_pred_concat = ' | '.join(filter(None, all_pred_strings))
    all_gold_concat = ' | '.join(filter(None, all_gold_strings))
    P_micro, R_micro, F1_micro = bert_score_compute(
        [all_pred_concat], [all_gold_concat], lang="en", verbose=False, model_type="roberta-large", idf=False
    )
    # Guard precision against NaN too, exactly like recall and F1.
    micro_p_bert = P_micro.item() if not np.isnan(P_micro.item()) else 0.0
    micro_r_bert = R_micro.item() if not np.isnan(R_micro.item()) else 0.0
    micro_f1_bert = F1_micro.item() if not np.isnan(F1_micro.item()) else 0.0
else:
    # All zeros for empties
    macro_p_bert, macro_r_bert, macro_f1_bert = 0.0, 0.0, 0.0
    micro_p_bert, micro_r_bert, micro_f1_bert = 0.0, 0.0, 0.0

# Greedy BERTScore: each predicted triple is scored against its best-matching
# gold triple of the SAME item (precision), and each gold triple against its
# best-matching prediction (recall).
#
# One BERTScorer loads roberta-large once and uses a single IDF table
# estimated over every gold triple of the dataset. Calling
# bert_score_compute(..., idf=True) per item would reload the model on every
# call. It would also estimate IDF from that item's handful of repeated
# references, giving every token they all share a weight of 0. An exact
# match then scores NaN, and its F1 silently dropped to 0.0.
all_pred_triple_strs = [f"{h} {r} {t}" for triples in item_details for h, r, t in triples['pred_triples']]
all_gold_triple_strs = [f"{h} {r} {t}" for triples in item_details for h, r, t in triples['gold_triples']]
greedy_scorer = None
if all_pred_triple_strs and all_gold_triple_strs:
    greedy_scorer = BERTScorer(model_type="roberta-large", lang="en", idf=True, idf_sents=all_gold_triple_strs)

item_gbs_p, item_gbs_r, item_gbs_f1 = [], [], []
greedy_p_sum, greedy_r_sum = 0.0, 0.0  # pooled best-match scores for micro averages

for detail in item_details:
    pred_strings = [f"{h} {r} {t}" for h, r, t in detail['pred_triples']]
    gold_strings = [f"{h} {r} {t}" for h, r, t in detail['gold_triples']]

    if not pred_strings and not gold_strings:
        item_gbs_p.append(1.0)
        item_gbs_r.append(1.0)
        item_gbs_f1.append(1.0)
        continue

    if not pred_strings:  # Has gold, but no preds (all missed)
        item_gbs_p.append(1.0)  # Vacuously true: all 0 preds are "correct"
        item_gbs_r.append(0.0)
        item_gbs_f1.append(0.0)
        continue

    if not gold_strings:  # Has preds, but no gold (all false pos)
        item_gbs_p.append(0.0)
        item_gbs_r.append(1.0)  # Vacuously true: all 0 golds are "found"
        item_gbs_f1.append(0.0)
        continue

    # bert_score keeps the max over a candidate's reference group, so giving
    # every gold triple as references yields each prediction's best match.
    P_gbs, _, _ = greedy_scorer.score(pred_strings, [gold_strings] * len(pred_strings))
    # Flipped call: precision of gold-vs-preds equals recall of preds-vs-gold.
    P_recall, _, _ = greedy_scorer.score(gold_strings, [pred_strings] * len(gold_strings))
    pred_best = np.nan_to_num(P_gbs.cpu().numpy())
    gold_best = np.nan_to_num(P_recall.cpu().numpy())

    gbs_p = float(pred_best.mean())
    gbs_r = float(gold_best.mean())
    gbs_f1 = 2 * gbs_p * gbs_r / (gbs_p + gbs_r) if (gbs_p + gbs_r) > 0 else 0.0

    greedy_p_sum += float(pred_best.sum())
    greedy_r_sum += float(gold_best.sum())
    item_gbs_p.append(gbs_p)
    item_gbs_r.append(gbs_r)
    item_gbs_f1.append(gbs_f1)

# Macro averages (mean over items)
macro_gbs_p = np.nanmean(item_gbs_p) if item_gbs_p else 0.0
macro_gbs_r = np.nanmean(item_gbs_r) if item_gbs_r else 0.0
macro_gbs_f1 = np.nanmean(item_gbs_f1) if item_gbs_f1 else 0.0

# Micro: pool the per-item best-match scores over all triples. Predictions are
# never matched against another item's gold triples. Triples of items whose
# other side is empty still count in the denominator but add 0 to the numerator.
num_pred_triples = len(all_pred_triple_strs)
num_gold_triples = len(all_gold_triple_strs)
micro_gbs_p = greedy_p_sum / num_pred_triples if num_pred_triples > 0 else 0.0
micro_gbs_r = greedy_r_sum / num_gold_triples if num_gold_triples > 0 else 0.0
micro_gbs_f1 = 2 * micro_gbs_p * micro_gbs_r / (micro_gbs_p + micro_gbs_r) if (micro_gbs_p + micro_gbs_r) > 0 else 0.0

# G-BERTScore (Hungarian Matching) from EXPLAGRAPHS
def get_g_bert_score(all_gold_edges, all_pred_edges, idf=False):
    """Score predicted edges against gold edges with one-to-one BERTScore matching.

    For every item, the pairwise BERTScore F1 of each (gold edge, predicted
    edge) pair fills a (num_gold, num_pred) matrix. ``linear_sum_assignment``
    then picks the one-to-one matching with the largest total score. Precision
    and recall are that total divided by the number of predicted and gold
    edges, respectively.

    Args:
        all_gold_edges: per-item lists of gold edge strings.
        all_pred_edges: per-item lists of predicted edge strings, aligned
            with ``all_gold_edges``.
        idf: forwarded to ``bert_score_compute``. All pairs of all items are
            scored in one call.

    Returns:
        Three numpy arrays (precision, recall, F1), one entry per item.
        - No gold and no predicted edges: 1/1/1.
        - Only gold edges: P=1, R=0, F1=0.
        - Only predicted edges: P=0, R=1, F1=0.
    """
    references = []
    candidates = []
    ref_cand_index = {}
    for graph_idx, (gold_edges, pred_edges) in enumerate(zip(all_gold_edges, all_pred_edges)):
        for gold_edge in gold_edges:
            for pred_edge in pred_edges:
                references.append(gold_edge)
                candidates.append(pred_edge)
                ref_cand_index[(graph_idx, gold_edge, pred_edge)] = len(references) - 1

    # Call BERTScore only when at least one (gold, pred) pair exists. Items
    # missing a side are scored by the conventions below either way, so an
    # empty pair list must not zero them all out.
    bs_F1 = None
    if references:
        _, _, bs_F1 = bert_score_compute(candidates, references, lang='en', verbose=False, model_type="roberta-large", idf=idf)
        bs_F1 = bs_F1.cpu().numpy()

    precisions, recalls, f1s = [], [], []
    for graph_idx, (gold_edges, pred_edges) in enumerate(zip(all_gold_edges, all_pred_edges)):
        num_gold = len(gold_edges)
        num_pred = len(pred_edges)
        if num_gold == 0 and num_pred == 0:
            precisions.append(1.0)
            recalls.append(1.0)
            f1s.append(1.0)
            continue
        if num_pred == 0:
            precisions.append(1.0)
            recalls.append(0.0)
            f1s.append(0.0)
            continue
        if num_gold == 0:
            precisions.append(0.0)
            recalls.append(1.0)
            f1s.append(0.0)
            continue

        # Both sides are non-empty here, so this item's pairs were scored above.
        score_matrix = np.zeros((num_gold, num_pred))
        for i, gold_edge in enumerate(gold_edges):
            for j, pred_edge in enumerate(pred_edges):
                idx = ref_cand_index.get((graph_idx, gold_edge, pred_edge))
                if idx is not None:
                    score_matrix[i, j] = bs_F1[idx]

        row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)
        matched_sim = score_matrix[row_ind, col_ind]

        sample_precision = matched_sim.sum() / num_pred
        sample_recall = matched_sim.sum() / num_gold
        sample_f1 = 2 * sample_precision * sample_recall / (sample_precision + sample_recall) if sample_precision + sample_recall > 0 else 0.0

        precisions.append(sample_precision)
        recalls.append(sample_recall)
        f1s.append(sample_f1)

    return np.array(precisions), np.array(recalls), np.array(f1s)

# Render each item's triples as "head relation tail" edge strings, keeping the
# per-item grouping that get_g_bert_score expects.
all_gold_edges = [
    [f"{h} {r} {t}" for h, r, t in detail['gold_triples']]
    for detail in item_details
]
all_pred_edges = [
    [f"{h} {r} {t}" for h, r, t in detail['pred_triples']]
    for detail in item_details
]

precisions_h, recalls_h, f1s_h = get_g_bert_score(all_gold_edges, all_pred_edges, idf=False)

# Macro: plain mean of the per-item scores (0.0 when there are no items).
macro_gbs_p_h, macro_gbs_r_h, macro_gbs_f1_h = (
    np.nanmean(scores) if len(scores) > 0 else 0.0
    for scores in (precisions_h, recalls_h, f1s_h)
)

# Micro: precision_i * |pred_i| recovers item i's matched-similarity total;
# pool those totals and divide by the dataset-wide triple counts.
total_matched_sum = sum(
    (item_p * len(item_preds) for item_p, item_preds in zip(precisions_h, all_pred_edges)),
    0.0,
)
if pred_total > 0:
    micro_gbs_p_h = total_matched_sum / pred_total
else:
    micro_gbs_p_h = 0.0
if gold_total > 0:
    micro_gbs_r_h = total_matched_sum / gold_total
else:
    micro_gbs_r_h = 0.0
if micro_gbs_p_h + micro_gbs_r_h > 0:
    micro_gbs_f1_h = 2 * micro_gbs_p_h * micro_gbs_r_h / (micro_gbs_p_h + micro_gbs_r_h)
else:
    micro_gbs_f1_h = 0.0

# Soft Semantic Score
# A predicted triple is a soft match when the one-to-one (Hungarian)
# assignment pairs it with a gold triple of the same item. Their
# sentence-embedding cosine similarity (negatives clipped to 0) must be
# strictly above `threshold`.
threshold = 0.8  # Keep as is, or tune per dataset (literature often uses 0.75-0.9)
global_matched_preds = global_matched_golds = 0
item_ps_soft, item_rs_soft, item_f1s_soft = [], [], []

for detail in item_details:
    pred_triples = detail['pred_triples']
    gold_triples = detail['gold_triples']

    if not (pred_triples or gold_triples):
        # Nothing expected and nothing predicted: a perfect item.
        for per_item_scores in (item_ps_soft, item_rs_soft, item_f1s_soft):
            per_item_scores.append(1.0)
        continue

    pred_strings = [f"{h} {r} {t}" for h, r, t in pred_triples]
    gold_strings = [f"{h} {r} {t}" for h, r, t in gold_triples]

    matched_preds = matched_golds = 0
    if pred_strings and gold_strings:
        pred_embs = sentence_model.encode(pred_strings, convert_to_tensor=True).cpu().numpy()
        gold_embs = sentence_model.encode(gold_strings, convert_to_tensor=True).cpu().numpy()

        # Entry (i, j): cosine similarity of prediction i and gold triple j.
        norms = np.linalg.norm(pred_embs, axis=1)[:, np.newaxis] * np.linalg.norm(gold_embs, axis=1)
        sim_matrix = np.maximum(np.dot(pred_embs, gold_embs.T) / norms, 0)

        row_ind, col_ind = linear_sum_assignment(sim_matrix, maximize=True)
        matched_sim = sim_matrix[row_ind, col_ind]

        # Each assigned pair consumes one prediction and one gold triple.
        matched_preds = int(np.count_nonzero(matched_sim > threshold))
        matched_golds = matched_preds

    if pred_triples:
        p_item = matched_preds / len(pred_triples)
    else:
        p_item = 0.0 if gold_triples else 1.0
    if gold_triples:
        r_item = matched_golds / len(gold_triples)
    else:
        r_item = 0.0 if pred_triples else 1.0
    f1_item = 2 * p_item * r_item / (p_item + r_item) if p_item + r_item > 0 else 0.0

    item_ps_soft.append(p_item)
    item_rs_soft.append(r_item)
    item_f1s_soft.append(f1_item)
    global_matched_preds += matched_preds
    global_matched_golds += matched_golds
    logging.info(f"Item {detail['index']} Semantic: P={p_item:.4f}, R={r_item:.4f}, F1={f1_item:.4f}")

micro_p_soft = global_matched_preds / pred_total if pred_total > 0 else 0.0
micro_r_soft = global_matched_golds / gold_total if gold_total > 0 else 0.0
micro_f1_soft = 2 * micro_p_soft * micro_r_soft / (micro_p_soft + micro_r_soft) if micro_p_soft + micro_r_soft > 0 else 0.0

macro_p_soft, macro_r_soft, macro_f1_soft = (
    np.nanmean(scores) if scores else 0.0
    for scores in (item_ps_soft, item_rs_soft, item_f1s_soft)
)

# --- Final Output ---
# Build one human-readable report, then echo it to stdout and append it to
# the log file, in that order.
output_str = f"""
Benchmark Results:
---------------------------------
Number of examples: {len(data)}
Strict Averages:
  - Macro: P={macro_p_strict:.4f}, R={macro_r_strict:.4f}, F1={macro_f1_strict:.4f}
  - Micro: P={micro_p_strict:.4f}, R={micro_r_strict:.4f}, F1={micro_f1_strict:.4f}

BERTScore Averages:
  - Macro: P={macro_p_bert:.4f}, R={macro_r_bert:.4f}, F1={macro_f1_bert:.4f}
  - Micro: P={micro_p_bert:.4f}, R={micro_r_bert:.4f}, F1={micro_f1_bert:.4f}

Greedy BS Averages:
  - Macro: P={macro_gbs_p:.4f}, R={macro_gbs_r:.4f}, F1={macro_gbs_f1:.4f}
  - Micro: P={micro_gbs_p:.4f}, R={micro_gbs_r:.4f}, F1={micro_gbs_f1:.4f}

G-BERTScore Averages:
  - Macro: P={macro_gbs_p_h:.4f}, R={macro_gbs_r_h:.4f}, F1={macro_gbs_f1_h:.4f}
  - Micro: P={micro_gbs_p_h:.4f}, R={micro_gbs_r_h:.4f}, F1={micro_gbs_f1_h:.4f}

Semantic Score Averages (Threshold: {threshold}):
  - Macro: P={macro_p_soft:.4f}, R={macro_r_soft:.4f}, F1={macro_f1_soft:.4f}
  - Micro: P={micro_p_soft:.4f}, R={micro_r_soft:.4f}, F1={micro_f1_soft:.4f}
"""
for _report_sink in (print, logging.info):
    _report_sink(output_str)

### 3. SciERC dataset

In [None]:
import json
import time
import logging
import os
from collections import Counter
from typing import List, Tuple, Dict, Set
from kg_gen import KGGen
import bert_score
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
import re
import numpy as np
from scipy.optimize import linear_sum_assignment  # For Hungarian matching
import datetime

# --- Configuration ---
# Read the key from the environment instead of hard-coding it in the notebook.
# An unset variable leaves the key empty, exactly as before.
API_KEY = os.environ.get("OPENAI_API_KEY", "")

# --- Paths ---
current_date = datetime.datetime.now().strftime('%m%d')
LOG_FILE_PATH = f'{current_date}_benchmark_scierc_test_all.log'
DATASET_PATH = '../datasets/scierc_test_converted.json'
MODEL_NAME = "openai/gpt-4o"
SENTENCE_MODEL_NAME = 'all-mpnet-base-v2'
NUM_EXAMPLES_TO_TEST = None
FAILED_ENTRIES_PATH = f'{current_date}_scierc_failed_entries.json'

# --- Set up logging ---
# basicConfig does nothing once the root logger has handlers, e.g. after the
# ADE cell ran in the same kernel. force=True replaces those handlers so this
# run's messages go to LOG_FILE_PATH rather than the previous cell's log file.
logging.basicConfig(level=logging.INFO, filename=LOG_FILE_PATH,
                    filemode='a', format='%(message)s', force=True)

# --- Main Script ---
# Warn only on a missing/placeholder key (same check as the ADE cell). The
# previous startswith("sk-proj-") test also warned for valid keys of other forms.
if not API_KEY or "YOUR_API_KEY" in API_KEY:
    logging.warning("OpenAI API key not set.")
    print("WARNING: OpenAI API key not set.")

kg = KGGen(
    model=MODEL_NAME,
    temperature=0.1,
    api_key=API_KEY,
)

# --- Stricter Prompt ---
# Extraction instructions passed as `context=` to kg.generate() for every item.
# The allowed relation names listed below presumably mirror the SciERC gold
# labels (TODO confirm against the dataset). Predicted relations are only
# stripped, lower-cased and hyphenated during normalization, so any other
# relation name can never match a gold triple exactly.
CONTEXT = """You MUST extract relationships ONLY using these exact relation names: used-for, feature-of, hyponym-of, part-of, compare, evaluate-for, conjunction. You MUST NOT use any other relation names—any deviation is invalid. Subject and object MUST be complete noun phrases from the text. Output ONLY unique triples as (Subject, relation, Object) or [] if none. Do NOT use relations like 'has-been-studied-as', 'of', 'in'—those are invalid; stick ONLY to the list.

Invalid Example (Do NOT do this):
Sentence: "Recognition of proper nouns in Japanese text has been studied as a part of the more general problem of morphological analysis in Japanese text processing."
Wrong Triples: [("Recognition", "has-been-studied-as", "part"), ("Recognition", "of", "proper nouns")]  # Invalid relations

Valid Examples:
Sentence: "The agreement in question involves number in nouns and reflexive pronouns and is syntactic rather than semantic in nature because grammatical number in English , like grammatical gender in languages such as French , is partly arbitrary ."
Triples: [("nouns", "conjunction", "reflexive pronouns"), ("grammatical gender", "feature-of", "languages"), ("french", "hyponym-of", "languages")]

Sentence: "In this paper , a novel method to learn the intrinsic object structure for robust visual tracking is proposed ."
Triples: [("novel method", "used-for", "learn the intrinsic object structure"), ("intrinsic object structure", "used-for", "robust visual tracking")]

Apply to the text. Remember: ONLY use relations from the list."""

# --- Normalization Logic ---
def normalize_entity(entity_text: str) -> str:
    """Canonicalize an entity mention so gold and predicted triples compare equal.

    Steps: trim whitespace, lower-case, strip surrounding quote characters,
    drop one leading article ("a", "an", "the") followed by whitespace, then
    drop a trailing period (and any whitespace before it).

    Quotes are removed *before* the article check, so '"the parser"' becomes
    'parser' rather than keeping the article behind the quote.

    Args:
        entity_text: raw entity string; any non-string yields "".

    Returns:
        The normalized entity, or "" if nothing remains.
    """
    if not isinstance(entity_text, str):
        return ""

    text = entity_text.strip().lower().strip('\'"').strip()
    # Text is already lower-cased, so no IGNORECASE flag is needed.
    text = re.sub(r'^(a|an|the)\s+', '', text)
    text = re.sub(r'\s*\.$', '', text)
    # A quote can sit just before a removed trailing period ('"parser".').
    return text.strip('\'"').strip()

# Load dataset
try:
    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)
except FileNotFoundError:
    logging.error(f"FATAL: Dataset not found at path: {DATASET_PATH}")
    print(f"FATAL: Dataset not found at path: {DATASET_PATH}")
    # exit() is an interactive helper from the `site` module. Without an
    # argument it exits with status 0 (success), and in a notebook it may shut
    # the kernel down. SystemExit(1) reports the failure and just stops the cell.
    raise SystemExit(1)

# A falsy limit (None or 0) means "use the whole test set".
if NUM_EXAMPLES_TO_TEST:
    data = data[:NUM_EXAMPLES_TO_TEST]

print(f"Loading sentence transformer model: {SENTENCE_MODEL_NAME}...")
sentence_model = SentenceTransformer(SENTENCE_MODEL_NAME)
print("Model loaded.")

# --- Benchmark Loop ---
def _normalize_triples(raw_triples):
    """Normalize raw (subject, relation, object) triples into canonical form.

    Subjects and objects go through ``normalize_entity``. Relations are
    stripped, lower-cased and have spaces replaced by hyphens. Gold and
    predicted triples therefore share one canonical form.

    Entries are dropped when they do not have exactly three elements, or
    when the subject/object normalizes to an empty string.

    Returns:
        A sorted list of unique ``(subject, relation, object)`` tuples.
    """
    normalized = set()
    for triple in raw_triples:
        if len(triple) != 3:
            continue
        subj, pred, obj = triple
        subj_norm = normalize_entity(subj)
        obj_norm = normalize_entity(obj)
        pred_norm = pred.strip().lower().replace(" ", "-")
        if subj_norm and obj_norm:
            normalized.add((subj_norm, pred_norm, obj_norm))
    return sorted(normalized)


all_gold_triples, all_pred_triples, all_pred_strings, all_gold_strings = [], [], [], []
item_details = []
success_count = 0
fail_count = 0
failed_entries = []

for i, item in enumerate(data):
    text = item['text']
    gold_triples_raw = item.get('triple_list', [])
    # Gold triples get the same normalization as predictions. A malformed
    # (non-3-element) gold entry is skipped instead of aborting the whole run.
    gold_triples = _normalize_triples(gold_triples_raw)

    logging.info(f"\n----- Processing Item {i+1}/{len(data)} -----")
    logging.info(f"Text: {text}")
    print(f"Processing Item {i+1}/{len(data)}...")

    pred_triples = []
    try:
        response = kg.generate(
            input_data=text,
            context=CONTEXT
        )
        pred_triples_raw = response.relations if response and hasattr(response, 'relations') else []
        logging.info(f"KGGen Extraction: {pred_triples_raw}")
        pred_triples = _normalize_triples(pred_triples_raw)
        success_count += 1

    except Exception as e:
        # Best-effort: log the failure, record it for later inspection, and
        # score this item with no predictions.
        logging.error(f"An API or parsing error occurred for item {i+1}: {e}", exc_info=True)
        pred_triples = []
        fail_count += 1
        failed_entries.append({
            'index': i+1,
            'text': text,
            'gold_triples_raw': gold_triples_raw,
            'error': str(e)
        })

    logging.info(f"Gold (Normalized): {gold_triples}")
    logging.info(f"Pred (Normalized): {pred_triples}")

    all_gold_triples.extend(gold_triples)
    all_pred_triples.extend(pred_triples)

    # One "|"-joined string per item, used by the base BERTScore metric.
    pred_str = " | ".join(sorted([f"{h} {r} {t}" for h, r, t in set(pred_triples)]))
    gold_str = " | ".join(sorted([f"{h} {r} {t}" for h, r, t in set(gold_triples)]))
    all_pred_strings.append(pred_str)
    all_gold_strings.append(gold_str)

    item_details.append({
        'index': i+1,
        'gold_list': set(gold_triples), 'pred_list': set(pred_triples),
        'gold_str': gold_str, 'pred_str': pred_str,
        'gold_triples': gold_triples, 'pred_triples': pred_triples
    })

# Persist failed items (if any) as JSON so they can be inspected or re-run.
if not failed_entries:
    logging.info("No failed entries.")
else:
    with open(FAILED_ENTRIES_PATH, 'w') as f:
        json.dump(failed_entries, f, indent=4)
    logging.info(f"Failed entries written to {FAILED_ENTRIES_PATH}")

# Log processing counts: record the success/failure tallies in the log file
# first, then echo the same two lines to stdout.
_summary_lines = (f"Successful items: {success_count}", f"Failed items: {fail_count}")
logging.info("\n----- Processing Summary -----")
for _line in _summary_lines:
    logging.info(_line)
for _line in _summary_lines:
    print(_line)

# --- Metrics Calculation ---
logging.info("\n----- FINAL METRICS -----")

# Strict Metrics: exact (subject, relation, object) matches.
# A true positive must be a gold triple of the SAME item. Pooling every item
# into one global Counter would let a prediction for one sentence "match" a
# gold triple that only occurs in a different sentence, inflating micro scores.
tp = sum(len(detail['pred_list'] & detail['gold_list']) for detail in item_details)
pred_total = len(all_pred_triples)
gold_total = len(all_gold_triples)
micro_p_strict = tp / pred_total if pred_total > 0 else 0.0
micro_r_strict = tp / gold_total if gold_total > 0 else 0.0
micro_f1_strict = (
    2 * micro_p_strict * micro_r_strict / (micro_p_strict + micro_r_strict)
    if micro_p_strict + micro_r_strict > 0 else 0.0
)

# Per-item (macro) scores. An item with neither gold nor predicted triples
# scores P=R=F1=1; if exactly one side is empty, P, R and F1 are all 0.
item_ps_strict, item_rs_strict, item_f1s_strict = [], [], []
for detail in item_details:
    tp_item = len(detail['pred_list'] & detail['gold_list'])
    p_item = tp_item / len(detail['pred_triples']) if detail['pred_triples'] else (1.0 if not detail['gold_triples'] else 0.0)
    r_item = tp_item / len(detail['gold_triples']) if detail['gold_triples'] else (1.0 if not detail['pred_triples'] else 0.0)
    f1_item = 2 * p_item * r_item / (p_item + r_item) if p_item + r_item > 0 else 0.0
    item_ps_strict.append(p_item)
    item_rs_strict.append(r_item)
    item_f1s_strict.append(f1_item)
    logging.info(
        f"Item {detail['index']} Strict: P={p_item:.4f}, R={r_item:.4f}, F1={f1_item:.4f}")

macro_p_strict = np.nanmean(item_ps_strict) if item_ps_strict else 0.0
macro_r_strict = np.nanmean(item_rs_strict) if item_rs_strict else 0.0
macro_f1_strict = np.nanmean(item_f1s_strict) if item_f1s_strict else 0.0

# Base BERTScore (with IDF added for robustness)
from bert_score import score as bert_score_compute
from bert_score import BERTScorer

if any(all_pred_strings) and any(all_gold_strings):
    # Base BERTScore (P, R, F1): each item's "|"-joined predicted triples are
    # scored against its "|"-joined gold triples. With idf=True the IDF weights
    # are estimated from all items' gold strings within this single call.
    P_macro, R_macro, F1_macro = bert_score_compute(
        all_pred_strings, all_gold_strings, lang="en", verbose=False, model_type="roberta-large", idf=True
    )
    macro_p_bert = P_macro.mean().item()
    macro_r_bert = R_macro.mean().item()
    macro_f1_bert = F1_macro.mean().item()

    # NOTE(review): this scores the whole dataset as ONE sentence pair. bert_score
    # presumably truncates each input to the model's maximum length (512 tokens
    # for roberta-large), so only the leading triples may contribute -- verify
    # before relying on this number.
    all_pred_concat = ' | '.join(filter(None, all_pred_strings))
    all_gold_concat = ' | '.join(filter(None, all_gold_strings))
    P_micro, R_micro, F1_micro = bert_score_compute(
        [all_pred_concat], [all_gold_concat], lang="en", verbose=False, model_type="roberta-large", idf=False
    )
    # Guard precision against NaN too, exactly like recall and F1.
    micro_p_bert = P_micro.item() if not np.isnan(P_micro.item()) else 0.0
    micro_r_bert = R_micro.item() if not np.isnan(R_micro.item()) else 0.0
    micro_f1_bert = F1_micro.item() if not np.isnan(F1_micro.item()) else 0.0
else:
    # All zeros for empties
    macro_p_bert, macro_r_bert, macro_f1_bert = 0.0, 0.0, 0.0
    micro_p_bert, micro_r_bert, micro_f1_bert = 0.0, 0.0, 0.0

# Greedy BERTScore: each predicted triple is scored against its best-matching
# gold triple of the SAME item (precision), and each gold triple against its
# best-matching prediction (recall).
#
# One BERTScorer loads roberta-large once and uses a single IDF table
# estimated over every gold triple of the dataset. Calling
# bert_score_compute(..., idf=True) per item would reload the model on every
# call. It would also estimate IDF from that item's handful of repeated
# references, giving every token they all share a weight of 0. An exact
# match then scores NaN, and its F1 silently dropped to 0.0.
all_pred_triple_strs = [f"{h} {r} {t}" for triples in item_details for h, r, t in triples['pred_triples']]
all_gold_triple_strs = [f"{h} {r} {t}" for triples in item_details for h, r, t in triples['gold_triples']]
greedy_scorer = None
if all_pred_triple_strs and all_gold_triple_strs:
    greedy_scorer = BERTScorer(model_type="roberta-large", lang="en", idf=True, idf_sents=all_gold_triple_strs)

item_gbs_p, item_gbs_r, item_gbs_f1 = [], [], []
greedy_p_sum, greedy_r_sum = 0.0, 0.0  # pooled best-match scores for micro averages

for detail in item_details:
    pred_strings = [f"{h} {r} {t}" for h, r, t in detail['pred_triples']]
    gold_strings = [f"{h} {r} {t}" for h, r, t in detail['gold_triples']]

    if not pred_strings and not gold_strings:
        item_gbs_p.append(1.0)
        item_gbs_r.append(1.0)
        item_gbs_f1.append(1.0)
        continue

    if not pred_strings:  # Has gold, but no preds (all missed)
        item_gbs_p.append(1.0)  # Vacuously true: all 0 preds are "correct"
        item_gbs_r.append(0.0)
        item_gbs_f1.append(0.0)
        continue

    if not gold_strings:  # Has preds, but no gold (all false pos)
        item_gbs_p.append(0.0)
        item_gbs_r.append(1.0)  # Vacuously true: all 0 golds are "found"
        item_gbs_f1.append(0.0)
        continue

    # bert_score keeps the max over a candidate's reference group, so giving
    # every gold triple as references yields each prediction's best match.
    P_gbs, _, _ = greedy_scorer.score(pred_strings, [gold_strings] * len(pred_strings))
    # Flipped call: precision of gold-vs-preds equals recall of preds-vs-gold.
    P_recall, _, _ = greedy_scorer.score(gold_strings, [pred_strings] * len(gold_strings))
    pred_best = np.nan_to_num(P_gbs.cpu().numpy())
    gold_best = np.nan_to_num(P_recall.cpu().numpy())

    gbs_p = float(pred_best.mean())
    gbs_r = float(gold_best.mean())
    gbs_f1 = 2 * gbs_p * gbs_r / (gbs_p + gbs_r) if (gbs_p + gbs_r) > 0 else 0.0

    greedy_p_sum += float(pred_best.sum())
    greedy_r_sum += float(gold_best.sum())
    item_gbs_p.append(gbs_p)
    item_gbs_r.append(gbs_r)
    item_gbs_f1.append(gbs_f1)

# Macro averages (mean over items)
macro_gbs_p = np.nanmean(item_gbs_p) if item_gbs_p else 0.0
macro_gbs_r = np.nanmean(item_gbs_r) if item_gbs_r else 0.0
macro_gbs_f1 = np.nanmean(item_gbs_f1) if item_gbs_f1 else 0.0

# Micro: pool the per-item best-match scores over all triples. Predictions are
# never matched against another item's gold triples. Triples of items whose
# other side is empty still count in the denominator but add 0 to the numerator.
num_pred_triples = len(all_pred_triple_strs)
num_gold_triples = len(all_gold_triple_strs)
micro_gbs_p = greedy_p_sum / num_pred_triples if num_pred_triples > 0 else 0.0
micro_gbs_r = greedy_r_sum / num_gold_triples if num_gold_triples > 0 else 0.0
micro_gbs_f1 = 2 * micro_gbs_p * micro_gbs_r / (micro_gbs_p + micro_gbs_r) if (micro_gbs_p + micro_gbs_r) > 0 else 0.0

# G-BERTScore (Hungarian Matching) from EXPLAGRAPHS
def get_g_bert_score(all_gold_edges, all_pred_edges, idf=False):
    """Score predicted edges against gold edges with one-to-one BERTScore matching.

    For every item, the pairwise BERTScore F1 of each (gold edge, predicted
    edge) pair fills a (num_gold, num_pred) matrix. ``linear_sum_assignment``
    then picks the one-to-one matching with the largest total score. Precision
    and recall are that total divided by the number of predicted and gold
    edges, respectively.

    Args:
        all_gold_edges: per-item lists of gold edge strings.
        all_pred_edges: per-item lists of predicted edge strings, aligned
            with ``all_gold_edges``.
        idf: forwarded to ``bert_score_compute``. All pairs of all items are
            scored in one call.

    Returns:
        Three numpy arrays (precision, recall, F1), one entry per item.
        - No gold and no predicted edges: 1/1/1.
        - Only gold edges: P=1, R=0, F1=0.
        - Only predicted edges: P=0, R=1, F1=0.
    """
    references = []
    candidates = []
    ref_cand_index = {}
    for graph_idx, (gold_edges, pred_edges) in enumerate(zip(all_gold_edges, all_pred_edges)):
        for gold_edge in gold_edges:
            for pred_edge in pred_edges:
                references.append(gold_edge)
                candidates.append(pred_edge)
                ref_cand_index[(graph_idx, gold_edge, pred_edge)] = len(references) - 1

    # Call BERTScore only when at least one (gold, pred) pair exists. Items
    # missing a side are scored by the conventions below either way, so an
    # empty pair list must not zero them all out.
    bs_F1 = None
    if references:
        _, _, bs_F1 = bert_score_compute(candidates, references, lang='en', verbose=False, model_type="roberta-large", idf=idf)
        bs_F1 = bs_F1.cpu().numpy()

    precisions, recalls, f1s = [], [], []
    for graph_idx, (gold_edges, pred_edges) in enumerate(zip(all_gold_edges, all_pred_edges)):
        num_gold = len(gold_edges)
        num_pred = len(pred_edges)
        if num_gold == 0 and num_pred == 0:
            precisions.append(1.0)
            recalls.append(1.0)
            f1s.append(1.0)
            continue
        if num_pred == 0:
            precisions.append(1.0)
            recalls.append(0.0)
            f1s.append(0.0)
            continue
        if num_gold == 0:
            precisions.append(0.0)
            recalls.append(1.0)
            f1s.append(0.0)
            continue

        # Both sides are non-empty here, so this item's pairs were scored above.
        score_matrix = np.zeros((num_gold, num_pred))
        for i, gold_edge in enumerate(gold_edges):
            for j, pred_edge in enumerate(pred_edges):
                idx = ref_cand_index.get((graph_idx, gold_edge, pred_edge))
                if idx is not None:
                    score_matrix[i, j] = bs_F1[idx]

        row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)
        matched_sim = score_matrix[row_ind, col_ind]

        sample_precision = matched_sim.sum() / num_pred
        sample_recall = matched_sim.sum() / num_gold
        sample_f1 = 2 * sample_precision * sample_recall / (sample_precision + sample_recall) if sample_precision + sample_recall > 0 else 0.0

        precisions.append(sample_precision)
        recalls.append(sample_recall)
        f1s.append(sample_f1)

    return np.array(precisions), np.array(recalls), np.array(f1s)

# Render each item's triples as "head relation tail" edge strings, keeping the
# per-item grouping that get_g_bert_score expects.
all_gold_edges = [
    [f"{h} {r} {t}" for h, r, t in detail['gold_triples']]
    for detail in item_details
]
all_pred_edges = [
    [f"{h} {r} {t}" for h, r, t in detail['pred_triples']]
    for detail in item_details
]

precisions_h, recalls_h, f1s_h = get_g_bert_score(all_gold_edges, all_pred_edges, idf=False)

# Macro: plain mean of the per-item scores (0.0 when there are no items).
macro_gbs_p_h, macro_gbs_r_h, macro_gbs_f1_h = (
    np.nanmean(scores) if len(scores) > 0 else 0.0
    for scores in (precisions_h, recalls_h, f1s_h)
)

# Micro: precision_i * |pred_i| recovers item i's matched-similarity total;
# pool those totals and divide by the dataset-wide triple counts.
total_matched_sum = sum(
    (item_p * len(item_preds) for item_p, item_preds in zip(precisions_h, all_pred_edges)),
    0.0,
)
if pred_total > 0:
    micro_gbs_p_h = total_matched_sum / pred_total
else:
    micro_gbs_p_h = 0.0
if gold_total > 0:
    micro_gbs_r_h = total_matched_sum / gold_total
else:
    micro_gbs_r_h = 0.0
if micro_gbs_p_h + micro_gbs_r_h > 0:
    micro_gbs_f1_h = 2 * micro_gbs_p_h * micro_gbs_r_h / (micro_gbs_p_h + micro_gbs_r_h)
else:
    micro_gbs_f1_h = 0.0

# Soft Semantic Score
# A predicted triple is a soft match when the one-to-one (Hungarian)
# assignment pairs it with a gold triple of the same item. Their
# sentence-embedding cosine similarity (negatives clipped to 0) must be
# strictly above `threshold`.
threshold = 0.8
global_matched_preds = global_matched_golds = 0
item_ps_soft, item_rs_soft, item_f1s_soft = [], [], []

for detail in item_details:
    pred_triples = detail['pred_triples']
    gold_triples = detail['gold_triples']

    if not (pred_triples or gold_triples):
        # Nothing expected and nothing predicted: a perfect item.
        for per_item_scores in (item_ps_soft, item_rs_soft, item_f1s_soft):
            per_item_scores.append(1.0)
        continue

    pred_strings = [f"{h} {r} {t}" for h, r, t in pred_triples]
    gold_strings = [f"{h} {r} {t}" for h, r, t in gold_triples]

    matched_preds = matched_golds = 0
    if pred_strings and gold_strings:
        pred_embs = sentence_model.encode(pred_strings, convert_to_tensor=True).cpu().numpy()
        gold_embs = sentence_model.encode(gold_strings, convert_to_tensor=True).cpu().numpy()

        # Entry (i, j): cosine similarity of prediction i and gold triple j.
        norms = np.linalg.norm(pred_embs, axis=1)[:, np.newaxis] * np.linalg.norm(gold_embs, axis=1)
        sim_matrix = np.maximum(np.dot(pred_embs, gold_embs.T) / norms, 0)

        row_ind, col_ind = linear_sum_assignment(sim_matrix, maximize=True)
        matched_sim = sim_matrix[row_ind, col_ind]

        # Each assigned pair consumes one prediction and one gold triple.
        matched_preds = int(np.count_nonzero(matched_sim > threshold))
        matched_golds = matched_preds

    if pred_triples:
        p_item = matched_preds / len(pred_triples)
    else:
        p_item = 0.0 if gold_triples else 1.0
    if gold_triples:
        r_item = matched_golds / len(gold_triples)
    else:
        r_item = 0.0 if pred_triples else 1.0
    f1_item = 2 * p_item * r_item / (p_item + r_item) if p_item + r_item > 0 else 0.0

    item_ps_soft.append(p_item)
    item_rs_soft.append(r_item)
    item_f1s_soft.append(f1_item)
    global_matched_preds += matched_preds
    global_matched_golds += matched_golds
    logging.info(f"Item {detail['index']} Semantic: P={p_item:.4f}, R={r_item:.4f}, F1={f1_item:.4f}")

micro_p_soft = global_matched_preds / pred_total if pred_total > 0 else 0.0
micro_r_soft = global_matched_golds / gold_total if gold_total > 0 else 0.0
micro_f1_soft = 2 * micro_p_soft * micro_r_soft / (micro_p_soft + micro_r_soft) if micro_p_soft + micro_r_soft > 0 else 0.0

macro_p_soft, macro_r_soft, macro_f1_soft = (
    np.nanmean(scores) if scores else 0.0
    for scores in (item_ps_soft, item_rs_soft, item_f1s_soft)
)

# --- Final Output ---
# Build one human-readable report, then echo it to stdout and append it to
# the log file, in that order.
output_str = f"""
Benchmark Results:
---------------------------------
Number of examples: {len(data)}
Strict Averages:
  - Macro: P={macro_p_strict:.4f}, R={macro_r_strict:.4f}, F1={macro_f1_strict:.4f}
  - Micro: P={micro_p_strict:.4f}, R={micro_r_strict:.4f}, F1={micro_f1_strict:.4f}

BERTScore Averages:
  - Macro: P={macro_p_bert:.4f}, R={macro_r_bert:.4f}, F1={macro_f1_bert:.4f}
  - Micro: P={micro_p_bert:.4f}, R={micro_r_bert:.4f}, F1={micro_f1_bert:.4f}

Greedy BS Averages:
  - Macro: P={macro_gbs_p:.4f}, R={macro_gbs_r:.4f}, F1={macro_gbs_f1:.4f}
  - Micro: P={micro_gbs_p:.4f}, R={micro_gbs_r:.4f}, F1={micro_gbs_f1:.4f}

G-BERTScore Averages:
  - Macro: P={macro_gbs_p_h:.4f}, R={macro_gbs_r_h:.4f}, F1={macro_gbs_f1_h:.4f}
  - Micro: P={micro_gbs_p_h:.4f}, R={micro_gbs_r_h:.4f}, F1={micro_gbs_f1_h:.4f}

Semantic Score Averages (Threshold: {threshold}):
  - Macro: P={macro_p_soft:.4f}, R={macro_r_soft:.4f}, F1={macro_f1_soft:.4f}
  - Micro: P={micro_p_soft:.4f}, R={micro_r_soft:.4f}, F1={micro_f1_soft:.4f}
"""
for _report_sink in (print, logging.info):
    _report_sink(output_str)