In [1]:
import torch
import torch.nn.functional as f
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

In [2]:
dataset = load_dataset("multi_news")
test_dataset = dataset['test']

In [3]:
base_path = 'sentence-transformers/'
model_names=['paraphrase-MiniLM-L6-v2', 'all-mpnet-base-v2', 'distiluse-base-multilingual-cased-v2', 'all-MiniLM-L12-v2'] # add more model names to evaluate
device = 'cuda'

In [4]:
def compute_summary_score_for_batch(model, document_batch, summary_batch):
    input_encoded = model.encode(document_batch)
    summary_encoded = model.encode(summary_batch)
    score = util.cos_sim(input_encoded, summary_encoded)
    score = torch.diag(score)
    return score

In [7]:
def evaluate(model_name='all-mpnet-base-v2', batch_size=32):
    total_score = 0
    model = SentenceTransformer(base_path + model_name)
    test_batches = len(test_dataset['document']) // batch_size
    loop = tqdm(total=test_batches, position=0, leave=False)
    for i in range(0, len(test_dataset['document']), batch_size):
        document_batch = test_dataset['document'][i:i+batch_size]
        summary_batch = test_dataset['summary'][i:i+batch_size]
        score = compute_summary_score_for_batch(model, document_batch, summary_batch)
        total_score += score.mean().item()
        average_loss = total_score / ((i // batch_size) + 1)

        loop.set_description(f"Average Score: {average_loss:.4f}")
        loop.update(1)
    final_average_loss = total_score / test_batches
    print(f"Final Average Score for {model_name}: {final_average_loss:.4f}")

In [8]:
for name in model_names:
    torch.cuda.empty_cache()
    evaluate(name)

                                                                        

Final Average Score for paraphrase-MiniLM-L6-v2: 0.6666


                                                                        

Final Average Score for all-mpnet-base-v2: 0.8133


                                                                        

Final Average Score for distiluse-base-multilingual-cased-v2: 0.6530


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/573 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/134M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/352 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

                                                                        

Final Average Score for all-MiniLM-L12-v2: 0.6997


