In [2]:
import torch
import torch.nn as nn
from torchsummary import summary

In [19]:
class DyT(nn.Module):
    """Element-wise ``gamma * tanh(alpha * x) + beta`` with learnable parameters.

    Args:
        dims: Shape of the per-feature ``gamma`` (scale) and ``beta`` (shift)
            parameters.
        init_alpha: Starting value of the single learnable ``alpha``.
        **kwargs: Accepted and ignored, so callers can forward a shared
            keyword dictionary without filtering it first.
    """

    def __init__(self, dims, init_alpha=0.5, **kwargs):
        super().__init__()
        # Parameters are registered in the order alpha, beta, gamma.
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)  # shape (1,)
        shift = torch.zeros(dims)
        scale = torch.ones(dims)
        self.beta = nn.Parameter(shift)
        self.gamma = nn.Parameter(scale)

    def forward(self, x):
        """Return ``gamma * tanh(alpha * x) + beta``."""
        squashed = torch.tanh(self.alpha * x)
        rescaled = self.gamma * squashed
        return rescaled + self.beta

In [52]:
class Encoder_Block(nn.Module): ## prenorm support
    """Transformer encoder block: self-attention and feed-forward, normalised with DyT.

    Args:
        dim: Embedding size of the input and output.
        num_heads: Number of attention heads.
        d_mha: Dropout probability passed to ``nn.MultiheadAttention``.
        d_ff: Probability of the dropout layer that ends the feed-forward stack.
        d_res: Probability of the dropout applied to each sub-layer output
            before it is added to the residual.
        prenorm: If True, normalise the *input* of each sub-layer (pre-norm);
            otherwise normalise *after* each residual add (post-norm).
        **kwargs: Forwarded to both ``DyT`` layers (e.g. ``init_alpha``).
    """

    def __init__(self, dim=768, num_heads=12, d_mha=0.1, d_ff=0.1, d_res=0.1, prenorm=True, **kwargs):
        super().__init__()
        ff_hidden_dim=dim*4
        self.prenorm = prenorm ## BERT uses postnorm

        self.mha = nn.MultiheadAttention(dim, num_heads, d_mha, batch_first=True)
        self.norm1 = DyT(dim, **kwargs)

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, dim),
            nn.Dropout(d_ff)
        )

        self.norm2 = DyT(dim, **kwargs)
        self.dropout = nn.Dropout(d_res)

    def forward(self, x, pad_mask=None):
        """Apply self-attention and the feed-forward stack, each with a residual add.

        Args:
            x: Input tensor. The attention layer is built with
                ``batch_first=True``, so it is read as (batch, seq, dim).
            pad_mask: Optional mask handed to ``nn.MultiheadAttention`` as
                ``key_padding_mask``. For a boolean mask PyTorch treats
                ``True`` as "ignore this key position".

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.prenorm:
            # Pre-norm: only the sub-layer input is normalised. The residual
            # branch must carry the un-normalised x, so the normalised tensor
            # is kept in a separate variable instead of overwriting x.
            normed = self.norm1(x)
            attn_output, _ = self.mha(normed, normed, normed, key_padding_mask=pad_mask)
            x = x + self.dropout(attn_output)

            ff_output = self.ff(self.norm2(x))
            x = x + self.dropout(ff_output)
        else:
            # Post-norm: normalise the sum of residual and sub-layer output.
            attn_output, _ = self.mha(x, x, x, key_padding_mask=pad_mask)
            x = self.norm1(x + self.dropout(attn_output))

            ff_output = self.ff(x)
            x = self.norm2(x + self.dropout(ff_output))

        return x

In [42]:
class Encoder(nn.Module):
    """Stack of ``Encoder_Block`` layers followed by a final ``DyT``.

    Args:
        num_layers: Number of ``Encoder_Block`` layers.
        dim: Embedding size shared by every block and the final norm.
        prenorm: Passed to every block to select pre-norm or post-norm.
        **kwargs: Forwarded to every ``Encoder_Block`` and to the final
            ``DyT`` (e.g. ``num_heads``, ``d_mha``, ``init_alpha``).
    """

    def __init__(self, num_layers=12, dim=768, prenorm=True, **kwargs):
        super().__init__()

        self.layers = nn.ModuleList([
            Encoder_Block(dim, prenorm=prenorm, **kwargs) for _ in range(num_layers)
        ])

        # Forward kwargs here too, so a custom ``init_alpha`` reaches the final
        # DyT just as it reaches the DyT layers inside each block. DyT ignores
        # keywords it does not use.
        self.norm = DyT(dim, **kwargs)
        # Stored as an attribute only; forward() applies the final norm in
        # both pre-norm and post-norm mode.
        self.prenorm = prenorm

    def forward(self, x, pad_mask=None):
        """Run ``x`` through every block in order, then through the final norm.

        Args:
            x: Input tensor handed to the first block.
            pad_mask: Optional mask passed unchanged to every block.

        Returns:
            Output of the final ``DyT`` applied to the last block's output.
        """
        for block in self.layers:
            x = block(x, pad_mask)
        x = self.norm(x)

        return x

In [43]:
# Summarise a 2-layer encoder. Use the GPU when one is present and fall back
# to the CPU otherwise, so this cell also runs on CPU-only machines.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary(Encoder(num_layers=2).to(summary_device), (512, 768), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
               DyT-1             [-1, 512, 768]               0
MultiheadAttention-2  [[-1, 512, 768], [-1, 512, 512]]               0
           Dropout-3             [-1, 512, 768]               0
               DyT-4             [-1, 512, 768]               0
            Linear-5            [-1, 512, 3072]       2,362,368
              ReLU-6            [-1, 512, 3072]               0
            Linear-7             [-1, 512, 768]       2,360,064
           Dropout-8             [-1, 512, 768]               0
           Dropout-9             [-1, 512, 768]               0
    Encoder_Block-10             [-1, 512, 768]               0
              DyT-11             [-1, 512, 768]               0
MultiheadAttention-12  [[-1, 512, 768], [-1, 512, 512]]               0
          Dropout-13             [-1, 512, 768]               0
              DyT-14    

In [48]:
class MyBERT(nn.Module):
    """Token + segment + position embeddings feeding an ``Encoder``.

    Args:
        vocab_size: Number of rows in the token embedding table.
        seq_len: Maximum sequence length, i.e. the number of learned position
            embeddings.
        dim: Embedding size.
        **kwargs: Forwarded to ``Encoder`` (e.g. ``num_layers``, ``prenorm``).
    """

    def __init__(self, vocab_size, seq_len=512, dim=768, **kwargs):
        super().__init__()
        self.encoder = Encoder(dim=dim,**kwargs)
        self.token_embeddings = nn.Embedding(vocab_size, dim)
        self.segment_embeddings = nn.Embedding(2, dim)
        self.positional_embeddings = nn.Embedding(seq_len, dim)
        # Shape (1, seq_len); sliced and broadcast over the batch in forward().
        self.register_buffer("position_ids", torch.arange(seq_len).unsqueeze(0))

    def forward(self, x, segment, pad_mask=None):
        """Embed the inputs and run them through the encoder.

        Args:
            x: Token ids of shape (batch, seq). The sequence length may be
                anything up to the ``seq_len`` given at construction.
            segment: Segment ids (0 or 1), indexed into a 2-row embedding
                table and added to the token embeddings.
            pad_mask: Optional mask passed unchanged to the encoder.

        Returns:
            The encoder output for the summed embeddings.

        Raises:
            ValueError: If the input is longer than the number of position
                embeddings.
        """
        batch_size, seq_len = x.shape
        max_len = self.position_ids.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"input sequence length {seq_len} exceeds the maximum of {max_len} positions"
            )
        # Slice before expanding: expand() cannot shrink the (1, max_len)
        # buffer, so shorter inputs would otherwise raise.
        position_ids = self.position_ids[:, :seq_len].expand(batch_size, seq_len)
        x = self.token_embeddings(x) + self.segment_embeddings(segment) + self.positional_embeddings(position_ids)

        x = self.encoder(x, pad_mask)

        return x

In [54]:
# Small demo model: 1000-token vocabulary, 768-wide embeddings, 4 encoder blocks.
bert = MyBERT(vocab_size=1000, dim=768, num_layers=4)

In [55]:
batch_size = 2
seq_len = 512

# Random input IDs between 0 and 999
x = torch.randint(0, 1000, (batch_size, seq_len))

# Dummy segment IDs (all 0s = single segment for now)
segment = torch.zeros_like(x, dtype=torch.long)

# Optional padding mask. It is handed to nn.MultiheadAttention as
# `key_padding_mask`, where True = padded position (ignored as a key) and
# False = real token. All False means nothing is padded: full attention.
# (An all-True mask would hide every key and produce NaNs.)
pad_mask = torch.zeros_like(x, dtype=torch.bool)

# Forward pass
output = bert(x, segment, pad_mask)
print(output.shape)  # should be [batch_size, seq_len, 768] = [2, 512, 768]

# Sanity checks: expected shape, and no NaNs coming out of attention.
assert output.shape == (batch_size, seq_len, 768)
assert not torch.isnan(output).any()

torch.Size([2, 512, 768])
