In [1]:
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm

# === 1. Get predictions on the validation set ===
# NOTE(review): `trainer` and `val_dataset` must already exist in this kernel.
# The recorded run of this cell (In [1], i.e. a fresh kernel) failed with
# NameError on `trainer` — re-run the data/training cells before this one.
predictions = trainer.predict(val_dataset)
# Index of the highest-scoring class per example; label 1 is treated below as
# "this option is a correct answer".
pred_labels = np.argmax(predictions.predictions, axis=-1)
true_labels = predictions.label_ids
# label_ids is presumably None when the dataset carries no labels (HF
# PredictionOutput — confirm). Without this check every gold set would be empty
# and the metrics below would silently report 0 recall.
if true_labels is None:
    raise ValueError(
        "trainer.predict(val_dataset) returned no label_ids; "
        "cannot compute set-level metrics without gold labels."
    )

# Convert validation dataset back to pandas for grouping.
# Prefer the dataset's own to_pandas() when it has one (presumably a Hugging Face
# `datasets.Dataset` — TODO confirm). It avoids iterating row by row and
# presumably keeps text columns such as "acronym"/"option_text" even when a
# torch / column-restricted format is set. Fall back to the original conversion
# otherwise.
if hasattr(val_dataset, "to_pandas"):
    val_df = val_dataset.to_pandas()
else:
    val_df = pd.DataFrame(val_dataset)

# === 2. Attach predicted labels to each example ===
# Assumes predictions come back in dataset order (row i of pred_labels <-> row i
# of val_df); pandas raises if the lengths differ.
val_df["pred_label"] = pred_labels
val_df["true_label"] = true_labels

def _prf(tp, fp, fn):
    """Return (precision, recall, f1) from true/false-positive and false-negative counts.

    Each ratio is 0 when its denominator is 0, matching the original inline
    formulas used for both the per-acronym and the global metrics.
    """
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1


# === 3. Compute per-acronym F1 based on set comparison ===
# For each acronym, compare the set of option texts the model labelled 1 with
# the set of option texts whose gold label is 1.
# NOTE(review): rows are grouped by "acronym" only; if several questions share
# the same acronym they are merged into one set — confirm that is intended.
# NOTE(review): options are compared by their text, so duplicate option_text
# values within one acronym collapse into a single set element.
results = []
for acronym, group in val_df.groupby("acronym"):
    # Get sets of options judged true
    predicted_true = set(group.loc[group["pred_label"] == 1, "option_text"])
    actual_true = set(group.loc[group["true_label"] == 1, "option_text"])

    # "VP" = vrais positifs (true positives); the column name is kept unchanged.
    VP = len(predicted_true & actual_true)
    FP = len(predicted_true - actual_true)
    FN = len(actual_true - predicted_true)

    # Note: an acronym with neither predicted nor gold options scores 0, not 1.
    precision, recall, f1 = _prf(VP, FP, FN)

    results.append({
        "acronym": acronym,
        "VP": VP,
        "FP": FP,
        "FN": FN,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    })

# === 4. Display per-acronym and global F1 ===
# Explicit columns keep the frame (and the column sums below) well-defined even
# when val_df is empty and `results` has no rows.
results_df = pd.DataFrame(
    results,
    columns=["acronym", "VP", "FP", "FN", "precision", "recall", "f1"],
)
display(results_df)  # IPython/Jupyter display helper (not imported here)

# Micro-averaged (pooled) metrics: sum the counts over all acronyms, then compute
# precision/recall/F1 once. This is not an average of the per-acronym scores.
global_VP = results_df["VP"].sum()
global_FP = results_df["FP"].sum()
global_FN = results_df["FN"].sum()

global_precision, global_recall, global_f1 = _prf(global_VP, global_FP, global_FN)

# Macro-averaged F1: unweighted mean of the per-acronym F1 scores.
macro_f1 = results_df["f1"].mean() if len(results_df) else 0

print("\n📊 Global Metrics (based on set-level comparison):")
print(f"Precision: {global_precision:.3f}")
print(f"Recall:    {global_recall:.3f}")
print(f"F1-score:  {global_f1:.3f}")
print(f"Macro-F1 (mean over acronyms): {macro_f1:.3f}")


NameError: name 'trainer' is not defined