# BGE M3

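## Dense embeddings only

Fine-tune an encoder-only BGE checkpoint (`BAAI/bge-base-en-v1.5` by default) as a cross-encoder reranker with FlagEmbedding. Each training group pairs a query with one positive passage and `train_group_size - 1` negative passages. The model outputs a single score (`num_labels = 1`) for each (query, passage) pair.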

In [1]:
# %pip install FlagEmbedding
# !pip install -U FlagEmbedding

Collecting FlagEmbedding
  Downloading FlagEmbedding-1.3.4.tar.gz (163 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting transformers>=4.44.2 (from FlagEmbedding)
  Downloading transformers-4.52.2-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.2/40.2 kB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.19.0 (from FlagEmbedding)
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting accelerate>=0.20.1 (from FlagEmbedding)
  Downloading accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)
Collecting sentence_transformers (from FlagEmbedding)
  Downloading sentence_transformers-4.1.0-py3-none-any.whl.metadata (13 kB)
Collecting peft (from FlagEmbedding)
  Downloading peft-0.15.2-py3-none-any.whl.metadata (13 kB)
Collecting ir-datasets (from F

In [2]:
!pip install unsloth accelerate pandas matplotlib 

Collecting unsloth
  Downloading unsloth-2025.5.7-py3-none-any.whl.metadata (47 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.1/47.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting matplotlib
  Downloading matplotlib-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting unsloth_zoo>=2025.5.8 (from unsloth)
  Downloading unsloth_zoo-2025.5.8-py3-none-any.whl.metadata (8.0 kB)
Collecting torch>=2.4.0 (from unsloth)
  Downloading torch-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting bitsandbytes (from unsloth)
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting triton>=3.0.0 (from unsloth)
  Downloading triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)
Collecting tyro (from 

In [3]:
import logging
from typing import Tuple
from transformers import (
    AutoModelForSequenceClassification, AutoConfig,
    AutoTokenizer, PreTrainedTokenizer
)

from FlagEmbedding.abc.finetune.reranker import AbsRerankerRunner, AbsRerankerModel
from FlagEmbedding.finetune.reranker.encoder_only.base.modeling import CrossEncoderModel
from FlagEmbedding.finetune.reranker.encoder_only.base.trainer import EncoderOnlyRerankerTrainer

# Notebook-level logger; logging.getLogger returns the same logger object for the same name on every call.
logger = logging.getLogger(__name__)


In [4]:
# Candidate checkpoints for `config["model_name_or_path"]`.
# Loading bge-base-en-v1.5 for sequence classification initializes a fresh classifier
# head (see the warning printed by the model-loading cell).
# NOTE(review): bge-reranker-v2-m3 presumably already ships a trained reranking head — confirm.
AVAILABLE_MODELS = ["BAAI/bge-base-en-v1.5", "BAAI/bge-large-en-v1.5", "BAAI/bge-reranker-v2-m3"]
availible_models = AVAILABLE_MODELS  # original (misspelled) name kept as an alias for existing references

In [5]:
# Flat hyper-parameter dict handed to HfArgumentParser.parse_dict in the next cell.
# It is assembled from three sections (model / data / training), matching the three
# argument dataclasses that parser is built from. `**` merging preserves insertion
# order, so the resulting key order is model keys, then data keys, then training keys.
config = {
    # ---- Model arguments ----
    **{
        "model_name_or_path": "BAAI/bge-base-en-v1.5",
        "config_name": None,
        "tokenizer_name": None,
        "cache_dir": "./cache",
        "trust_remote_code": False,
        "model_type": "encoder",
        "token": None,  # Hugging Face access token; keep None and rely on the environment rather than hard-coding it
    },
    # ---- Data arguments ----
    **{
        "train_data": ["./ft_data/training.json"],  # JSON/JSONL files or directories containing them
        "cache_path": "./data_cache",
        "train_group_size": 8,  # 1 positive + 7 negatives per query
        "query_max_len": 32,
        "passage_max_len": 128,
        "max_len": 512,
        "pad_to_multiple_of": None,
        "max_example_num_per_dataset": 100000,
        "query_instruction_for_rerank": "Search query:",  # optional prefix for every query
        "query_instruction_format": "{}{}",
        "knowledge_distillation": False,
        "passage_instruction_for_rerank": "Passage:",  # optional prefix for every passage
        "passage_instruction_format": "{}{}",
        "shuffle_ratio": 0.0,
        "sep_token": "\n",
    },
    # ---- Training arguments ----
    **{
        "output_dir": "./results",
        "overwrite_output_dir": True,
        "do_train": True,
        "do_eval": False,
        "max_steps": 1000,
        "per_device_train_batch_size": 8,
        "per_device_eval_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "learning_rate": 5e-5,
        "weight_decay": 0.01,
        # "num_train_epochs": 3,  # disabled; training length is set by max_steps
        "lr_scheduler_type": "linear",
        "warmup_ratio": 0.1,
        "logging_dir": "./logs",
        "logging_steps": 25,
        "save_steps": 100,
        "save_total_limit": 2,
        "fp16": False,
        "sub_batch_size": None,
        "report_to": "none",
    },
}

In [6]:
from transformers import HfArgumentParser

from FlagEmbedding.abc.finetune.reranker import (
    AbsRerankerModelArguments,
    AbsRerankerDataArguments,
    AbsRerankerTrainingArguments
)
from FlagEmbedding.finetune.reranker.encoder_only.base import EncoderOnlyRerankerRunner


# Split the flat `config` dict into the three FlagEmbedding argument dataclasses.
parser = HfArgumentParser(
    (
        AbsRerankerModelArguments,
        AbsRerankerDataArguments,
        AbsRerankerTrainingArguments,
    )
)
model_args, data_args, training_args = parser.parse_dict(config)

In [7]:
import os
import math
import random
import logging
import datasets
import numpy as np
import torch.distributed as dist
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
    PreTrainedTokenizer, 
    DataCollatorWithPadding,
    BatchEncoding,
    DataCollatorForSeq2Seq
)
from typing import List


# Re-created here so this cell also runs on its own. getLogger(__name__) returns the same
# object the earlier imports cell created.
logger = logging.getLogger(__name__)


class RerankerTrainDataset(Dataset):
    """Training dataset for an encoder-only (cross-encoder) reranker.

    Each item is one training *group*: a query paired with one positive passage and
    ``train_group_size - 1`` negative passages. Every (query, passage) pair is tokenized
    with ``create_one_example``.

    Record fields read by ``__getitem__``: ``query`` (str), ``pos`` (list of str) and
    ``neg`` (list of str). Optional ``query_prompt`` / ``passage_prompt`` fields override
    the configured instructions. ``pos_scores`` / ``neg_scores`` are required when
    ``args.knowledge_distillation`` is enabled.

    Args:
        args (AbsRerankerDataArguments): Data arguments.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.

    Raises:
        ValueError: If no non-empty ``.json``/``.jsonl`` file is found in ``args.train_data``.
    """
    # Annotations are strings so the class can be defined before the FlagEmbedding /
    # transformers argument types are imported (i.e. regardless of cell execution order).
    def __init__(
        self,
        args: "AbsRerankerDataArguments",
        tokenizer: "PreTrainedTokenizer"
    ):
        self.args = args
        self.tokenizer = tokenizer

        train_datasets = []
        for data_path in args.train_data:
            if os.path.isdir(data_path):
                # sorted() so the concatenation order does not depend on directory-listing order.
                candidate_files = [os.path.join(data_path, name) for name in sorted(os.listdir(data_path))]
            else:
                candidate_files = [data_path]
            for file_path in candidate_files:
                if not file_path.endswith(('.json', '.jsonl')):
                    continue
                temp_dataset = self._load_dataset(file_path)
                if len(temp_dataset) == 0:
                    continue
                train_datasets.append(temp_dataset)

        if not train_datasets:
            raise ValueError(
                f"No non-empty .json/.jsonl training files found in train_data={args.train_data!r}."
            )
        self.dataset = datasets.concatenate_datasets(train_datasets)

        self.max_length = self.args.query_max_len + self.args.passage_max_len

    def _load_dataset(self, file_path: str):
        """Load one JSON/JSONL file as a Hugging Face dataset.

        The data is randomly subsampled to ``args.max_example_num_per_dataset`` rows when
        the file is larger. Teacher-score columns are dropped when knowledge distillation
        is off, and required when it is on.

        Args:
            file_path (str): Path to load the datasets from.

        Raises:
            ValueError: `pos_scores` and `neg_scores` not found in the features of training
                data while knowledge distillation is enabled.

        Returns:
            datasets.Dataset: Loaded HF dataset.
        """
        # Only the main process logs. torch.distributed may be unavailable or uninitialized
        # when running in a single-process notebook.
        is_main_process = True
        if dist.is_available() and dist.is_initialized():
            is_main_process = dist.get_rank() == 0

        if is_main_process:
            logger.info(f'loading data from {file_path} ...')

        temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
        if len(temp_dataset) > self.args.max_example_num_per_dataset:
            temp_dataset = temp_dataset.select(
                random.sample(range(len(temp_dataset)), self.args.max_example_num_per_dataset)
            )

        if not self.args.knowledge_distillation:
            if 'pos_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['pos_scores'])
            if 'neg_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['neg_scores'])
        else:
            if 'pos_scores' not in temp_dataset.column_names or 'neg_scores' not in temp_dataset.column_names:
                raise ValueError(f"`pos_scores` and `neg_scores` not found in the features of training data in {file_path}, which is necessary when using knowledge distillation.")
        return temp_dataset

    def _shuffle_text(self, text):
        """Randomly reorder the text as data augmentation.

        With probability ``args.shuffle_ratio``, a text longer than 100 characters is cut
        into (at most) three chunks that are shuffled and re-joined with spaces. Otherwise
        the text is returned unchanged.

        Args:
            text (str): Input text.

        Returns:
            str: Shuffled text.
        """
        if self.args.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.args.shuffle_ratio:
            split_text = []
            chunk_size = len(text)//3 + 1
            for i in range(0, len(text), chunk_size):
                split_text.append(text[i:i+chunk_size])
            random.shuffle(split_text)
            return " ".join(split_text)
        else:
            return text

    def __len__(self):
        return len(self.dataset)

    def create_one_example(self, qry_encoding: str, doc_encoding: str):
        """Tokenize one (query, passage) pair into model inputs.

        Query and passage are first tokenized separately, each with a loose per-side cap.
        The pair is then joined by ``prepare_for_model``, which truncates only the passage
        so the pair fits ``query_max_len + passage_max_len`` tokens.

        Args:
            qry_encoding (str): Query to be encoded.
            doc_encoding (str): Document to be encoded.

        Returns:
            dict: A dictionary containing tokenized and prepared inputs, ready for model consumption.
        """
        qry_inputs = self.tokenizer.encode(qry_encoding, truncation=True, max_length=self.args.query_max_len + self.args.passage_max_len // 4, add_special_tokens=False)
        doc_inputs = self.tokenizer.encode(doc_encoding, truncation=True, max_length=self.args.passage_max_len + self.args.query_max_len // 2, add_special_tokens=False)
        item = self.tokenizer.prepare_for_model(
            qry_inputs,
            doc_inputs,
            truncation='only_second',
            max_length=self.args.query_max_len + self.args.passage_max_len,
            padding=False,
        )
        return item

    @staticmethod
    def _prompt_or_default(data, key, default):
        """Return ``data[key]`` when it is present and not None, otherwise ``default``.

        When only some JSON records carry the key, the loaded column can hold None for the
        others. A plain ``key in data`` check would then format the literal text "None"
        into the input.
        """
        value = data.get(key)
        return default if value is None else value

    def __getitem__(self, item):
        """Build one training group: the query paired with 1 positive and N-1 negatives.

        Returns:
            tuple: ``(batch_data, teacher_scores)``. ``batch_data`` holds
            ``train_group_size`` tokenized pairs, positive first. ``teacher_scores`` is a
            parallel list of numbers, or None when knowledge distillation is off.

        Raises:
            ValueError: If the record has no positive passage, has no negatives while
                ``train_group_size > 1``, or has non-numeric teacher scores.
        """
        data = self.dataset[item]
        train_group_size = self.args.train_group_size
        num_negatives = train_group_size - 1

        query = data['query']
        if self.args.query_instruction_for_rerank is not None:
            query = self.args.query_instruction_format.format(
                self._prompt_or_default(data, 'query_prompt', self.args.query_instruction_for_rerank),
                query
            )

        passages = []
        teacher_scores = []

        assert isinstance(data['pos'], list) and isinstance(data['neg'], list)
        if not data['pos']:
            raise ValueError(f"Training example {item} has an empty 'pos' list; at least one positive passage is required.")
        if num_negatives > 0 and not data['neg']:
            raise ValueError(
                f"Training example {item} has an empty 'neg' list but train_group_size={train_group_size} "
                f"requires {num_negatives} negative passage(s)."
            )

        # Only the positive passage goes through shuffle augmentation.
        pos_idx = random.choice(list(range(len(data['pos']))))
        passages.append(self._shuffle_text(data['pos'][pos_idx]))

        neg_all_idx = list(range(len(data['neg'])))
        if len(data['neg']) < num_negatives:
            # Too few negatives: sample from the index pool repeated `num` times, so each
            # negative is reused at most `num` times.
            num = math.ceil(num_negatives / len(data['neg']))
            neg_idxs = random.sample(neg_all_idx * num, num_negatives)
        else:
            neg_idxs = random.sample(neg_all_idx, num_negatives)
        for neg_idx in neg_idxs:
            passages.append(data['neg'][neg_idx])

        if self.args.knowledge_distillation:
            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
            teacher_scores.append(data['pos_scores'][pos_idx])
            for neg_idx in neg_idxs:
                teacher_scores.append(data['neg_scores'][neg_idx])
            if not all(isinstance(score, (int, float)) for score in teacher_scores):
                raise ValueError("pos_score or neg_score must be digit")
        else:
            teacher_scores = None

        if self.args.passage_instruction_for_rerank is not None:
            passage_prompt = self._prompt_or_default(data, 'passage_prompt', self.args.passage_instruction_for_rerank)
            passages = [
                self.args.passage_instruction_format.format(passage_prompt, p)
                for p in passages
            ]

        batch_data = [self.create_one_example(query, passage) for passage in passages]

        return batch_data, teacher_scores



In [8]:
from transformers import AutoTokenizer, AutoConfig

In [9]:
# Load the tokenizer. `tokenizer_name`, when set, overrides the model path, the same way
# `config_name` is treated when loading the model config.
# Honor the configured `trust_remote_code` instead of hard-coding True: True allows the Hub
# repository to run its own Python code in this kernel.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)

In [10]:
# A single output per input pair: the model produces one score per (query, passage).
num_labels = 1

# `config_name`, when set, overrides the model path for the architecture config.
config_model = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    num_labels=num_labels,
    trust_remote_code=model_args.trust_remote_code,
    token=model_args.token,
    cache_dir=model_args.cache_dir,
)

In [11]:

# Load the backbone with a sequence-classification head. For bge-base-en-v1.5 the
# classifier weights are newly initialized (see the warning printed below), so they
# only become meaningful after training. `from_tf` is set when the path points to a
# TensorFlow ".ckpt" checkpoint.
base_model = AutoModelForSequenceClassification.from_pretrained(
    model_args.model_name_or_path,
    config=config_model,
    from_tf=".ckpt" in model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    token=model_args.token,
    cache_dir=model_args.cache_dir,
)

# Wrap the backbone in FlagEmbedding's cross-encoder training model.
model = CrossEncoderModel(
    base_model,
    train_batch_size=training_args.per_device_train_batch_size,
    tokenizer=tokenizer,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at BAAI/bge-base-en-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# base_model
# model.compute_loss?

In [13]:
# Show whether gradient checkpointing is enabled. `config` does not set it, so this is the
# training-arguments dataclass default.
training_args.gradient_checkpointing

False

In [14]:
# With gradient checkpointing on, make the model's inputs require grad.
# NOTE(review): presumably needed so gradients still flow through checkpointed
# segments — confirm against CrossEncoderModel.enable_input_require_grads.
if training_args.gradient_checkpointing:
    model.enable_input_require_grads()


In [16]:
# Build the training dataset from the parsed data arguments and the loaded tokenizer.
train_dataset = RerankerTrainDataset(args=data_args, tokenizer=tokenizer)


In [20]:
@dataclass
class RerankerCollator(DataCollatorWithPadding):
    """Collate reranker training groups into one padded batch of (query, passage) pairs.

    Each feature is the ``(batch_data, teacher_scores)`` tuple returned by
    ``RerankerTrainDataset.__getitem__``. Groups are flattened, so the padded tensors
    contain one row per pair, and teacher scores (if any) are flattened in the same order.

    Attributes:
        query_max_len: Query token budget. Added to ``passage_max_len`` to form the
            ``max_length`` passed to ``tokenizer.pad``.
        passage_max_len: Passage token budget.
    """
    # NOTE(review): the defaults read the global `config` dict at class-definition time;
    # the notebook passes both values explicitly when building the collator.
    query_max_len: int = config["query_max_len"]
    passage_max_len: int = config["passage_max_len"]

    def __call__(self, features) -> dict:
        """Pad all pairs in ``features``.

        Returns:
            dict: ``{"pair": <padded tokenizer output>, "teacher_scores": <flat list or None>}``.
        """
        teacher_scores = [f[1] for f in features]
        if teacher_scores[0] is None:
            teacher_scores = None
        elif isinstance(teacher_scores[0], list):
            # Flatten in linear time; sum(list_of_lists, []) is quadratic.
            teacher_scores = [score for group in teacher_scores for score in group]

        features = [f[0] for f in features]
        if isinstance(features[0], list):
            features = [pair for group in features for pair in group]

        collated = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.query_max_len + self.passage_max_len,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        return {
            "pair": collated,
            "teacher_scores": teacher_scores,
        }

In [21]:
# Build the collator from the parsed `data_args`, the same source the dataset uses,
# rather than from the raw `config` dict. This also honors the configured
# `pad_to_multiple_of`, which was previously hard-coded to None.
data_collator = RerankerCollator(
    tokenizer=tokenizer,
    query_max_len=data_args.query_max_len,
    passage_max_len=data_args.passage_max_len,
    pad_to_multiple_of=data_args.pad_to_multiple_of,
    padding=True,
    return_tensors="pt",
)

In [22]:
from torch.optim import AdamW


def _noop_optimizer_train(self):
    """No-op stand-in for ``optimizer.train()``; returns None."""
    return None


# NOTE(review): presumably works around an AttributeError raised when the Trainer /
# optimizer wrapper calls ``optimizer.train()`` on a torch AdamW that has no such method.
# Confirm against the installed transformers/accelerate versions.
# This mutates the AdamW class for the rest of the kernel session.
if not hasattr(AdamW, 'train'):
    AdamW.train = _noop_optimizer_train

# `tokenizer=` produces a deprecation warning on recent transformers (see the output).
# TODO: switch to `processing_class=` once the minimum supported transformers version
# accepts it.
trainer = EncoderOnlyRerankerTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

  trainer = EncoderOnlyRerankerTrainer(


In [None]:
# Run fine-tuning with the arguments parsed from `config` (max_steps=1000 there).
trainer.train()



Step,Training Loss


In [33]:
# Save the fine-tuned model to disk. With no argument this presumably writes to
# training_args.output_dir ("./results" in `config`).

trainer.save_model()

Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
