# Fine-tune classifier with ModernBERT in 2025

Large Language Models (LLMs) have become ubiquitous in 2024. However, smaller, specialized models - particularly for classification tasks - remain critical for building efficient and cost-effective AI systems. One key use case is routing user prompts to the most appropriate LLM or selecting optimal few-shot examples, where fast, accurate classification is essential.

This blog post demonstrates how to fine-tune ModernBERT, a new state-of-the-art encoder model, for classifying user prompts to implement an intelligent LLM router. ModernBERT is a refreshed version of BERT models, with 8192 token context length, significantly better downstream performance, and much faster processing speeds.

You will learn how to:
1. Set up the environment and install libraries
2. Load and prepare the classification dataset  
3. Fine-tune & evaluate ModernBERT with the Hugging Face `Trainer`
4. Run inference & test model

## Quick intro: ModernBERT

ModernBERT is a modernization of BERT maintaining full backward compatibility while delivering dramatic improvements through architectural innovations like rotary positional embeddings (RoPE), alternating attention patterns, and hardware-optimized design. The model comes in two sizes:
- ModernBERT Base (149M parameters)
- ModernBERT Large (395M parameters)

ModernBERT achieves state-of-the-art performance across classification, retrieval and code understanding tasks while being 2-4x faster than previous encoder models. This makes it ideal for high-throughput production applications like LLM routing, where both accuracy and latency are critical.

ModernBERT was trained on 2 trillion tokens of diverse data including web documents, code, and scientific articles - making it much more robust than traditional BERT models trained primarily on Wikipedia. This broader knowledge helps it better understand the nuances of user prompts across different domains.

If you want to learn more about ModernBERT's architecture and training process, check out the official [blog](https://huggingface.co/blog/modernbert). 

---

Now let's get started building our LLM router with ModernBERT! 🚀

*Note: This tutorial was created and tested on an NVIDIA L4 GPU with 24GB of VRAM.*

## 1. Set up environment and install libraries

Our first step is to install PyTorch and the Hugging Face libraries, including transformers and datasets.

In [None]:
# Install Pytorch & other libraries
%pip install "torch==2.4.1" tensorboard flash-attn "setuptools<71.0.0" scikit-learn 

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.1.0" \
  "evaluate==0.4.3" \
  "hf-transfer==0.1.8"
  #"transformers==4.47.1" \

# ModernBERT is not yet available in an official release, so we need to install it from GitHub
%pip install "git+https://github.com/huggingface/transformers.git@6e0515e99c39444caae39472ee1b2fd76ece32f1" --upgrade


We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training. You need a [Hugging Face](https://huggingface.co/join) account for this. Once you have one, use the `login` util from the `huggingface_hub` package to log in and store your token (access key) on disk.

In [None]:
from huggingface_hub import login

login(token="", add_to_git_credential=True) # ADD YOUR TOKEN HERE
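Instead of pasting the token into the notebook, it can be read from the environment. The following is a minimal sketch, assuming the token has been exported as `HF_TOKEN`; the helper names `hub_token_from_env` and `login_from_env` are hypothetical and not part of `huggingface_hub`.

```python
import os


def hub_token_from_env(variable="HF_TOKEN"):
    """Return the token stored in an environment variable, or None if unset or empty."""
    token = os.environ.get(variable, "").strip()
    return token or None


def login_from_env():
    """Log in with the token from the environment instead of a hard-coded string."""
    # Imported lazily so the helper above works without huggingface_hub installed.
    from huggingface_hub import login

    token = hub_token_from_env()
    if token is None:
        raise RuntimeError("Set the HF_TOKEN environment variable first.")
    login(token=token, add_to_git_credential=True)
```

Calling `login_from_env()` then replaces the `login(token="...")` call above and keeps the access key out of the saved notebook.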

## 2. Load and prepare the dataset

In our example we want to fine-tune ModernBERT to act as a router for user prompts. Therefore we need a classification dataset consisting of user prompts and their "difficulty" score. We are going to use the `DevQuasar/llm_router_dataset-synth` dataset, which is a synthetic dataset of ~15,000 user prompts with a difficulty score of "large_llm" (`1`) or "small_llm" (`0`). 


We will use the `load_dataset()` method from the [🤗 Datasets](https://huggingface.co/docs/datasets/index) library to load the `DevQuasar/llm_router_dataset-synth` dataset.

In [1]:
from datasets import load_dataset

# Dataset id from huggingface.co/datasets
dataset_id = "DevQuasar/llm_router_dataset-synth"

# Load raw dataset
raw_dataset = load_dataset(dataset_id)

print(f"Train dataset size: {len(raw_dataset['train'])}")
print(f"Test dataset size: {len(raw_dataset['test'])}")

Train dataset size: 15306
Test dataset size: 4921


Let’s check out an example of the dataset.

In [2]:
from random import randrange

random_id = randrange(len(raw_dataset['train']))
raw_dataset['train'][random_id]
# {'id': '6225a9cd-5cba-4840-8e21-1f9cf2ded7e6',
# 'prompt': 'How many legs does a spider have?',
# 'label': 0}

{'id': '5445f394-7896-4151-a7cb-ecee3cd1e342',
 'prompt': 'How does the concept of entropy influence the performance of a clustering algorithm, and what are its implications on data visualization?',
 'label': 1}

To train our model, we need to convert our text prompts to token IDs. This is done by a Tokenizer, which tokenizes the inputs (including converting the tokens to their corresponding IDs in the pre-trained vocabulary). If you want to learn more about this, check out **[chapter 6](https://huggingface.co/course/chapter6/1?fw=pt)** of the [Hugging Face Course](https://huggingface.co/course/chapter1/1).
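To make the idea of token IDs, truncation and padding concrete, here is a toy sketch. It assumes a tiny made-up whitespace vocabulary and a hypothetical `toy_encode` helper; ModernBERT's real tokenizer uses a learned subword vocabulary and produces different IDs.

```python
# Toy illustration only: a whitespace "tokenizer" with a tiny made-up vocabulary.
PAD_ID, UNK_ID = 0, 1
vocab = {"[PAD]": PAD_ID, "[UNK]": UNK_ID, "how": 2, "many": 3, "legs": 4,
         "does": 5, "a": 6, "spider": 7, "have": 8}


def toy_encode(text, max_length):
    """Map words to IDs, truncate to max_length, then pad up to max_length."""
    words = text.lower().replace("?", "").split()
    ids = [vocab.get(word, UNK_ID) for word in words][:max_length]
    attention_mask = [1] * len(ids)  # 1 marks real tokens
    padding = max_length - len(ids)
    return {
        "input_ids": ids + [PAD_ID] * padding,
        "attention_mask": attention_mask + [0] * padding,  # 0 marks padding
    }


encoded = toy_encode("How many legs does a spider have?", max_length=10)
print(encoded["input_ids"])       # [2, 3, 4, 5, 6, 7, 8, 0, 0, 0]
print(encoded["attention_mask"])  # [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
```

The attention mask tells the model which positions are padding, which is why the real tokenizer returns it alongside `input_ids`.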

In [3]:
from transformers import AutoTokenizer

# Model id to load the tokenizer
model_id = "answerdotai/ModernBERT-base"
# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.model_max_length = 1024 # set model_max_length to 1024 as prompts are not longer than 1024 tokens

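# Tokenize helper function
def tokenize(batch):
    # Pad/truncate every prompt to model_max_length so all examples share one shape
    return tokenizer(batch['prompt'], padding='max_length', truncation=True, return_tensors="pt")

# Tokenize dataset
raw_dataset =  raw_dataset.rename_column("label", "labels") # to match Trainer
# batched=True passes lists of prompts, so each stored example is a flat sequence of token IDs
tokenized_dataset = raw_dataset.map(tokenize, batched=True, remove_columns=["prompt","id"])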

print(tokenized_dataset["train"].features.keys())
# dict_keys(['labels', 'input_ids', 'attention_mask'])

Map:   0%|          | 0/4921 [00:00<?, ? examples/s]

dict_keys(['labels', 'input_ids', 'attention_mask'])


## 3. Fine-tune & evaluate ModernBERT with the Hugging Face `Trainer`

After we have processed our dataset, we can start training our model. We will use the [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) model. The first step is to load our model with the `AutoModelForSequenceClassification` class from the [Hugging Face Hub](https://huggingface.co/answerdotai/ModernBERT-base). This initializes the pre-trained ModernBERT weights with a classification head on top. Here we pass the number of classes (2) from our dataset and the label names to get readable outputs at inference time.

In [4]:
from transformers import AutoModelForSequenceClassification

# Model id to load the model
model_id = "answerdotai/ModernBERT-base"

# Prepare model labels - useful for inference
labels = tokenized_dataset["train"].features["labels"].names
num_labels = len(labels)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i  # integer class ids, as stored in the model config
    id2label[i] = label

# Download the model from huggingface.co/models
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, num_labels=num_labels, label2id=label2id, id2label=id2label,
)

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertForSequenceClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this 
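The classification head outputs one raw score (logit) per class, and `id2label` turns the winning index into a readable name. The following sketch shows that last step in plain Python. The logits are made-up numbers, not real model output, and the label names follow the dataset description (0 = `small_llm`, 1 = `large_llm`).

```python
import math

id2label = {0: "small_llm", 1: "large_llm"}


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    highest = max(logits)
    exps = [math.exp(x - highest) for x in logits]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits):
    """Pick the most probable class and return its readable name and probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}


# Hypothetical logits for a single prompt.
print(decode([-1.2, 2.3]))
```

This mirrors the shape of what a text-classification pipeline reports: a label plus a score.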

We evaluate our model during training. The `Trainer` supports evaluation during training by providing a `compute_metrics` method. We use the `evaluate` library to calculate the [f1 metric](https://huggingface.co/spaces/evaluate-metric/f1) during training on our test split.

In [5]:
import evaluate
import numpy as np

# Metric Id
metric = evaluate.load("f1")

# Metric helper method
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels, average="weighted")
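To see what `average="weighted"` means, the sketch below recomputes a weighted F1 score by hand on a few made-up labels: each class gets its own F1, and the per-class scores are averaged using the number of true examples per class as weights. The helper names are hypothetical, and this is an illustration of the definition rather than the `evaluate` implementation.

```python
def f1_for_class(predictions, references, cls):
    """F1 for one class, treating it as the positive class."""
    pairs = list(zip(predictions, references))
    tp = sum(p == cls and r == cls for p, r in pairs)
    fp = sum(p == cls and r != cls for p, r in pairs)
    fn = sum(p != cls and r == cls for p, r in pairs)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def weighted_f1(predictions, references):
    """Average the per-class F1 scores, weighted by each class's share of the references."""
    total = len(references)
    return sum(
        f1_for_class(predictions, references, cls) * references.count(cls) / total
        for cls in sorted(set(references))
    )


references = [0, 0, 0, 0, 1, 1]   # made-up ground truth
predictions = [0, 0, 0, 1, 1, 0]  # made-up predictions
print(round(weighted_f1(predictions, references), 4))  # 0.6667
```

Here class 0 has an F1 of 0.75 with four true examples and class 1 has an F1 of 0.5 with two, so the weighted score is 0.75 * 4/6 + 0.5 * 2/6.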

The last step is to define the hyperparameters (`TrainingArguments`) we use for our training. Here we enable optimizations for faster training: bf16 mixed precision and the fused AdamW optimizer. `TrainingArguments` also offers a `torch_compile` option, which is left disabled in the configuration below.

We also leverage the [Hugging Face Hub](https://huggingface.co/models) integration of the `Trainer` to push our checkpoints, logs, and metrics during training into a repository.

In [6]:
from huggingface_hub import HfFolder
from transformers import Trainer, TrainingArguments

# Define training args
training_args = TrainingArguments(
    output_dir= "modernbert-llm-router",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    num_train_epochs=5,
    bf16=True, # bfloat16 training 
    torch_compile=False, # optimizations
    optim="adamw_torch_fused", # improved optimizer 
    # logging & evaluation strategies
    logging_strategy="steps",
    logging_steps=100,
    eval_strategy="epoch", # replaces the deprecated `evaluation_strategy` argument
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    # push to hub parameters
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_token=HfFolder.get_token(),

)

# Create a Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2024-12-24 14:53:55,367] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/opt/conda/envs/pytorch/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
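For a rough sense of how much work this configuration schedules, the number of optimizer steps follows from the dataset size and the batch size. The arithmetic below is a sketch that assumes a single GPU and no gradient accumulation. With several devices the effective batch size grows and the step count shrinks accordingly.

```python
import math

train_samples = 15306        # train split size printed when loading the dataset
per_device_batch_size = 16   # per_device_train_batch_size
num_epochs = 5               # num_train_epochs
logging_steps = 100          # logging_steps

# Assumptions: one device, no gradient accumulation, last partial batch is kept.
steps_per_epoch = math.ceil(train_samples / per_device_batch_size)
total_steps = steps_per_epoch * num_epochs
logged_points = total_steps // logging_steps

print(steps_per_epoch)  # 957
print(total_steps)      # 4785
print(logged_points)    # 47
```

With `eval_strategy` and `save_strategy` set to `"epoch"`, evaluation and checkpointing each happen five times over the run, once per epoch.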


We can start our training by using the **`train`** method of the `Trainer`.

In [7]:
# Start training
trainer.train()

InternalTorchDynamoError: Caught InternalTorchDynamoError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 84, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1160, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 895, in forward
    hidden_states = self.embeddings(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 210, in forward
    self.compiled_embeddings(input_ids)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 846, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 189, in _fn
    torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/cuda/random.py", line 75, in set_rng_state
    _lazy_call(cb)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/cuda/__init__.py", line 244, in _lazy_call
    callable()
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/cuda/random.py", line 73, in cb
    default_generator.set_state(new_state_copy)
torch._dynamo.exc.InternalTorchDynamoError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True



![tensorboard](../assets/tensorboard.png)

Using PyTorch 2.0 and supported features in `transformers` allows us to train our BERT model on `10_000` samples within `457.7964` seconds. 

We also ran the training without the `torch_compile` option to compare the training times. The training without `torch_compile` took 696 seconds, had a `train_samples_per_second` value of 43.1 and an `f1` score of `0.929`.

```bash
{'train_runtime': 696.2701, 'train_samples_per_second': 43.1, 'eval_f1': 0.928788}
```

By using the `torch_compile` option and the `adamw_torch_fused` optimizer, the training time drops by about 34% (from 696 to 458 seconds), which corresponds to roughly 52% higher throughput compared to the training without PyTorch 2.0. 

```bash
{'train_runtime': 457.7964, 'train_samples_per_second': 65.55, 'eval_f1': 0.931773}
```

Our absolute training time went down from 696s to 457s. The `train_samples_per_second` value increased from 43 to 65. The `f1` score is the same or slightly better than for the training without `torch_compile`.
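The relative numbers can be derived directly from the two result dictionaries above. The short calculation below separates the two ways of expressing the gain: the reduction in runtime and the increase in throughput.

```python
baseline = {"train_runtime": 696.2701, "train_samples_per_second": 43.1}
compiled = {"train_runtime": 457.7964, "train_samples_per_second": 65.55}

# Share of the baseline runtime that was saved.
time_reduction = 1 - compiled["train_runtime"] / baseline["train_runtime"]

# Relative increase in samples processed per second.
throughput_gain = (
    compiled["train_samples_per_second"] / baseline["train_samples_per_second"] - 1
)

print(f"runtime reduction: {time_reduction:.0%}")
print(f"throughput gain:   {throughput_gain:.0%}")
```

The runtime falls by roughly a third while throughput rises by roughly half. Both describe the same speedup, so it is worth stating which of the two a percentage refers to.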

PyTorch 2.0 is incredibly powerful! 🚀 

Let's save our results and tokenizer to the Hugging Face Hub and create a model card.

In [None]:
# Save tokenizer next to the model checkpoints and create model card
tokenizer.save_pretrained(training_args.output_dir)
trainer.create_model_card()
trainer.push_to_hub()

## 4. Run Inference & test model

To wrap up this tutorial, we will run inference on a few examples and test our model. We will use the `pipeline` method from the `transformers` library to run inference on our model.

In [None]:
from transformers import pipeline

# Local output directory used during training (same value as `output_dir` above).
# Use "<your-username>/modernbert-llm-router" instead to load the model from the Hub.
repository_id = "modernbert-llm-router"

# Load the fine-tuned model and tokenizer
classifier = pipeline("text-classification", model=repository_id, tokenizer=repository_id, device=0)

sample = "How many legs does a spider have?"

pred = classifier(sample)
print(pred)
# Prints a list with one dict: the predicted 'label' ("small_llm" or "large_llm") and its 'score'
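Once the classifier returns a label, the routing step itself is a small lookup. The sketch below assumes the prediction has the pipeline's usual shape, a list with one dict holding `label` and `score`. The model names in `ROUTES`, the `route` helper and the `min_confidence` threshold are all hypothetical choices for illustration.

```python
# Hypothetical model names; replace them with the endpoints you actually use.
ROUTES = {"small_llm": "my-small-model", "large_llm": "my-large-model"}


def route(prediction, min_confidence=0.7):
    """Choose a target model from a text-classification result.

    When the classifier is unsure, fall back to the large model, on the
    assumption that answer quality matters more than cost.
    """
    top = prediction[0]
    if top["score"] < min_confidence:
        return ROUTES["large_llm"]
    return ROUTES[top["label"]]


print(route([{"label": "small_llm", "score": 0.98}]))  # my-small-model
print(route([{"label": "small_llm", "score": 0.55}]))  # my-large-model
```

Falling back to the larger model on low confidence is one possible policy. A cost-sensitive setup might prefer the opposite default.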

## Conclusion

In this tutorial, we learned how to use PyTorch 2.0 to train a text classification model on the BANKING77 dataset. We saw that PyTorch 2.0 is a powerful tool to speed up your training times. In our example running on an NVIDIA A10G, we achieved about 52% higher training throughput. The Hugging Face Trainer allows you to easily integrate PyTorch 2.0 into your training pipeline by simply adding the `torch_compile` option to the `TrainingArguments`. We can further benefit from PyTorch 2.0 by using the new fused AdamW optimizer when bf16 is available. 

Additionally, I want to mention that the run with `torch_compile` finished in about 34% less time (roughly 52% higher throughput), which can translate into a comparable cost saving for training or into faster iteration cycles and time to production. You should be able to see even better improvements by using A100 GPUs or by reducing the `Trainer` overhead, e.g. removing evaluation and logging. 

PyTorch 2.0 is now officially launched and we are excited to see what the future brings. 🚀