# Bi-Encoder ranking model

## Data loading

In [1]:
from datasets import load_from_disk
from src.utils.config_management import CONFIG

In [2]:
# Load the previously saved Hugging Face dataset from the location configured in CONFIG.
hf_dataset = load_from_disk(CONFIG['paths']['data']['dalip_hf_dataset'])

In [3]:
# Show the available splits with their columns and row counts.
hf_dataset

DatasetDict({
    train: Dataset({
        features: ['answer_id', 'question_id', 'answer_creation_date', 'answer_score', 'answer_normalized_score', 'answer_log_normalized_score', 'answer_body', 'answer_last_edit_date', 'answer_last_activity_date', 'answer_comment_count', 'answer_community_owned_date', 'question_creation_date', 'question_score', 'question_view_count', 'question_body', 'question_last_edit_date', 'question_last_activity_date', 'question_title', 'question_tags', 'question_answer_count', 'question_comment_count', 'question_favorite_count', 'question_closed_date', 'question_community_owned_date', 'answer_accepted'],
        num_rows: 42700
    })
    test: Dataset({
        features: ['answer_id', 'question_id', 'answer_creation_date', 'answer_score', 'answer_normalized_score', 'answer_log_normalized_score', 'answer_body', 'answer_last_edit_date', 'answer_last_activity_date', 'answer_comment_count', 'answer_community_owned_date', 'question_creation_date', 'question_score', 

## Data preprocessing

In [4]:
from src.utils.text_preprocessing import Preprocessor

In [5]:
# NOTE(review): presumably keeps <code> elements while the rest of the HTML is cleaned —
# confirm against src.utils.text_preprocessing.Preprocessor.
preprocessor = Preprocessor(preserve_html_tags=['code'])

In [6]:
# Apply the preprocessor to every split in batches. Judging by the column list printed in the
# next cell, this adds the 'question_text' and 'answer_text' columns used for modelling.
hf_dataset = hf_dataset.map(preprocessor, batched=True)

In [7]:
# Re-inspect the dataset to verify which columns preprocessing added.
hf_dataset

DatasetDict({
    train: Dataset({
        features: ['answer_id', 'question_id', 'answer_creation_date', 'answer_score', 'answer_normalized_score', 'answer_log_normalized_score', 'answer_body', 'answer_last_edit_date', 'answer_last_activity_date', 'answer_comment_count', 'answer_community_owned_date', 'question_creation_date', 'question_score', 'question_view_count', 'question_body', 'question_last_edit_date', 'question_last_activity_date', 'question_title', 'question_tags', 'question_answer_count', 'question_comment_count', 'question_favorite_count', 'question_closed_date', 'question_community_owned_date', 'answer_accepted', 'question_text', 'answer_text'],
        num_rows: 42700
    })
    test: Dataset({
        features: ['answer_id', 'question_id', 'answer_creation_date', 'answer_score', 'answer_normalized_score', 'answer_log_normalized_score', 'answer_body', 'answer_last_edit_date', 'answer_last_activity_date', 'answer_comment_count', 'answer_community_owned_date', 'question_cr

## Fine-tuning

### Create pairs dataset

In [8]:
import torch
import os
import pandas as pd
from typing import Literal, Union
from datasets import Dataset
from itertools import combinations
import random
from tqdm import tqdm
import math

In [9]:
# --- Reproducibility ---------------------------------------------------------
# Pair sampling relies on `random.shuffle` and the scoring head is randomly initialised,
# so both RNGs are seeded up front to make "Restart Kernel -> Run All" repeatable.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Data / pair sampling ----------------------------------------------------
TARGET_COL = 'answer_normalized_score'  # column that defines the ground-truth ranking
PAIRS_SAMPLING_STRATEGY = 'mean'        # 'mean' (random pairs) or 'topk' (top answers vs. the rest)
N_SAMPLES = 10                          # passed as `n` when building the training pairs

# --- Model -------------------------------------------------------------------
MODEL_PATH = 'mmukh/SOBertBase'
MODEL_NAME = MODEL_PATH.split('/')[-1]
MAX_LENGTH = 1024                       # truncation length used by the data collator
EMBEDDINGS_POOLING = 'mean'
LOSS = 'margin_ranking_loss'

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 1
# Number of samples that should contribute to a single optimizer step; gradient
# accumulation makes up for the small per-device batch.
EFFECTIVE_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = math.ceil(EFFECTIVE_BATCH_SIZE / BATCH_SIZE)

# --- Output ------------------------------------------------------------------
MODEL_OUTPUT_PATH = os.path.join(CONFIG['paths']['models']['dalip_bi-encoder_ranking'],
                                 f'bi-encoder_ranking_{MODEL_NAME}_{EMBEDDINGS_POOLING}')

In [10]:
def create_pairs_dataset_df(dataset_df,
                            pairs_sampling_strategy: Union[Literal['mean'], Literal['topk']] = 'mean',
                            n: Union[int, Literal['all']] = 'all',
                            seed: Union[int, None] = None,
                            ) -> pd.DataFrame:
    """Build a pairwise-ranking frame with one row per (question, answer_1, answer_2).

    Answers are grouped by ``question_id`` and each group is sorted by ``TARGET_COL``
    in descending order. Pairs whose two answers have the same ``TARGET_COL`` value
    are never emitted.

    Args:
        dataset_df: Frame with the columns ``question_id``, ``answer_id``,
            ``question_text``, ``answer_text`` and ``TARGET_COL``. Its index must be
            unique, because rows are looked up by index label.
        pairs_sampling_strategy: Only consulted when ``n`` is an integer.
            ``'mean'`` shuffles all answer pairs of a question and keeps the first
            ``n`` with differing scores. ``'topk'`` pairs each of the ``n``
            best-ranked answers with every answer ranked below it.
        n: ``'all'`` keeps every pair with differing scores; an integer limits the
            pairs per question as described above.
        seed: Optional seed for a private random generator used by the ``'mean'``
            strategy. When ``None`` the module-level ``random`` state is used.

    Returns:
        DataFrame with the columns ``question_id``, ``answer_1_id``, ``answer_2_id``,
        ``question_text``, ``answer_1_text``, ``answer_2_text``,
        ``answer_1_<TARGET_COL>``, ``answer_2_<TARGET_COL>`` and ``label``
        (1 if answer 1 has the higher score, otherwise -1). The columns are present
        even when no pair could be built.

        Because answer 1 always precedes answer 2 in the descending order, ``label``
        is 1 for every row unless a score is NaN.

    Raises:
        ValueError: If ``n`` is an integer and ``pairs_sampling_strategy`` is not
            one of ``'mean'`` / ``'topk'``.
    """
    # Previously an unknown strategy silently produced an empty dataset.
    if n != 'all' and pairs_sampling_strategy not in ('mean', 'topk'):
        raise ValueError(
            f"Unknown pairs_sampling_strategy {pairs_sampling_strategy!r}; expected 'mean' or 'topk'."
        )

    # A private generator makes sampling reproducible without touching global state.
    rng = random.Random(seed) if seed is not None else random

    score_1_col = f'answer_1_{TARGET_COL}'
    score_2_col = f'answer_2_{TARGET_COL}'
    output_columns = ['question_id', 'answer_1_id', 'answer_2_id',
                      'question_text', 'answer_1_text', 'answer_2_text',
                      score_1_col, score_2_col, 'label']
    source_columns = ['question_id', 'answer_id', 'question_text', 'answer_text', TARGET_COL]

    def select_pair_idxs(ranked_idxs, scores, pairs_sampling_strategy, n):
        """Return the (answer_1, answer_2) index pairs to keep for one question."""
        if n != 'all' and pairs_sampling_strategy == 'topk':
            return [
                (anchor_idx, other_idx)
                for rank, anchor_idx in enumerate(ranked_idxs[:max(n, 0)])
                for other_idx in ranked_idxs[rank + 1:]
                if scores[anchor_idx] != scores[other_idx]
            ]

        all_pairs_idxs = list(combinations(ranked_idxs, 2))
        if n == 'all':
            return [(first, second) for first, second in all_pairs_idxs
                    if scores[first] != scores[second]]

        # 'mean': shuffle the full candidate list first (ties included) so the random
        # stream is consumed exactly as before, then keep the first n usable pairs.
        rng.shuffle(all_pairs_idxs)
        selected = []
        for first, second in all_pairs_idxs:
            if len(selected) >= n:
                break
            if scores[first] != scores[second]:
                selected.append((first, second))
        return selected

    def create_pairs(group, pairs_sampling_strategy, n):
        """Build the pair records (list of dicts) for the answers of one question."""
        group = group.sort_values(TARGET_COL, ascending=False)

        # One dict per row, keyed by index label: plain dict lookups replace the repeated
        # `group.loc[idx][col]` calls, which built a full row Series on every access.
        rows = group[source_columns].to_dict('index')
        scores = {idx: row[TARGET_COL] for idx, row in rows.items()}
        ranked_idxs = group.index.tolist()

        pairs = []
        for first_idx, second_idx in select_pair_idxs(ranked_idxs, scores, pairs_sampling_strategy, n):
            first, second = rows[first_idx], rows[second_idx]
            pairs.append({
                'question_id': first['question_id'],
                'answer_1_id': first['answer_id'],
                'answer_2_id': second['answer_id'],
                'question_text': first['question_text'],
                'answer_1_text': first['answer_text'],
                'answer_2_text': second['answer_text'],
                score_1_col: first[TARGET_COL],
                score_2_col: second[TARGET_COL],
                'label': 1 if first[TARGET_COL] > second[TARGET_COL] else -1,
            })
        return pairs

    records = []
    for _, group in tqdm(dataset_df.groupby('question_id')):
        records.extend(create_pairs(group, pairs_sampling_strategy, n))

    # Passing the columns explicitly keeps the schema stable when `records` is empty.
    return pd.DataFrame(records, columns=output_columns)

In [11]:
# Materialise both splits as pandas DataFrames so pairs can be built per question via groupby.
train_dataset_df = pd.DataFrame(hf_dataset['train'])
test_dataset_df = pd.DataFrame(hf_dataset['test'])

In [12]:
# Training pairs are sub-sampled per question according to the configured strategy.
train_pairs_dataset_df = create_pairs_dataset_df(train_dataset_df, pairs_sampling_strategy=PAIRS_SAMPLING_STRATEGY, n=N_SAMPLES)
# With n='all' every pair with differing scores is kept; the strategy argument is not consulted.
test_pairs_dataset_df = create_pairs_dataset_df(test_dataset_df, pairs_sampling_strategy='mean', n='all')

# Evaluation rows carry a single answer, which the data collator reads from 'answer_1_text'.
# The guard keeps this cell idempotent: re-running it after the rename no longer fails.
if 'answer_text' in hf_dataset['test'].column_names:
    hf_dataset['test'] = hf_dataset['test'].rename_column('answer_text', 'answer_1_text')

100%|██████████| 7776/7776 [00:41<00:00, 186.85it/s]
100%|██████████| 1945/1945 [00:19<00:00, 98.65it/s] 


In [13]:
# Wrap the training pairs as a Hugging Face Dataset so they can be handed to the Trainer.
train_pairs_dataset = Dataset.from_pandas(train_pairs_dataset_df)

### Define model and data collator

In [14]:
from dataclasses import dataclass
import torch.nn as nn
from transformers import PreTrainedTokenizerBase, MegatronBertModel, PreTrainedTokenizerFast

2025-05-16 12:17:07.133952: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-16 12:17:07.779017: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-16 12:17:07.779086: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-16 12:17:07.884323: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-16 12:17:08.104983: I tensorflow/core/platform/cpu_feature_guar

In [15]:
# Load the pretrained encoder identified by MODEL_PATH.
encoder_model = MegatronBertModel.from_pretrained(MODEL_PATH)

Some weights of MegatronBertModel were not initialized from the model checkpoint at mmukh/SOBertBase and are newly initialized: ['embeddings.token_type_embeddings.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_PATH)
# Batch padding needs a pad token. If the tokenizer has none, reuse the padding index of the
# encoder's word-embedding layer as the pad token id.
if not tokenizer.pad_token:
    pad_token_id = encoder_model.embeddings.word_embeddings.padding_idx
    print(f'Setting pad token id to {pad_token_id}...')
    tokenizer.pad_token_id = pad_token_id
    # NOTE(review): in the recorded run this id (0) resolved to the '<unk>' token — confirm
    # that sharing the id between padding and unknown tokens is intended.
    print(f'Pad token set to {tokenizer.pad_token}')

Setting pad token id to 0...
Pad token set to <unk>


In [17]:
class BiEncoderRanker(nn.Module):
    """Bi-encoder that scores an answer for a question.

    Question and answer are encoded separately by the same encoder, pooled into one
    vector each, concatenated and passed through a small MLP that outputs a scalar
    relevance score per (question, answer) pair.

    Args:
        encoder_model: Encoder module. It must expose
            ``embeddings.word_embeddings.embedding_dim`` and return an object with a
            ``last_hidden_state`` attribute when called with the tokenized inputs.
        embeddings_pooling: Pooling strategy for token embeddings. Only ``'mean'``
            is implemented.
    """

    def __init__(self, encoder_model, embeddings_pooling):
        super().__init__()
        self.encoder = encoder_model
        self.embeddings_pooling = embeddings_pooling
        self.hidden_size = self.encoder.embeddings.word_embeddings.embedding_dim

        # Input is the concatenation [question_embedding ; answer_embedding].
        self.scorer = nn.Sequential(
            nn.Linear(2 * self.hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def get_sentence_embeddings(self, tokenized_inputs):
        """Encode a tokenized batch and pool it into one vector per sequence.

        Args:
            tokenized_inputs: Mapping passed to the encoder as keyword arguments; it
                must contain ``'attention_mask'``.

        Returns:
            Tensor with one pooled embedding per input sequence.

        Raises:
            ValueError: If ``embeddings_pooling`` is not supported. Previously this
                surfaced as an ``UnboundLocalError`` after the encoder had already run.
        """
        if self.embeddings_pooling != 'mean':
            raise ValueError(
                f"Unsupported embeddings_pooling {self.embeddings_pooling!r}; only 'mean' is implemented."
            )

        last_hidden_state = self.encoder(**tokenized_inputs).last_hidden_state

        # Cast the mask to the hidden-state dtype and let broadcasting expand it over
        # the hidden dimension, so padded positions contribute zeros to the sum.
        mask = tokenized_inputs['attention_mask'].unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        # Token counts are whole numbers, so clamping to 1 only affects fully masked
        # rows, which would otherwise yield 0/0 = NaN.
        token_counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / token_counts

    def gradient_checkpointing_enable(self, **gradient_checkpointing_kwargs):
        """Enable gradient checkpointing on the encoder, forwarding any keyword arguments."""
        if hasattr(self.encoder, 'gradient_checkpointing_enable'):
            # Forward the caller's options instead of silently dropping them.
            self.encoder.gradient_checkpointing_enable(**gradient_checkpointing_kwargs)
        else:
            raise NotImplementedError('Encoder model does not support gradient checkpointing.')

    def gradient_checkpointing_disable(self, **gradient_checkpointing_kwargs):
        """Disable gradient checkpointing on the encoder (keyword arguments are ignored)."""
        if hasattr(self.encoder, 'gradient_checkpointing_disable'):
            self.encoder.gradient_checkpointing_disable()
        else:
            raise NotImplementedError('Encoder model does not support gradient checkpointing.')

    def forward(self, questions_tokenized, answers_1_tokenized, answers_2_tokenized=None, labels=None):
        """Score the first answer and, if given, the second answer of each pair.

        Args:
            questions_tokenized: Tokenized questions.
            answers_1_tokenized: Tokenized first answers.
            answers_2_tokenized: Optional tokenized second answers (pairwise training).
            labels: Accepted for interface compatibility; not used here because the
                loss is computed outside the model.

        Returns:
            Dict with ``'answer_1_scores'`` and, when second answers are given,
            ``'answer_2_scores'``; each holds one score per sample.
        """
        question_embeddings = self.get_sentence_embeddings(questions_tokenized)
        answer_1_embeddings = self.get_sentence_embeddings(answers_1_tokenized)
        combined_1 = torch.cat([question_embeddings, answer_1_embeddings], dim=1)
        answer_1_scores = self.scorer(combined_1).squeeze(-1)

        outputs = {'answer_1_scores': answer_1_scores}

        if answers_2_tokenized is not None:
            # The question embedding is reused, so the question is encoded only once.
            answer_2_embeddings = self.get_sentence_embeddings(answers_2_tokenized)
            combined_2 = torch.cat([question_embeddings, answer_2_embeddings], dim=1)
            answer_2_scores = self.scorer(combined_2).squeeze(-1)

            outputs['answer_2_scores'] = answer_2_scores

        return outputs

In [18]:
# Wrap the encoder in the ranking model and move all parameters to the selected device.
model = BiEncoderRanker(encoder_model, embeddings_pooling=EMBEDDINGS_POOLING).to(device)

In [19]:
@dataclass
class BiEncoderPairwiseDataCollator:
    """Turn a list of samples into tokenized tensors for the bi-encoder ranker.

    Two sample layouts are handled:
    - training pairs, which carry 'answer_2_text' and a 'label';
    - evaluation rows, which carry a single answer and use the value of
      ``TARGET_COL`` as the label.
    """
    # Tokenizer applied to the question and answer texts.
    tokenizer: PreTrainedTokenizerBase
    # Forwarded to the tokenizer's `padding` argument.
    padding: bool = True

    def _tokenize(self, texts):
        """Tokenize a list of texts, truncated to MAX_LENGTH, as PyTorch tensors."""
        return self.tokenizer(texts, padding=self.padding, truncation=True, max_length=MAX_LENGTH,
                              return_tensors='pt')

    def __call__(self, batch):
        """Collate samples into the keyword arguments expected by the model."""
        question_texts = [sample['question_text'] for sample in batch]
        answer_1_texts = [sample['answer_1_text'] for sample in batch]
        # Present for training pairs only.
        answer_2_texts = [sample['answer_2_text'] for sample in batch if 'answer_2_text' in sample]
        # Training samples provide 'label'; evaluation samples fall back to the target score.
        raw_labels = [sample['label'] if 'label' in sample else sample[TARGET_COL] for sample in batch]

        collated_batch = {
            'questions_tokenized': self._tokenize(question_texts),
            'answers_1_tokenized': self._tokenize(answer_1_texts),
            'labels': torch.tensor(raw_labels).float(),
        }

        if answer_2_texts:
            collated_batch['answers_2_tokenized'] = self._tokenize(answer_2_texts)

        return collated_batch

In [20]:
# Collator instance shared by training and evaluation; padding keeps its default (True).
data_collator = BiEncoderPairwiseDataCollator(tokenizer=tokenizer)

### Train model

In [21]:
from transformers import Trainer, TrainingArguments
from src.evaluation import RankingEvaluator
import wandb
import pandas as pd
import math

In [22]:
if LOSS == 'margin_ranking_loss':
    loss_fn = nn.MarginRankingLoss(margin=1.0)
else:
    # Fail here rather than with a NameError on `loss_fn` at the first training step.
    raise ValueError(f"Unsupported LOSS {LOSS!r}; only 'margin_ranking_loss' is implemented.")

def trainer_loss_fn(outputs, labels, num_items_in_batch=None):
    """Loss callback handed to the Trainer via ``compute_loss_func``.

    Args:
        outputs: Dict returned by the model; always contains 'answer_1_scores' and,
            for pairwise (training) batches, 'answer_2_scores'.
        labels: Pair labels passed on to ``loss_fn`` for pairwise batches.
        num_items_in_batch: Accepted because the Trainer supplies it; unused.

    Returns:
        The pairwise loss for training batches. For evaluation batches, which have
        no second answer, a zero scalar placeholder is returned; the actual ranking
        loss on the test pairs is computed in ``compute_metrics``.
    """
    answer_1_scores = outputs['answer_1_scores']

    if 'answer_2_scores' in outputs:  # if training
        return loss_fn(answer_1_scores, outputs['answer_2_scores'], labels)

    # if evaluation: create the placeholder on the same device/dtype as the scores
    # instead of always on the CPU.
    return answer_1_scores.new_zeros(())

In [23]:
test_question_ids = hf_dataset['test']['question_id']

evaluator = RankingEvaluator(ndcg_k=list(range(1, 11)),
                             ndcg_gain_func='exponential', ndcg_discount_func='logarithmic')

def compute_metrics(eval_pred):
    """Compute the pairwise ranking loss and ranking metrics for one evaluation run.

    Args:
        eval_pred: Object with ``predictions`` (one predicted score per evaluated
            test row, in test-set order) and ``label_ids`` (the ground-truth values).

    Returns:
        Dict of metrics: the ranking loss over all test pairs (stored under the key
        ``LOSS``) plus whatever ``evaluator`` returns, with the 'mae' entry removed.

    Side effects:
        Logs the per-answer predictions to wandb as 'predictions_table'.
    """
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids

    predictions_df = pd.DataFrame()
    predictions_df['answer_id'] = hf_dataset['test']['answer_id']
    predictions_df[TARGET_COL] = hf_dataset['test'][TARGET_COL]
    # Copy the slice before adding a column, so the write targets an independent frame
    # and cannot trigger pandas' SettingWithCopy ambiguity.
    predictions_df = predictions_df.iloc[:len(predictions)].copy()
    predictions_df['predicted_score'] = predictions
    # NOTE(review): `test_question_ids` is not truncated to len(predictions) like the frame
    # above — confirm that evaluation always covers the full test split.

    # Attach the predicted score of each answer to its pairs. Merging only the two needed
    # columns avoids the suffixed duplicate columns the full-frame merges produced.
    answer_scores = predictions_df[['answer_id', 'predicted_score']]
    pairs_predictions_df = test_pairs_dataset_df
    for side in ('answer_1', 'answer_2'):
        side_scores = answer_scores.rename(columns={'answer_id': f'{side}_id',
                                                    'predicted_score': f'{side}_predicted_score'})
        pairs_predictions_df = pairs_predictions_df.merge(side_scores, on=f'{side}_id')

    loss = loss_fn(
        torch.as_tensor(pairs_predictions_df['answer_1_predicted_score'].to_numpy(), dtype=torch.float32),
        torch.as_tensor(pairs_predictions_df['answer_2_predicted_score'].to_numpy(), dtype=torch.float32),
        torch.as_tensor(pairs_predictions_df['label'].to_numpy(), dtype=torch.float32),
    )

    # Report a plain float rather than a 0-dim tensor.
    metrics = {LOSS: loss.item()}
    metrics.update(evaluator(labels, predictions, test_question_ids))
    # 'mae' is dropped from the ranking report; tolerate evaluators that do not return it.
    metrics.pop('mae', None)

    wandb.log({'predictions_table': wandb.Table(dataframe=predictions_df)})

    return metrics

In [24]:
training_args = TrainingArguments(
    output_dir=MODEL_OUTPUT_PATH,
    logging_steps=1,
    # len / batch size / accumulation steps = optimizer steps per epoch (single device);
    # dividing by 10 schedules roughly ten evaluations per epoch.
    eval_steps=int(len(train_pairs_dataset) / BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS / 10),
    eval_strategy = "steps",
    save_strategy = "epoch",
    save_total_limit=1,
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE * 2,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    num_train_epochs=5,
    weight_decay=0.01,
    report_to='wandb',
    # The data collator reads raw text columns from each sample, so they must not be dropped.
    remove_unused_columns=False,
    # Memory-saving options, currently disabled:
    # gradient_checkpointing=True,
    # optim="adamw_8bit"
)

In [25]:
# Start a Weights & Biases run; Trainer logs go to it via report_to='wandb'.
run = wandb.init(
    project='dalip-stackoverflow-answer-ranking',
    tags=['bi-encoder', 'ranking']
)

# Record the experiment settings next to the metrics.
# NOTE(review): `preprocessor.__dict__` logs every instance attribute of the preprocessor —
# confirm all of them are serialisable and meant to be logged.
wandb.config.update({
    'preprocessing': preprocessor.__dict__,
    'dataset': {
        'pairs_sampling_strategy': PAIRS_SAMPLING_STRATEGY,
        'n': N_SAMPLES
    },
    'model_name': MODEL_NAME,
    'vectorizer': {
        'vectorization_type': 'embeddings',
        'embeddings_pooling': EMBEDDINGS_POOLING,
        'max_length': MAX_LENGTH
    },
    'loss_fn': LOSS
})

# Training uses the answer pairs; evaluation scores every test answer individually and
# derives the pairwise loss and ranking metrics in `compute_metrics`.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_pairs_dataset,
    eval_dataset=hf_dataset['test'],
    compute_loss_func=trainer_loss_fn,
    compute_metrics=compute_metrics
)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbunnynobugs[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Run fine-tuning, then close the wandb run.
# NOTE(review): if training raises, `run.finish()` is not reached and the run stays open.
trainer.train()
run.finish()



Step,Training Loss,Validation Loss,Margin Ranking Loss,Ndcg@1 G.exponential D.logarithmic,Ndcg@2 G.exponential D.logarithmic,Ndcg@3 G.exponential D.logarithmic,Ndcg@4 G.exponential D.logarithmic,Ndcg@5 G.exponential D.logarithmic,Ndcg@6 G.exponential D.logarithmic,Ndcg@7 G.exponential D.logarithmic,Ndcg@8 G.exponential D.logarithmic,Ndcg@9 G.exponential D.logarithmic,Ndcg@10 G.exponential D.logarithmic,Hit Rate@1,Runtime,Samples Per Second,Steps Per Second
1316,4.1893,0.0,0.844493,0.520546,0.622459,0.700874,0.755762,0.777599,0.786951,0.791173,0.793832,0.794831,0.795371,0.401542,341.4328,30.937,15.47
