In [None]:
!pip install transformers
!pip install torch torchvision torchaudio
!pip install tqdm
!pip install sentence-transformers
!pip install nltk

import nltk
nltk.download('punkt')

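## SBERT QA baseline

For each entry, the text is split into sentences, and the sentence whose SBERT embedding scores highest against the embedding of the entry's title is taken as the answer. The answers for each split are written to `RESULT_PATH` and then scored with `ClickbaitResolverEvaluator`.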

In [None]:
import nltk
import json
import torch
import os
from sentence_transformers import SentenceTransformer, util

import sys
sys.path.append(".")
sys.path.append("..") # Adds higher directory to python modules path.
from eval.eval import ClickbaitResolverEvaluator

In [None]:
# Dataset splits to answer and evaluate.
ENTRY_SETS = ["train", "dev"]

# Directory holding the final_<split>.json input files. The trailing "/" matters:
# file paths below are built by plain string formatting, not os.path.join.
DATA_PATH = "../data/"

# Output directory for this baseline's <split>.json answer files.
RESULT_PATH = DATA_PATH + "baseline_results/sbert_qa/"

In [None]:
# Sentence-Transformers model used to embed both the titles and the text sentences.
SBERT_MODEL_NAME = 'multi-qa-mpnet-base-cos-v1'
embedder = SentenceTransformer(SBERT_MODEL_NAME)

In [None]:
def compute_sbert_qa(entries, verbose=True):
    """Answer each entry with the sentence of its text that best matches its title.

    For every entry, the text is split into sentences with
    ``nltk.tokenize.sent_tokenize``. The title and all sentences are then
    embedded with the global ``embedder``, and the sentence with the highest
    dot-product score against the title is returned as the answer.

    Args:
        entries: iterable of dicts with at least "id", "title" and "text" keys.
        verbose: when True (the default, matching the previous behaviour),
            print "<title> -> <answer>" for every entry.

    Returns:
        A list of {"id": ..., "answer": ...} dicts, one per entry, in input
        order. An entry whose text yields no sentences gets an empty answer
        instead of aborting the whole run.
    """
    results = []
    for entry in entries:
        text = entry["text"]
        # Drop a single leading "." (and the whitespace after it) so it is not
        # tokenized as a sentence of its own. startswith() is also safe on "",
        # where text[0] would raise IndexError.
        if text.startswith("."):
            text = text[1:].strip()
        sentences = nltk.tokenize.sent_tokenize(text, language='english')

        if sentences:
            query_embedding = embedder.encode(entry["title"], convert_to_tensor=True)
            corpus_embeddings = embedder.encode(sentences, convert_to_tensor=True)
            # Dot product equals cosine similarity only for unit-length embeddings.
            # multi-qa-mpnet-base-cos-v1 produces those (per its model docs); use
            # util.cos_sim instead if `embedder` is switched to a non-normalized model.
            scores = util.dot_score(query_embedding, corpus_embeddings)[0]
            answer = sentences[int(torch.argmax(scores))].strip()
        else:
            # Nothing to rank: taking topk/argmax of an empty score tensor raises.
            answer = ""

        if verbose:
            print(f"{entry['title']} -> {answer}")
        results.append({"id": entry["id"], "answer": answer})
    return results

In [None]:
os.makedirs(RESULT_PATH, exist_ok=True)

# Answer every split and save the answers to RESULT_PATH/<split>.json.
# Both files are opened as UTF-8 explicitly. JSON text is UTF-8, and
# ensure_ascii=False writes raw non-ASCII characters that the platform default
# encoding (e.g. cp1252 on Windows) may be unable to encode or decode.
for s in ENTRY_SETS:
    with open(f"{DATA_PATH}final_{s}.json", "r", encoding="utf-8") as entry_file:
        results = compute_sbert_qa(json.load(entry_file))

    with open(f"{RESULT_PATH}{s}.json", "w", encoding="utf-8") as result_file:
        json.dump(results, result_file, indent=2, ensure_ascii=False)

In [None]:
evaluator = ClickbaitResolverEvaluator()

# Run the evaluator on each split's saved answer file together with the
# corresponding final_<split>.json entries file, then print its report.
for split_name in ENTRY_SETS:
    answers_json = f"{RESULT_PATH}{split_name}.json"
    entries_json = f"{DATA_PATH}final_{split_name}.json"
    agg_results, results = evaluator.run_file(answers_json, entries_json)
    # NOTE(review): the meaning of the positional False flag is defined by
    # print_results — confirm it and pass it by keyword.
    evaluator.print_results(agg_results, results, False)