In [45]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import os
import sys

import torch
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask

In [5]:
sys.path.append("/data/horse/ws/lama722b-nanite-lm/nanite-lm")

In [25]:
from codebase.transformer import Attention, FeedForward, RotaryEmbedding

In [76]:
@dataclass
class PostModernBertArgs:
    vocab_size: int = 256
    dim: int = 768
    n_layers: int = 22
    head_dim: Optional[int] = None

    # Embedding Related Params
    pad_token_id: int = 255
    norm_eps: float = 1e-5
    norm_bias: float = False
    embedding_dropout: float = 0.0

    #Model Args
    max_seqlen: int = 512
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[int] = None
    rope_theta: float = 10_000.0
    head_dim: Optional[int] = None
    n_heads: int = 12
    n_kv_heads: Optional[int] = None
    

In [86]:
class PostModernBertEncoderBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        # ModernBERT uses nn.Identity as attn_norm for layer_idx ==0
        self.attn_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=args.norm_bias)
        self.head_dim = args.head_dim or args.dim // args.n_heads
        self.n_heads = args.n_heads or args.dim // args.head_dim
        self.n_kv_heads = args.n_kv_heads or self.n_heads
        self.attn = Attention(args.dim,
                              head_dim=self.head_dim,
                              n_heads=self.n_heads,
                              n_kv_heads=self.n_kv_heads,
                              rope_theta=args.rope_theta
                             )
        self.mlp = FeedForward(args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier)
        self.mlp_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=args.norm_bias)

    def forward(self, x, freq_cis):
        h = x + self.attn(x, freq_cis)
        # Residual Connection
        # MLP Output
        out = h + self.mlp(self.mlp_norm(h))
        return out

class PostModernEmbeddings(nn.Module):
    """
    Currently similar to ModernBert without `torch.compile`
    """
    def __init__(self, args):
        super().__init__()
        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim, padding_idx=args.pad_token_id)
        self.norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=args.norm_bias)
        self.drop = nn.Dropout(args.embedding_dropout)

    def forward(self, input_ids):
        h = self.drop(self.norm(self.tok_embeddings(input_ids)))
        return h

class PostModernBert(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.max_seqlen = args.max_seqlen
        self.tok_embeddings = PostModernEmbeddings(args)
        self.rope_embeddings = RotaryEmbedding(
            theta=args.rope_theta,
            head_dim=args.head_dim or args.dim // args.n_heads,
            max_seqlen=args.max_seqlen,
        )
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(PostModernBertEncoderBlock(args))
        self.final_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=args.norm_bias)

    def forward(self, inp, attention_mask=None): 
        batch_size, seq_len = inp.shape[:2]
        #TODO(krotonus): Move the below variable
        tok_idx = None
        
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_len),
            )
        freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
        h = self.tok_embeddings(inp)
        for i, layer in enumerate(self.layers):
            h = layer(h, freq_cis)
        h = self.final_norm(h)
        return h

In [87]:
args = PostModernBertArgs()
model = PostModernBert(args)
inp = torch.randint(low=0, high=256, size=(1, 10))
model(inp)

tensor([[[ 0.0859, -0.1358,  1.4519,  ...,  0.3180, -1.1336, -0.6670],
         [-0.4470, -0.2797, -0.2225,  ...,  0.5280,  0.5631, -0.6856],
         [-0.6883, -0.7299,  2.5745,  ...,  0.4008,  0.0605,  0.0774],
         ...,
         [-1.4310, -0.9367,  0.3493,  ...,  1.0679, -0.2719,  0.1380],
         [-0.6686, -0.0198,  0.2524,  ..., -0.4789, -0.5860, -1.4794],
         [ 0.3705, -0.3117,  1.2196,  ...,  1.1364, -0.8761, -0.8457]]],
       grad_fn=<NativeLayerNormBackward0>)

In [79]:
model

PostModernBert(
  (tok_embeddings): PostModernEmbeddings(
    (tok_embeddings): Embedding(256, 768, padding_idx=255)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
  (rope_embeddings): RotaryEmbedding()
  (layers): ModuleList(
    (0-21): 22 x PostModernBertEncoderBlock(
      (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (wq): Linear(in_features=768, out_features=768, bias=False)
        (wk): Linear(in_features=768, out_features=768, bias=False)
        (wv): Linear(in_features=768, out_features=768, bias=False)
        (wo): Linear(in_features=768, out_features=768, bias=False)
      )
      (mlp): FeedForward(
        (w1): Linear(in_features=768, out_features=2048, bias=False)
        (w3): Linear(in_features=768, out_features=2048, bias=False)
        (w2): Linear(in_features=2048, out_features=768, bias=False)
      )
      (mlp_norm): LayerNorm((768,), eps=1e-0

In [52]:
from transformers import AutoConfig, AutoModel
ref_config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base")

ERROR! Session/line number was not unique in database. History logging moved to new session 4


In [88]:
ref_model = AutoModel.from_config(ref_config)
print(ref_model)
ref_model(inp).last_hidden_state

ModernBertModel(
  (embeddings): ModernBertEmbeddings(
    (tok_embeddings): Embedding(50368, 768, padding_idx=50283)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
  (layers): ModuleList(
    (0): ModernBertEncoderLayer(
      (attn_norm): Identity()
      (attn): ModernBertAttention(
        (Wqkv): Linear(in_features=768, out_features=2304, bias=False)
        (rotary_emb): ModernBertRotaryEmbedding()
        (Wo): Linear(in_features=768, out_features=768, bias=False)
        (out_drop): Identity()
      )
      (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): ModernBertMLP(
        (Wi): Linear(in_features=768, out_features=2304, bias=False)
        (act): GELUActivation()
        (drop): Dropout(p=0.0, inplace=False)
        (Wo): Linear(in_features=1152, out_features=768, bias=False)
      )
    )
    (1-21): 21 x ModernBertEncoderLayer(
      (attn_norm): LayerNorm((768,), eps=1e-05, e

tensor([[[-1.6153, -0.8174, -0.5756,  ...,  1.8049,  1.5337,  1.3146],
         [ 0.1290,  0.8749, -1.3029,  ...,  0.0916,  0.4125, -0.1262],
         [ 0.4988, -0.9913, -0.5715,  ..., -1.4603, -1.0956, -0.5315],
         ...,
         [-2.2732,  0.6711, -0.0564,  ...,  1.1797, -0.6134, -0.3962],
         [ 1.1997,  0.4882,  0.7147,  ..., -0.7807, -0.1498, -0.5456],
         [-0.4344, -0.6453,  0.1459,  ...,  0.0474,  0.9059, -1.3447]]],
       grad_fn=<NativeLayerNormBackward0>)

In [81]:
model(inp)

tensor([[[-6.8827e-01, -7.6610e-01,  5.3896e-01,  ..., -4.8102e-01,
          -8.7296e-01, -4.4662e-03],
         [-3.5031e-01, -1.5203e+00, -1.4971e-03,  ..., -4.7378e-01,
           6.1782e-01,  1.5122e+00],
         [ 1.5161e-01, -1.7240e+00, -3.6489e-01,  ...,  1.1788e+00,
           1.0126e+00,  2.6658e-01],
         ...,
         [-9.1646e-01,  2.1504e-01,  3.2687e-01,  ..., -6.3135e-01,
           9.6177e-01,  8.9055e-01],
         [-1.8178e-01, -5.7361e-01, -9.8695e-01,  ..., -9.3489e-01,
           2.8982e-01,  1.2466e-02],
         [-1.7038e+00, -1.3339e+00, -3.5118e-01,  ...,  6.0886e-01,
          -6.5583e-01, -5.5107e-01]]], grad_fn=<NativeLayerNormBackward0>)

In [89]:
import torch.nn.attention.flex_attention as flex

In [92]:
def mod_fn(b, h, q_idx, kv_idx):
    return q_idx != kv_idx

B = 2
H = 12 # Number of query heads
q_len = 512
kv_len = 512
inp = torch.rand((B, H, q_len, kv_len))
flex.create_block_mask(mod_fn, B, H, q_len, kv_len, device="cpu")

BlockMask(
    kv_num_blocks=torch.Size([2, 12, 4]),
    kv_indices=torch.Size([2, 12, 4, 4]),
    full_kv_num_blocks=torch.Size([2, 12, 4]),
    full_kv_indices=torch.Size([2, 12, 4, 4]),
    q_num_blocks=torch.Size([2, 12, 4]),
    q_indices=torch.Size([2, 12, 4, 4]),
    full_q_num_blocks=torch.Size([2, 12, 4]),
    full_q_indices=torch.Size([2, 12, 4, 4]),
    BLOCK_SIZE=(128, 128),
    shape=(2, 12, 512, 512),
    sparsity=0.00%,
    mask_mod=mod_fn
)