<a href="https://colab.research.google.com/github/NusRAT-LiA/BERT-Bidirectional-Encoder-Representations-from-Transformers/blob/main/bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/471.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m471.0/471.6 kB[0m [31m21.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m112.6/116.3 kB[0m [31m126.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90

In [None]:
from datasets import load_dataset
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import math
import random

In [None]:
class LayerNormalization(nn.Module):
    """Normalise activations over the last dimension, then scale and shift.

    ``alpha`` (initialised to ones) and ``bias`` (initialised to zeros) are
    learnable per-feature parameters. The spread is the *unbiased* standard
    deviation from ``Tensor.std`` and ``eps`` is added to that std rather than
    to the variance, so results differ slightly from ``torch.nn.LayerNorm``.
    """

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # learnable per-feature scale
        self.bias = nn.Parameter(torch.zeros(features))  # learnable per-feature shift

    def forward(self, x):
        # Statistics are taken along the last axis; keepdim preserves rank so they broadcast back over x
        centered = x - x.mean(dim=-1, keepdim=True)
        # eps keeps the denominator away from zero when a row is (nearly) constant
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias

class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Expands each position from ``d_model`` to ``d_ff`` features and projects
    back to ``d_model``; applied independently at every sequence position.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)   # expansion (W1, b1)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)   # contraction (W2, b2)

    def forward(self, x):
        # (..., d_model) -> (..., d_ff)
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        # (..., d_ff) -> (..., d_model)
        return self.linear_2(hidden)

class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    The input is normalised *before* the sublayer runs, and the un-normalised
    input is added back on the skip path.
    """

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        # sublayer is any callable mapping a tensor shaped like x to the same shape
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    q/k/v are projected by bias-free linear maps, split into ``h`` heads of
    width ``d_k = d_model // h``, attended per head, re-concatenated and
    projected by ``w_o``. The most recent attention weights are kept on
    ``self.attention_scores`` (useful for visualisation).
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        # Every head must receive an equal slice of the embedding
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_model = d_model  # embedding width
        self.h = h              # number of heads
        self.d_k = d_model // h  # per-head width

        def projection():
            return nn.Linear(d_model, d_model, bias=False)

        self.w_q = projection()
        self.w_k = projection()
        self.w_v = projection()
        self.w_o = projection()
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return ``(softmax(QK^T / sqrt(d_k)) V, weights)`` for per-head tensors.

        Positions where ``mask == 0`` get a score of -1e9 before the softmax,
        i.e. effectively zero weight. ``mask`` must broadcast against the
        score tensor of shape ``(batch, h, q_len, k_len)``.
        """
        scale = math.sqrt(query.shape[-1])
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights

    def forward(self, q, k, v, mask):
        def split_heads(t):
            # (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
            return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

        query = split_heads(self.w_q(q))
        key = split_heads(self.w_k(k))
        value = split_heads(self.w_v(v))

        context, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # (batch, h, seq_len, d_k) -> (batch, seq_len, h * d_k)
        batch = context.shape[0]
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.w_o(merged)

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a pre-norm residual."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Index 0 wraps self-attention, index 1 wraps the feed-forward network
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(hidden):
            # Queries, keys and values all come from the same (normalised) input
            return self.self_attention_block(hidden, hidden, hidden, src_mask)

        attended = attention_residual(x, self_attend)
        return feed_forward_residual(attended, self.feed_forward_block)

class Encoder(nn.Module):
    """Apply a stack of encoder layers in order, then a final layer normalisation.

    Each layer is called as ``layer(x, mask)``. With the pre-norm residual
    blocks used in this file, the residual stream leaving the last layer is
    not normalised, so the closing ``norm`` does that.
    """

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        # The layer index was never used, so iterate the ModuleList directly
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [None]:
class BERTEmbeddings(nn.Module):
    """Sum of learned token, position and segment embeddings, scaled and dropped out.

    ``forward(input_ids, token_type_ids)`` takes two ``(batch, seq_len)`` index
    tensors (the sequence length is read from dim 1). It returns
    ``dropout((tok + pos + seg) * sqrt(d_model))``. ``seq_len`` must not exceed
    ``max_position_embeddings``.
    """

    def __init__(self, vocab_size: int, d_model: int, max_position_embeddings: int, segment_vocab_size: int, dropout: float):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(max_position_embeddings, d_model)
        self.segment_embeddings = nn.Embedding(segment_vocab_size, d_model)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, input_ids, token_type_ids):
        length = input_ids.size(1)
        # Positions 0..length-1, broadcast to every row of the batch
        positions = torch.arange(length, device=input_ids.device).expand_as(input_ids)

        summed = (
            self.token_embeddings(input_ids)
            + self.position_embeddings(positions)
            + self.segment_embeddings(token_type_ids)
        )
        return self.dropout(summed * math.sqrt(self.d_model))


In [None]:
class BERTModel(nn.Module):
    """BERT encoder: embeddings followed by ``N`` encoder blocks and a final norm.

    Args:
        vocab_size: Number of token ids.
        max_position_embeddings: Longest supported sequence length.
        segment_vocab_size: Number of distinct token-type ids.
        d_model: Hidden width. N: number of encoder blocks. h: attention heads.
        dropout: Dropout probability used throughout. d_ff: feed-forward width.

    ``forward`` returns the encoder output of shape ``(batch, seq_len, d_model)``.
    """

    def __init__(self, vocab_size: int, max_position_embeddings: int, segment_vocab_size: int, d_model: int=768, N: int=12, h: int=12, dropout: float=0.1, d_ff: int=3072):
        super().__init__()
        # Embeddings layer
        self.embeddings = BERTEmbeddings(vocab_size, d_model, max_position_embeddings, segment_vocab_size, dropout)

        # N independent encoder blocks, each with its own attention and feed-forward weights
        encoder_blocks = []
        for _ in range(N):
            encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
            feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
            encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
            encoder_blocks.append(encoder_block)

        self.encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))

    def forward(self, input_ids, token_type_ids, attention_mask=None):
        if attention_mask is not None and attention_mask.dim() == 2:
            # A (batch, seq_len) padding mask would otherwise be aligned with the
            # (q_len, k_len) score dims inside attention: it crashes for batch > 1
            # or silently mis-masks when batch == seq_len. Reshape it to
            # (batch, 1, 1, seq_len) so it masks keys for every head and query.
            attention_mask = attention_mask[:, None, None, :]
        embeddings = self.embeddings(input_ids, token_type_ids)
        return self.encoder(embeddings, attention_mask)


In [None]:
class BERTPretrainingHeads(nn.Module):
    """Linear output heads for the two pretraining objectives.

    * MLM: ``mlm_head`` maps each position's hidden vector to ``vocab_size`` logits.
    * NSP: ``nsp_head`` maps the pooled vector to 2 logits.

    NOTE(review): ``self.norm`` is registered (so it appears in ``state_dict``)
    but is never applied in ``forward``. Confirm whether it was meant to
    normalise an input before one of the heads.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.mlm_head = nn.Linear(d_model, vocab_size)
        self.nsp_head = nn.Linear(d_model, 2)
        self.norm = LayerNormalization(d_model)  # currently unused, see class note

    def forward(self, encoder_output: torch.Tensor, pooled_output: torch.Tensor):
        # Returns (per-position vocabulary logits, per-example 2-way logits)
        return self.mlm_head(encoder_output), self.nsp_head(pooled_output)


class BERTModelForPretraining(nn.Module):
    """``BERTModel`` plus MLM/NSP heads; ``forward`` returns ``(mlm_logits, nsp_logits)``.

    The NSP head reads the encoder output at position 0, which is where the
    preprocessing code in this file places the ``[CLS]`` token.
    """

    def __init__(self, vocab_size: int, max_position_embeddings: int, segment_vocab_size: int, d_model: int=768, N: int=12, h: int=12, dropout: float=0.1, d_ff: int=3072):
        super().__init__()
        self.bert = BERTModel(vocab_size, max_position_embeddings, segment_vocab_size, d_model, N, h, dropout, d_ff)
        self.pretraining_heads = BERTPretrainingHeads(d_model, vocab_size)

    def forward(self, input_ids, token_type_ids, attention_mask=None):
        hidden_states = self.bert(input_ids, token_type_ids, attention_mask)
        cls_vector = hidden_states[:, 0, :]  # (batch, d_model) at the first position
        return self.pretraining_heads(hidden_states, cls_vector)


In [None]:
def compute_mlm_loss(mlm_logits, mlm_labels, ignore_index=-100):
    """Mean cross-entropy over the positions whose label is not ``ignore_index``.

    Args:
        mlm_logits: Logits of shape ``(..., vocab_size)``.
        mlm_labels: Target ids with the same leading shape as ``mlm_logits``;
            ``ignore_index`` marks positions that were not masked.
        ignore_index: Label value excluded from the loss.

    Returns:
        A scalar tensor. If no position carries a real label (e.g. a batch of
        empty or very short lines where nothing got masked), this returns a
        graph-connected zero instead of the NaN that mean-reduced
        ``CrossEntropyLoss`` produces. That NaN would otherwise corrupt every
        weight on the next optimizer step.
    """
    # reshape (not view) also accepts non-contiguous inputs
    flat_logits = mlm_logits.reshape(-1, mlm_logits.size(-1))
    flat_labels = mlm_labels.reshape(-1)
    if not (flat_labels != ignore_index).any():
        # Zero loss that still back-propagates (all-zero gradients)
        return flat_logits.sum() * 0.0
    loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_index)
    return loss_fct(flat_logits, flat_labels)

def compute_nsp_loss(nsp_logits, nsp_labels):
    """Mean 2-way cross-entropy for next-sentence prediction.

    ``nsp_logits`` is reshaped to ``(-1, 2)`` and ``nsp_labels`` to ``(-1,)``.
    Labels are class indices 0 or 1.
    """
    return nn.functional.cross_entropy(nsp_logits.view(-1, 2), nsp_labels.view(-1))

def compute_pretraining_loss(mlm_logits, mlm_labels, nsp_logits, nsp_labels):
    """Return the unweighted sum ``MLM loss + NSP loss`` as a scalar tensor."""
    masked_lm_term = compute_mlm_loss(mlm_logits, mlm_labels)
    next_sentence_term = compute_nsp_loss(nsp_logits, nsp_labels)
    return masked_lm_term + next_sentence_term


In [None]:
def train_step(model, input_ids, token_type_ids, attention_mask, mlm_labels, nsp_labels, optimizer):
    """Run one joint MLM + NSP optimisation step and return the loss as a float.

    Gradients are cleared *before* the forward pass, as in the training loops
    below. This keeps gradients left over from earlier work out of this update.
    A ``(batch, seq_len)`` attention mask is reshaped to
    ``(batch, 1, 1, seq_len)`` so it broadcasts over heads and query positions.
    """
    optimizer.zero_grad()

    if attention_mask is not None and attention_mask.dim() == 2:
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

    mlm_logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)
    loss = compute_pretraining_loss(mlm_logits, mlm_labels, nsp_logits, nsp_labels)

    loss.backward()
    optimizer.step()

    return loss.item()


In [None]:
# Assuming we have a data loader that provides:
# - input_ids: Tensor of shape (batch_size, seq_len)
# - token_type_ids: Tensor of shape (batch_size, seq_len)
# - attention_mask: Tensor of shape (batch_size, seq_len)
# - mlm_labels: Tensor of shape (batch_size, seq_len)
# - nsp_labels: Tensor of shape (batch_size)

def pretrain_bert(model, data_loader, num_epochs: int, learning_rate: float, device: torch.device):
    """Pretrain ``model`` on the joint MLM + NSP objective with AdamW.

    Each batch must be a mapping with 'input_ids', 'token_type_ids',
    'attention_mask' (batch, seq_len), 'mlm_labels' (batch, seq_len) and
    'nsp_labels' (batch,). 'input_ids' is fed to the model as-is.
    NOTE(review): it should therefore already contain the [MASK]
    substitutions. ``preprocess_wikitext`` stores those under 'mlm_input_ids',
    not 'input_ids'.

    Prints the mean loss after each epoch.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    num_batches = len(data_loader)

    for epoch in range(num_epochs):
        model.train()  # enable dropout
        total_loss = 0.0

        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            mlm_labels = batch['mlm_labels'].to(device)
            nsp_labels = batch['nsp_labels'].to(device)

            if attention_mask.dim() == 2:
                # (batch, seq_len) -> (batch, 1, 1, seq_len) so it broadcasts against
                # (batch, h, seq_len, seq_len) attention scores; matches pretrain_bert_with_dataloader
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            mlm_logits, nsp_logits = model(input_ids, token_type_ids, attention_mask)
            loss = compute_pretraining_loss(mlm_logits, mlm_labels, nsp_logits, nsp_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError for an empty data loader
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

    print("Pretraining complete!")


In [None]:
# Create a simple tokenizer without using transformers
def simple_tokenizer(text, vocab):
    """Map the whitespace-separated words of ``text`` to ids in ``vocab``.

    ``build_vocab`` stores ordinary words lower-cased, while the special tokens
    ('[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]') are upper-case. Each word is
    therefore looked up verbatim first, so special tokens such as '[MASK]'
    survive. It is then looked up in lower-case, so 'The' finds 'the' instead
    of collapsing to '[UNK]'. Anything still unknown maps to the '[UNK]' id.

    Args:
        text: Raw input string.
        vocab: Mapping from token string to integer id; must contain '[UNK]'.

    Returns:
        A list of integer ids, one per whitespace-separated word.
    """
    unk_id = vocab['[UNK]']
    return [
        vocab[word] if word in vocab else vocab.get(word.lower(), unk_id)
        for word in text.split()
    ]

In [None]:
# Fetch the training split of the raw WikiText-2 corpus from the Hugging Face Hub
dataset = load_dataset(path="wikitext", name="wikitext-2-raw-v1", split="train")

# Build a basic vocabulary with common tokens, a [MASK] token, and others
def build_vocab(texts, vocab_size=30522):
    """Build a word -> id vocabulary from an iterable of strings.

    Ids 0-4 are reserved for '[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]'.
    The remaining ``vocab_size - 5`` slots go to the most frequent lower-cased,
    whitespace-split words. Ties are kept in first-seen order.
    """
    from collections import Counter

    specials = ('[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]')
    vocab = {token: token_id for token_id, token in enumerate(specials)}

    frequencies = Counter(word for line in texts for word in line.lower().split())
    top_words = frequencies.most_common(vocab_size - 5)

    # Regular words are numbered right after the reserved ids
    for token_id, (word, _count) in enumerate(top_words, start=len(specials)):
        vocab[word] = token_id

    return vocab

# Raw text lines of the training split (reused below for preprocessing)
texts = dataset["text"]
# Word-level vocabulary: reserved special tokens + most frequent lower-cased words
vocab = build_vocab(texts, vocab_size=30522)

# Function to create inputs for MLM
def preprocess_wikitext(text, vocab, seq_len=128):
    """Turn one line of text into fixed-length, randomly masked MLM tensors.

    The line is tokenised, truncated to ``seq_len - 2`` ids, wrapped in
    [CLS] ... [SEP] and right-padded with [PAD] to ``seq_len``. Each real
    token between [CLS] and [SEP] is independently replaced by [MASK] with
    probability 0.15, drawing from the ``random`` module.

    Returns a dict of 1-D tensors:
        input_ids       original (unmasked) ids
        mlm_input_ids   ids with [MASK] substitutions applied
        attention_mask  1 for real tokens, 0 for padding
        token_type_ids  all zeros (single segment)
        mlm_labels      original id at masked positions, -100 elsewhere
    """
    body = simple_tokenizer(text, vocab)[:seq_len - 2]  # leave room for [CLS] and [SEP]
    real_ids = [vocab['[CLS]'], *body, vocab['[SEP]']]
    pad_count = seq_len - len(real_ids)

    input_ids = real_ids + [vocab['[PAD]']] * pad_count
    attention_mask = [1] * len(real_ids) + [0] * pad_count
    token_type_ids = [0] * seq_len

    mlm_input_ids = list(input_ids)
    mlm_labels = [-100] * len(input_ids)  # -100 = ignored by the loss

    # Positions 1 .. len(real_ids) - 2 skip [CLS] and [SEP]
    for pos in range(1, len(real_ids) - 1):
        if random.random() < 0.15:
            mlm_labels[pos] = input_ids[pos]
            mlm_input_ids[pos] = vocab['[MASK]']

    return {
        "input_ids": torch.tensor(input_ids),
        "mlm_input_ids": torch.tensor(mlm_input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "token_type_ids": torch.tensor(token_type_ids),
        "mlm_labels": torch.tensor(mlm_labels)
    }

# Only the first 1000 lines are used, to keep the demonstration quick
subset = texts[:1000]
processed_dataset = [preprocess_wikitext(line, vocab) for line in subset]

# Shuffled mini-batches of 8 examples each
batch_size = 8
data_loader = DataLoader(dataset=processed_dataset, batch_size=batch_size, shuffle=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [None]:
# BERT-base sized configuration over the custom word-level vocabulary
bert_hparams = {
    "vocab_size": len(vocab),
    "max_position_embeddings": 512,  # longest supported sequence
    "segment_vocab_size": 2,         # token-type ids 0 and 1
    "d_model": 768,
    "N": 12,                         # encoder layers
    "h": 12,                         # attention heads per layer
    "dropout": 0.1,
    "d_ff": 3072,
}
model = BERTModelForPretraining(**bert_hparams)

# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Example training loop using the dataset and DataLoader
def pretrain_bert_with_dataloader(model, data_loader, num_epochs: int, learning_rate: float, device: torch.device):
    """Pretrain ``model`` with the MLM objective only, using AdamW.

    Each batch must provide 'mlm_input_ids' (masked inputs), 'token_type_ids',
    'attention_mask' and 'mlm_labels' (-100 where nothing was masked), as
    produced by ``preprocess_wikitext``. The NSP logits are computed but
    ignored because these batches carry no NSP labels.

    The model is moved to ``device`` here, as ``pretrain_bert`` does, so the
    function does not rely on the caller having done it. Prints the mean loss
    after each epoch.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    num_batches = len(data_loader)

    for epoch in range(num_epochs):
        model.train()  # enable dropout
        total_loss = 0.0

        for batch in data_loader:
            input_ids = batch['mlm_input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            mlm_labels = batch['mlm_labels'].to(device)

            if attention_mask.dim() == 2:
                # (batch, seq_len) -> (batch, 1, 1, seq_len): broadcast over heads and query positions
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            mlm_logits, _nsp_logits = model(input_ids, token_type_ids, attention_mask)
            loss = compute_mlm_loss(mlm_logits, mlm_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError for an empty data loader
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    print("Pretraining complete!")

# One epoch of MLM-only pretraining over the preprocessed subset
pretrain_bert_with_dataloader(
    model=model,
    data_loader=data_loader,
    num_epochs=1,
    learning_rate=5e-5,
    device=device,
)


Epoch 1/1, Loss: 7.7380
Pretraining complete!


In [None]:
# Function to tokenize and prepare the input for inference
def prepare_inference_input(sentence, vocab, seq_len=128):
    """Encode one sentence as a batch of size 1 for the model.

    The sentence is tokenised, truncated to ``seq_len - 2`` ids, wrapped in
    [CLS] ... [SEP] and right-padded with [PAD] to ``seq_len``. Returns
    'input_ids', 'attention_mask' and 'token_type_ids' as ``(1, seq_len)``
    long tensors.

    NOTE(review): tensors are moved to the module-level ``device`` global,
    not to a parameter. It must be defined before this is called.
    """
    ids = simple_tokenizer(sentence, vocab)[:seq_len - 2]
    ids = [vocab['[CLS]'], *ids, vocab['[SEP]']]
    pad_count = seq_len - len(ids)

    def as_batch(values):
        # list -> (1, seq_len) long tensor on the target device
        return torch.tensor(values, dtype=torch.long).unsqueeze(0).to(device)

    padded_ids = ids + [vocab['[PAD]']] * pad_count
    real_token_flags = [1] * len(ids) + [0] * pad_count
    segment_ids = [0] * seq_len

    return {
        "input_ids": as_batch(padded_ids),
        "attention_mask": as_batch(real_token_flags),
        "token_type_ids": as_batch(segment_ids),
    }

In [None]:
# Function to predict the masked words
def predict_masked_words(sentence, vocab, model, seq_len=128):
    """Return the model's top-1 word for each [MASK] token in ``sentence``.

    Runs one forward pass without gradient tracking and takes the argmax of
    the MLM logits at every [MASK] position, in left-to-right order. If the
    sentence has no [MASK], it prints a notice and returns ``[]``. The caller
    is responsible for putting ``model`` in eval mode.
    """
    inputs = prepare_inference_input(sentence, vocab, seq_len)

    with torch.no_grad():
        mlm_logits, _ = model(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"])

    mask_token_id = vocab['[MASK]']
    # Column (sequence) indices of every [MASK]; the batch holds a single row
    positions = torch.nonzero(inputs["input_ids"] == mask_token_id, as_tuple=True)[1]

    if positions.numel() == 0:
        print("No [MASK] token found in the input.")
        return []

    best_ids = mlm_logits[0, positions].argmax(dim=-1).tolist()

    id_to_word = {token_id: word for word, token_id in vocab.items()}
    return [id_to_word[token_id] for token_id in best_ids]

In [None]:
# Disable dropout, then ask the MLM head to fill in the masked word
model.eval()
sentence = "test [MASK]"
predicted_words = predict_masked_words(sentence=sentence, vocab=vocab, model=model)
print(f"Predicted words for masked positions: {predicted_words}")

Predicted words for masked positions: ['[UNK]']
