In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig

class ModernBERT(nn.Module):
    """BERT-style encoder with a per-token vocabulary projection.

    The model sums token, learned position and segment embeddings, normalises
    and applies dropout, runs the result through ``config.num_hidden_layers``
    ``TransformerBlock`` layers and projects every position to vocabulary
    logits.

    Args:
        config: Object exposing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``type_vocab_size``,
            ``hidden_dropout_prob`` and ``num_hidden_layers`` (plus whatever
            ``TransformerBlock`` reads from it).
    """

    def __init__(self, config):
        super(ModernBERT, self).__init__()

        # Token Embeddings (Word Vectors)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        # Positional Encoding (Learnable)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Segment Embeddings (Token Type IDs)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # Layer Normalization and Dropout
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Transformer Encoder Layers (Stack of Self-Attention Blocks)
        self.encoder_layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])

        # Output Layer (For Masked Language Model)
        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Compute vocabulary logits and, optionally, the token-level loss.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional mask forwarded to every encoder layer;
                zeros mark positions that must not be attended to. A mask with
                the same ``(batch, seq_len)`` shape as ``input_ids`` is
                reshaped to ``(batch, 1, 1, seq_len)`` so that it broadcasts
                over heads and query positions. Any other shape is forwarded
                unchanged.
            token_type_ids: Optional segment ids shaped like ``input_ids``;
                defaults to all zeros.
            labels: Optional target token ids shaped like ``input_ids``.
                Positions equal to ``-100`` are excluded from the loss.

        Returns:
            The logits tensor of shape ``(batch, seq_len, vocab_size)`` when
            ``labels`` is ``None``. Otherwise a dict with keys ``"loss"``
            (scalar cross-entropy) and ``"logits"``.

        Raises:
            ValueError: If ``seq_len`` exceeds the size of the position
                embedding table.
        """
        seq_length = input_ids.size(1)

        # Fail early with a readable message instead of an out-of-range
        # embedding lookup.
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_position_embeddings={max_positions}"
            )

        # Token Embeddings
        token_embeddings = self.embedding(input_ids)

        # Positional Embeddings
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # Segment Embeddings
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Combine Embeddings
        embeddings = token_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        # A (batch, seq_len) padding mask cannot broadcast against
        # (batch, heads, seq_len, seq_len) attention scores; add singleton
        # head and query axes so it masks key positions.
        if attention_mask is not None and attention_mask.dim() == 2 \
                and attention_mask.shape == input_ids.shape:
            attention_mask = attention_mask[:, None, None, :]

        # Pass through Transformer Encoder Layers
        hidden_states = embeddings
        for layer in self.encoder_layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Final Output Layer (unnormalised per-token vocabulary scores)
        logits = self.output_layer(hidden_states)

        if labels is None:
            return logits

        # Flatten to (batch * seq_len, vocab) vs (batch * seq_len,) and skip
        # positions labelled -100.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
        )
        return {"loss": loss, "logits": logits}


class TransformerBlock(nn.Module):
    """One encoder layer: self-attention, then a position-wise MLP.

    Each sub-layer output is added back to its input and the sum is passed
    through its own LayerNorm. Dropout is applied to the MLP output only.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner_width = config.intermediate_size

        # Self-attention sub-layer.
        self.attention = MultiHeadSelfAttention(config)

        # One normalisation per residual connection.
        self.norm1 = nn.LayerNorm(width, eps=1e-12)
        self.norm2 = nn.LayerNorm(width, eps=1e-12)

        # Expand -> GELU -> contract.
        self.ffn = nn.Sequential(
            nn.Linear(width, inner_width),
            nn.GELU(),
            nn.Linear(inner_width, width),
        )

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x, attention_mask=None):
        """Apply the layer to ``x``; ``attention_mask`` is handed to the attention module as-is."""
        # Residual + norm around attention.
        attended = self.norm1(x + self.attention(x, attention_mask))

        # Residual + norm around the (dropped-out) feed-forward output.
        projected = self.dropout(self.ffn(attended))
        return self.norm2(attended + projected)


class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention split across several heads.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super(MultiHeadSelfAttention, self).__init__()

        # The per-head reshape in forward() only works when the hidden width
        # splits evenly across heads; reject bad configs up front.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )

        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5  # Scaling factor for dot-product

        # Projection Layers
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        self.out = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, x, attention_mask=None):
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional mask; entries equal to ``0`` are blocked.
                A mask of shape exactly ``(batch, seq_len)`` is treated as a
                per-key padding mask and reshaped to
                ``(batch, 1, 1, seq_len)``. Any other shape must already
                broadcast against ``(batch, heads, seq_len, seq_len)``.
                When ``batch == seq_len`` a 2-D mask is therefore always read
                as a padding mask.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_length, hidden_size = x.size()

        # Project Q, K, V and split into heads: (batch, heads, seq_len, head_dim)
        Q = self.query(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled Dot-Product Attention: (batch, heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # Apply Attention Mask (If Given)
        if attention_mask is not None:
            if attention_mask.dim() == 2 and tuple(attention_mask.shape) == (batch_size, seq_length):
                # Padding mask: insert head and query axes so it masks keys.
                attention_mask = attention_mask[:, None, None, :]
            # Use the dtype's most negative finite value rather than -inf so a
            # row with every key masked yields uniform weights, not NaN.
            scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Compute Attention Output and merge heads back to hidden_size
        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)

        return self.out(attn_output)


In [None]:
# Hyperparameters for the encoder, grouped by what they control.
config = BertConfig(
    # --- inputs ---
    vocab_size=30522,  # has to equal the tokenizer's vocabulary size
    max_position_embeddings=512,  # longest sequence the position table covers
    type_vocab_size=2,  # number of segment ids (BERT uses 2)
    # --- encoder shape ---
    num_hidden_layers=12,  # transformer blocks in the stack
    num_attention_heads=12,  # self-attention heads per block
    hidden_size=768,  # model width (standard BERT size)
    intermediate_size=3072,  # feed-forward inner width
    # --- dropout ---
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)

# Instantiate the network from the configuration above.
model = ModernBERT(config)


In [None]:
from transformers import Trainer, TrainingArguments

# Run settings: where checkpoints go, how long to train, and how often to
# log / save.
training_args = TrainingArguments(
    output_dir="./modernbert_checkpoints",
    overwrite_output_dir=True,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    logging_steps=500,
    save_steps=1000,
)

# NOTE(review): `tokenized_datasets` is not defined in the cells shown here;
# it is assumed to come from an earlier cell — confirm before Run All.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=tokenized_datasets["train"],
)

# Start training; as the cell's last expression the return value is displayed.
trainer.train()
