# Cross-Encoders

In [1]:
from typing import List, Union, Tuple
from arqmath_code.Entities.Post import Question, Answer
from arqmath_code.topic_file_reader import Topic
from src import init_data
# Build the topic reader and the data reader for task 1. Per this cell's output,
# the call reads users, comments, votes, post links and posts.
# NOTE(review): `task=1` presumably selects the ARQMath answer-retrieval task - confirm in src.init_data.
topic_reader, data_reader = init_data(task=1)

reading users
reading comments
reading votes
reading post links
reading posts


In [8]:
from arqmath_code.post_reader_record import DataReaderRecord
from src.post_processors.top_k_filter import TopKFilter
from src.post_processors.answer_score_retriever_for_questions import AnswerScoreRetrieverForQuestions
from src.sbert.question_s_bert import QuestionSBERT
from src.base.pipeline import Pipeline
from src.sbert.cross_encoder import SBertCrossEncoder


class SBertPipelineWithCrossEncoder(Pipeline):
    """Retrieval pipeline that ranks with an SBERT model and then re-scores with a cross-encoder.

    Stages, in the order ``run`` applies them:

    1. ``QuestionSBERT`` ranks a (possibly truncated) list of questions against the queries.
    2. ``AnswerScoreRetrieverForQuestions`` post-processes that ranking.
       NOTE(review): presumably swaps question hits for their answers, since ``run``
       is declared to return ``Answer`` entries - confirm in the post-processor.
    3. ``TopKFilter(k=candidate_k)`` trims the candidate list before the cross-encoder.
    4. ``SBertCrossEncoder`` re-scores the remaining candidates.
    5. A default-constructed ``TopKFilter`` trims the final ranking.

    All tunables default to the values that were previously hard-coded, so constructing
    the pipeline with only ``data_reader`` behaves exactly as before.
    """

    def __init__(
            self,
            data_reader: DataReaderRecord,
            bi_encoder_model_id: str = 'all-MiniLM-L6-v2',
            cross_encoder_model_id: str = 'cross-encoder/ms-marco-TinyBERT-L-2-v2',
            candidate_k: int = 5000,
            max_questions: Union[int, None] = 1000,
    ):
        """Create the pipeline stages.

        :param data_reader: reader that provides the questions via ``get_questions()``.
        :param bi_encoder_model_id: model id handed to ``QuestionSBERT``.
        :param cross_encoder_model_id: model id handed to ``SBertCrossEncoder``.
        :param candidate_k: ``k`` of the filter applied before cross-encoder re-scoring.
        :param max_questions: only the first ``max_questions`` questions returned by the
            data reader are searched; pass ``None`` to search all of them.
        :raises ValueError: if ``max_questions`` is negative (a negative slice bound would
            silently drop questions from the end instead of capping the list).
        """
        super().__init__(data_reader)
        if max_questions is not None and max_questions < 0:
            raise ValueError(f"max_questions must be None or >= 0, got {max_questions}")
        self.max_questions = max_questions
        self.sbert = QuestionSBERT(model_id=bi_encoder_model_id)
        self.answer_score_retriever = AnswerScoreRetrieverForQuestions()
        self.top_k_filter = TopKFilter(k=candidate_k)
        # Uses TopKFilter's own default k for the final cut-off.
        self.final_top_k_filter = TopKFilter()
        self.cross_encoder = SBertCrossEncoder(cross_encoder_model_id)

    def run(self, queries: List[Topic]) -> List[Tuple[Topic, Answer, float]]:
        """Run every stage for ``queries`` and return the final ranking.

        :param queries: topics to retrieve results for.
        :return: the ranking produced by the final top-k filter, as
            ``(topic, post, score)`` tuples according to the declared type.
        """
        # A slice bound of None means "no limit", so max_questions=None searches everything.
        questions: List[Question] = self.data_reader.get_questions()[:self.max_questions]
        ranking: List[Tuple[Topic, Union[Question, Answer], float]] = self.sbert(
            queries=queries, documents=questions)
        ranking = self.answer_score_retriever(queries=queries, ranking=ranking)
        # Trim the candidates before handing them to the cross-encoder.
        ranking = self.top_k_filter(queries=queries, ranking=ranking)
        ranking = self.cross_encoder(queries=queries, ranking=ranking)
        ranking = self.final_top_k_filter(queries=queries, ranking=ranking)
        return ranking

In [9]:
from src.runner import Runner
from datetime import datetime

# Print a wall-clock timestamp before and after the run so the elapsed time
# can be read off the cell output.
print(datetime.now())

# The runner receives the pipeline class itself plus the readers loaded earlier.
runner = Runner(
    SBertPipelineWithCrossEncoder,
    n=1,
    data_reader=data_reader,
    topic_reader=topic_reader,
)
# NOTE(review): the path argument is presumably where the results TSV is written - confirm in Runner.run.
ranking = runner.run(
    "../results/model_results/SBert-Cross-encoder-test.tsv"
)

print(datetime.now())

# Bare last expression: let the notebook render the returned ranking.
ranking

2022-11-15 19:04:23.014767
read from cached embeddings at  ../arqmath_dataset/model_embeddings/document_embeddings_all-MiniLM-L6-v2.npy


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/3097 [00:00<?, ?it/s]

[-9.276285   4.2935276 -8.854645  ... -0.7223669 -0.5448551 -2.5718584]
2022-11-15 19:12:22.019144


Unnamed: 0,Topic_Id,Post_Id,Score,Run_Number,Rank
0,A.301,2122,6.308325,0,0
1,A.301,4209,6.281318,0,1
2,A.301,3207,6.202952,0,2
3,A.301,3084,6.183568,0,3
4,A.301,813,6.106048,0,4
...,...,...,...,...,...
99095,A.400,3949,-10.475514,0,986
99096,A.400,2895847,-10.512490,0,987
99097,A.400,1895,-10.586302,0,988
99098,A.400,2516,-10.660255,0,989
