In [None]:
import torch
import torch.nn as nn

# How the BERT model works and how it was trained

BERT's masked pretraining, often described as a cloze task, is central to how the model is trained. A random 15% of the tokens in each training instance are selected for prediction, and the model must recover the original token at each selected position. This is the masked language model (MLM) objective. Because a target can sit anywhere in the sequence, the model has to use both the preceding and the following tokens, which is what makes the learned representations bidirectional. Not every selected token is replaced with the special [MASK] token: in the original recipe, 80% of the selected tokens become [MASK], 10% are swapped for a random token, and 10% are left unchanged. The [MASK] token never appears in downstream inputs, so this mix keeps the model from depending on it and pushes it to maintain a useful contextual representation for every input token.
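The selection and corruption step described above can be sketched as follows. This is a minimal illustration, not the original implementation: the function name `mask_tokens`, its parameters, and the use of `-100` as the ignored label (the default `ignore_index` of `nn.CrossEntropyLoss`) are assumptions made for the example.

```python
import torch


def mask_tokens(input_ids, mask_token_id, vocab_size, special_ids=(), mlm_prob=0.15, generator=None):
    """Return (corrupted_ids, labels) for a masked language model objective."""
    labels = input_ids.clone()

    # Select roughly mlm_prob of the positions, never the special tokens.
    probs = torch.full(input_ids.shape, mlm_prob)
    special = torch.zeros_like(input_ids, dtype=torch.bool)
    for sid in special_ids:
        special |= input_ids == sid
    probs.masked_fill_(special, 0.0)
    selected = torch.bernoulli(probs, generator=generator).bool()

    # Only selected positions contribute to the loss.
    labels[~selected] = -100

    corrupted = input_ids.clone()
    roll = torch.rand(input_ids.shape, generator=generator)

    # 80% of the selected positions become [MASK].
    to_mask = selected & (roll < 0.8)
    corrupted[to_mask] = mask_token_id

    # 10% become a random token (which may, by chance, equal the original).
    to_random = selected & (roll >= 0.8) & (roll < 0.9)
    random_ids = torch.randint(vocab_size, input_ids.shape, generator=generator)
    corrupted[to_random] = random_ids[to_random]

    # The remaining 10% of selected positions are left unchanged.
    return corrupted, labels
```

The model is then trained with cross-entropy between its per-position vocabulary logits and `labels`, so unselected positions are skipped automatically.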

BERT's input representation is the sum of three embeddings. Token embeddings encode the WordPiece identity of each token. Segment embeddings assign each token a segment ID that marks which sentence it belongs to, so the model can relate tokens across a sentence pair. Positional embeddings encode each token's position, which is needed because self-attention on its own is blind to word order. BERT learns these positional embeddings rather than using fixed sinusoids. Together, the three give the encoder what it needs to build context-aware representations during masked pretraining.
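To see why position information has to be added explicitly, the sketch below checks that a bare self-attention layer gives the same per-token outputs when the input order is reversed, and that this stops being true once positional embeddings are added. The sizes, the reversed permutation, and the variable names are assumptions chosen for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, seq_len = 16, 6

attention = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True).eval()
tokens = torch.randn(1, seq_len, hidden_size)
perm = torch.arange(seq_len).flip(0)  # reverse the token order


def self_attend(x):
    """Run self-attention (query = key = value) and return the outputs only."""
    with torch.no_grad():
        out, _ = attention(x, x, x)
    return out


# Without positions: reordering the inputs only reorders the outputs.
order_blind = torch.allclose(
    self_attend(tokens)[:, perm], self_attend(tokens[:, perm]), atol=1e-5
)

# With learned positional embeddings: the same token now gets a different
# output depending on where it sits in the sequence.
positions = nn.Embedding(seq_len, hidden_size)
with torch.no_grad():
    pos = positions(torch.arange(seq_len)).unsqueeze(0)
order_aware = not torch.allclose(
    self_attend(tokens + pos)[:, perm], self_attend(tokens[:, perm] + pos), atol=1e-5
)
```

Both `order_blind` and `order_aware` should evaluate to `True` here.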

During fine-tuning, [MASK] tokens are not used. The pretrained encoder receives ordinary, unmasked task inputs, a small task-specific output layer is added on top (for classification, typically over the final hidden state of the [CLS] token), and all parameters are trained end to end on the labeled task data. This creates a mismatch with pretraining, where [MASK] does appear in the input. That mismatch is the reason pretraining replaces only part of the selected tokens with [MASK] and leaves the rest as random or original tokens.
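As a minimal sketch of one fine-tuning step in the style of the original BERT setup: inputs are fed unmasked, and a small classification layer reads the hidden state at position 0, where [CLS] sits. `TinyEncoder` is a hypothetical stand-in for a pretrained encoder, and the sizes, learning rate, and random data are assumptions chosen only to keep the example fast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size, num_classes = 100, 32, 2


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a pretrained BERT encoder."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        layer = nn.TransformerEncoderLayer(
            hidden_size, nhead=4, dim_feedforward=64, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, input_ids):
        return self.encoder(self.embed(input_ids))  # (batch, seq, hidden)


encoder = TinyEncoder()
classifier = nn.Linear(hidden_size, num_classes)  # new task-specific head
optimizer = torch.optim.AdamW(
    list(encoder.parameters()) + list(classifier.parameters()), lr=2e-5
)

# Task inputs contain no [MASK] tokens; position 0 plays the role of [CLS].
input_ids = torch.randint(0, vocab_size, (8, 16))
labels = torch.randint(0, num_classes, (8,))

hidden = encoder(input_ids)
logits = classifier(hidden[:, 0])  # classify from the [CLS] position
loss = nn.functional.cross_entropy(logits, labels)

# Encoder and head are updated together, end to end.
optimizer.zero_grad()
loss.backward()
optimizer.step()
```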

In summary, BERT is pretrained with a cloze-style objective in which 15% of the input tokens are selected and predicted from context on both sides, with most but not all of them replaced by [MASK]. Token, segment, and positional embeddings are summed to form the input, giving the encoder word identity, sentence membership, and word order. During fine-tuning, no [MASK] tokens appear. The pretrained encoder is trained end to end with a task-specific head, and the mixed replacement scheme used in pretraining limits the mismatch between the two phases. Together these choices let BERT capture contextual relationships in natural language and made it an influential model in natural language processing.

In [None]:
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, num_segments, max_sequence_length, hidden_size, dropout=0.1):
        super().__init__()
        # WordPiece token embeddings, e.g. (30_000, 768) in BERT-base
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        # Learned positional embeddings, e.g. (512, 768). max_sequence_length is the
        # context window size, so longer inputs must be truncated
        self.positional_embedding = nn.Embedding(max_sequence_length, hidden_size)
        # Segment embeddings distinguish sentence A from sentence B. The original BERT
        # uses 2 segment IDs; with num_segments=3 a third ID can be reserved for padding
        self.segment_embedding = nn.Embedding(num_segments, hidden_size)

        self.dropout = nn.Dropout(dropout)
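        # Register as a buffer so the position indices follow the module across devices
        self.register_buffer(
            "positional_input",
            torch.arange(max_sequence_length).unsqueeze(0),
            persistent=False,
        )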

    def forward(self, input_ids, segment_ids):
        embeddings = (
            self.token_embedding(input_ids)
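            # Slice positions to the actual input length so shorter sequences broadcast
            + self.positional_embedding(self.positional_input[:, : input_ids.size(1)])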
            + self.segment_embedding(segment_ids)
        )
        return self.dropout(embeddings)

In [None]:
class BERTTransformerModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        num_segments,
        max_sequence_length,
        hidden_size,
        num_attention_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.embedding = BERTEmbedding(
            vocab_size, num_segments, max_sequence_length, hidden_size, dropout=dropout
        )
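        self.encoder = nn.TransformerEncoder(
            # batch_first=True so inputs are (batch, seq, hidden), matching x[:, 0] below
            nn.TransformerEncoderLayer(
                hidden_size, num_attention_heads, dropout=dropout, batch_first=True
            ),
            num_layers,
        )
        # Single-logit head on the [CLS] representation
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, segment_ids):
        x = self.embedding(input_ids, segment_ids)  # (batch, seq, hidden)
        x = self.encoder(x)
        x = self.fc(x[:, 0])  # first position of each sequence, i.e. [CLS]
        return x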