In [7]:
import numpy as np
import torch
import torch.nn as nn
import math

In [20]:
class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model, dropout_prob: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p = dropout_prob)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [25]:
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, max_len, dropout_prob):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embed_size)
        self.position_embeddings = PositionalEncoding(d_model= embed_size, max_len = max_len, dropout_prob = dropout_prob)
        self.segment_embeddings = nn.Embedding(2, embed_size)
        self.norm = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout_prob)
    
    def forward(self, input_ids, segment_ids):
        device = input_ids.device
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids, device = device)

        token_embs = self.token_embeddings(input_ids)
        segment_embs = self.segment_embeddings(segment_ids)
        base_embeddings = token_embs + segment_embs

        final_embeddings = self.position_embeddings(base_embeddings)


        normalized_embeddings = self.norm(final_embeddings)
        final_output = self.dropout(normalized_embeddings)

        return final_output


In [26]:
if __name__ == '__main__':
    # --- Configuration for the test ---
    VOCAB_SIZE = 30000  # Size of a typical vocabulary
    EMBED_SIZE = 768    # Dimension for BERT-base
    MAX_LEN = 512       # Max sequence length for BERT
    BATCH_SIZE = 4      # Number of sequences in a batch
    SEQ_LENGTH = 128    # Length of the example sequences

    # --- Instantiate the model ---
    embedding_layer = BERTEmbedding(
        vocab_size=VOCAB_SIZE,
        embed_size=EMBED_SIZE,
        max_len=MAX_LEN,
        dropout_prob=0.1
    )
    
    print("BERTEmbedding module initialized:")
    print(embedding_layer)

    # --- Create dummy input data ---
    # `input_ids` are random integers from 0 to VOCAB_SIZE-1
    # Shape: (BATCH_SIZE, SEQ_LENGTH)
    dummy_input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))

    # `segment_ids` are 0s and 1s
    # Shape: (BATCH_SIZE, SEQ_LENGTH)
    dummy_segment_ids = torch.zeros(BATCH_SIZE, SEQ_LENGTH, dtype=torch.long)
    # Let's make the second half of each sequence belong to segment 1
    dummy_segment_ids[:, SEQ_LENGTH // 2:] = 1

    print("\n--- Running a forward pass ---")
    print(f"Shape of input_ids: {dummy_input_ids.shape}")
    print(f"Shape of segment_ids: {dummy_segment_ids.shape}")

    # --- Get the output ---
    output_embeddings = embedding_layer(dummy_input_ids, dummy_segment_ids)
    
    # --- Verify the output ---
    print("\n--- Verifying the output ---")
    print(f"Shape of output embeddings: {output_embeddings.shape}")
    print(f"Expected output shape: {(BATCH_SIZE, SEQ_LENGTH, EMBED_SIZE)}")
    
    # Check if the shape is correct
    assert output_embeddings.shape == (BATCH_SIZE, SEQ_LENGTH, EMBED_SIZE)
    
    print("\n✅ Test passed! The output shape is correct.")

BERTEmbedding module initialized:
BERTEmbedding(
  (token_embeddings): Embedding(30000, 768)
  (position_embeddings): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (segment_embeddings): Embedding(2, 768)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

--- Running a forward pass ---
Shape of input_ids: torch.Size([4, 128])
Shape of segment_ids: torch.Size([4, 128])

--- Verifying the output ---
Shape of output embeddings: torch.Size([4, 128, 768])
Expected output shape: (4, 128, 768)

✅ Test passed! The output shape is correct.
