In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to local source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [11]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Checkpoint identifier handed to both from_pretrained() calls below, so the
# tokenizer and the masked-LM model come from the same checkpoint.
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
# Put the model in evaluation mode. Being the last expression of the cell,
# the value returned by eval() is also what the cell displays.
model.eval()

ModernBertForMaskedLM(
  (model): ModernBertModel(
    (embeddings): ModernBertEmbeddings(
      (tok_embeddings): Embedding(50368, 768, padding_idx=50283)
      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (layers): ModuleList(
      (0): ModernBertEncoderLayer(
        (attn_norm): Identity()
        (attn): ModernBertAttention(
          (Wqkv): Linear(in_features=768, out_features=2304, bias=False)
          (rotary_emb): ModernBertRotaryEmbedding()
          (Wo): Linear(in_features=768, out_features=768, bias=False)
          (out_drop): Identity()
        )
        (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): ModernBertMLP(
          (Wi): Linear(in_features=768, out_features=2304, bias=False)
          (act): GELUActivation()
          (drop): Dropout(p=0.0, inplace=False)
          (Wo): Linear(in_features=1152, out_features=768, bias=False)
        )
      )
      (1-21)

In [None]:

# Fill-mask sanity check: encode a sentence containing one [MASK] token and
# read off the model's top prediction for that position.
# `inputs` is reused by the export and ONNX Runtime cells further down.
text = "The capital of Germany is [MASK]."
inputs = tokenizer(text, return_tensors="pt")

# Inference only: disable autograd so no graph is recorded for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)

# To get predictions for the mask:
# find the position of the first [MASK] id in the (single) encoded sequence;
# list.index raises ValueError if the text contains no mask token.
masked_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
# Highest-scoring vocabulary id at the masked position.
predicted_token_id = outputs.logits[0, masked_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print("Predicted token:", predicted_token)
# Predicted token:  Berlin

In [4]:
import torch
import os

# Export the loaded masked-LM model to ONNX with dynamic batch and sequence axes.

# Set the model to evaluation mode
model.eval()

# Dummy tensors used only to trace the graph. The sequence length is taken
# from the `inputs` produced by the fill-mask example cell; both axes are
# declared dynamic below, so other shapes can be fed at inference time.
batch_size = 1
sequence_length = inputs["input_ids"].shape[1]
dummy_input_ids = torch.zeros((batch_size, sequence_length), dtype=torch.int64)
dummy_attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.int64)

# Mark batch and sequence dimensions as variable for both inputs and the output.
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'output': {0: 'batch_size', 1: 'sequence_length'}
}

# Export the model to ONNX
output_path = "modernbert.onnx"
torch.onnx.export(
    model,                                         # Model being exported
    # Only the two real tensors are passed positionally; every other forward()
    # argument keeps its default. Padding the tuple with positional Nones would
    # bind them to whichever parameters occupy those slots in this model's
    # forward(), which is fragile (ModernBERT takes no token_type_ids input).
    (dummy_input_ids, dummy_attention_mask),
    output_path,                                   # Output file path
    input_names=['input_ids', 'attention_mask'],   # Input names
    output_names=['output'],                       # Output name (read back as the logits later in the notebook)
    dynamic_axes=dynamic_axes,                     # Dynamic axes specification
    opset_version=20,                              # ONNX opset version
    do_constant_folding=True,                      # Fold constants for optimization
)

print(f"Model exported to {output_path}")


Model exported to modernbert.onnx


In [7]:
import onnxruntime
import numpy as np

# Run the exported ONNX model on the same example as the fill-mask cell and
# check that its probabilities agree with the PyTorch model.

# Initialize ONNX Runtime session
ort_session = onnxruntime.InferenceSession("modernbert.onnx")

# Convert the PyTorch tensors to numpy arrays for ONNX Runtime
ort_inputs = {
    'input_ids': inputs['input_ids'].numpy(),
    'attention_mask': inputs['attention_mask'].numpy()
}

# Run inference with ONNX Runtime
ort_outputs = ort_session.run(['output'], ort_inputs)

# Get the output logits
ort_logits = ort_outputs[0]

# Convert to probabilities using a numerically stable softmax: subtracting the
# per-position maximum leaves the result unchanged mathematically but keeps
# np.exp from overflowing to inf (which would yield nan probabilities).
shifted_logits = ort_logits - np.max(ort_logits, axis=-1, keepdims=True)
exp_logits = np.exp(shifted_logits)
ort_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

# Get the predicted class (highest probability) at every position
ort_predicted = np.argmax(ort_probs, axis=-1)

print("ONNX Runtime Prediction:", tokenizer.decode(ort_predicted[0]))

# Compare with PyTorch model output (optional)
with torch.no_grad():
    torch_outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
    torch_probs = torch.softmax(torch_outputs.logits, dim=-1)
    torch_predicted = torch.argmax(torch_probs, dim=-1)

print("PyTorch Model Prediction:", tokenizer.decode(torch_predicted[0]))

# Verify the outputs match; assert_allclose raises AssertionError otherwise
np.testing.assert_allclose(ort_probs, torch_probs.numpy(), rtol=1e-3, atol=1e-3)
print("✓ ONNX Runtime and PyTorch outputs match within tolerance")


ONNX Runtime Prediction: [CLS]The capital of Germany is Berlin.[SEP]
PyTorch Model Prediction: [CLS]The capital of Germany is Berlin.[SEP]
✓ ONNX Runtime and PyTorch outputs match within tolerance
