In [1]:
from datasets import load_dataset
from keybert import KeyBERT
import re
from typing import List, Tuple
import numpy as np
from tqdm.notebook import tqdm
import time
import pickle
from datetime import datetime


In [2]:
# Load SemEval dataset
sem_eval_ds = load_dataset("SemEvalWorkshop/sem_eval_2010_task_8")

# Load Inspec dataset
inspec_ds = load_dataset("midas/inspec", "generation")


In [3]:
def clean_sem_eval_sentence(text: str) -> str:
    """Remove entity XML tags from SemEval sentences."""
    return re.sub(r'</?e[12]>', '', text).strip()

def prepare_inspec_text(tokens: List[str]) -> str:
    """Convert list of tokens to clean string for Inspec."""
    return ' '.join([t for t in tokens if not (t.startswith('-') and t.endswith('-'))])

def extract_keyphrases(text: str, model: KeyBERT, top_n: int = 8) -> List[str]:
    """Extract keyphrases using KeyBERT with maxsum enabled."""
    keyphrases = model.extract_keywords(
        text,
        keyphrase_ngram_range=(1, 3),
        stop_words='english',
        top_n=top_n,
        use_maxsum=True     # <-- Extension 1: improves diversity
    )
    return [kp[0] for kp in keyphrases]

# Initialize KeyBERT with default embedding model
print("Loading KeyBERT model (default)...")
model = KeyBERT()


Loading KeyBERT model (default)...


In [4]:
def evaluate_matches(true_phrases: List[str], extracted_phrases: List[str], partial_match: bool = True) -> Tuple[float, float, float]:
    """Calculate precision, recall, and F1 score with optional partial matching."""
    if partial_match:
        matched_true = set()
        matched_extracted = set()
        for i, ext in enumerate(extracted_phrases):
            for j, true in enumerate(true_phrases):
                if ext.lower() in true.lower() or true.lower() in ext.lower():
                    matched_extracted.add(i)
                    matched_true.add(j)
        matches = len(matched_true)
    else:
        matches = sum(1 for ext in extracted_phrases if any(ext.lower() == true.lower() for true in true_phrases))

    precision = matches / len(extracted_phrases) if extracted_phrases else 0
    recall = matches / len(true_phrases) if true_phrases else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1


In [5]:
def process_sem_eval(dataset, model, top_n=2):
    results = []
    metrics = []
    for sample in tqdm(dataset):
        text = clean_sem_eval_sentence(sample['sentence'])
        true_entities = re.findall(r'<e[12]>(.*?)</e[12]>', sample['sentence'])
        extracted = extract_keyphrases(text, model, top_n=top_n)

        precision, recall, f1 = evaluate_matches(true_entities, extracted)
        metrics.append((precision, recall, f1))

        results.append({
            'sentence': sample['sentence'],
            'true_entities': true_entities,
            'extracted_phrases': extracted,
            'metrics': {'precision': precision, 'recall': recall, 'f1': f1}
        })
    avg_metrics = {
        'precision': np.mean([m[0] for m in metrics]),
        'recall': np.mean([m[1] for m in metrics]),
        'f1': np.mean([m[2] for m in metrics]),
    }
    return results, avg_metrics


In [6]:
def process_inspec(dataset, model, top_n=8):
    results = []
    metrics = []
    for sample in tqdm(dataset):
        text = prepare_inspec_text(sample['document'])
        true_keyphrases = sample['extractive_keyphrases']
        extracted = extract_keyphrases(text, model, top_n=top_n)

        precision, recall, f1 = evaluate_matches(true_keyphrases, extracted)
        metrics.append((precision, recall, f1))

        results.append({
            'document': sample['document'],
            'true_keyphrases': true_keyphrases,
            'extracted_phrases': extracted,
            'metrics': {'precision': precision, 'recall': recall, 'f1': f1}
        })
    avg_metrics = {
        'precision': np.mean([m[0] for m in metrics]),
        'recall': np.mean([m[1] for m in metrics]),
        'f1': np.mean([m[2] for m in metrics]),
    }
    return results, avg_metrics


In [7]:
print("Processing SemEval-2010 train dataset...")
sem_eval_train_results, sem_eval_train_metrics = process_sem_eval(sem_eval_ds['train'], model, top_n=2)

print("Processing SemEval-2010 test dataset...")
sem_eval_test_results, sem_eval_test_metrics = process_sem_eval(sem_eval_ds['test'], model, top_n=2)

print("\nSemEval-2010 Metrics:")
print(f"Train - Precision: {sem_eval_train_metrics['precision']:.3f}, Recall: {sem_eval_train_metrics['recall']:.3f}, F1: {sem_eval_train_metrics['f1']:.3f}")
print(f"Test  - Precision: {sem_eval_test_metrics['precision']:.3f}, Recall: {sem_eval_test_metrics['recall']:.3f}, F1: {sem_eval_test_metrics['f1']:.3f}")


Processing SemEval-2010 train dataset...


  0%|          | 0/8000 [00:00<?, ?it/s]

Processing SemEval-2010 test dataset...


  0%|          | 0/2717 [00:00<?, ?it/s]


SemEval-2010 Metrics:
Train - Precision: 0.548, Recall: 0.548, F1: 0.548
Test  - Precision: 0.544, Recall: 0.544, F1: 0.544


In [8]:
print("Processing Inspec train dataset...")
inspec_train_results, inspec_train_metrics = process_inspec(inspec_ds['train'], model, top_n=8)

print("Processing Inspec test dataset...")
inspec_test_results, inspec_test_metrics = process_inspec(inspec_ds['test'], model, top_n=8)

print("\nInspec Metrics:")
print(f"Train - Precision: {inspec_train_metrics['precision']:.3f}, Recall: {inspec_train_metrics['recall']:.3f}, F1: {inspec_train_metrics['f1']:.3f}")
print(f"Test  - Precision: {inspec_test_metrics['precision']:.3f}, Recall: {inspec_test_metrics['recall']:.3f}, F1: {inspec_test_metrics['f1']:.3f}")


Processing Inspec train dataset...


  0%|          | 0/1000 [00:00<?, ?it/s]

Processing Inspec test dataset...


  0%|          | 0/500 [00:00<?, ?it/s]


Inspec Metrics:
Train - Precision: 0.271, Recall: 0.403, F1: 0.300
Test  - Precision: 0.286, Recall: 0.411, F1: 0.311
