In [None]:
import xml.etree.ElementTree as ET
import html

def read_answers(xml_file_path):
    """Load the answer collection from an XML file.

    Every ``DOC`` element found anywhere under the root is read positionally:
    its first child supplies the document id and its second child supplies
    the document text, which is HTML-unescaped.

    Args:
        xml_file_path: Path (or file object) accepted by ``ET.parse``.

    Returns:
        dict mapping document id -> unescaped text. If an id occurs more than
        once, the last occurrence wins. Documents whose text element is empty
        map to ``""``.
    """
    dic_answer_id_text = {}
    root = ET.parse(xml_file_path).getroot()
    for doc in root.iter('DOC'):
        doc_no = doc[0].text
        # An empty element has ``text is None``; html.unescape() would raise
        # TypeError on it, so fall back to an empty string.
        dic_answer_id_text[doc_no] = html.unescape(doc[1].text or "")
    return dic_answer_id_text

In [None]:
!pip install python-terrier
!pip install --upgrade git+https://github.com/terrier-org/pyterrier.git

In [None]:
import pyterrier as pt
if not pt.started():
  pt.init()

# Index the collection: one {'docno', 'body'} record per answer post.
dic_answer_id_text = read_answers("LawPosts.xml")
posts = [
  {'docno': str(answer_id), 'body': answer_text}
  for answer_id, answer_text in dic_answer_id_text.items()
]

# Only the 'body' field is indexed for retrieval.
RETRIEVAL_FIELDS = ['body']
# NOTE(review): the meta values look like per-key maximum stored lengths
# (docno: 20, body: 20000) — confirm against the PyTerrier docs.
iter_indexer = pt.IterDictIndexer(
  "./index",
  meta={'docno': 20, 'body': 20000},
  overwrite=True,
)
indexref1 = iter_indexer.index(posts, fields=RETRIEVAL_FIELDS)

In [None]:
# Load Topics
import pandas as pd

def read_topics(xml_file_path):
    """Read the topics file and build title-only queries.

    Every ``Question`` element is read positionally: its first child supplies
    the topic id and its second child supplies the title. The title is
    HTML-unescaped and every non-alphanumeric character is replaced by a
    single space.

    Args:
        xml_file_path: Path (or file object) accepted by ``ET.parse``.

    Returns:
        list of ``[topic_id_as_str, cleaned_title]`` pairs in document order.
        A question whose title element is empty yields ``""`` as its title.
    """
    lst_topics = []
    root = ET.parse(xml_file_path).getroot()
    for question in root.iter('Question'):
        topic_id = question[0].text
        # An empty title element has ``text is None``; html.unescape() would
        # raise TypeError on it, so fall back to an empty string.
        title = html.unescape(question[1].text or "")
        # Keep only letters/digits; everything else becomes a space.
        title = "".join(ch if ch.isalnum() else " " for ch in title)
        lst_topics.append([str(topic_id), title])
    return lst_topics

# Parse the test topics into [qid, cleaned title] pairs (see read_topics).
lst_topics = read_topics("TestTopics.xml")

In [None]:
# TF-IDF run over the title-only queries; results are saved in TREC format.
queries = pd.DataFrame(data=lst_topics, columns=['qid', 'query'])
tf_idf_retriever = pt.BatchRetrieve(indexref1, wmodel="TF_IDF")
result = tf_idf_retriever.transform(queries)
pt.io.write_results(result, "res_TF_IDF.txt", format='trec')
print(result)

In [None]:
# BM25 run over the title-only queries; results are saved in TREC format.
queries = pd.DataFrame(data=lst_topics, columns=['qid', 'query'])
bm25_retriever = pt.BatchRetrieve(indexref1, wmodel="BM25")
result = bm25_retriever.transform(queries)
pt.io.write_results(result, "res_BM25.txt", format='trec')
print(result)

In [None]:
#YAKE
!pip install git+https://github.com/LIAAD/yake

In [None]:
import yake
import pandas as pd
# Module-level YAKE extractor used by read_topics_yake().
# NOTE(review): n=1 presumably limits keywords to single words and top=10
# caps the number returned — confirm against the YAKE documentation.
kw_extractor = yake.KeywordExtractor(n=1, top=10)


def read_topics_yake(xml_file_path):
    """Read the topics file and build keyword-expanded queries.

    Every ``Question`` element is read positionally: first child = topic id,
    second child = title, third child = body. Title and body are
    HTML-unescaped and every non-alphanumeric character is replaced by a
    space. Up to five keywords extracted from the body by the module-level
    ``kw_extractor`` are placed in front of the cleaned title.

    Args:
        xml_file_path: Path (or file object) accepted by ``ET.parse``.

    Returns:
        list of ``[topic_id_as_str, "<keywords> <cleaned title>"]`` pairs in
        document order. Empty title/body elements are treated as ``""``.
    """
    lst_topics = []
    root = ET.parse(xml_file_path).getroot()
    for question in root.iter('Question'):
        topic_id = question[0].text

        # Empty elements have ``text is None``; html.unescape() would raise
        # TypeError on it, so fall back to an empty string.
        body = html.unescape(question[2].text or "")
        body = "".join(ch if ch.isalnum() else " " for ch in body)

        title = html.unescape(question[1].text or "")
        title = "".join(ch if ch.isalnum() else " " for ch in title)

        # Only the first five extracted keywords are used; each item's first
        # element is the keyword text.
        keywords = kw_extractor.extract_keywords(body)[:5]
        expansion = " ".join(kw[0] for kw in keywords).strip()

        lst_topics.append([str(topic_id), expansion + " " + title])
    return lst_topics

# Parse the test topics into [qid, keywords + cleaned title] pairs (see read_topics_yake).
lst_topics_yake = read_topics_yake("TestTopics.xml")


In [None]:
import pyterrier as pt
if not pt.started():
  pt.init()

# Run both weighting models over the YAKE-expanded queries; each run is
# written to its own TREC-format file and printed.
for wmodel_name, run_path in (
  ("TF_IDF", "res_TF_IDF_YAKE_2.txt"),
  ("BM25", "res_BM25_YAKE_2.txt"),
):
  queries = pd.DataFrame(lst_topics_yake, columns=['qid', 'query'])
  result = pt.BatchRetrieve(indexref1, wmodel=wmodel_name).transform(queries)
  pt.io.write_results(result, run_path, format='trec')
  print(result)

In [None]:
! pip install ranx

In [None]:
from ranx import Qrels, Run, compare, evaluate
# Relevance judgements for the test topics, read in TREC format.
qrels = Qrels.from_file("qrel_test.tsv", kind="trec")

# Runs to evaluate. Only the two YAKE runs are loaded; the commented-out lines
# are other runs from earlier experiments.
# run_11 = Run.from_file("res_TF_IDF.txt", kind="trec")
# run_22 = Run.from_file("res_BM25.txt", kind="trec")
run_1 = Run.from_file("res_BM25_YAKE_2.txt", kind="trec")
run_2 = Run.from_file("res_TF_IDF_YAKE_2.txt", kind="trec")
# run_3 = Run.from_file("all-MiniLM-L12-v2_finetuned.tsv", kind="trec")
# run_4 = Run.from_file("all-distilroberta-v1_finetuned.tsv", kind="trec")

# run_3 = Run.from_file("res_TF_IDF_YAKE.txt", kind="trec")
# run_4 = Run.from_file("res_BM25_YAKE.txt", kind="trec")
# run_5 = Run.from_file("distilroberta.tsv", kind="trec")
# # run_6 = Run.from_file("msmarcodistilbert.tsv", kind="trec")
# run_7 = Run.from_file("all-MiniLM-L12-v2 .tsv", kind="trec")

# Disabled: statistical comparison of all runs.
# report = compare(
# qrels=qrels,
# runs=[run_11, run_22,run_1, run_2, run_3, run_4, run_5, run_7],
# metrics=["precision@1", "mrr@1000"],
# max_p=0.05, # P-value threshold
# stat_test="student"
# )
# report

# temp = evaluate(qrels, run, ["map@100", "mrr@10", "ndcg@10"]) # temp is a dictionary
# #per query results
# Only run_1 (BM25 + YAKE) is evaluated here; run_2 is loaded above but unused
# in the active code.
# NOTE(review): return_mean=False presumably yields per-query precision@1
# scores instead of one averaged value — confirm the return type in ranx docs.
temp = evaluate(qrels, run_1, ["precision@1"], return_mean=False)
temp

# temp.count(1)

In [None]:
# SentenceBERT
!pip install -U sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, losses, util, models, evaluation
from torch.utils.data import DataLoader

# Sentence-embedding model (checkpoint name 'all-distilroberta-v1'); it is
# read as a module-level global by retrieval().
model = SentenceTransformer('all-distilroberta-v1')

In [None]:
from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, losses, util, models, evaluation
from torch.utils.data import DataLoader
import csv 
import os
# Restrict CUDA to device "0".
# NOTE(review): this is set after `model` was created in the previous cell; the
# variable usually has to be set before CUDA is initialised to have any
# effect — confirm, and move it ahead of the model load if needed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def read_queries(xml_file_path):
    """Read the topics file and return the raw question titles.

    Every ``Question`` element is read positionally: its first child supplies
    the topic id and its second child supplies the title, which is
    HTML-unescaped but otherwise kept verbatim.

    Args:
        xml_file_path: Path (or file object) accepted by ``ET.parse``.

    Returns:
        dict mapping topic id -> unescaped title. If an id occurs more than
        once, the last occurrence wins. A question whose title element is
        empty maps to ``""``.
    """
    dic_topics = {}
    root = ET.parse(xml_file_path).getroot()
    for question in root.iter('Question'):
        topic_id = question[0].text
        # An empty title element has ``text is None``; html.unescape() would
        # raise TypeError on it, so fall back to an empty string.
        dic_topics[topic_id] = html.unescape(question[1].text or "")
    return dic_topics

def read_corpus(xml_collection_file_path, xml_file_path):
  """Load the answer collection and the query topics.

  The collection is read first (via ``read_answers``), then the topics
  (via ``read_queries``).

  Returns:
    tuple ``(queries, collection)`` — note the queries come first even
    though the collection path is the first argument.
  """
  collection = read_answers(xml_collection_file_path)
  return read_queries(xml_file_path), collection

def retrieval(top_k=1000):
    """Rank every answer post against every test topic by embedding similarity.

    Reads "LawPosts.xml" and "TestTopics.xml" through ``read_corpus``, encodes
    all answer texts once with the module-level ``model``, then encodes each
    query and looks up its closest answers with ``util.semantic_search``.

    Args:
        top_k: Maximum number of hits requested per query (default 1000).

    Returns:
        dict mapping topic id -> {answer id: similarity score} holding the
        hits returned for that topic.
    """
    final_result = {}
    print("model loaded")
    # This is an important part: queries and candidate answers are loaded here.
    queries, candidates = read_corpus("LawPosts.xml", "TestTopics.xml")
    print("corpus read")
    # Materialise the ids once, in the same dict order as the encoded texts,
    # so a hit's 'corpus_id' can be mapped back to an answer id in O(1)
    # instead of rebuilding the key list for every hit.
    candidate_ids = list(candidates.keys())
    corpus_embeddings = model.encode(list(candidates.values()), convert_to_tensor=True)
    print("corpus encoded")
    for topic_id, query in queries.items():
        query_embedding = model.encode(query, convert_to_tensor=True)
        # One query is searched at a time, so only the first result list is used.
        hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
        final_result[topic_id] = {
            candidate_ids[hit['corpus_id']]: hit['score'] for hit in hits
        }
    return final_result

# Run dense retrieval and dump the ranking as a tab-separated TREC-style run:
# topic id, "Q0", post id, rank, score, run tag.
retrieval_results = retrieval()
with open("distilroberta.tsv", mode='w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    for topic_id, scored_posts in retrieval_results.items():
        # Highest score first; at most 1000 rows are written per topic.
        ranked_posts = sorted(scored_posts.items(), key=lambda pair: pair[1], reverse=True)
        for rank, (post_id, score) in enumerate(ranked_posts[:1000], start=1):
            csv_writer.writerow([topic_id, "Q0", post_id, str(rank), str(score), "distilroberta"])