In [None]:
import torch
import torch.nn as nn

# Goal: show how the BERT model works and how it was trained.

In [None]:
class BERTEmbedding(nn.Module):
    """Input embedding layer: token + learned positional + segment embeddings.

    The three embeddings are summed element-wise and passed through dropout.

    Args:
        vocab_size: Number of rows in the token embedding table
            (e.g. ~30,000 WordPiece tokens in BERT-base).
        num_segments: Number of distinct segment ids.
        max_sequence_length: Largest sequence length that can be embedded
            (the context window; e.g. 512 in BERT-base).
        hidden_size: Width of every embedding vector (e.g. 768 in BERT-base).
        dropout: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, num_segments, max_sequence_length, hidden_size, dropout=0.1):
        super().__init__()
        # One learned vector per vocabulary entry.
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        # One learned vector per position; positions >= max_sequence_length
        # have no embedding, so longer inputs cannot be represented.
        self.positional_embedding = nn.Embedding(max_sequence_length, hidden_size)
        # One learned vector per segment id (e.g. sentence A / sentence B;
        # a third id is presumably reserved for padding -- confirm against
        # the data pipeline).
        self.segment_embedding = nn.Embedding(num_segments, hidden_size)

        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer so it follows the module across devices
        # (.to()/.cuda()). persistent=False keeps it out of state_dict, so
        # checkpoints saved before this change still load unchanged.
        self.register_buffer(
            "positional_input",
            torch.arange(max_sequence_length).unsqueeze(0),
            persistent=False,
        )

    def forward(self, input_ids, segment_ids):
        """Embed a batch of token ids.

        Args:
            input_ids: Integer tensor of token ids whose last dimension is the
                sequence length (typically ``(batch, seq_len)``).
            segment_ids: Integer tensor of segment ids, same shape as
                ``input_ids``.

        Returns:
            Tensor of summed embeddings after dropout, with a trailing
            ``hidden_size`` dimension appended to the input shape.

        Raises:
            ValueError: If the sequence is longer than ``max_sequence_length``.
        """
        seq_len = input_ids.size(-1)
        max_len = self.positional_input.size(-1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_sequence_length {max_len}"
            )

        # Only look up as many positions as the input actually has, so
        # sequences shorter than the maximum broadcast correctly.
        position_ids = self.positional_input[:, :seq_len]
        embeddings = (
            self.token_embedding(input_ids)
            + self.positional_embedding(position_ids)
            + self.segment_embedding(segment_ids)
        )
        return self.dropout(embeddings)

In [None]:
class BERTTransformerModel(nn.Module):
    """BERT-style encoder with a single-output head on the first token.

    Pipeline: ``BERTEmbedding`` -> stack of ``nn.TransformerEncoderLayer`` ->
    linear layer applied to the hidden state at sequence position 0.

    Args:
        vocab_size: Size of the token vocabulary.
        num_segments: Number of distinct segment ids.
        max_sequence_length: Maximum sequence length the embedding supports.
        hidden_size: Model width (embedding and encoder dimension).
        num_attention_heads: Number of self-attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used in the embedding and encoder layers.
    """

    def __init__(
        self,
        vocab_size,
        num_segments,
        max_sequence_length,
        hidden_size,
        num_attention_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.embedding = BERTEmbedding(
            vocab_size, num_segments, max_sequence_length, hidden_size, dropout=dropout
        )
        # batch_first=True is required: the embeddings are laid out as
        # (batch, seq, hidden) and forward() reads x[:, 0] as the first token
        # of each sequence. With the PyTorch default (batch_first=False) the
        # encoder would treat dim 0 as the sequence axis and attend across
        # unrelated samples in the batch.
        # Feed-forward width and activation are left at PyTorch defaults.
        encoder_layer = nn.TransformerEncoderLayer(
            hidden_size,
            num_attention_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        # Maps the first-token hidden state to one scalar per sequence.
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, segment_ids):
        """Run the model on a batch.

        Args:
            input_ids: Integer tensor of token ids, ``(batch, seq_len)``.
            segment_ids: Integer tensor of segment ids, ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, 1)``: the linear head applied to the
            encoder output at position 0 of each sequence (presumably the
            [CLS] token -- confirm against the tokenization pipeline).
        """
        x = self.embedding(input_ids, segment_ids)
        x = self.encoder(x)
        # Position 0 of every sequence is used as the sequence summary.
        x = self.fc(x[:, 0])
        return x