In [6]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

# Work on showing how the BERT model works and how it was trained

BERT's masked pretraining, often referred to as the cloze procedure, is a pivotal aspect of its training methodology. In this phase, a random subset of words in each training instance is replaced with a special [MASK] token. The model is then tasked with predicting the original words that were masked out, turning the training into a masked language model (MLM) objective. This approach is instrumental in training a bidirectional understanding of context, as the model must consider both preceding and succeeding words to accurately predict the masked tokens. Additionally, 15% of the words are randomly chosen for masking, and some are kept unchanged to avoid the model overfitting to the [MASK] token. This strategy ensures that the model learns a robust contextual representation of words, capturing intricate semantic relationships within sentences.

BERT incorporates several ingenious tricks to enhance its training efficiency. One such technique is the use of segment embeddings, where each token is assigned a segment ID to distinguish between sentences in the input. This allows the model to understand the relationships between tokens in different segments, reinforcing its grasp of sentence-level context. Another key trick involves the use of a positional embedding to convey the order of words in a sentence, as BERT doesn't inherently account for word order. These techniques contribute to the creation of rich, context-aware representations during masked pretraining.

During the fine-tuning phase, the [MASK] tokens play a crucial role in adapting BERT to specific tasks. In the task-specific datasets, a small fraction of the tokens is masked, and the model is fine-tuned to predict these masked tokens, similar to the pretraining phase. However, in fine-tuning, only a fraction of the masked tokens is replaced with the [MASK] token, while the rest are replaced with the actual words. This approach prevents the model from solely relying on the [MASK] token during fine-tuning, ensuring that it retains a nuanced understanding of the task-specific data.

In summary, BERT's training methodology involves the cloze procedure during masked pretraining, where [MASK] tokens are used for predicting masked words bidirectionally. Ingenious tricks like segment embeddings and positional embeddings enhance the model's contextual understanding. During fine-tuning, the [MASK] tokens continue to be crucial, but with modifications to prevent over-reliance on them. These strategies collectively contribute to BERT's remarkable ability to capture intricate contextual relationships in natural language, making it a pioneering model in the realm of natural language processing.

In [2]:
class BERTEmbedding(nn.Module):
    """
    BERT Embedding module that combines token, positional, and segment embeddings.
    """

    def __init__(self, vocab_size: int, num_segments: int, max_sequence_length: int,
                 hidden_size: int, dropout: float = 0.1):
        """
        Initialize BERT Embedding module.

        Args:
            vocab_size (int): Size of the vocabulary.
            num_segments (int): Number of segments for segment embeddings.
            max_sequence_length (int): Maximum sequence length.
            hidden_size (int): Size of the hidden layer.
            dropout (float): Dropout rate (default: 0.1).
        """
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size) # (30_000, 768) Uses Wordpiece Embedding
        self.positional_embedding = nn.Embedding(max_sequence_length, hidden_size) # (512, 768) Positional Embedding max_length is the context window size, model won't be able to see beyond this
        self.segment_embedding = nn.Embedding(num_segments, hidden_size) # (3, 768) Segment Embedding for distinguishing between sentences [CLS], [SEP], [PAD]
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("positional_input", torch.arange(max_sequence_length).unsqueeze(0))

    def forward(self, input_ids: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the BERT Embedding module.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            segment_ids (torch.Tensor): Segment IDs for distinguishing different segments.

        Returns:
            torch.Tensor: Combined embeddings.
        """
        embeddings = (self.token_embedding(input_ids)
                      + self.positional_embedding(self.positional_input)
                      + self.segment_embedding(segment_ids))
        return self.dropout(embeddings)

In [3]:
class BERTTransformerModel(nn.Module):
    """
    BERT Transformer model for sequence classification.
    """

    def __init__(self, vocab_size: int, num_segments: int, max_sequence_length: int,
                 hidden_size: int, num_attention_heads: int, num_layers: int,
                 dropout: float = 0.1):
        """
        Initialize BERT Transformer model.

        Args:
            vocab_size (int): Size of the vocabulary.
            num_segments (int): Number of segments for segment embeddings.
            max_sequence_length (int): Maximum sequence length.
            hidden_size (int): Size of the hidden layer.
            num_attention_heads (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
            dropout (float): Dropout rate (default: 0.1).
        """
        super().__init__()
        self.embedding = BERTEmbedding(vocab_size, num_segments, max_sequence_length,
                                       hidden_size, dropout=dropout)
        encoder_layer = nn.TransformerEncoderLayer(hidden_size, num_attention_heads,
                                                   dropout=dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, input_ids: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the BERT Transformer model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            segment_ids (torch.Tensor): Segment IDs for distinguishing different segments.

        Returns:
            torch.Tensor: Output logits.
        """
        x = self.embedding(input_ids, segment_ids)
        x = self.encoder(x)
        x = self.fc(x[:, 0])  # Assuming classification based on the [CLS] token
        return x

In [4]:
# Example data (replace with your actual data)
input_ids = torch.tensor([[1, 2, 3, 4, 0], [1, 2, 3, 0, 0]])  # Input token IDs
segment_ids = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])  # Segment IDs
labels = torch.tensor([1, 0])  # Example labels (binary classification)

In [None]:
# Hyperparameters
vocab_size = 30000
num_segments = 3
max_sequence_length = 5
hidden_size = 768
num_attention_heads = 12
num_layers = 6
dropout = 0.1
learning_rate = 1e-4
num_epochs = 10

# Instantiate model
model = BERTTransformerModel(vocab_size, num_segments, max_sequence_length, hidden_size,
                             num_attention_heads, num_layers, dropout)

model(input_ids, segment_ids)  # Forward pass

In [8]:
# Create DataLoader
dataset = TensorDataset(input_ids, segment_ids, labels)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Instantiate model
model = BERTTransformerModel(vocab_size, num_segments, max_sequence_length, hidden_size,
                             num_attention_heads, num_layers, dropout)

# Loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss with logits
optimizer = Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    total_loss = 0.0
    for input_ids_batch, segment_ids_batch, labels_batch in dataloader:
        optimizer.zero_grad()
        outputs = model(input_ids_batch, segment_ids_batch)
        loss = criterion(outputs.squeeze(), labels_batch.float())  # Squeeze to remove extra dimension
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}")



Epoch 1, Loss: 0.7733625173568726
Epoch 2, Loss: 1.1450214385986328
Epoch 3, Loss: 0.7164051532745361
Epoch 4, Loss: 0.9264397621154785
Epoch 5, Loss: 0.6959352493286133
Epoch 6, Loss: 0.7335981130599976
Epoch 7, Loss: 0.7223803997039795
Epoch 8, Loss: 0.7208968997001648
Epoch 9, Loss: 0.6882054805755615
Epoch 10, Loss: 0.7581155896186829
