# Unsupervised Pre-training with Code Examples


This notebook demonstrates the unsupervised pre-training process for a BERT model using **Masked Language Modeling (MLM)**. The steps include loading data, tokenization, masking, model training, and saving the pre-trained model.


In [1]:

from transformers import BertTokenizer, BertForMaskedLM, AdamW
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm


## 1. Data Preparation

In [2]:

# Load dataset (e.g., WikiText dataset)
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize data
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])


Map:   0%|          | 0/1801350 [00:00<?, ? examples/s]

## 2. Masking for MLM

In [3]:
# Create masked inputs for MLM
def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # Exclude from loss computation

    inputs[masked_indices] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    return inputs, labels

# Prepare input data
inputs = torch.tensor(tokenized_dataset["input_ids"])
inputs, labels = mask_tokens(inputs, tokenizer)

## 3. Define the Model

In [4]:
# Load pre-trained BERT model for Masked Language Modeling
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

## 4. Train the Model

In [12]:
train_dataloader = DataLoader(inputs, batch_size=32, shuffle=True)
loop = tqdm(train_dataloader, leave=True)
for batch in loop:
    print(batch.shape)
    break

  0%|          | 0/56293 [00:00<?, ?it/s]

torch.Size([32, 512])





IndexError: too many indices for tensor of dimension 2

In [5]:
# DataLoader setup
train_dataloader = DataLoader(inputs, batch_size=32, shuffle=True)

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
model.train()
for epoch in range(3):  # 3 epochs
    loop = tqdm(train_dataloader, leave=True)
    for batch in loop:
        input_ids, labels = batch
        input_ids = input_ids.to('mps')
        labels = labels.to('mps')

        outputs = model(input_ids, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())

  0%|          | 0/56293 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)

## 5. Save the Model

In [None]:

# Save the trained model and tokenizer
model.save_pretrained("./pretrained_bert")
tokenizer.save_pretrained("./pretrained_bert")


## 6. Validate the Model

In [None]:

from sklearn.metrics import accuracy_score

# Prepare validation data (reusing tokenized dataset)
validation_inputs = torch.tensor(tokenized_dataset["input_ids"][:1000])  # Using a subset for validation
validation_inputs, validation_labels = mask_tokens(validation_inputs, tokenizer)

# Validation loop
def validate_model(model, validation_inputs, validation_labels, batch_size=32):
    model.eval()
    dataloader = DataLoader(list(zip(validation_inputs, validation_labels)), batch_size=batch_size)
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            input_ids, labels = batch
            input_ids = input_ids.to("mps")
            labels = labels.to("mps")
            outputs = model(input_ids)
            logits = outputs.logits
            pred_labels = torch.argmax(logits, dim=-1)
            
            # Collect predictions and true labels
            for pred, label in zip(pred_labels, labels):
                predictions.extend(pred[label != -100].tolist())
                true_labels.extend(label[label != -100].tolist())
    
    accuracy = accuracy_score(true_labels, predictions)
    return accuracy

# Calculate validation accuracy
validation_accuracy = validate_model(model, validation_inputs, validation_labels)
print(f"Validation Accuracy: {validation_accuracy:.4f}")


## 7. Visualize Training Loss and Validation Accuracy

In [None]:

import matplotlib.pyplot as plt

# Simulated training loss and validation accuracy for demonstration
epochs = [1, 2, 3]
training_loss = [1.2, 0.9, 0.7]
validation_accuracies = [0.65, 0.72, 0.78]

# Plotting
plt.figure(figsize=(10, 5))

# Loss
plt.subplot(1, 2, 1)
plt.plot(epochs, training_loss, marker='o', label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Over Epochs")
plt.legend()

# Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, validation_accuracies, marker='o', label="Validation Accuracy", color="green")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy Over Epochs")
plt.legend()

plt.tight_layout()
plt.show()
