In [1]:
import transformers
transformers.__version__

'4.51.3'

In [2]:
# !pip install transformers sentence-transformers
# !pip install -U bitsandbytes peft typing_extensions


In [3]:
import os
import torch
import random
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
import datasets

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoConfig
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorWithPadding

In [None]:


class RerankerTrainDataset(Dataset):
    """
    Dataset for BGE reranker training.

    Args:
        data_dir (str): Directory containing JSON/JSONL files.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
        max_query_length (int): Maximum length for query tokens.
        max_passage_length (int): Maximum length for passage tokens.
        train_group_size (int): Number of passages per query (1 pos + N-1 neg).
    """
    def __init__(self, data_dir, tokenizer: PreTrainedTokenizer, max_query_length=32, max_passage_length=256, train_group_size=8):
        self.tokenizer = tokenizer
        self.max_query_length = max_query_length
        self.max_passage_length = max_passage_length
        self.train_group_size = train_group_size

        # Load all JSON/JSONL files from directory
        dataset_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(('.json', '.jsonl'))]
        self.dataset = datasets.load_dataset("json", data_files=dataset_files, split="train")
        self.dataset = self.dataset.map(lambda x: {"q_len": len(tokenizer(x["query"])["input_ids"])})
        self.dataset = self.dataset.map(lambda x: {"p_len": len(tokenizer(x["pos"][0])["input_ids"])})
        self.dataset = self.dataset.filter(lambda x: x["q_len"]<=max_query_length)
        self.dataset = self.dataset.filter(lambda x: x["p_len"]<=max_passage_length)
        
        

    def __len__(self):
        return len(self.dataset)

    def _prepare_pair(self, query, passage):
        """Encodes query-passage pair and prepares it for model input."""
        qry_inputs = self.tokenizer.encode(
            query,
            truncation=True,
            max_length=self.max_query_length,
            add_special_tokens=False
        )
        doc_inputs = self.tokenizer.encode(
            passage,
            truncation=True,
            max_length=self.max_passage_length,
            add_special_tokens=False
        )

        return self.tokenizer.prepare_for_model(
            qry_inputs,
            doc_inputs,
            truncation="only_second",
            max_length=self.max_query_length + self.max_passage_length,
            padding=False
        )

    def __getitem__(self, idx):
        """
        Returns tokenized query-passage pairs.
        The first passage is positive, and the remaining are negatives.
        """
        try:
            query_prompt = "Represent this sentence for searching relevant passages:" 
            data = self.dataset[idx]
            query = query_prompt + data["query"] 
            
            # Select one positive passage and 7 negatives
            positive_passage = random.choice(data["pos"])
            
            # Handle cases where the number of negatives is less than required
            neg_count = self.train_group_size - 1  # Expected number of negatives
            available_negatives = data["neg"]
    
            if len(available_negatives) >= neg_count:
                negative_passages = random.sample(available_negatives, neg_count)
            else:
                negative_passages = available_negatives * (neg_count // len(available_negatives)) + \
                                    random.sample(available_negatives, neg_count % len(available_negatives))
    
    
            # Format as [positive, negative1, ..., negative7]
            passages = [positive_passage] + negative_passages
    
            # Prepare all query-passage pairs
            prepared_pairs = [self._prepare_pair(query, passage) for passage in passages]
    
            return prepared_pairs  
            
        except ZeroDivisionError:
            print(idx, data["neg"])


In [None]:
from transformers import DataCollatorWithPadding
from typing import List
from transformers.tokenization_utils_base import BatchEncoding

class RerankerCollator(DataCollatorWithPadding):
    """
    Collator for BGE reranker.
    - Pads tokenized query-passage pairs.
    - Ensures correct batch formatting for Trainer.
    """

    def __init__(self, tokenizer, query_max_len=32, passage_max_len=128, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.query_max_len = query_max_len
        self.passage_max_len = passage_max_len

    def __call__(self, features) -> List[BatchEncoding]:
        """
        Prepares batch for model training.
        - Extracts teacher scores (if available).
        - Flattens nested query-passage pairs.
        - Pads and returns in Trainer-compatible format.
        """

        teacher_scores = [f[1] for f in features]
        if teacher_scores[0] is None:
            teacher_scores = None
            
        elif isinstance(teacher_scores[0], list):
            teacher_scores = sum(teacher_scores, [])  # Flatten list of lists

        # Flatten tokenized query-passage pairs
        if isinstance(features[0], list):
            features = sum(features, [])  # Flatten list of lists

        
        # Pad the sequences
        collated = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.query_max_len + self.passage_max_len,  # Uses overridden values
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        return {
            "pair": collated.to("cuda"), #,  # Tokenized input batch
            #"teacher_scores": teacher_scores,  # Optional distillation scores
        }


In [7]:
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
import torch.nn as nn

def load_model(model_id: str):
    # 4-bit quantization settings
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        llm_int8_skip_modules=["classifier", "pre_classifier"]
    )
    
    # Initialize base model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Prepare model for LoRA
    model = prepare_model_for_kbit_training(model)

    # Apply LoRA
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["query", "value"],  
        bias="none",
        task_type="SEQ_CLS",
    )
    model = get_peft_model(model, lora_config)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return tokenizer, model



In [8]:
tokenizer, model = load_model("BAAI/bge-m3")  


# "BAAI/bge-base-en-v1.5" -> 512

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at BAAI/bge-m3 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): XLMRobertaForSequenceClassification(
      (roberta): XLMRobertaModel(
        (embeddings): XLMRobertaEmbeddings(
          (word_embeddings): Embedding(250002, 1024, padding_idx=1)
          (position_embeddings): Embedding(8194, 1024, padding_idx=1)
          (token_type_embeddings): Embedding(1, 1024)
          (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): XLMRobertaEncoder(
          (layer): ModuleList(
            (0-23): 24 x XLMRobertaLayer(
              (attention): XLMRobertaAttention(
                (self): XLMRobertaSdpaSelfAttention(
                  (query): lora.Linear4bit(
                    (base_layer): Linear4bit(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                  

In [10]:
data = RerankerTrainDataset(data_dir = "data_bge", 
                     tokenizer = tokenizer, 
                    max_query_length=128, 
                     max_passage_length=1024, 
                     train_group_size=8)


data_collator = RerankerCollator(tokenizer, query_max_len=128, passage_max_len=1024)

In [None]:
import torch
from torch.nn import CrossEntropyLoss
from transformers import Trainer

class CustomTrainer(Trainer):
    """
    Custom Trainer for BGE Reranker.
    - Overrides compute_loss to handle ranking-specific loss.
    - Assumes each query has 8 passages (1 positive + 7 negatives) (Here N = 8).
    """
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute categorical cross-entropy loss for reranking.
        - Model outputs 8 logits per query (1 per passage).
        - The first passage (index 0) is the correct one.
        - Uses CrossEntropyLoss with label 0.
        """
        labels = torch.zeros(self.args.per_device_train_batch_size, dtype=torch.long, device=self.model.device)  # (batch_size,)
        # print(f"labels size: {labels.size()}")
        
        # Forward pass
        outputs = model(**inputs["pair"])
        logits = outputs.logits  # Shape: (batch_size, 8)
        # print(f"logits size: {logits.size()}")
        
        logits = logits.view(self.args.per_device_train_batch_size, -1)
        # print(f"logits after reshaping: {logits.size()}")
        
        # Compute categorical cross-entropy loss
        loss_fn = CrossEntropyLoss()
        loss = loss_fn(logits, labels)
        
        return (loss, outputs) if return_outputs else loss

In [13]:
training_args = TrainingArguments(output_dir="out", 
                                dataloader_pin_memory=False,
                                per_device_train_batch_size=4, 
                                logging_steps=1,
                                max_steps=1000,
                                bf16=True, 
                                learning_rate = 2e-5, 
                                # max_grad_norm=5,
                                lr_scheduler_type="constant", 
                                dataloader_drop_last=True)




In [14]:
trainer = CustomTrainer(model, training_args, data_collator, data)


No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [15]:
session = trainer.train()

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
1,3.0133
2,2.8258
3,2.6206
4,2.6595
5,2.6933
6,2.5881
7,2.499
8,2.396
9,2.7159
10,2.3668


  return fn(*args, **kwargs)


KeyboardInterrupt: 

In [None]:
# 22 gb for training the lora adaptors

In [16]:
trainer.save_model("bge_m3_reranker_lora_adapter_600")