# Sentence Transformer Evaluation

In [None]:
from sentence_transformers import SentenceTransformer, util
import os
from triplet_semantic_similarity_sbert import form_triplets, load_doc_col, split_doc_col, form_triplets
from enum import Enum
import torch


# Document collection returned by the project helper `load_doc_col`.
doc_col = load_doc_col()
# Number of best-scoring sentences kept per document when ranking against a query.
TOP_K = 10

# German query sentences, keyed by (query id, English name of the right).
queries = {
    (1, "Right To Information"): (
        "Das Bestehen eines Rechts auf Auskunft seitens des Verantwortlichen über die "
        "betreffenden personenbezogenen Daten."
    ),
    (2, "Right To Rectification Or Deletion"): (
        "Die betroffene Person hat das Recht, von dem Verantwortlichen "
        "unverzüglich die Berichtigung sie betreffender unrichtiger "
        "personenbezogener Daten zu verlangen. Die betroffene Person hat "
        "das Recht, von dem Verantwortlichen zu verlangen, dass sie "
        "betreffende personenbezogene Daten unverzüglich gelöscht werden."
    ),
    (3, "Right To Data Portability"): (
        "Die betroffene Person hat das Recht, die sie betreffenden personenbezogenen "
        "Daten, die sie einem Verantwortlichen bereitgestellt hat, in einem "
        "strukturierten, gängigen und maschinenlesbaren Format zu erhalten."
    ),
    (4, "Right To Withdraw Consent"): (
        "Das Bestehen eines Rechts, die Einwilligung jederzeit zu widerrufen."
    ),
    (5, "Right To Complain"): (
        "Das Bestehen eines Beschwerderechts bei einer Aufsichtsbehörde."
    ),
}

# Display name per query id, taken from the (id, name) keys of `queries`.
query_names = {ident: title for ident, title in queries}


class ModelPath(str, Enum):
    """Path of the stored model for each query id.

    Members are named ``QUERY_<id>`` and hold the model path as their value.
    Calling ``ModelPath(query_id)`` with an integer id resolves to the matching
    member. Because the enum mixes in ``str``, a member can be passed directly
    to ``os.path.join``.
    """

    # TODO: replace the empty placeholder with the real model directory and
    # add QUERY_2 ... QUERY_5 once those models exist.
    QUERY_1 = ""

    @classmethod
    def _missing_(cls, value):
        """Resolve an integer query id to the member named ``QUERY_<id>``.

        Returns None for anything else, which makes Enum raise ValueError.
        """
        if isinstance(value, int):
            return cls.__members__.get(f"QUERY_{value}")
        return None

In [None]:
# Load the model for every query and prepare its held-out evaluation data.
# `__file__` is not defined inside a notebook kernel, so fall back to the
# current working directory in that case.
base_dir = os.path.dirname(__file__) if "__file__" in globals() else os.getcwd()

eval_data_per_query = {}
for query_id, _query_name in queries:
    # `.value` is the path string stored on the enum member.
    model_path = os.path.join(base_dir, ModelPath(query_id).value)

    # NOTE(review): the unpacking order (train_data, test_data, test_docs,
    # train_docs) is unusual - confirm it matches what split_doc_col returns.
    train_data, test_data, test_docs, train_docs = split_doc_col(doc_col, query_id)
    test_triplets = form_triplets(query_id, test_docs)
    # The original called SentenceTransformer.load() without a path and never
    # used model_path; load the model stored at that path instead.
    model = SentenceTransformer(model_path)
    eval_data_per_query[query_id] = {"model": model, "test_triplets": test_triplets, "test_docs": test_docs}


In [None]:
# For every query, rank the sentences of each test document by cosine
# similarity to the query and keep the TOP_K best ones per document.
query_results = {}
# Iterate over items(): the keys are (id, name) tuples, the values the query text.
for (query_id, _query_name), query in queries.items():
    eval_dict = eval_data_per_query[query_id]
    model = eval_dict["model"]
    test_docs = eval_dict["test_docs"]
    encoded_query = model.encode(query, convert_to_tensor=True)

    found_labels = []   # per document: labels of the top-ranked sentences
    found_blobs = []    # per document: the top-ranked sentences themselves
    scores = []         # per document: cosine scores, best first
    embedded_docs = []  # per document: sentence embeddings
    labels = []         # all labels of all documents, flattened
    # Assumes test_doc and doc_labels are index-aligned sequences (one label
    # entry per sentence) - TODO confirm against split_doc_col.
    for test_doc, doc_labels in test_docs:
        embedded_corpus = model.encode(test_doc, convert_to_tensor=True)
        cos_scores = util.cos_sim(encoded_query, embedded_corpus)[0]
        # torch.topk raises if k exceeds the number of candidates.
        top_k = min(TOP_K, cos_scores.shape[0])
        top_results, indices = torch.topk(cos_scores, k=top_k)
        indices = indices.tolist()
        # Index by rank so blobs and labels line up with the scores.
        found_blobs.append([test_doc[idx] for idx in indices])
        found_labels.append([doc_labels[idx] for idx in indices])
        scores.append(top_results.tolist())
        embedded_docs.append(embedded_corpus)
        labels.extend(doc_labels)

    # "labels" used to appear twice in this literal, so found_labels was
    # silently overwritten; it is now stored under its own key.
    query_results[query_id] = {
        "blobs": found_blobs,
        "found_labels": found_labels,
        "scores": scores,
        "encoded_query": encoded_query,
        "embedded_docs": embedded_docs,
        "labels": labels,
    }

## Evaluation

### Accuracies per Query

If the matching sentence is among the top-k suggestions for a document, we count it as a hit.

In [None]:
# Accuracy per query: share of test documents whose top-ranked sentences
# contain a sentence labelled with the query id.
hit_dict = {}
# NOTE(review): the original never defined the "confidence score"; here it is
# the mean of the best cosine score per document - confirm this is intended.
confidence_dict = {}
for query_id, _query_name in queries:
    query_result = query_results[query_id]
    found_labels = query_result["found_labels"]
    scores = query_result["scores"]

    hit_or_not = [query_id in found_doc_labels for found_doc_labels in found_labels]
    # Guard against a query without any test documents.
    hit_dict[query_id] = sum(hit_or_not) / len(hit_or_not) if hit_or_not else 0.0

    best_scores = [max(doc_scores) for doc_scores in scores if doc_scores]
    confidence_dict[query_id] = sum(best_scores) / len(best_scores) if best_scores else 0.0

for query_id, accuracy in hit_dict.items():
    print(f"For {query_names[query_id]} Accuracy is {accuracy}, with confidence score of: {confidence_dict[query_id]}")