
Minimum Bayes Risk Decoding for Hugging Face Transformers



ZurichNLP/mbr


mbr 🔥


mbr adds Sampling-based Minimum Bayes Risk decoding to Hugging Face transformers. Originally proposed by Eikema & Aziz (2022), this technique is a risk-minimizing algorithm for generating text with a language model. This repository implements several optimizations for MBR decoding. Most notably, mbr introduces reference aggregation (Vamvas & Sennrich, 2024).

Pronunciation: ember /ˈɛm.bɚ/

Installation

pip install mbr

Requirements:

  • Python >= 3.9
  • PyTorch
  • Hugging Face transformers < 4.39

Usage

The main components of mbr are:

  • mbr.MBRGenerationMixin: overrides a model's generate method to add MBR decoding.
  • mbr.MBRConfig: specifies the parameters of MBR decoding, e.g., the number of samples to generate and the metric to optimize.

1. Load a Hugging Face transformers model

Models need to inherit from MBRGenerationMixin for MBR decoding to work. There are two ways to achieve this, shown here using the Llama model as an example:

Variant A:

from transformers import LlamaForCausalLM

from mbr import MBRGenerationMixin

class MBRLlamaForCausalLM(MBRGenerationMixin, LlamaForCausalLM):
    pass

Then, you can use MBRLlamaForCausalLM as you would use LlamaForCausalLM:

model = MBRLlamaForCausalLM.from_pretrained(...)

Variant B:

from transformers import LlamaForCausalLM

from mbr import MBR

model = MBR(LlamaForCausalLM).from_pretrained(...)
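Variant B saves you from writing the subclass by hand. One way a factory like MBR(...) can work is to build the subclass at runtime with Python's built-in type(), placing the mixin first so that its generate method takes precedence in the method resolution order. The sketch below assumes this approach and uses hypothetical stand-in classes (BaseModel, Mixin, make_mbr_class) instead of the real transformers and mbr classes, so it runs without either library.

```python
class BaseModel:
    """Hypothetical stand-in for a transformers model class such as LlamaForCausalLM."""

    def generate(self, prompt: str) -> str:
        return f"greedy({prompt})"


class Mixin:
    """Hypothetical stand-in for MBRGenerationMixin: overrides generate()."""

    def generate(self, prompt: str) -> str:
        # A real mixin would sample candidates and select the MBR output here;
        # this sketch only tags the result to show which method ran first.
        return "mbr:" + super().generate(prompt)


def make_mbr_class(model_class: type) -> type:
    # The mixin comes first in the bases, so its generate() precedes the
    # model's generate() in the MRO, mirroring Variant A's class statement.
    return type(f"MBR{model_class.__name__}", (Mixin, model_class), {})


MBRBaseModel = make_mbr_class(BaseModel)
print(MBRBaseModel.__name__)
print(MBRBaseModel().generate("Hello"))
```

The order of the bases matters: with (model_class, Mixin), the model's own generate would be found first and the mixin would never run.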

2. Configure MBR decoding

Create an MBRConfig object to pass to the model's generate method:

from mbr import MBRConfig

mbr_config = MBRConfig(
    num_samples=10,
    metric="chrf",
)

3. Generate text as usual

Call the model's generate method directly, or use the Pipeline API. Make sure to pass the mbr_config, as well as the model's tokenizer.

from transformers import pipeline

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = generator("Hello,", mbr_config=mbr_config, tokenizer=tokenizer)

How MBR decoding works

The following research papers, among many others, provide a description of Sampling-based Minimum Bayes Risk decoding:

In practice, MBR decoding is most commonly implemented as follows (using machine translation as an example):

  • Instead of searching for the single most probable output sequence (e.g., using beam search), generate a number of samples.
  • Score each sample against the other samples using a metric (e.g., BLEU).
  • Return the sample with the highest score. Intuitively, this can be seen as returning the median of all samples.
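The three steps can be sketched in plain Python. This is not mbr's implementation: the utility unigram_f1 is a hypothetical, deliberately simple stand-in for BLEU or ChrF, and the samples are hard-coded instead of being drawn from a model.

```python
from collections import Counter
from typing import Callable, List


def unigram_f1(hypothesis: str, reference: str) -> float:
    """Toy utility: F1 of word-unigram overlap (a stand-in for BLEU/ChrF)."""
    hyp, ref = Counter(hypothesis.split()), Counter(reference.split())
    overlap = sum((hyp & ref).values())  # clipped counts of shared words
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


def mbr_select(samples: List[str], references: List[str],
               utility: Callable[[str, str], float]) -> str:
    # Score every sample against every reference and average the scores,
    # which estimates each sample's expected utility.
    scores = [
        sum(utility(s, r) for r in references) / len(references)
        for s in samples
    ]
    # Return the sample with the highest average score.
    best = max(range(len(samples)), key=lambda i: scores[i])
    return samples[best]


samples = [
    "the cat sat on the mat",
    "the cat sat on a mat",
    "a dog ran in the park",
]
# As in mbr's default setting, the samples double as references.
print(mbr_select(samples, samples, unigram_f1))
```

The cost is len(samples) × len(references) utility calls, i.e. quadratic in the number of samples when they double as references.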

(Figure: Illustration of MBR decoding)

The terminology around MBR decoding varies:

Term used in this codebase    Alternative terms
samples                       candidates, hypotheses
references                    pseudo-references, evidence
metric score                  expected utility, (negative) expected risk, error
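In the terminology of the table, the selection rule can be written as follows, with S the set of samples, R the set of references, and u the metric. This is the standard Monte Carlo formulation of sampling-based MBR; the notation is ours rather than taken from the mbr codebase. Minimizing the expected risk, with risk taken as the negative metric score, yields the same output.

```latex
\hat{y} = \operatorname*{arg\,max}_{s \in \mathcal{S}} \; \frac{1}{|\mathcal{R}|} \sum_{r \in \mathcal{R}} u(s, r)
```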

Details

Configuring the sampling

The generation of the samples can be customized by passing a generation_config to the generate method or to the pipeline call:

from transformers import GenerationConfig

generation_config = GenerationConfig.from_pretrained("mymodel",
    do_sample=True,
    num_beams=1,
    epsilon_cutoff=0.02,
)
model.generate(..., generation_config=generation_config)

Separate set of references

By default, the samples themselves are used as references (or a subset of the samples, if num_references is smaller than num_samples).

You could also sample the reference set independently, using a custom generation config for the references:

from transformers import GenerationConfig

references_config = GenerationConfig.from_pretrained("mymodel",
    do_sample=True,
    num_beams=1,
    top_p=0.9,
)
model.generate(..., references_config=references_config)

Choosing a metric

By default, mbr uses fastChrF, which is optimized for efficient comparison of many samples to many references.

You can also plug in metrics from the Hugging Face Evaluate library.

A full list of metrics is found here. Some typical choices are:

To use a metric from Hugging Face, either specify the metric's name (e.g., "comet", "bleurt") or pass an evaluate.Metric object directly.

Since different metrics output differently structured dicts, you need to specify the metric_output_field that should be used as the metric score.

from evaluate import load

from mbr import MBRConfig

metric = load('bleu')
mbr_config = MBRConfig(
    metric=metric,
    metric_output_field="bleu",  # the BLEU metric returns a dict with a "bleu" field
    # ... further parameters
)

Customizing the metric computation

Internally, mbr will call the metric's compute method to calculate the metric score for each sample.

By default, mbr will call compute separately for each sample–reference pair. Since this requires many compute calls, it can make sense to optimize the metric computation. Different metrics will require different optimization strategies. To override the default way of calling the metric, subclass MetricRunner and pass an instance to the generate method:

from typing import Tuple

import torch

from mbr import MetricRunner

class MyMetricRunner(MetricRunner):

    def __call__(self,
                 input_ids: torch.LongTensor,
                 sample_ids: Tuple[torch.LongTensor],
                 reference_ids: Tuple[torch.LongTensor],
                 ):
        # Since v0.2.0, a MetricRunner returns a MetricOutput dict
        # instead of the raw tensor of scores (see the changelog).
        ...  # TODO: implement your efficient metric computation here

model.generate(..., metric_runner=MyMetricRunner())
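As an illustration of what an optimized metric computation can look like, the sketch below scores each unique sample–reference pair only once and fills the full score matrix from those results. It works on strings rather than token IDs and does not subclass mbr's MetricRunner; pairwise_scores, toy_compute, and the "score" output field are hypothetical. The compute callable mimics the dict-returning style of Evaluate metrics, and metric_output_field selects the score from that dict.

```python
from typing import Callable, Dict, List, Tuple


def pairwise_scores(samples: List[str], references: List[str],
                    compute: Callable[..., Dict[str, float]],
                    metric_output_field: str) -> List[List[float]]:
    # Memoize scores so that duplicate pairs trigger only one metric call.
    cache: Dict[Tuple[str, str], float] = {}
    matrix = []
    for s in samples:
        row = []
        for r in references:
            if (s, r) not in cache:
                output = compute(predictions=[s], references=[r])
                cache[(s, r)] = output[metric_output_field]
            row.append(cache[(s, r)])
        matrix.append(row)
    return matrix


calls = []  # records each actual metric call


def toy_compute(predictions: List[str], references: List[str]) -> Dict[str, float]:
    """Hypothetical metric: exact match, returned in a dict like Evaluate metrics."""
    calls.append((predictions[0], references[0]))
    return {"score": float(predictions[0] == references[0])}


samples = ["a b", "a b", "a c"]  # sampling often yields duplicates
matrix = pairwise_scores(samples, samples, toy_compute, "score")
print(matrix)      # [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(len(calls))  # 4 metric calls instead of 9
```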

For COMET, an optimized implementation is already provided in CometMetricRunner:

from mbr import MBRConfig
from mbr.metrics.comet import CometMetricRunner

mbr_config = MBRConfig(
    metric="comet",
    metric_output_field="mean_score",
    # ... further parameters
)

metric_runner = CometMetricRunner(mbr_config, tokenizer)
model.generate(..., metric_runner=metric_runner)

Optimizations

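MBR decoding is notoriously slow: every sample is scored against every reference, so when the samples double as references, the number of metric computations grows quadratically with the number of samples. mbr implements several optimizations: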

  • Cached encoder outputs: For encoder-decoder models, the encoder outputs are computed only once and reused during sampling.
  • Optimized ChrF metric: fastChrF, a streamlined ChrF variant for MBR implemented in Rust, is used by default.
  • Cached metrics: Most metrics are computed only once for each unique sample–reference pair, since sampling tends to produce duplicate samples and references.
  • Optimized COMET metric: Inspired by Amrhein & Sennrich (2022), CometMetricRunner caches sequence embeddings and reuses them for all pairwise comparisons.
  • Reference aggregation for COMET (Vamvas & Sennrich, 2024): Consider using mbr.metrics.comet.AggregateCometMetricRunner instead of the default CometMetricRunner if you have many references.
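Reference aggregation replaces the pairwise comparisons with a comparison against a single aggregate of all references. For ChrF, one way to do this, in the spirit of Vamvas & Sennrich (2024), is to average the character n-gram counts of the references and score each sample once against that average, making the cost linear rather than quadratic in the number of samples. The functions below are hypothetical and simplified (character n-grams only, orders 1 to 3 in the example, beta = 2); they are not the fastChrF implementation.

```python
from collections import Counter
from typing import Dict, List


def char_ngrams(text: str, n: int) -> Counter:
    """Counts of the character n-grams of `text`, with spaces removed as in ChrF."""
    text = text.replace(" ", "")
    return Counter(text[i:i + n] for i in range(len(text) - n + 1))


def aggregate_references(references: List[str], n: int) -> Dict[str, float]:
    """Average the n-gram counts of all references into one fractional 'reference'."""
    total: Counter = Counter()
    for ref in references:
        total.update(char_ngrams(ref, n))
    return {gram: count / len(references) for gram, count in total.items()}


def chrf_vs_aggregate(sample: str, aggregates: List[Dict[str, float]],
                      beta: float = 2.0) -> float:
    """Simplified ChrF of `sample` against one aggregate per n-gram order (n = 1, 2, ...)."""
    precisions, recalls = [], []
    for n, ref in enumerate(aggregates, start=1):
        hyp = char_ngrams(sample, n)
        # Clipped overlap between the sample's counts and the averaged counts.
        overlap = sum(min(count, ref.get(gram, 0.0)) for gram, count in hyp.items())
        hyp_total, ref_total = sum(hyp.values()), sum(ref.values())
        precisions.append(overlap / hyp_total if hyp_total else 0.0)
        recalls.append(overlap / ref_total if ref_total else 0.0)
    # Average precision and recall over n-gram orders, then combine into F-beta.
    p = sum(precisions) / len(precisions)
    r = sum(recalls) / len(recalls)
    if p + r == 0:
        return 0.0
    return (1 + beta ** 2) * p * r / (beta ** 2 * p + r)


samples = ["the cat sat", "the cat sits", "a dog ran"]
max_n = 3
# Aggregate once: the work depends on the references, not on the number of samples.
aggregates = [aggregate_references(samples, n) for n in range(1, max_n + 1)]
# Score each sample once against the aggregate instead of against every reference.
scores = [chrf_vs_aggregate(s, aggregates) for s in samples]
best = samples[scores.index(max(scores))]
print(best)
```

With a single reference the aggregate equals that reference's counts, so the score reduces to this simplified pairwise ChrF.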

Example scripts

The experiments directory contains the code for reproductions of experiments from the following papers:

Code for research papers

Related projects

Changelog

  • v0.3.0 (draft)

    • New feature: Reference Aggregation (Vamvas & Sennrich, 2024):
      • Set fastChrF with reference aggregation as default metric
      • Add AggregateCometMetricRunner to allow for reference aggregation with COMET
    • Bugfix: Disable dropout for COMET metric
  • v0.2.0

    • Breaking change: Rename MBRGenerationConfig to MBRConfig
    • Breaking change: MetricRunner now returns a MetricOutput dict instead of the raw tensor of scores.
    • Make the size of the metric cache configurable via MBRConfig.metric_cache_size
    • Allow the number of references to be larger than the number of samples (if the references are generated separately from the samples).
    • Remove GenerationConfig as parent class of MBRConfig

Citation

When using this code for research, please cite the following paper:

@misc{vamvas-sennrich-2024-linear,
      title={Linear-time Minimum Bayes Risk Decoding with Reference Aggregation},
      author={Jannis Vamvas and Rico Sennrich},
      year={2024},
      eprint={2402.04251},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}