# This notebook will pretrain a BERT model on the WikiText-2 dataset

## Let's first load the WikiText-2 dataset

In [None]:
#let's get the data..let's try to use the Hugging Face API

from datasets import load_dataset

# Load the wikitext-2-raw-v1 configuration of the wikitext dataset
raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")

# The 'raw_datasets' object will now contain the dataset, typically split into
# 'train', 'validation', and 'test' splits.
# You can access them like this:
train_data = raw_datasets["train"]
validation_data = raw_datasets["validation"]
test_data = raw_datasets["test"]

print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(validation_data)}")
print(f"Number of test examples: {len(test_data)}")

## Let's not create a custom Dataset and Collate function for dynamic padding

In [537]:
import random
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
import torch

checkpoint = "google-bert/bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, data, max_length):
        self.tokenizer=tokenizer
        self.max_length = max_length
        
        L = [p["text"] for p in data]
        self.NSP_data = []

        for paragraph in L:
            cand_sentences = paragraph.split(' . ')
            number_of_sentences = len(cand_sentences)
            if number_of_sentences<2:
                continue
            else:
                for i in range(number_of_sentences-1):
                    if random.random()<0.5:
                        self.NSP_data.append(((cand_sentences[i], cand_sentences[i+1]), 1))
                else:
                    repla_sent = random.choice(L)
                    repla_sent = repla_sent.split(' . ')
                    repla_sent = random.choice(repla_sent)
                    self.NSP_data.append(((cand_sentences[i], repla_sent), 0))


    
    def __getitem__(self, idx):
        pair, label = self.NSP_data[idx]
        return self.tokenizer(*pair, truncation=True, max_length=self.max_length), label
        
    def __len__(self):
        return len(self.NSP_data)


data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)

def collate_fn(batch):
    tokens = [item[0] for item in batch]
    
    tokens = data_collator(tokens)  #output will be {"input_ids" : ? , "token_type_ids" : ?, "attention_mask" : ?, "labels" : ?}


    mask = (tokens["labels"] == -100)

    mlm_pred_positions = tokens["labels"].clone()
    mlm_pred_positions[mask] = 0
    mlm_pred_positions_batch, mlm_pred_positions_pos = torch.nonzero(mlm_pred_positions, as_tuple=True)
    mlm_preds = torch.nonzero(mlm_pred_positions)
    tokens["labels"] = tokens["labels"][mlm_preds[:,0], mlm_preds[:,1]]
    

    #swap values 1<->0 in attention_mask as nn.TransformerEncoderLayer uses a different convention
    mask = (tokens["attention_mask"]==1)
    tokens["attention_mask"][mask] = 0
    tokens["attention_mask"][~mask] = 1
    tokens["attention_mask"] = tokens["attention_mask"].bool()
            
    nsp_labels = [item[1] for item in batch]
    
    #nsp_labels = torch.LongTensor(nsp_labels)
    X = {k:v for k, v in tokens.items() if k != "labels"}  #X excludes "labels", as this concerns more the output.
    X["labels_batch"] = mlm_pred_positions_batch
    X["labels_positions"] = mlm_pred_positions_pos
    #y = (tokens["labels"], nsp_labels)
    y = (tokens["labels"], torch.zeros(len(batch), dtype=torch.int64))
    return X,y

## Let's now build the BERT model and custom criterion used

In [606]:
import torch.nn as nn
import math

#custom model
class BertPretraining(nn.Module):
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, max_len):
        super().__init__()
        self.num_hiddens=num_hiddens
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.pos_embedding = nn.Parameter(torch.randn(1,max_len, num_hiddens))
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"{i}", nn.TransformerEncoderLayer(d_model=num_hiddens, 
                                                                   nhead=num_heads,
                                                                   dim_feedforward=ffn_num_hiddens,
                                                                   dropout=dropout, 
                                                                   batch_first=True))
            
        self.mlp = nn.Sequential(nn.LazyLinear(num_hiddens), nn.ReLU(), nn.LayerNorm(num_hiddens), nn.LazyLinear(vocab_size))

        self.nsp = nn.Sequential(nn.LazyLinear(num_hiddens), nn.Tanh(), nn.LazyLinear(2))

    
            

    def forward(self, input_ids, token_type_ids, attention_mask, labels_batch, labels_positions):
        X = self.token_embedding(input_ids)*math.sqrt(self.num_hiddens) + self.segment_embedding(token_type_ids)
        X = X + self.pos_embedding[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, src_key_padding_mask=attention_mask)

        ##get mlm_Y_hat prediction
        masked_X = X[labels_batch, labels_positions]
        mlm_Y_hat = self.mlp(masked_X)

        #get nsp_Y_hat prediction
        nsp_Y_hat  =self.nsp(X[:,0,:])

        return mlm_Y_hat, nsp_Y_hat


#custom criterion
class CriterionBert(torch.nn.modules.loss._Loss):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, y_pred, y_true):
        loss = nn.CrossEntropyLoss()
        
        #compute masked language model loss:
        mlm_l = loss(y_pred[0], y_true[0].reshape(-1))

        #compute next sentence prediction loss:
        nsp_l = loss(y_pred[1], y_true[1].reshape(-1))

        #total loss
        l = mlm_l + nsp_l

        return l

## Optuna based objective function for hyperparameter tuning

In [618]:
#approach using optuna

import optuna
from sklearn.model_selection import KFold


#define model, optimizer and criterion    
model = BertPretraining(vocab_size = len(tokenizer), num_hiddens=128, ffn_num_hiddens=256, num_heads=2, num_blks=2, 
                        dropout=0.2, max_len=15)
device = torch.device("cuda:0")
model.to(device)


def objective(trial):

    #define hyperparameter search space. Keep it simple and just search "lr" and "weight_decay"
    
    lr = trial.suggest_float("lr", 1e-5, 1e0, log=True)
    
    weight_decay = trial.suggest_float("Ridge", 1e-5, 1e0, log=True)

    #define optimizer and criterion
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    criterion = CriterionBert()


    #define data and cross-validation splitting values

    dataset = MyDataset(tokenizer, train_data, 15)

    kf = KFold(n_splits=3, shuffle=True)
    fold_accuracies=[]

    for train_idx, val_idx in kf.split(range(len(dataset))):
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        train_dataloader = torch.utils.data.DataLoader(train_subset, batch_size=256, shuffle=True, pin_memory=True, num_workers=10, 
                                                       collate_fn = collate_fn, prefetch_factor=2, multiprocessing_context='fork')

        val_dataloader = torch.utils.data.DataLoader(train_subset, batch_size=256, shuffle=False, pin_memory=True, num_workers=10, 
                                                    collate_fn = collate_fn, prefetch_factor=2, multiprocessing_context='fork')

        num_epochs=5
        for epoch in range(num_epochs):
            model.train()
            for inputs, labels in train_dataloader:
                #move data to GPU
                inputs = {k:v.to(device, non_blocking=True) for k, v in inputs.items()}
                labels_mlp = labels[0].to(device, non_blocking=True)
                labels_nsp = labels[1].to(device, non_blocking=True)
                labels = (labels_mlp, labels_nsp)
                
                outputs = model(**inputs)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                inputs = {k:v.to(device, non_blocking=True) for k, v in inputs.items()}
                labels_mlm = labels[0].to(device, non_blocking=True)
                labels_nsp = labels[1].to(device, non_blocking=True)
                total += labels_mlm.shape[0] + labels_nsp.shape[0]
                outputs = model(**inputs) #mlm_Y_hat and nsp_Y_hat
                #get predictions from mlm_Y_hat
                _, mlm_preds = torch.max(outputs[0],1)
                correct += (mlm_preds==labels_mlm).sum().item()

                #get predictions from nsp_Y_hat
                _, nsp_preds = torch.max(outputs[1],1)
                correct += (nsp_preds==labels_nsp).sum().item()

        fold_accuracies.append(correct/total)

    return sum(fold_accuracies)/len(fold_accuracies)

## Run the optuna study

In [619]:
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=3, n_jobs=10, show_progress_bar=True)

[I 2025-08-30 10:18:34,467] A new study created in memory with name: no-name-ce7afe24-b26b-44dc-b5ca-651ab22ea2e2


  0%|          | 0/3 [00:00<?, ?it/s]

[W 2025-08-30 10:18:38,046] Trial 2 failed with parameters: {'lr': 3.26946681009366e-05, 'Ridge': 0.0002053279746605567} because of the following error: RuntimeError('one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128, 384]], which is output 0 of AsStridedBackward0, is at version 11; expected version 10 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).').
Traceback (most recent call last):
  File "/home/ildar/anaconda3/envs/play/lib/python3.12/site-packages/optuna/study/_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_17234/3522909497.py", line 59, in objective
    loss.backward()
  File "/home/ildar/anaconda3/envs/play/lib/python3.12/site-packages/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/home/ildar/ana

RuntimeError: DataLoader worker (pid 909083) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

[W 2025-08-30 10:18:38,140] Trial 1 failed with parameters: {'lr': 0.004539507271152895, 'Ridge': 7.24419158627919e-05} because of the following error: RuntimeError('one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128, 28996]], which is output 0 of AsStridedBackward0, is at version 19; expected version 18 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).').
Traceback (most recent call last):
  File "/home/ildar/anaconda3/envs/play/lib/python3.12/site-packages/optuna/study/_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_17234/3522909497.py", line 59, in objective
    loss.backward()
  File "/home/ildar/anaconda3/envs/play/lib/python3.12/site-packages/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/home/ildar/an

## Appendix

### In this appendix we play around with the dataset to get a feel for its output signature

#### The dataset uses the tokenizer to create the __getitem__. Let's see what it looks like

In [580]:
L = [p["text"] for p in train_data]

import random
NSP_data = []  #elements will be (sen1, sen2, True/False) depending on whether sen2 follows sen1

for paragraph in L:
    cand_sentences = paragraph.split(' . ')
    number_of_sent = len(cand_sentences)
    if number_of_sent<2:
        continue
    else:
        for i in range(number_of_sent-1):
            if random.random()<0.5:
                NSP_data.append(((cand_sentences[i], cand_sentences[i+1]), 1))
            #look at cand_sentences[i] and cand_sentences[i+1]
            else:
                repla_sent = random.choice(L)
                repla_sent = repla_sent.split(' . ')
                repla_sent = random.choice(repla_sent)
                NSP_data.append(((cand_sentences[i], repla_sent), 0))

In [585]:
pair, _ = NSP_data[5]
tokenizer(*pair, truncation=True, max_length=70)

{'input_ids': [101, 1799, 1122, 5366, 1103, 2530, 1956, 1104, 1103, 1326, 117, 1122, 1145, 9315, 2967, 27939, 117, 1216, 1112, 1543, 1103, 1342, 1167, 1111, 5389, 3970, 1111, 1326, 25551, 1116, 102, 134, 134, 22130, 134, 134, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

#### Let's see how collate function pads/batches data together 

In [588]:
pair1, _ = NSP_data[5]
pair2, _ = NSP_data[6]

x = tokenizer(*pair1, truncation=True, max_length=70)
y = tokenizer(*pair2, truncation=True, max_length=70)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)

tokens = data_collator([x,y])
tokens

{'input_ids': tensor([[  101,   103,  1122,  5366,  1103,  2530,  1956,  1104,  1103,  1326,
           117,  1122,  1145,  9315,  2967, 27939,   117,  1216,  1112,  1543,
          1103,  1342,  1167,  1111,   103,   103,   103,  1326, 25551,  1116,
           102,   134,   134, 22130,   134,   103,   102,     0,     0,     0,
             0,     0],
        [  101, 23543,  5592,   103,  1777, 10942, 25028,  1105,  3996, 15375,
           103, 17784, 18504, 12355,  1241,  1608,  1121,  2166, 10813,   117,
          1373,  1114,   103,   103,  3464, 17758,  1563,  1900, 26713, 18763,
         16075, 10946,   102,   138,   103,  1264,  1104,  5094,  8630,  1103,
          5444,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 

In [586]:
dataset = MyDataset(tokenizer, train_data, 15)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=collate_fn)

X, y = next(iter(dataloader))
print("X:", X,"\n","y:", y)

X: {'input_ids': tensor([[  101, 14895, 21006,  1185, 12226,  3781,  3464,   102, 12226,  3781,
          3464,  1104,  1103,  2651,   102],
        [  101, 18653,  1643, 26179,  1158,  1103,  1269, 11970,  1104, 12394,
          1105,  1842,   137,   118,   102],
        [  101,   103,  1342,  1310,  1718,   103,  1333,   102, 19729,  1122,
          5366,  1103,   103,  1956,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False]]), 'labels_batch': tensor([2, 2, 2,