In [1]:
# Install necessary libraries if not already installed
!pip install torch transformers datasets

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from datasets import load_dataset
import math



Defaulting to user installation because normal site-packages is not writeable


In [2]:
# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")



Using device: cuda


In [3]:
# Load the BookCorpus dataset from Hugging Face
dataset = load_dataset('bookcorpus', split='train')
print("Total samples:", len(dataset))


Total samples: 74004228


In [4]:
# Limit to 100k samples as specified
dataset = dataset.select(range(100000))
print("Subset size:", len(dataset))


Subset size: 100000


In [5]:
# Initialize the BERT tokenizer (WordPiece)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [6]:
# Tokenization and Preprocessing Function
def preprocess(text):
    # Tokenize and encode input text
    encoding = tokenizer(text, 
                         add_special_tokens=True, 
                         truncation=True,
                         max_length=128, 
                         padding='max_length', 
                         return_tensors='pt')
    
    # Flatten tensor to (seq_length)
    input_ids = encoding['input_ids'].squeeze()
    attention_mask = encoding['attention_mask'].squeeze()
    
    # Create segment ids (all zeros for single sentence input)
    segment_ids = torch.zeros_like(input_ids)
    
    return input_ids, attention_mask, segment_ids


In [7]:
from torch.utils.data import Dataset

# Custom Dataset Class for BookCorpus
class BookCorpusDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        input_ids, attention_mask, segment_ids = preprocess(text)
        return input_ids, attention_mask, segment_ids


In [8]:
# Create DataLoader for 
train_dataset = BookCorpusDataset(dataset)
train_loader = DataLoader(train_dataset, batch_size= 8, shuffle=True)

# Sample check
for batch in train_loader:
    input_ids, attention_mask, segment_ids = batch
    print("Input IDs:", input_ids.shape)
    print("Attention Mask:", attention_mask.shape)
    print("Segment IDs:", segment_ids.shape)
    break


Input IDs: torch.Size([8, 128])
Attention Mask: torch.Size([8, 128])
Segment IDs: torch.Size([8, 128])


Tokenization and Masking for MLM

In [9]:
import random

def mask_tokens(input_ids, tokenizer, mask_prob=0.15):
    """
    Prepare masked tokens inputs/labels for masked language modeling.
    """
    labels = input_ids.clone()
    
    # Masking (15% of tokens)
    probability_matrix = torch.full(labels.shape, mask_prob)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # Only compute loss on masked tokens
    
    # Replace 80% of the time with [MASK]
    indices_replaced = (torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices)
    input_ids[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    
    # Replace 10% of the time with random token
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced)
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long).to(input_ids.device)  # Move to the same device
    input_ids[indices_random] = random_words[indices_random]
    
    # 10% of the time, keep the original token (already done by default)
    return input_ids, labels



Implementation: BERTEmbedding



In [10]:
import torch
import torch.nn as nn

class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_dim, max_position_embeddings, segment_vocab_size=2, dropout_prob=0.1):
        super(BERTEmbedding, self).__init__()
        # Token Embeddings
        self.token_embeddings = nn.Embedding(vocab_size, hidden_dim)
        
        # Position Embeddings
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_dim)
        
        # Segment Embeddings (For sentence pairs in NLI tasks)
        self.segment_embeddings = nn.Embedding(segment_vocab_size, hidden_dim)
        
        # Layer Normalization and Dropout
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)
    
    def forward(self, input_ids, segment_ids):
        seq_length = input_ids.size(1)
        
        # Position IDs [0, 1, 2, ..., seq_length-1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        
        # Get Embeddings
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        segment_embeddings = self.segment_embeddings(segment_ids)
        
        # Sum and Normalize
        embeddings = token_embeddings + position_embeddings + segment_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        
        return embeddings


Implementation: MultiHeadSelfAttention



In [11]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout_prob=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        assert hidden_dim % num_heads == 0
        
        # Parameters
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        # Linear layers for Q, K, V
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        
        # Output projection
        self.out = nn.Linear(hidden_dim, hidden_dim)
        
        # Dropout for attention scores
        self.dropout = nn.Dropout(dropout_prob)
        
    def forward(self, hidden_states, attention_mask):
        batch_size, seq_length, hidden_dim = hidden_states.size()
        
        # Linear projections
        query = self.query(hidden_states)
        key = self.key(hidden_states)
        value = self.value(hidden_states)
        
        # Reshape to (batch_size, num_heads, seq_length, head_dim)
        query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled Dot-Product Attention
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_scores += attention_mask
        
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)
        
        context = torch.matmul(attention_probs, value)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_dim)
        
        output = self.out(context)
        
        return output


Implementation: FeedForward



In [12]:
class FeedForward(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim, dropout_prob=0.1):
        super(FeedForward, self).__init__()
        self.dense1 = nn.Linear(hidden_dim, intermediate_dim)
        self.activation = nn.GELU()
        self.dense2 = nn.Linear(intermediate_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-12)
    
    def forward(self, hidden_states):
        intermediate = self.activation(self.dense1(hidden_states))
        output = self.dense2(intermediate)
        output = self.dropout(output)
        output = self.layer_norm(output + hidden_states)  # Residual Connection
        return output


Implementation: BERTLayer



In [13]:
class BERTLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, intermediate_dim, dropout_prob=0.1):
        super(BERTLayer, self).__init__()
        self.attention = MultiHeadSelfAttention(hidden_dim, num_heads, dropout_prob)
        self.feed_forward = FeedForward(hidden_dim, intermediate_dim, dropout_prob)
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)
    
    def forward(self, hidden_states, attention_mask):
        # Multi-Head Self Attention
        attention_output = self.attention(hidden_states, attention_mask)
        attention_output = self.dropout(attention_output)
        attention_output = self.layer_norm(attention_output + hidden_states)
        
        # Feed Forward Network
        layer_output = self.feed_forward(attention_output)
        
        return layer_output


In [14]:
class BERTModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=768, num_layers=12, num_heads=12, intermediate_dim=3072, max_position_embeddings=512, segment_vocab_size=2, dropout_prob=0.1):
        super(BERTModel, self).__init__()
        
        # Embedding Layer
        self.embedding = BERTEmbedding(vocab_size, hidden_dim, max_position_embeddings, segment_vocab_size, dropout_prob)
        
        # Stacking BERT Layers
        self.layers = nn.ModuleList([
            BERTLayer(hidden_dim, num_heads, intermediate_dim, dropout_prob) for _ in range(num_layers)
        ])
    
    def forward(self, input_ids, segment_ids, attention_mask):
        # Input Embedding
        hidden_states = self.embedding(input_ids, segment_ids)
        
        # Apply attention mask to ignore padding tokens
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        
        # Pass through each BERT Layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)
        
        return hidden_states


In [15]:
class MLMHead(nn.Module):
    def __init__(self, hidden_dim, vocab_size):
        super(MLMHead, self).__init__()
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-12)
        
        # Output layer
        self.decoder = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        logits = self.decoder(hidden_states)
        
        return logits


In [16]:
class BERTForMaskedLM(nn.Module):
    def __init__(self, vocab_size, hidden_dim=768, num_layers=12, num_heads=12, intermediate_dim=3072, max_position_embeddings=512, segment_vocab_size=2, dropout_prob=0.1):
        super(BERTForMaskedLM, self).__init__()
        
        # Base BERT Model
        self.bert = BERTModel(vocab_size, hidden_dim, num_layers, num_heads, intermediate_dim, max_position_embeddings, segment_vocab_size, dropout_prob)
        
        # MLM Head
        self.mlm_head = MLMHead(hidden_dim, vocab_size)
    
    def forward(self, input_ids, segment_ids, attention_mask):
        # Forward pass through BERT
        hidden_states = self.bert(input_ids, segment_ids, attention_mask)
        
        # Get logits for masked language modeling
        logits = self.mlm_head(hidden_states)
        
        return logits


In [None]:
import torch.optim as optim
from torch.nn import CrossEntropyLoss

# Initialize model and move to GPU
vocab_size = tokenizer.vocab_size
model = BERTForMaskedLM(vocab_size).to(device)

# Optimizer (AdamW)
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

# Loss Function (CrossEntropy for MLM)
criterion = CrossEntropyLoss()

# Training Loop
epochs = 3  # Feel free to adjust

model.train()  # Set model to training mode
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    total_loss = 0
    
    for batch in train_loader:
        input_ids, attention_mask, segment_ids = [x.to(device) for x in batch]
        
        # Masking for MLM
        input_ids, labels = mask_tokens(input_ids, tokenizer)
        input_ids, labels = input_ids.to(device), labels.to(device)
        
        # Forward pass
        logits = model(input_ids, segment_ids, attention_mask)
        
        # Calculate loss
        loss = criterion(logits.view(-1, vocab_size), labels.view(-1))
        total_loss += loss.item()
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    avg_loss = total_loss / len(train_loader)
    print(f"Average Loss: {avg_loss:.4f}")




Epoch 1/3
Average Loss: 0.6857
Epoch 2/3
Average Loss: 0.5534
Epoch 3/3


In [None]:
# Save model weights
torch.save(model.state_dict(), 'bert_mlm_weights.pth')
print("Model weights saved.")