In [11]:
import torch
import torch.nn.functional as f
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

In [67]:
dataset = load_dataset("multi_news")
test_dataset = dataset['test']

In [77]:
base_path = 'sentence-transformers/'
model_names=['paraphrase-MiniLM-L6-v2', 'all-mpnet-base-v2', 'distiluse-base-multilingual-cased-v2', 'all-MiniLM-L12-v2'] # add more model names to evaluate
device = 'cuda'

In [81]:
def compute_summary_score_for_batch(model, document_batch, summary_batch):
    input_encoded = model.encode(document_batch)
    summary_encoded = model.encode(summary_batch)
    score = util.cos_sim(input_encoded, summary_encoded)
    score = torch.diag(score)
    return score

In [79]:
def evaluate(model_name='all-mpnet-base-v2', batch_size=32):
    total_score = 0
    model = SentenceTransformer(base_path + model_name)
    test_batches = len(test_dataset['document']) // batch_size
    loop = tqdm(total=test_batches, position=0, leave=False)
    for i in range(0, len(test_dataset['document']), batch_size):
        document_batch = test_dataset['document'][i:i+batch_size]
        summary_batch = test_dataset['summary'][i:i+batch_size]

        score = compute_summary_score_for_batch(model, document_batch, summary_batch)

        total_score += score.mean().item()
        # TODO: compute the average_reward so far for the display of the progress bar.
        average_loss = total_score / ((i // batch_size) + 1)

        loop.set_description(f"Average Score: {average_loss:.4f}")
        loop.update(1)
    final_average_loss = total_score / test_batches
    print(f"Final Average Score: {final_average_loss:.4f}")

In [83]:
for name in model_names:
    evaluate(name)

                                                                        

Final Average Score: 0.8133


