In [None]:

import random
import re
from typing import List

import numpy as np
import torch
from scipy import sparse
from scipy.optimize import linprog
from scipy.spatial.distance import cdist, cosine
from transformers import BertTokenizer, BertModel

class MoverScore:
    """BERT-based similarity scorer for short texts.

    ``compute_moverscore`` follows MoverScore (Zhao et al., 2019):
    - Each text is represented by the contextual embeddings of its word
      pieces, and each piece is weighted by its IDF.
    - The score is ``1 - EMD``. EMD is the minimum cost of moving the
      hypothesis weight mass onto the reference weight mass.
    - The per-unit cost is the Euclidean distance between L2-normalised
      embeddings.
    """

    def __init__(self, model_name='bert-base-uncased'):
        # Load the pretrained tokenizer and encoder named by ``model_name``.
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()  # Inference only: disables dropout.

    def get_bert_embedding(self, text):
        """Return the mean-pooled last-layer hidden state of ``text`` as a 1-D array.

        The mean runs over every position, including the [CLS]/[SEP] tokens
        that the tokenizer adds. Assumes ``text`` is a single string:
        ``squeeze(0)`` only drops the batch axis when the batch size is 1.
        """
        tokens = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**tokens)
        embeddings = outputs.last_hidden_state.squeeze(0)
        return embeddings.mean(dim=0).cpu().numpy()  # Average pooling

    def compute_idf(self, corpus):
        """Return ``{word_piece: log(N / df)}`` for every word piece in ``corpus``.

        ``N`` is ``len(corpus)``. ``df`` is the number of documents whose
        tokenization contains the piece, so pieces present in every document
        get weight 0. An empty corpus yields ``{}``.
        """
        num_docs = len(corpus)
        doc_freq = {}
        for doc in corpus:
            # set(): count each piece at most once per document.
            for token in set(self.tokenizer.tokenize(doc)):
                doc_freq[token] = doc_freq.get(token, 0) + 1
        return {token: np.log(num_docs / count) for token, count in doc_freq.items()}

    def compute_moverscore(self, hypothesis, reference, idf_dict):
        """Return the MoverScore-style similarity of ``hypothesis`` to ``reference``.

        Args:
            hypothesis: Candidate text (a single string).
            reference: Reference text (a single string).
            idf_dict: Word-piece IDF weights, e.g. from ``compute_idf``.

        Returns:
            ``1 - EMD`` between the two IDF-weighted sets of L2-normalised
            token embeddings. Pairwise distances lie in [0, 2], so the score
            lies in [-1, 1]. Identical texts score 1.0, up to solver
            tolerance. Returns 0.0 when either text has no word pieces left
            after dropping CLS/SEP/PAD.
        """
        hyp_tokens, hyp_emb = self._token_embeddings(hypothesis)
        ref_tokens, ref_emb = self._token_embeddings(reference)
        if not hyp_tokens or not ref_tokens:
            # Nothing to align (e.g. an empty string): report no similarity.
            return 0.0
        hyp_weights = self._idf_weights(hyp_tokens, idf_dict)
        ref_weights = self._idf_weights(ref_tokens, idf_dict)
        # Direct Euclidean distance, so identical rows give exactly 0.
        cost = cdist(hyp_emb, ref_emb, metric='euclidean')
        return 1.0 - self._earth_movers_distance(hyp_weights, ref_weights, cost)

    def _token_embeddings(self, text):
        """Return ``(tokens, embeddings)`` for the word pieces of ``text``.

        The tokenizer's CLS, SEP and PAD tokens are dropped. ``embeddings``
        is a ``(len(tokens), hidden_size)`` array of L2-normalised last-layer
        hidden states, row-aligned with ``tokens``. Assumes ``text`` is a
        single string.
        """
        encoded = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        with torch.no_grad():
            # [0]: the first (and only) sequence in the batch.
            hidden = self.model(**encoded).last_hidden_state[0]
        pieces = self.tokenizer.convert_ids_to_tokens(encoded['input_ids'][0].tolist())
        skip = {self.tokenizer.cls_token, self.tokenizer.sep_token, self.tokenizer.pad_token}
        keep = [i for i, piece in enumerate(pieces) if piece not in skip]
        index = torch.tensor(keep, dtype=torch.long, device=hidden.device)
        kept = torch.nn.functional.normalize(hidden.index_select(0, index), p=2, dim=-1)
        return [pieces[i] for i in keep], kept.cpu().numpy()

    @staticmethod
    def _idf_weights(tokens, idf_dict):
        """Return IDF weights for ``tokens``, normalised to sum to 1.

        - Pieces missing from ``idf_dict`` (e.g. headline words absent from
          the IDF corpus) get the largest IDF in the dict, or 1.0 if the dict
          is empty.
        - If every weight is 0 (all pieces occur in every document), the
          weights fall back to uniform so the transport problem stays
          well-posed.
        - ``tokens`` must be non-empty.
        """
        unseen = max(idf_dict.values(), default=1.0)
        weights = np.array([idf_dict.get(tok, unseen) for tok in tokens], dtype=np.float64)
        total = weights.sum()
        if total <= 0:
            return np.full(len(tokens), 1.0 / len(tokens))
        return weights / total

    @staticmethod
    def _earth_movers_distance(source_weights, target_weights, cost):
        """Return the minimum total cost of moving ``source_weights`` onto ``target_weights``.

        ``cost[i, j]`` is the cost per unit of mass moved from source ``i``
        to target ``j``. Both weight vectors must be non-negative and have
        equal sums. The problem is solved as a linear program over the
        row-major flattened flow matrix ``flow[i, j] >= 0``.

        Raises:
            RuntimeError: If the LP solver does not report success.
        """
        n, m = cost.shape
        # Row i of the flow matrix must ship exactly source_weights[i] ...
        ship = sparse.kron(sparse.identity(n), np.ones((1, m)))
        # ... and column j must receive exactly target_weights[j].
        receive = sparse.kron(np.ones((1, n)), sparse.identity(m))
        a_eq = sparse.vstack([ship, receive]).tocsr()
        b_eq = np.concatenate([source_weights, target_weights])
        result = linprog(np.ravel(cost), A_eq=a_eq, b_eq=b_eq, bounds=(0, None), method='highs')
        if not result.success:
            raise RuntimeError(f"EMD linear program failed: {result.message}")
        return float(result.fun)


# A whitespace-delimited word counts as a numeral when, after dropping leading
# and trailing non-word characters (punctuation, currency signs), what remains
# is ASCII digits, optionally grouped by "," or "." separators.
# Examples: "19," -> "19", "(30)" -> "30", "1,200" -> "1,200".
# Words such as "COVID-19" or "x²" are not numerals.
_NUMERAL_WORD = re.compile(r"\W*([0-9]+(?:[.,][0-9]+)*)\W*")


def generate_numeral_aware_headline(article: str) -> str:
    """Generate a numeral-aware headline incorporating key numerical reasoning.

    Args:
        article: Free text to scan for numerals.

    Returns:
        The headline depends on how many numerals the article contains:
        - Two or more: ``"<first> impacted in <second> related incidents."``
        - Exactly one: ``"<first> affected in an incident."``
        - None: ``"No numeral-specific data found."``

        Numerals are reported with surrounding punctuation removed.
    """
    matches = map(_NUMERAL_WORD.fullmatch, article.split())
    numerals = [match.group(1) for match in matches if match]
    if not numerals:
        return "No numeral-specific data found."
    if len(numerals) > 1:
        return f"{numerals[0]} impacted in {numerals[1]} related incidents."
    return f"{numerals[0]} affected in an incident."


def generate_numeral_aware_headlines_loop(articles: List[str]) -> List[str]:
    """Generate a list of numeral-aware headlines for a list of articles.

    Each headline comes from ``generate_numeral_aware_headline``. The result
    has one entry per article, in input order.
    """
    return [generate_numeral_aware_headline(text) for text in articles]


def compare_generated_to_reference(generated_headlines: List[str], reference_headlines: List[str], articles: List[str], *, scorer=None) -> float:
    """Compare the generated headlines to the reference headlines using MoverScore.

    Args:
        generated_headlines: Candidate headlines.
        reference_headlines: Reference headlines, aligned index-by-index with
            ``generated_headlines``.
        articles: Corpus used to build the IDF weights passed to the scorer.
        scorer: Optional pre-built object exposing ``compute_idf`` and
            ``compute_moverscore`` (e.g. a ``MoverScore``). If omitted, a new
            ``MoverScore()`` is constructed, which loads the BERT model on
            every call. Pass a scorer in to reuse one.

    Returns:
        The arithmetic mean of the per-pair scores.

    Raises:
        ValueError: If the two headline lists differ in length, or are empty.
            Without this check, ``zip`` would silently drop the unmatched tail
            and skew the average.
    """
    if len(generated_headlines) != len(reference_headlines):
        raise ValueError(
            f"got {len(generated_headlines)} generated headlines but "
            f"{len(reference_headlines)} reference headlines"
        )
    if not generated_headlines:
        raise ValueError("cannot average MoverScore over zero headlines")
    # Validate before constructing the scorer so bad input fails fast.
    if scorer is None:
        scorer = MoverScore()
    idf_dict = scorer.compute_idf(articles)
    total = sum(
        scorer.compute_moverscore(gen, ref, idf_dict)
        for gen, ref in zip(generated_headlines, reference_headlines)
    )
    return total / len(generated_headlines)  # Average MoverScore


# Example usage
articles = [
    "At least 30 gunmen burst into a drug rehabilitation center in a Mexican border state capital and opened fire, killing 19 men and wounding four people, police said.",
    "Gunmen also killed 16 people in another drug-plagued northern city. The killings in Chihuahua city and in Ciudad Madero marked one of the bloodiest weeks ever in Mexico."
]
reference_headlines = [
    "Gunmen attacked a drug rehabilitation center in Mexico, killing 19 and wounding 4.",
    "16 people were killed in another northern city, marking one of Mexico's bloodiest weeks."
]

# Step 1: Generate numeral-aware headlines
generated_headlines = generate_numeral_aware_headlines_loop(articles)

# Step 2: Compare generated headlines to reference headlines
numeral_aware_score = compare_generated_to_reference(generated_headlines, reference_headlines, articles)

# Outputs: label each list so the source articles are not mistaken for the
# generated headlines.
print("Articles:")
print(articles)
print("Generated Headlines:")
print(generated_headlines)
print(f"Numeral-aware Headline Generation MoverScore: {numeral_aware_score}")


Generated Headlines:
['At least 30 gunmen burst into a drug rehabilitation center in a Mexican border state capital and opened fire, killing 19 men and wounding four people, police said.', 'Gunmen also killed 16 people in another drug-plagued northern city. The killings in Chihuahua city and in Ciudad Madero marked one of the bloodiest weeks ever in Mexico.']
['30 impacted in 19 related incidents.', '16 affected in an incident.']
Numeral-aware Headline Generation MoverScore: 0.5906293392181396
