## Model

In [None]:
from pathlib import Path

import polars as pl
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from sentence_transformers.models import StaticEmbedding
from tqdm import tqdm

In [None]:
# Retrieval encoder: a StaticEmbedding module distilled from BAAI/bge-m3,
# wrapped in a SentenceTransformer so the rest of the notebook can call
# `model.encode(...)` and `model.similarity(...)` on it.
static_embedding = StaticEmbedding.from_distillation("BAAI/bge-m3")
# NOTE(review): `model_kwargs` / `trust_remote_code` are load-time options; when
# the model is assembled from `modules=` they may have no effect (i.e. no fp16
# cast) — confirm against the installed sentence-transformers version.
model = SentenceTransformer(
    modules=[static_embedding],
    model_kwargs=dict(torch_dtype=torch.float16),
    trust_remote_code=True,
    device="cuda",
)

# Alternative encoder — the full BAAI/bge-m3 model instead of the static
# distillation (swap it in for the `model = ...` above):
# model = SentenceTransformer(
#     "BAAI/bge-m3",
#     device="cuda",
#     trust_remote_code=True,
#     model_kwargs={
#         "torch_dtype": torch.float16,
#     },
# )

# Cross-encoder used in section 3 to rerank the merged candidates, in fp16 on GPU.
reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    automodel_args=dict(torch_dtype=torch.float16),
    trust_remote_code=True,
    device="cuda",
)

## Data

In [3]:
corpus = pl.read_csv("../data/corpus.csv")
public_test = pl.read_csv("../data/public_test.csv")

# Plain Python lists for the encoders and BM25; list positions line up with
# the frames' row order.
corpus_text = corpus.get_column("text").to_list()
corpus_id = corpus.get_column("cid").to_list()
questions = public_test.get_column("question").to_list()
question_id = public_test.get_column("qid").to_list()

In [None]:
corpus.head(5)

In [None]:
public_test.head(5)

## Embedding

### 1. Embed the corpus and the questions

In [None]:
# One embedding per corpus passage, in `corpus_text` order.
corpus_embeddings = model.encode(
    corpus_text, show_progress_bar=True, batch_size=16
)

In [None]:
# One embedding per public-test question, in `questions` order.
public_question_embeddings = model.encode(
    questions, show_progress_bar=True, batch_size=16
)

### 2.1 Retrieve the top-k corpus passages for each question (embedding baseline)

In [8]:
# Score matrix: rows follow the questions, columns follow the corpus passages.
similarities = model.similarity(
    public_question_embeddings,
    corpus_embeddings,
)

In [9]:
# The 8 highest-scoring passages per question (along the corpus axis).
top_k = similarities.topk(8, dim=1)

In [10]:
# Row positions into `corpus` for each question's embedding candidates.
bge_top_k = [row.tolist() for row in top_k.indices]

### 2.2 Retrieve the top-k corpus passages for each question (BM25)

In [11]:
import bm25s

In [None]:
# BM25 index over the corpus; passing `corpus=` makes retrieve() return the
# passage texts rather than positions.
retriever = bm25s.BM25(corpus=corpus_text)
corpus_tokens = bm25s.tokenize(corpus_text)
retriever.index(corpus_tokens)
del corpus_tokens  # intermediate not needed after indexing

In [None]:
# The retriever was built with `corpus=corpus_text`, so `retrieve` hands back
# passage *texts*; map them back to row positions in `corpus`.
# A dict built once replaces a `corpus_text.index(...)` linear scan per hit.
# `setdefault` keeps the first occurrence, which matches `list.index`:
# duplicated texts all resolve to their first row, as before.
text_to_row = {}
for row, text in enumerate(corpus_text):
    text_to_row.setdefault(text, row)

# One batched retrieval for all questions instead of one call per question;
# the result has one row of 8 hits per question.
bm25_results, bm25_scores = retriever.retrieve(
    bm25s.tokenize(questions),
    k=8,
)
bm25_top_k = [
    [text_to_row[text] for text in hits] for hits in bm25_results.tolist()
]

In [14]:
# Debug: merged (deduplicated) candidate rows for the question at position 1.
example_id = list(set([*bge_top_k[1], *bm25_top_k[1]]))

In [15]:
# Union of embedding and BM25 candidates per question. Duplicates are removed
# and the set order is arbitrary; the reranker decides the final order.
merge_top_k = [
    list(set(dense_ids + sparse_ids))
    for dense_ids, sparse_ids in zip(bge_top_k, bm25_top_k)
]

### 3. Rerank the merged candidates and keep the top 10 per question

In [16]:
def final_rerank(query, corpus, top_k_indices, top_n=10, batch_size=4):
    """Rerank one query's candidate passages with the global cross-encoder.

    Args:
        query: The question text.
        corpus: Table/mapping whose ``"text"`` column can be indexed by the
            integer positions in ``top_k_indices``.
        top_k_indices: Candidate row positions into ``corpus`` (here, the
            merged embedding + BM25 candidates).
        top_n: How many reranked candidates to keep (default 10).
        batch_size: Batch size forwarded to ``reranker.rank``.

    Returns:
        Up to ``top_n`` entries of ``top_k_indices``, in the order returned by
        ``reranker.rank`` (highest score first).
    """
    if len(top_k_indices) == 0:
        # Nothing to score; skip the cross-encoder call entirely.
        return []
    texts = corpus["text"]
    documents = [texts[idx] for idx in top_k_indices]
    ranked = reranker.rank(query=query, documents=documents, batch_size=batch_size)
    # Each hit's "corpus_id" is its position within `documents`, so map it back
    # to the caller's row position in `corpus`.
    return [top_k_indices[hit["corpus_id"]] for hit in ranked[:top_n]]

In [None]:
# Reranked candidate rows for every public-test question, in question order.
rerank_index = [
    final_rerank(question, corpus, merge_top_k[row])
    for row, question in enumerate(
        tqdm(public_test["question"], total=len(public_test))
    )
]

In [18]:
# Translate row positions into corpus ids (`cid`) for the submission.
rerank_corpus_id = [[corpus_id[idx] for idx in ranked_rows] for ranked_rows in rerank_index]

In [None]:
# Sanity check: exactly one ranked list per public-test question.
assert len(rerank_corpus_id) == len(public_test), "expected one ranking per question"
len(rerank_corpus_id)

## Final Output

In [19]:
# One line per question: "<qid> <cid_1> <cid_2> ...".
q_id = public_test["qid"]
list_output = [
    f"{qid} {' '.join(map(str, cids))}" for qid, cids in zip(q_id, rerank_corpus_id)
]

# Save the output. open() does not create missing folders, so make sure the
# target directory exists first (a fresh checkout would otherwise raise
# FileNotFoundError).
output_path = Path("bm25") / "predict.txt"
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
    f.write("\n".join(list_output))

In [None]:
corpus.head(5)

In [None]:
# Inspect one question: final reranked cids vs. the raw embedding and BM25 candidates.
index = 0
(
    rerank_corpus_id[index],
    list(map(corpus_id.__getitem__, bge_top_k[index])),
    list(map(corpus_id.__getitem__, bm25_top_k[index])),
)