In [16]:
pip install streamlit sentence-transformers scikit-learn matplotlib

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [17]:
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, mean_squared_error, precision_score, recall_score, roc_curve, precision_recall_curve
import matplotlib.pyplot as plt

In [18]:
# Pretrained Sentence-BERT encoder shared by the similarity helpers in this notebook
model = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")

In [3]:
# Cosine similarity between two documents using the SBERT encoder above
def compute_similarity(doc1: str, doc2: str):
    """Return the cosine similarity of the SBERT embeddings of ``doc1`` and ``doc2``.

    Both documents are encoded together in one ``model.encode`` call (as tensors);
    the single-element similarity tensor from ``util.cos_sim`` is converted to a
    plain Python float via ``.item()``.
    """
    first_vec, second_vec = model.encode([doc1, doc2], convert_to_tensor=True)
    return util.cos_sim(first_vec, second_vec).item()

In [19]:
# Evaluation: treat the SBERT cosine score as a binary "similar / not similar" classifier
def _plot_similarity_curves(y_true, y_score, roc_auc):
    """Draw ROC and precision-recall curves side by side (``y_true`` must contain both classes)."""
    fpr, tpr, _ = roc_curve(y_true, y_score)
    precision, recall, _ = precision_recall_curve(y_true, y_score)

    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(12, 5))

    ax_roc.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.2f})')
    ax_roc.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    ax_roc.set(xlabel='False Positive Rate', ylabel='True Positive Rate',
               title='Receiver Operating Characteristic')
    ax_roc.legend(loc='lower right')

    ax_pr.plot(recall, precision, label='Precision-Recall curve')
    ax_pr.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall Curve')
    ax_pr.legend(loc='lower left')

    fig.tight_layout()
    plt.show()


def evaluate_similarity(doc1, doc2, true_label, threshold=0.7, plot=True):
    """Score document pair(s) with ``compute_similarity`` and report classification metrics.

    Works on a single pair (``doc1``/``doc2`` strings, scalar ``true_label``) or on
    parallel sequences of documents and labels; metrics are then computed over
    all pairs, which is the only setting where ROC-AUC and the curves make sense.

    Parameters
    ----------
    doc1, doc2 : str or sequence of str
        Documents to compare, pairwise.
    true_label : int or sequence of int
        Ground truth per pair: 1 = similar, 0 = not similar.
    threshold : float, default 0.7
        A pair is predicted "similar" when its score is >= ``threshold``.
    plot : bool, default True
        Draw ROC and precision-recall curves. Skipped (with a note) when the
        labels contain only one class, because the curves are undefined then.

    Returns
    -------
    dict
        Keys ``accuracy``, ``f1``, ``precision``, ``recall``, ``rmse`` (between
        the 0/1 labels and the raw scores) and ``roc_auc`` (NaN when only one
        class is present).

    Raises
    ------
    ValueError
        If no pairs are given or ``doc1``, ``doc2`` and ``true_label`` differ in length.
    """
    docs_a = [doc1] if isinstance(doc1, str) else list(doc1)
    docs_b = [doc2] if isinstance(doc2, str) else list(doc2)
    y_true = np.asarray(true_label, dtype=int).ravel()  # 1=similar, 0=not similar
    if not len(docs_a) == len(docs_b) == y_true.size:
        raise ValueError(
            "doc1, doc2 and true_label must have the same length "
            f"(got {len(docs_a)}, {len(docs_b)} and {y_true.size})"
        )
    if y_true.size == 0:
        raise ValueError("at least one document pair is required")

    y_score = np.array([compute_similarity(a, b) for a, b in zip(docs_a, docs_b)])
    y_pred = (y_score >= threshold).astype(int)

    # roc_auc_score raises ValueError (and the curves are meaningless) unless
    # y_true holds both classes -- always the case for a single pair.
    has_both_classes = np.unique(y_true).size > 1

    # zero_division=0 keeps the previous values but silences UndefinedMetricWarning
    # when there are no positive predictions / labels.
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    rmse = np.sqrt(mean_squared_error(y_true, y_score))
    roc_auc = roc_auc_score(y_true, y_score) if has_both_classes else float("nan")

    print("---- Evaluation Results ----")
    if y_score.size == 1:
        print(f"Similarity Score: {y_score[0]:.4f}")
        print(f"Prediction (Threshold={threshold}): {'Similar' if y_pred[0] else 'Not Similar'}")
        print(f"Ground Truth: {'Similar' if y_true[0] else 'Not Similar'}")
    else:
        print(f"Pairs Evaluated: {y_score.size}")
        print(f"Mean Similarity Score: {y_score.mean():.4f}")
        print(f"Threshold: {threshold}")
    print(f"Accuracy: {acc:.4f}")
    print(f"F1-score: {f1:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"RMSE: {rmse:.4f}")
    if has_both_classes:
        print(f"ROC-AUC: {roc_auc:.4f}")
    else:
        print("ROC-AUC: undefined (needs both similar and not-similar pairs)")

    if plot:
        if has_both_classes:
            _plot_similarity_curves(y_true, y_score, roc_auc)
        else:
            print("Skipping ROC / PR curves: they need both similar and not-similar pairs.")

    return {
        "accuracy": acc,
        "f1": f1,
        "precision": prec,
        "recall": rec,
        "rmse": rmse,
        "roc_auc": roc_auc,
    }