# Latent Dirichlet Allocation (LDA)


In [1]:
from src.base.pipeline import Pipeline
from src.runner import Runner
from typing import Tuple, List
from arqmath_code.Entities.Post import Answer
from arqmath_code.topic_file_reader import Topic

## LDA pipeline

In [None]:
from src.pre_processors.nltk_tokenization_and_stopword_removal import NLTKTokenizationAndStopwordRemoval, \
    NLTKTokenizationAndStopwordRemovalForQueries
from src.pre_processors.remove_xml_tags import RemoveXMLTagsFromDocumentBody, RemoveXMLTagsFromQueries
from src.latent.latent_dirichlet_allocation import LatentDirichletAllocationModel
from src.post_processors.top_k_filter import TopKFilter
from arqmath_code.post_reader_record import DataReaderRecord


class LDAPipeline(Pipeline):
    """Retrieval pipeline that ranks answer posts for topics with an LDA model.

    Stages, in order:
      1. documents (all answer posts from the data reader) go through
         ``RemoveXMLTagsFromDocumentBody`` then ``NLTKTokenizationAndStopwordRemoval``;
      2. queries go through ``RemoveXMLTagsFromQueries`` then
         ``NLTKTokenizationAndStopwordRemovalForQueries``;
      3. ``LatentDirichletAllocationModel`` scores queries against documents;
      4. ``TopKFilter(k=1000)`` cuts the ranking down.
    """

    def __init__(self, data_reader: DataReaderRecord):
        """Create every stage of the pipeline.

        :param data_reader: corpus source, handed to the ``Pipeline`` base class.
            ``run`` later reads it back as ``self.data_reader`` (presumably set by
            the base class — confirm in ``src.base.pipeline``).
        """
        super().__init__(data_reader)
        # Scoring and cut-off stages.
        self.lda_model = LatentDirichletAllocationModel(save_embeddings=True)
        self.top1000 = TopKFilter(k=1000)
        # Document-side preprocessing stages.
        self.document_tag_remover = RemoveXMLTagsFromDocumentBody()
        self.document_tokenizer = NLTKTokenizationAndStopwordRemoval()
        # Query-side preprocessing stages.
        self.query_tag_remover = RemoveXMLTagsFromQueries()
        self.query_tokenizer = NLTKTokenizationAndStopwordRemovalForQueries()

    def run(self, queries: List[Topic]) -> List[Tuple[Topic, Answer, float]]:
        """Preprocess corpus and topics, rank with LDA, and keep the top 1000.

        :param queries: topics to retrieve answers for.
        :return: whatever ``self.top1000`` returns for the LDA ranking.
        """
        corpus = self.data_reader.get_all_answer_posts()

        # The document stages receive the *raw* topics: query preprocessing
        # only happens afterwards, so this ordering must be kept.
        print("Start document preprocessing")
        untagged_docs = self.document_tag_remover(queries, corpus)
        tokenized_docs = self.document_tokenizer(queries, untagged_docs)

        print("Start query preprocessing")
        untagged_queries = self.query_tag_remover(queries)
        tokenized_queries = self.query_tokenizer(untagged_queries)

        print("Start ranking")
        lda_ranking = self.lda_model(queries=tokenized_queries, documents=tokenized_docs)

        print("Start top 1000 filtering")
        # queries=None: the filter is given only the ranking here
        # (presumably it needs nothing else — confirm in TopKFilter).
        return self.top1000(queries=None, ranking=lda_ranking)


## Run pipeline

In [4]:
from datetime import datetime

# Run the LDA pipeline once and write its ranking to the model-results folder.
# The timestamps printed before and after bracket the run, so its wall-clock
# duration can be read from the cell output.
result_file = "../results/model_results/lda_200_v2.tsv"

print(datetime.now())
runner = Runner(LDAPipeline, n=1)
ranking = runner.run(result_file)
print(datetime.now())

# Bare expression so the notebook displays the ranking.
ranking

2022-11-26 15:25:49.719414
reading users
reading comments
reading votes
reading post links
reading posts
Start document preprocessing
Start query preprocessing
Start ranking
Finished count vectorizer
Finished LDA embedding
Start top 1000 filtering
2022-11-26 16:38:43.381812


Unnamed: 0,Topic_Id,Post_Id,Score,Run_Number,Rank
0,A.301,1431851,0.818038,0,0
1,A.301,1016077,0.810230,0,1
2,A.301,2688957,0.795476,0,2
3,A.301,126181,0.788615,0,3
4,A.301,312260,0.786997,0,4
...,...,...,...,...,...
99995,A.400,552932,0.711746,0,995
99996,A.400,2559874,0.711726,0,996
99997,A.400,1620792,0.711696,0,997
99998,A.400,902004,0.711558,0,998


## Evaluation

In [1]:
from arqmath_code.evaluation.task1 import arqmath_to_prime_task1
from arqmath_code.evaluation.task1 import task1_get_results

In [2]:
# Load the official 2022 Task 1 relevance judgments.
qrel_path = "../arqmath_dataset/evaluation/Task 1/Qrel Files/qrel_task1_2022_official.tsv"
qrel_dictionary = arqmath_to_prime_task1.read_qrel_to_dictionary(qrel_path)

# Convert the run files found under submission_dir, writing into the prim and
# TREC output directories (presumably one converted file per run — confirm).
arqmath_to_prime_task1.convert_result_files_to_trec(
    submission_dir="../results/model_results/",
    qrel_result_dic=qrel_dictionary,
    prim_dir="../results/ARQmath_prim/",
    trec_dir="../results/ARQmath_trec/",
)

In [3]:
# NOTE(review): 78 is presumably the number of judged topics in the 2022 qrels
# (the per-topic output skips e.g. A.311 and A.314) — confirm against the qrel file.
number_topics = 78

qrel_path = "../arqmath_dataset/evaluation/Task 1/Qrel Files/qrel_task1_2022_official.tsv"

# Score the prim-format runs with the "trec_eval" tool and collect the summary
# into one TSV. NOTE(review): the trailing "_" in "complete_results_v1_.tsv"
# looks accidental — left unchanged because other tooling may expect this name.
task1_get_results.get_result(
    trec_eval_tool="trec_eval",
    qre_file_path=qrel_path,
    prim_result_dir="../results/ARQmath_prim/",
    evaluation_result_file="../results/complete_results_v1_.tsv",
    number_topics=number_topics,
)

-----------
['ndcg                  ', 'A.301', '0.1050']
-----------
['ndcg                  ', 'A.302', '0.0072']
-----------
['ndcg                  ', 'A.303', '0.0175']
-----------
['ndcg                  ', 'A.304', '0.2101']
-----------
['ndcg                  ', 'A.305', '0.0000']
-----------
['ndcg                  ', 'A.306', '0.0000']
-----------
['ndcg                  ', 'A.307', '0.0750']
-----------
['ndcg                  ', 'A.308', '0.4967']
-----------
['ndcg                  ', 'A.309', '0.0074']
-----------
['ndcg                  ', 'A.310', '0.0085']
-----------
['ndcg                  ', 'A.312', '0.0000']
-----------
['ndcg                  ', 'A.313', '0.3230']
-----------
['ndcg                  ', 'A.315', '0.1377']
-----------
['ndcg                  ', 'A.316', '0.0287']
-----------
['ndcg                  ', 'A.317', '0.0685']
-----------
['ndcg                  ', 'A.318', '0.0471']
-----------
['ndcg                  ', 'A.319', '0.0291']
-----------
['