# In [None]:  (notebook cell marker left over from the export)
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import DataLoader

# 1. Function to compute accuracy manually
def compute_accuracy(y_true, y_pred):
    """Return the fraction of positions where the prediction equals the label.

    Labels and predictions are paired positionally with ``zip``, so any
    surplus items in the longer sequence are never compared. The denominator
    is always ``len(y_true)``.

    Args:
        y_true: Sized sequence of ground-truth labels.
        y_pred: Iterable of predicted labels, aligned with ``y_true``.

    Returns:
        Number of matching pairs divided by ``len(y_true)``, or ``0.0`` when
        ``y_true`` is empty.
    """
    total = len(y_true)
    if total <= 0:
        return 0.0

    hits = 0
    for expected, guessed in zip(y_true, y_pred):
        if expected == guessed:
            hits += 1
    return hits / total

# 2. Function to compute macro F1 score excluding 'Other' category
def compute_macro_f1_excl_other(y_true, y_pred, other_label=18):
    """Return the macro-averaged F1 over classes ``0 .. other_label - 1``.

    Samples whose *true* label equals ``other_label`` are skipped entirely.
    For the remaining samples, per-class true positives, false positives and
    false negatives are counted, each class's F1 is computed, and the F1
    values are averaged over all ``other_label`` classes. A class that never
    appears contributes an F1 of 0.

    NOTE(review): because true-"Other" samples are skipped, a sample whose
    true label is "Other" but which is predicted as a relation class is NOT
    counted as a false positive for that class. The official SemEval-2010
    Task 8 scorer is understood to count such predictions against precision;
    confirm which definition is intended before comparing with published
    numbers.

    Args:
        y_true: Iterable of integer ground-truth labels.
        y_pred: Iterable of integer predicted labels, aligned with ``y_true``.
        other_label: Integer id of the "Other" class. Classes
            ``0 .. other_label - 1`` are the ones that are scored.

    Returns:
        The macro F1 as a ``float``. ``0.0`` is returned when
        ``other_label <= 0`` or when no sample has a scored true label.
    """
    num_classes = int(other_label)
    if num_classes <= 0:
        # No relation classes to average over; also avoids dividing by zero.
        return 0.0

    # Exact integer counts, one slot per scored class.
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes

    for t, p in zip(y_true, y_pred):
        t, p = int(t), int(p)
        if t == num_classes:
            # True label is "Other": the sample is left out of the metric.
            continue

        true_is_scored = 0 <= t < num_classes
        if t == p:
            if true_is_scored:
                tp[t] += 1
            continue

        # Misprediction: labels outside the scored range (including a
        # predicted "Other") have no counter of their own, but the miss still
        # costs the true class a false negative.
        if 0 <= p < num_classes:
            fp[p] += 1
        if true_is_scored:
            fn[t] += 1

    f1_sum = 0.0
    for hits, spurious, missed in zip(tp, fp, fn):
        # F1 = 2PR / (P + R) simplifies to 2TP / (2TP + FP + FN), which needs
        # no epsilon and is exactly 0 when the class never occurs.
        denom = 2 * hits + spurious + missed
        if denom:
            f1_sum += 2 * hits / denom

    return f1_sum / num_classes

# -------------------------- Evaluation Process --------------------------

# Pull the raw test split out of `dataset` (defined outside this section).
# Substitute your own dataset / preprocessing here if it differs.
test_sentences = dataset['test']['sentence']  # test sentences
test_labels = dataset['test']['relation']  # test labels (per the original note: 0-17, or 18 for 'Other')

# Wrap the split in the project's dataset class and a DataLoader.
# `SemEvalDataset` and `word2idx` (word-to-index mapping) must already exist.
# shuffle=False keeps the batches in dataset order.
test_dataset = SemEvalDataset(test_sentences, test_labels, word2idx)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# `model` is expected to be a trained model created earlier; if it is not,
# build it and load its weights first, e.g.:
#   model = YourModelClass(...)
#   model.load_state_dict(torch.load('your_model_path.pt'))
model.to(device)
model.eval()  # switch the model to evaluation mode

# Accumulators for the ground-truth and predicted label of every test sample
y_true_test, y_pred_test = [], []

# Run the model over the test set with gradient tracking disabled.
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        # Only the inputs have to be on the model's device. The labels are
        # never used in a device-side computation here, so they are read
        # directly instead of being copied to the device and back.
        X_batch = X_batch.to(device)
        outputs = model(X_batch)  # assumed shape [batch_size, num_classes] (e.g. [B, 19]) — TODO confirm
        predicted = outputs.argmax(dim=1)  # index of the highest score per row
        y_true_test.extend(y_batch.tolist())
        y_pred_test.extend(predicted.cpu().tolist())

# (A) Overall accuracy across every collected test sample
test_acc = compute_accuracy(y_true_test, y_pred_test)

# (B) Macro F1 with label 18 ('Other') left out of the average
test_macro_f1_excl_other = compute_macro_f1_excl_other(
    y_true_test,
    y_pred_test,
    other_label=18,
)

# Report both metrics to four decimal places
print(f"test dataset acc: {test_acc:.4f}")
print(f"test dataset Macro F1 (without Other): {test_macro_f1_excl_other:.4f}")

# (C) Plot the confusion matrix over all 19 classes (18 directed relations + "Other").
# Display names for the tick labels. Position i in this list is assumed to be
# the integer label i used by the model — TODO confirm against the label encoding.
class_names = [
    "Cause-Effect(e1,e2)", "Cause-Effect(e2,e1)", "Component-Whole(e1,e2)", "Component-Whole(e2,e1)",
    "Content-Container(e1,e2)", "Content-Container(e2,e1)", "Entity-Destination(e1,e2)", "Entity-Destination(e2,e1)",
    "Entity-Origin(e1,e2)", "Entity-Origin(e2,e1)", "Instrument-Agency(e1,e2)", "Instrument-Agency(e2,e1)",
    "Member-Collection(e1,e2)", "Member-Collection(e2,e1)", "Message-Topic(e1,e2)", "Message-Topic(e2,e1)",
    "Product-Producer(e1,e2)", "Product-Producer(e2,e1)", "Other"
]

# Build the confusion matrix. Passing `labels` explicitly fixes the matrix to
# one row/column per entry of `class_names`, even for classes that never occur.
cm = confusion_matrix(y_true_test, y_pred_test, labels=range(len(class_names)))

# Draw the matrix as an annotated heatmap (integer counts in each cell).
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Test Confusion Matrix")
# Rotate the long class names so neighbouring tick labels do not overlap.
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=45)
plt.tight_layout()
plt.show()