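# Sentence Transformer Evaluation

Evaluates the triplet-fine-tuned SentenceTransformer models (one model per data-subject-right query) by semantic search over each query's held-out test documents.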

In [30]:
from sentence_transformers import SentenceTransformer, util
import os
from triplet_semantic_similarity_sbert import form_triplets, load_doc_col, split_doc_col, form_triplets
from enum import Enum
import torch
from tiltify.config import Path
from tqdm import tqdm


# Number of highest-scoring blobs retrieved per test document and query.
TOP_K = 10
doc_col = load_doc_col()

# Search queries: key is (query_id, query_name), value is the German query text.
queries = {
    (1, "Right To Information"): (
        "Das Bestehen eines Rechts auf Auskunft seitens des Verantwortlichen über die "
        "betreffenden personenbezogenen Daten."
    ),
    (2, "Right To Rectification Or Deletion"): (
        "Die betroffene Person hat das Recht, von dem Verantwortlichen "
        "unverzüglich die Berichtigung sie betreffender unrichtiger "
        "personenbezogener Daten zu verlangen. Die betroffene Person hat "
        "das Recht, von dem Verantwortlichen zu verlangen, dass sie "
        "betreffende personenbezogene Daten unverzüglich gelöscht werden."
    ),
    (3, "Right To Data Portability"): (
        "Die betroffene Person hat das Recht, die sie betreffenden personenbezogenen "
        "Daten, die sie einem Verantwortlichen bereitgestellt hat, in einem "
        "strukturierten, gängigen und maschinenlesbaren Format zu erhalten."
    ),
    (4, "Right To Withdraw Consent"): "Das Bestehen eines Rechts, die Einwilligung jederzeit zu widerrufen.",
    (5, "Right To Complain"): "Das Bestehen eines Beschwerderechts bei einer Aufsichtsbehörde.",
}

# Display names keyed by query id, derived from `queries` so the two cannot drift apart.
query_names = {query_id: query_name for query_id, query_name in queries}

# Models directory of the training run being evaluated (relative path, joined onto the
# project root when the models are loaded).
# NOTE(review): the timestamp contains ':' characters, which are not valid in Windows paths.
MODEL_RUN_DIR = "scripts/triplet_training/real_runs/triplet_semantic_search_results_2022-09-30_19:18:01/models"

# One fine-tuned model per query, stored as
# <MODEL_RUN_DIR>/triplet_semantic_search_<query name in snake_case>.
model_paths = {
    query_id: f"{MODEL_RUN_DIR}/triplet_semantic_search_{query_name.lower().replace(' ', '_')}"
    for query_id, query_name in query_names.items()
}

In [31]:
# Load each query's fine-tuned model together with that query's test documents.
eval_data_per_query = {}
for query_id, _query_name in queries:
    # model_paths are relative to the project root.
    model_path = os.path.join(Path.root_path, model_paths[query_id])

    # Only the test documents are needed for evaluation; the train split is discarded.
    _train_data, _test_data, _train_docs, test_docs = split_doc_col(doc_col, query_id, with_docs=True)
    model = SentenceTransformer(model_path)
    eval_data_per_query[query_id] = {"model": model, "test_docs": test_docs}


In [32]:
# Semantic search: for every query, score all blobs of each test document by cosine
# similarity to the query embedding and keep the TOP_K best-scoring ones.
query_results = {}
for (query_id, query_name), query_text in queries.items():
    eval_dict = eval_data_per_query[query_id]
    model = eval_dict["model"]
    test_docs = eval_dict["test_docs"]
    # Each query is encoded with its own fine-tuned model.
    encoded_query = model.encode(query_text, convert_to_tensor=True)

    found_blobs = []    # per document: the retrieved blobs, best first
    found_labels = []   # per document: labels of the retrieved blobs, aligned with found_blobs
    scores = []         # per document: cosine scores of the retrieved blobs
    embedded_docs = []  # per document: embeddings of all of its blobs
    labels = []         # labels of all blobs of all documents, flattened across documents
    for test_doc, doc_labels in tqdm(test_docs, desc=query_name):
        embedded_corpus = model.encode(test_doc, convert_to_tensor=True)
        cos_scores = util.cos_sim(encoded_query, embedded_corpus)[0]
        # torch.topk raises if k exceeds the number of blobs, so clamp k for short documents.
        k = min(TOP_K, len(cos_scores))
        top_scores, top_indices = torch.topk(cos_scores, k=k, largest=True, sorted=True)
        top_indices = top_indices.tolist()
        found_blobs.append([test_doc[idx] for idx in top_indices])
        found_labels.append([doc_labels[idx] for idx in top_indices])
        scores.append(top_scores.tolist())
        embedded_docs.append(embedded_corpus)
        labels.extend(doc_labels)

    query_results[query_id] = {
        "blobs": found_blobs,
        "found_labels": found_labels,
        "scores": scores,
        "encoded_query": encoded_query,
        "embedded_docs": embedded_docs,
        "labels": labels,
    }

100%|██████████| 21/21 [00:21<00:00,  1.01s/it]
100%|██████████| 22/22 [00:19<00:00,  1.15it/s]
100%|██████████| 18/18 [00:24<00:00,  1.33s/it]
100%|██████████| 20/20 [00:18<00:00,  1.08it/s]
100%|██████████| 20/20 [00:27<00:00,  1.39s/it]


## Evaluation

### Accuracies per Query

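A test document counts as a hit when at least one of the `TOP_K` blobs retrieved for a query carries that query's label. The accuracy reported per query is the fraction of its test documents that are hits (hit rate@k). This is a lenient, document-level metric: it does not measure how many relevant blobs were retrieved, nor at which rank they appear.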

In [33]:
# Hit rate@TOP_K per query: the fraction of test documents for which at least one
# retrieved blob carries the query's label.
hit_dict = {}
for query_id, _query_name in queries:
    found_labels_per_doc = query_results[query_id]["found_labels"]
    # Each retrieved blob has its own collection of labels; a document is a hit if the
    # query id occurs in the labels of any of its retrieved blobs.
    hits = [
        any(query_id in blob_labels for blob_labels in doc_found_labels)
        for doc_found_labels in found_labels_per_doc
    ]
    # A query without test documents has an undefined hit rate; report NaN instead of
    # raising ZeroDivisionError.
    hit_dict[query_id] = sum(hits) / len(hits) if hits else float("nan")

for query_id, accuracy in hit_dict.items():
    print(f"For {query_names[query_id]} Accuracy is {accuracy}")

For Right To Information Accuracy is 1.0
For Right To Rectification Or Deletion Accuracy is 1.0
For Right To Data Portability Accuracy is 1.0
For Right To Withdraw Consent Accuracy is 1.0
For Right To Complain Accuracy is 1.0
