feat: add reranker training #1449

Merged
akoumpa merged 37 commits into main from adil/retriever-feat on Mar 30, 2026

Conversation

@adil-a (Collaborator) commented Mar 4, 2026

Summary

Adds cross-encoder (reranker) training support alongside the existing bi-encoder pipeline. Cross-encoders jointly attend to each query-document pair, producing relevance scores that are significantly more accurate than bi-encoder similarity; this makes them a natural fit for reranking retrieved candidates.
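
Below is a minimal sketch of the difference between the two scoring schemes, using plain Hugging Face transformers rather than this PR's API; the model name and mean pooling are illustrative assumptions.

```python
import torch
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer

query = "what is a reranker?"
passage = "A reranker re-scores candidates retrieved by a first-stage retriever."

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

# Bi-encoder: encode query and passage independently, then compare embeddings.
bi = AutoModel.from_pretrained("bert-base-uncased")
with torch.no_grad():
    q = bi(**tok(query, return_tensors="pt")).last_hidden_state.mean(dim=1)
    p = bi(**tok(passage, return_tensors="pt")).last_hidden_state.mean(dim=1)
bi_score = torch.nn.functional.cosine_similarity(q, p)

# Cross-encoder: attend over the concatenated pair and emit a single relevance logit.
cross = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=1)
with torch.no_grad():
    cross_score = cross(**tok(query, passage, return_tensors="pt")).logits.squeeze(-1)
```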

Cross-Encoder Model & Training Recipe

  • NeMoAutoModelCrossEncoder — New auto-model class with full infrastructure support (FSDP2, PEFT, kernel patching)
  • CrossEncoderModel — Wraps any ForSequenceClassification backbone. Routes through LlamaBidirectionalForSequenceClassification for Llama models, falls back to HF AutoModelForSequenceClassification for all others
  • TrainCrossEncoderRecipe — Extends TrainBiEncoderRecipe with cross-entropy loss on reshaped logits, training accuracy tracking, and validation with accuracy@1 and MRR metrics
  • accuracy() / batch_mrr() — Pure-function metric utilities shared by training and validation (see the loss-and-metrics sketch after this list)
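
A minimal sketch of how the loss and metrics above fit together, assuming the flattened batch orders each query's passages as [positive, negatives...] so the target index is always 0; all variable names here are illustrative, not the recipe's actual code.

```python
import torch
import torch.nn.functional as F

n_passages = 5
logits = torch.randn(4 * n_passages)                    # one score per (query, passage) pair
scores = logits.view(-1, n_passages)                    # reshape to (n_queries, n_passages)
labels = torch.zeros(scores.size(0), dtype=torch.long)  # positive passage sits at index 0

loss = F.cross_entropy(scores, labels)                  # listwise cross-entropy over passages

# accuracy@1: fraction of queries whose positive passage scores highest
acc = (scores.argmax(dim=-1) == labels).float().mean()

# MRR: mean reciprocal rank of the positive within each query's score list
ranks = scores.argsort(dim=-1, descending=True).argsort(dim=-1)  # 0-based rank of each passage
mrr = (1.0 / (ranks[:, 0].float() + 1.0)).mean()
```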

Data Pipeline

  • CrossEncoderCollator — Concatenates query+passage via configurable prompt template, tokenizes, pads. Compatible with NeMoAutoTokenizer
  • flatten_bi_encoder_to_cross_encoder() — Transforms grouped bi-encoder data into the flattened cross-encoder format, so the same data files work for both model types via a model_type config switch (a sketch follows this list)
  • make_retrieval_dataset(model_type="cross_encoder") — Unified dataset factory supporting both bi-encoder and cross-encoder formats
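
A hypothetical sketch of the grouped-to-flat transform described above; the helper and field names (query, positive, negatives) are assumptions for illustration, and the real flatten_bi_encoder_to_cross_encoder() in this PR may differ in detail.

```python
def flatten_example(example: dict, n_passages: int = 5) -> list[dict]:
    """Turn one grouped bi-encoder row into flat (query, passage) pairs,
    keeping the positive passage first so its label index is 0."""
    passages = [example["positive"]] + example["negatives"][: n_passages - 1]
    return [{"query": example["query"], "passage": p} for p in passages]

row = {
    "query": "what is a reranker?",
    "positive": "A reranker re-scores retrieved candidates.",
    "negatives": ["Bananas are yellow.", "The sky is blue.", "Paris is in France."],
}
flat = flatten_example(row, n_passages=3)  # 1 positive + 2 negatives, positive at index 0
```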

Architecture Simplification

  • Removed dual-encoder (lm_q/lm_p) design with share_encoder — single backbone is simpler and aligns with modern embedding models
  • Simplified EncoderStateDictAdapter (was BiencoderStateDictAdapter) — with a single backbone, the adapter reduces to a symmetric `model.` prefix strip/add
  • Any HF model as encoder backbone — the AutoModel/AutoModelForSequenceClassification fallback means users are not limited to custom bidirectional models
  • Unified naming: Biencoder -> BiEncoder, train_n_passages -> n_passages

CI & Testing

  • Functional test: 2-GPU FSDP2 training (32 steps) + quality evaluation asserting finetuned model improves positive-pair scoring and achieves >=75% ranking accuracy
  • Unit tests: accuracy(), batch_mrr(), CrossEncoderCollator (output shapes, labels, padding), flatten_bi_encoder_to_cross_encoder() (value validation), liger/SDPA retry for cross-encoder
  • L2_Retrieval CI job added to cicd-main.yml

Example

```yaml
model:
  _target_: nemo_automodel.NeMoAutoModelCrossEncoder.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B
  num_labels: 1
  pooling: avg
dataloader:
  dataset:
    _target_: nemo_automodel.components.datasets.llm.make_retrieval_dataset
    model_type: cross_encoder
    n_passages: 5
  collate_fn:
    _target_: nemo_automodel.components.datasets.llm.CrossEncoderCollator
    rerank_max_length: 512
    prompt_template: "question:{query} \n \n passage:{passage}"
```
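
For illustration, this is roughly what the collator's prompt_template does to each (query, passage) pair before tokenization; the snippet below is a standalone sketch, not the CrossEncoderCollator's actual code.

```python
template = "question:{query} \n \n passage:{passage}"
text = template.format(
    query="what is a reranker?",
    passage="A reranker re-scores retrieved candidates.",
)
# Each formatted string is then tokenized, truncated to rerank_max_length,
# and padded into the batch.
```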

Test plan

  • 2-GPU cross-encoder training completes without errors (32 steps, FSDP2)
  • Finetuned model positive-pair score improvement > 0 (observed: +9.35)
  • Finetuned ranking accuracy >= 75% (observed: 92.0%)
  • 22 unit tests pass (accuracy, batch_mrr, collator, flatten, retry)
  • ruff check clean on nemo_automodel/
  • CI L2_Retrieval job passes

Signed-off-by: adil-a <adil.asif2000@hotmail.com>
@copy-pr-bot (Bot) commented Mar 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

adil-a added 5 commits March 4, 2026 07:49
adil-a changed the title from "feat: add retriever training" to "feat: add reranker training" Mar 5, 2026
adil-a added 7 commits March 5, 2026 04:42
rnyak mentioned this pull request Mar 5, 2026
adil-a added 2 commits March 6, 2026 16:26
adil-a and others added 2 commits March 7, 2026 21:27
…assages, derive from dataloader

Remove redundant top-level train_n_passages and eval_negative_size from YAML configs and recipe __init__. The recipe now derives train_n_passages and val_n_passages directly from the dataloader dataset config (n_passages), making the dataset config the single source of truth. Also removes dead train_n_passages param from CrossEncoderCollator and unused temperature from cross-encoder config.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@adil-a (Collaborator, Author) commented Mar 8, 2026

Thanks for all the useful comments @rnyak. I overlooked these YAML details initially. I've now fixed them. Please take a look when possible.

…ection

Drop custom COVERAGE_ARGS (--data-file, --source, --parallel-mode) that
prevented coverage data from being collected. Matches the pattern used
by all other multi-GPU functional tests (L2_DCP, L2_HF_PEFT, etc.).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: adil-a <adil.asif2000@hotmail.com>
@adil-a (Collaborator, Author) commented Mar 30, 2026

/ok to test 3bb2f62

adil-a enabled auto-merge (squash) March 30, 2026 20:14
@akoumpa (Contributor) commented Mar 30, 2026

/claude review

```python
)

if num_labels is not None:
    batch_dict["labels"] = torch.zeros(num_labels, dtype=torch.long)
```
Contributor commented:

Potential bug: num_labels is read from features[0].get("num_labels"), which reflects the number of unique queries in a single dataset transform call (typically 1 when set_transform processes individual rows). But the labels tensor needs to have size equal to the number of query groups in the entire collated batch (i.e. len(features) // n_passages).

If a DataLoader batch has 8 flat (query, doc) pairs from 4 queries with n_passages=2, num_labels here would be 1 (from a single row's transform), producing labels = torch.zeros(1) instead of torch.zeros(4). This would cause a shape mismatch at F.cross_entropy(logits.view(-1, n_passages), labels) in the training recipe.

This might be masked by how HuggingFace set_transform + __getitems__ batches data in your environment, but it's fragile. Consider computing labels from the actual batch size:

```python
n_queries = len(features) // n_passages  # where n_passages comes from config
labels = torch.zeros(n_queries, dtype=torch.long)
```

akoumpa disabled auto-merge March 30, 2026 20:51