In [1]:
import torch
import torch.nn.functional as f
from datasets import load_dataset
from tqdm import tqdm
from SentenceSimilarity import CrossSimilarity
from sentence_transformers import util

In [2]:
# Load the Multi-News dataset (all splits); only the test split is scored below.
dataset = load_dataset("multi_news")
test_dataset = dataset['test']
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of carrying an unusable device string.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def eval_batch(cs: "CrossSimilarity", batched_text, batched_summarys, batch_size=32):
    """Score one batch of documents against their reference summaries.

    For every (document, summary) pair, ``cs.compute`` yields a proposed text
    and a per-example value.  The proposed text and the reference summary are
    embedded with ``cs.model.encode`` and compared by cosine similarity; the
    per-example values are averaged over the batch.

    Args:
        cs: Object exposing ``compute(text, summary) -> (proposed, value)``
            and ``model.encode(...)``.
        batched_text: Sequence of source documents.
        batched_summarys: Sequence of reference summaries, aligned with
            ``batched_text`` and of the same length.
        batch_size: Nominal batch size.  Retained for backward compatibility;
            the average is taken over the number of items actually supplied,
            so a short final batch is not under-reported.

    Returns:
        A tuple ``(score, kr)`` where ``score`` is a 1-D tensor holding one
        cosine similarity per pair and ``kr`` is the mean of the second value
        returned by ``cs.compute`` (presumably a keep rate - confirm against
        ``CrossSimilarity``).  An empty batch returns ``(torch.empty(0), 0.0)``.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    n_items = len(batched_text)
    if n_items != len(batched_summarys):
        raise ValueError(
            f"got {n_items} documents but {len(batched_summarys)} summaries"
        )
    if n_items == 0:
        return torch.empty(0), 0.0

    proposed_embeddings = []
    kr_total = 0
    for text, summary in zip(batched_text, batched_summarys):
        proposed, curr = cs.compute(text, summary)
        kr_total += curr
        # as_tensor accepts whatever encode returns (array, list or tensor).
        proposed_embeddings.append(torch.as_tensor(cs.model.encode(proposed)))

    proposed_matrix = torch.stack(proposed_embeddings)
    summary_matrix = torch.as_tensor(cs.model.encode(batched_summarys))

    # Only matching pairs are needed, so compare row i with row i directly
    # instead of building the full N x N matrix and keeping its diagonal.
    score = f.cosine_similarity(proposed_matrix, summary_matrix, dim=1)
    return score, kr_total / n_items

In [4]:
def evaluate(model_name='all-mpnet-base-v2', batch_size=32):
    """Evaluate a CrossSimilarity model on the test split and print its score.

    The test split is walked in consecutive slices of ``batch_size`` examples
    (the last slice may be shorter).  Each slice is scored with
    ``eval_batch`` and a running per-example average is shown on the progress
    bar together with the keep rate of the most recent batch.

    Args:
        model_name: Model identifier forwarded to ``CrossSimilarity``.
        batch_size: Number of examples handed to ``eval_batch`` per step.

    Returns:
        None.  The final average score is printed.
    """
    # NOTE(review): rt=0.4 is passed through unchanged; its meaning is defined
    # by CrossSimilarity and is not visible here.
    cs = CrossSimilarity(model_name=model_name, rt=0.4)

    # Read each column once instead of on every loop iteration.
    documents = test_dataset['document']
    summaries = test_dataset['summary']
    n_examples = len(documents)
    # Ceiling division: a trailing partial batch is still a batch.
    n_batches = (n_examples + batch_size - 1) // batch_size

    score_sum = 0.0
    n_scored = 0
    with tqdm(total=n_batches, position=0, leave=False) as loop:
        for start in range(0, n_examples, batch_size):
            document_batch = documents[start:start + batch_size]
            summary_batch = summaries[start:start + batch_size]
            # Pass the real size of this batch (the last one may be short)
            # rather than the nominal batch_size.
            score, kr = eval_batch(cs, document_batch, summary_batch, len(document_batch))
            # Accumulate per example so a short final batch is weighted by
            # its actual size rather than counted as a full batch.
            score_sum += score.sum().item()
            n_scored += len(score)
            average_score = score_sum / n_scored

            loop.set_description(f"Average Score: {average_score:.4f}, with final keep rate {kr:.4f}")
            loop.update(1)

    # An empty split has no score; report NaN instead of dividing by zero.
    final_average_score = score_sum / n_scored if n_scored else float('nan')
    print(f"Final Average Score for {model_name}: {final_average_score:.4f}")

In [None]:
# Release cached, currently unused GPU memory held by PyTorch before the run.
torch.cuda.empty_cache()
# Score the test split with the default model and batch size.
evaluate()