In [2]:
import torch.nn as nn
%run 01_embedding.ipynb
%run encoder.ipynb

BERTEmbedding module initialized:
BERTEmbedding(
  (token_embeddings): Embedding(30000, 768)
  (position_embeddings): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (segment_embeddings): Embedding(2, 768)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

--- Running a forward pass ---
Shape of input_ids: torch.Size([4, 128])
Shape of segment_ids: torch.Size([4, 128])

--- Verifying the output ---
Shape of output embeddings: torch.Size([4, 128, 768])
Expected output shape: (4, 128, 768)

✅ Test passed! The output shape is correct.


In [3]:
class BertModel(nn.Module):
    """BERT-style backbone: an embedding layer followed by a stack of encoder blocks.

    Args:
        vocab_size: Forwarded to ``BERTEmbedding`` as ``vocab_size``.
        d_model: Hidden width; forwarded to the embedding as ``embed_size`` and
            to every encoder block as ``d_model``.
        num_heads: Forwarded to every ``TransformerEncoderBlock``.
        d_ff: Forwarded to every ``TransformerEncoderBlock``.
        max_len: Forwarded to ``BERTEmbedding``.
        dropout_prob: Dropout probability handed to the embedding and to every
            encoder block.
        num_layers: Number of encoder blocks in the stack.
    """

    # NOTE(review): the ``dropout_prob`` default of 0.01 looks like a typo for
    # 0.1 (the value conventionally used for BERT) -- confirm. It is left
    # unchanged here so callers relying on the default see no behavior change.
    def __init__(self, vocab_size, d_model, num_heads, d_ff, max_len, dropout_prob: float = 0.01, num_layers: int = 12):
        super().__init__()
        self.embedding = BERTEmbedding(vocab_size = vocab_size, embed_size = d_model, max_len = max_len, dropout_prob = dropout_prob)
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderBlock(
                d_model = d_model,
                num_heads = num_heads,
                d_ff = d_ff,
                dropout_prob = dropout_prob,
            ) for _ in range(num_layers)
            ]
        )

    @staticmethod
    def _expand_attention_mask(attention_mask):
        """Insert singleton axes so the mask reaches rank 4 before the encoder blocks.

        A rank-2 mask gets axes inserted at positions 1 and 2 (the original
        behavior). A rank-3 mask only gets the axis at position 1, and a rank-4
        mask is passed through untouched instead of being expanded a second
        time. ``None`` is returned unchanged.

        Raises:
            ValueError: If the mask rank is not 2, 3 or 4.
        """
        if attention_mask is None:
            return None

        rank = attention_mask.dim()
        if rank == 2:
            # presumably (batch, seq_len) -> (batch, 1, 1, seq_len) -- TODO confirm
            return attention_mask.unsqueeze(1).unsqueeze(2)
        if rank == 3:
            # presumably (batch, q_len, k_len); only the head axis is missing -- TODO confirm
            return attention_mask.unsqueeze(1)
        if rank == 4:
            # Already has all four axes; unsqueezing again would yield a 6-D tensor.
            return attention_mask
        raise ValueError(
            f"attention_mask must have 2, 3 or 4 dimensions, got {rank}"
        )

    def forward(self, input_ids, segment_ids, attention_mask =  None):
        """Embed the inputs and run them through every encoder block in order.

        Args:
            input_ids: Token ids, passed to the embedding layer.
            segment_ids: Segment ids, passed to the embedding layer.
            attention_mask: Optional mask of rank 2, 3 or 4. It is expanded to
                rank 4 and handed to each block as ``mask``; its exact
                semantics are defined by ``TransformerEncoderBlock``.

        Returns:
            The output of the last encoder block.

        Raises:
            ValueError: If ``attention_mask`` has an unsupported rank.
        """
        attention_mask = self._expand_attention_mask(attention_mask)

        x = self.embedding(input_ids = input_ids, segment_ids = segment_ids)

        for layer in self.encoder_layers:
            x = layer(x, mask = attention_mask)

        return x

In [None]:
class BertForPreTraining(nn.Module):
    """``BertModel`` backbone with two linear heads applied to its output.

    * ``nsp_head``: Linear -> Tanh -> Linear, producing 2 logits from the
      backbone output at sequence position 0.
    * ``mlm_head``: a single Linear producing ``vocab_size`` logits for every
      position of the backbone output.

    All constructor arguments are forwarded by keyword to ``BertModel``;
    ``d_model`` and ``vocab_size`` additionally size the two heads.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_len, dropout_prob):
        super().__init__()

        # Backbone first so submodule registration order stays bert -> nsp_head -> mlm_head.
        backbone_kwargs = dict(
            vocab_size=vocab_size,
            d_model=d_model,
            num_layers=num_layers,
            num_heads=num_heads,
            d_ff=d_ff,
            max_len=max_len,
            dropout_prob=dropout_prob,
        )
        self.bert = BertModel(**backbone_kwargs)

        # Sentence-level head: project, squash, then classify into 2 classes.
        pooling_projection = nn.Linear(d_model, d_model)
        pooling_activation = nn.Tanh()
        binary_classifier = nn.Linear(d_model, 2)
        self.nsp_head = nn.Sequential(pooling_projection, pooling_activation, binary_classifier)

        # Token-level head: one logit per vocabulary entry at every position.
        self.mlm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, segment_ids, attention_mask = None):
        """Run the backbone and both heads.

        Args:
            input_ids: Passed positionally to the backbone.
            segment_ids: Passed positionally to the backbone.
            attention_mask: Optional mask, passed positionally to the backbone.

        Returns:
            A ``(mlm_logits, nsp_logits)`` tuple, in that order.
        """
        hidden_states = self.bert(input_ids, segment_ids, attention_mask)

        # Index 0 along dim 1 selects the first sequence position for every batch item.
        first_position_state = hidden_states[:, 0]
        nsp_logits = self.nsp_head(first_position_state)

        mlm_logits = self.mlm_head(hidden_states)
        return mlm_logits, nsp_logits