In [1]:
import torch
import torch.nn as nn
import math

In [3]:
class Embeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then LayerNorm and dropout.

    Args:
        vocab_size: Number of rows in the word-embedding table.
        hidden_size: Width of every embedding vector.
        max_position_embeddings: Number of rows in the position-embedding table.
        type_vocab_size: Number of rows in the token-type embedding table.
        dropout_prob: Probability handed to ``nn.Dropout``.
        layer_norm_eps: Epsilon handed to ``nn.LayerNorm``.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob, layer_norm_eps):
        super(Embeddings, self).__init__()
        # Token id 0 is designated as the padding index of the word table.
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed a batch of token ids.

        Args:
            input_ids: Integer tensor indexed as (batch, sequence position).
            token_type_ids: Optional tensor shaped like ``input_ids``; defaults
                to all zeros.
            position_ids: Optional tensor shaped like ``input_ids``; defaults
                to ``0 .. seq_length - 1`` for every row.

        Returns:
            Tensor of shape ``input_ids.shape + (hidden_size,)``.

        Raises:
            IndexError: If ``position_ids`` is omitted and the sequence is
                longer than the position-embedding table.
        """
        if position_ids is None:
            seq_length = input_ids.size(1)
            max_positions = self.position_embeddings.num_embeddings
            # Fail early with a readable message instead of an opaque
            # out-of-range lookup inside nn.Embedding. IndexError is kept
            # because that is what the CPU lookup raises for a bad index.
            if seq_length > max_positions:
                raise IndexError(
                    f"Sequence length {seq_length} exceeds max_position_embeddings "
                    f"({max_positions}); truncate the input or pass explicit position_ids."
                )
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            # expand() broadcasts the single row over the batch without copying.
            position_ids = position_ids.unsqueeze(0).expand(input_ids.size(0), -1)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no output projection).

    Args:
        hidden_size: Width of the incoming hidden states.
        num_attention_heads: Number of heads; must divide ``hidden_size``.
        dropout_prob: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_prob):
        super().__init__()
        per_head, leftover = divmod(hidden_size, num_attention_heads)
        if leftover:
            raise ValueError("Hidden size must be divisible by the number of attention heads")

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = per_head
        self.all_head_size = num_attention_heads * per_head

        # Three independent projections, all hidden_size -> all_head_size.
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last axis into (heads, head_size) and swap axes 1 and 2.

        The four-index permute means ``x`` must be 3-D, so the result is
        laid out as (dim0, heads, dim1, head_size).
        """
        leading_dims = x.shape[:-1]
        per_head_view = x.view(*leading_dims, self.num_attention_heads, self.attention_head_size)
        return per_head_view.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Attend every position of ``hidden_states`` to every other position.

        Args:
            hidden_states: 3-D tensor whose last axis has size ``hidden_size``.
            attention_mask: Optional tensor added to the raw scores before the
                softmax; it must broadcast against the 4-D score tensor.

        Returns:
            Tensor with the same leading two axes as ``hidden_states`` and a
            last axis of size ``all_head_size``.
        """
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Raw similarity of each query with each key, scaled by sqrt(head size).
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            scores = scores + attention_mask

        weights = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values, then move the head axis back next to head_size.
        per_head_context = (weights @ v).permute(0, 2, 1, 3).contiguous()

        # Merge (heads, head_size) back into a single feature axis.
        merged_shape = per_head_context.shape[:-2] + (self.all_head_size,)
        return per_head_context.view(*merged_shape)

class SelfOutput(nn.Module):
    """Projection, dropout, residual add and LayerNorm.

    Args:
        hidden_size: Input and output width of the linear projection.
        dropout_prob: Dropout probability applied to the projected tensor.
        layer_norm_eps: Epsilon handed to ``nn.LayerNorm``.
    """

    def __init__(self, hidden_size, dropout_prob, layer_norm_eps):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        # The residual is added before normalisation.
        return self.LayerNorm(projected + input_tensor)

class Intermediate(nn.Module):
    """Linear expansion from ``hidden_size`` to ``intermediate_size`` with GELU.

    Args:
        hidden_size: Width of the incoming features.
        intermediate_size: Width of the expanded features.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states):
        """Return ``GELU(dense(hidden_states))``."""
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)

class Output(nn.Module):
    """Projection back to ``hidden_size``, dropout, residual add and LayerNorm.

    Args:
        intermediate_size: Width of the incoming features.
        hidden_size: Width of the projected features and of the residual.
        dropout_prob: Dropout probability applied to the projected tensor.
        layer_norm_eps: Epsilon handed to ``nn.LayerNorm``.
    """

    def __init__(self, intermediate_size, hidden_size, dropout_prob, layer_norm_eps):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dropout(self.dense(hidden_states))
        # The residual is added before normalisation.
        return self.LayerNorm(contracted + input_tensor)

class Layer(nn.Module):
    """One encoder block: self-attention sub-layer followed by a feed-forward sub-layer.

    Args:
        hidden_size: Width of the hidden states entering and leaving the block.
        num_attention_heads: Number of heads given to ``SelfAttention``.
        intermediate_size: Width of the feed-forward expansion.
        dropout_prob: Dropout probability shared by all sub-modules.
        layer_norm_eps: Epsilon shared by both LayerNorm instances.
    """

    def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_prob, layer_norm_eps):
        super().__init__()
        self.attention = SelfAttention(hidden_size, num_attention_heads, dropout_prob)
        self.attention_output = SelfOutput(hidden_size, dropout_prob, layer_norm_eps)
        self.intermediate = Intermediate(hidden_size, intermediate_size)
        self.output = Output(intermediate_size, hidden_size, dropout_prob, layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        """Run attention and feed-forward, each wrapped in a residual + LayerNorm.

        Args:
            hidden_states: Input to the block.
            attention_mask: Optional mask forwarded untouched to ``SelfAttention``.

        Returns:
            The output of the ``Output`` sub-module.
        """
        # Attention sub-layer; the residual is the block input.
        context = self.attention(hidden_states, attention_mask)
        attended = self.attention_output(context, hidden_states)

        # Feed-forward sub-layer; the residual is the attention result.
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)

class Encoder(nn.Module):
    """Stack of ``num_hidden_layers`` independently parameterised ``Layer`` blocks.

    Args:
        hidden_size: Width of the hidden states.
        num_attention_heads: Number of attention heads per block.
        intermediate_size: Width of each block's feed-forward expansion.
        num_hidden_layers: How many blocks to stack.
        dropout_prob: Dropout probability passed to every block.
        layer_norm_eps: LayerNorm epsilon passed to every block.
    """

    def __init__(self, hidden_size, num_attention_heads, intermediate_size, num_hidden_layers, dropout_prob, layer_norm_eps):
        super().__init__()
        blocks = []
        for _ in range(num_hidden_layers):
            blocks.append(
                Layer(hidden_size, num_attention_heads, intermediate_size, dropout_prob, layer_norm_eps)
            )
        self.layer = nn.ModuleList(blocks)

    def forward(self, hidden_states, attention_mask=None):
        """Feed ``hidden_states`` through every block in order.

        The same ``attention_mask`` is handed to each block.
        """
        current = hidden_states
        for block in self.layer:
            current = block(current, attention_mask)
        return current

class Pooler(nn.Module):
    """Summarise a sequence by projecting the hidden state at index 0 of axis 1.

    Args:
        hidden_size: Input and output width of the linear projection.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # "Pooling" here is just picking the first token's hidden state.
        leading_token = hidden_states[:, 0]
        return self.activation(self.dense(leading_token))

class BERTModel(nn.Module):
    """Embeddings -> encoder stack -> pooler.

    Args:
        vocab_size: Size of the word-embedding table.
        hidden_size: Width of all hidden states.
        max_position_embeddings: Size of the position-embedding table.
        type_vocab_size: Size of the token-type embedding table.
        num_attention_heads: Attention heads per encoder block.
        intermediate_size: Feed-forward expansion width per encoder block.
        num_hidden_layers: Number of encoder blocks.
        dropout_prob: Dropout probability used throughout the model.
        layer_norm_eps: LayerNorm epsilon used throughout the model.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, num_attention_heads, intermediate_size, num_hidden_layers, dropout_prob, layer_norm_eps):
        super(BERTModel, self).__init__()
        self.embeddings = Embeddings(vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob, layer_norm_eps)
        self.encoder = Encoder(hidden_size, num_attention_heads, intermediate_size, num_hidden_layers, dropout_prob, layer_norm_eps)
        self.pooler = Pooler(hidden_size)

    @staticmethod
    def _to_additive_mask(attention_mask, dtype):
        """Turn a bool/integer keep-mask into an additive mask for the attention scores.

        The encoder adds the mask to the raw attention scores, so a 1/0 mask
        must first become 0 (keep) / very negative (drop) and gain the axes
        needed to broadcast against a (batch, heads, query, key) score tensor.

        * ``None`` and floating-point masks are returned unchanged; a float
          mask is taken to be additive already, which is how the encoder has
          always consumed it.
        * A 2-D bool/integer mask (batch, key) becomes (batch, 1, 1, key).
        * A 3-D bool/integer mask (batch, query, key) becomes
          (batch, 1, query, key).
        * Any other rank keeps its shape and only has its values converted.
        """
        if attention_mask is None or attention_mask.is_floating_point():
            return attention_mask

        if attention_mask.dim() == 2:
            keep = attention_mask[:, None, None, :]
        elif attention_mask.dim() == 3:
            keep = attention_mask[:, None, :, :]
        else:
            keep = attention_mask

        keep = keep.to(dtype=dtype)
        # 1 -> 0.0 (no change to the score), 0 -> most negative finite value,
        # which drives the softmax weight of that key to zero.
        return (1.0 - keep) * torch.finfo(dtype).min

    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: Integer token ids indexed as (batch, sequence position).
            token_type_ids: Optional segment ids, forwarded to the embeddings.
            position_ids: Optional position ids, forwarded to the embeddings.
            attention_mask: Optional mask. Bool/integer masks use 1 = attend,
                0 = ignore and are converted to additive form; floating-point
                masks are added to the attention scores as given.

        Returns:
            Tuple ``(encoder_output, pooled_output)``: the final hidden states
            of the encoder and the pooler's projection of the first position.
        """
        embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
        additive_mask = self._to_additive_mask(attention_mask, embedding_output.dtype)
        encoder_output = self.encoder(embedding_output, additive_mask)
        pooled_output = self.pooler(encoder_output)
        return encoder_output, pooled_output




In [4]:
# Example usage: build a model with the hyper-parameters below and check output shapes.
batch_size = 2
seq_length = 5
vocab_size = 30522
hidden_size = 768
max_position_embeddings = 512
type_vocab_size = 2
num_attention_heads = 12
intermediate_size = 3072
num_hidden_layers = 12
dropout_prob = 0.1
layer_norm_eps = 1e-12

# Seed the global RNG so the random weights and the dummy input are reproducible.
torch.manual_seed(0)

# Create an instance of BERTModel
model = BERTModel(vocab_size, hidden_size, max_position_embeddings, type_vocab_size, num_attention_heads, intermediate_size, num_hidden_layers, dropout_prob, layer_norm_eps)
# This cell only runs inference: eval() disables the dropout layers.
model.eval()

# Dummy input (random token ids)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))

# Get model output; no_grad() avoids building an autograd graph we never use.
with torch.no_grad():
    encoder_output, pooled_output = model(input_ids)

print("Encoder Output Shape:", encoder_output.shape)  # Expected: [batch_size, seq_length, hidden_size]
print("Pooled Output Shape:", pooled_output.shape)    # Expected: [batch_size, hidden_size]


Encoder Output Shape: torch.Size([2, 5, 768])
Pooled Output Shape: torch.Size([2, 768])
