In [1]:
# import required moduels
import os
import re
import math
import torch
import numpy as np
import torch.nn as nn
from random import *
import torch.optim as optim
import torch.nn.functional as F
from datasets import load_dataset

In [2]:
# check gpu is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
# create our work reproducilble everytime
SEED = 1234
torch.manual_seed(SEED)
# reduce randomness in GPU kernels
torch.backends.cudnn.deterministic = True
# prevents CUDNN selecting different algo each run
torch.backends.cudnn.benchmark = False
# check cuda is used by GPU name
print("GPU", torch.cuda.get_device_name(0))

GPU Tesla T4


## Data Set
I will be using BookCorpus dataset and use only 1% of the dataset to be trained.

In [None]:
# downgrade the datasets because of the uncompatibility
!pip install -q "datasets<4.0.0"
# restart runtime


In [5]:
# load subset of the dataset
dataset = load_dataset('bookcorpus', split='train[:100000]')
# dataset = load_dataset("lucadiliello/bookcorpusopen", split="train[:1%]")
print(dataset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Dataset({
    features: ['text'],
    num_rows: 100000
})


In [6]:
# extract sentences from train dataset
sentences = dataset['text']
sentences[1]

'but just one look at a minion sent him practically catatonic .'

In [7]:
# text cleaning
# lowercase
text = [x.lower() for x in sentences]
# removing punctuation
text = [re.sub("[.,!\\-]", " ", x) for x in text]
print(text)



## Build Vocabulary


In [8]:
from tqdm.auto import tqdm

# combine all words into a single list
word_list = list(set(" ".join(text).split()))
word2id = {
  '[PAD]': 0,
  '[CLS]': 1,
  '[SEP]': 2,
  '[MASK]': 3,
  '[UNK]': 4
}

# create word2id
for i,w in tqdm(enumerate(word_list), desc='Creating word2id'):
  word2id[w] = i + 5


# Precompute the id2word mapping (this can be done once after word2id is fully populated)
id2word = {v: k for k, v in word2id.items()}
vocab_size = len(word2id)
vocab_size

Creating word2id: 0it [00:00, ?it/s]

20889

In [9]:
# List of all tokens for the whole text
token_list = []

# Process sentences more efficiently
for sentence in tqdm(text, desc="Processing sentences"):
    token_list.append([word2id[word] for word in sentence.split()])


Processing sentences:   0%|          | 0/100000 [00:00<?, ?it/s]

In [10]:
#take a look at sentences
sentences[:2]

['usually , he would be tearing around the living room , playing with his toys .',
 'but just one look at a minion sent him practically catatonic .']

In [11]:
#take a look at token_list
token_list[:2]

[[17030,
  13614,
  4606,
  5053,
  19042,
  16593,
  6177,
  7632,
  7477,
  17169,
  1875,
  17986,
  9252],
 [5920, 12366, 10785, 15092, 586, 17525, 2399, 19299, 19625, 10264, 19911]]

In [12]:
batch_size = 6
max_mask   = 5  # max masked tokens when 15% exceed, it will only be max_pred
max_len    = 1000 # maximum of length to be padded;

In [13]:
import random

PAD_ID = word2id["[PAD]"]
CLS_ID = word2id["[CLS]"]
SEP_ID = word2id["[SEP]"]
MASK_ID = word2id["[MASK]"]
UNK_ID = word2id["[UNK]"]


# --------- Utility function: pad or truncate list to max_len ---------
def pad_to_len(seq, pad_value, max_len):
    """
    seq: a Python list (example: token ids)
    pad_value: what to pad with (example: PAD_ID)
    max_len: final required length

    If seq is shorter -> pad
    If seq is longer -> cut
    """
    seq = seq[:max_len]
    return seq + [pad_value] * (max_len - len(seq))


def make_batch():
    """
    This function creates ONE training batch for BERT pretraining:
    1) MLM (Masked Language Modeling)
    2) NSP (Next Sentence Prediction)

    Output batch contains tuples:
      (input_ids, segment_ids, masked_tokens, masked_pos, is_next)
    """

    batch = []
    positive = 0  # how many "is_next = 1"
    negative = 0  # how many "is_next = 0"

    # We want 50% positive and 50% negative inside one batch
    while positive < batch_size // 2 or negative < batch_size // 2:

        # ------------------- Pick sentence A -------------------
        a_idx = randrange(len(token_list))  # random index for sentence A
        tokens_a = token_list[a_idx]        # sentence A token IDs

        # ------------------- Decide NSP label -------------------
        # 1 means: sentence B is the real "next sentence"
        # 0 means: sentence B is random (not the next one)
        if positive >= batch_size // 2:
            is_next = 0   # if enough positives, force negatives
        elif negative >= batch_size // 2:
            is_next = 1   # if enough negatives, force positives
        else:
            is_next = randint(0, 1)  # random 0 or 1

        # ------------------- Pick sentence B -------------------
        if is_next == 1:
            # True next sentence
            # We approximate "next sentence" by using a_idx + 1
            b_idx = min(a_idx + 1, len(token_list) - 1)
            tokens_b = token_list[b_idx]
        else:
            # Random sentence (negative sample)
            b_idx = randrange(len(token_list))
            tokens_b = token_list[b_idx]

        # ------------------- Build BERT input format -------------------
        # BERT format: [CLS] A [SEP] B [SEP]
        input_ids = [CLS_ID] + tokens_a + [SEP_ID] + tokens_b + [SEP_ID]

        # segment ids: 0 for sentence A part, 1 for sentence B part
        # Example: [CLS] A [SEP] -> 0, B [SEP] -> 1
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # ------------------- Truncate if too long -------------------
        input_ids = input_ids[:max_len]
        segment_ids = segment_ids[:max_len]

        # ------------------- MLM: choose tokens to mask -------------------
        # We do NOT mask [CLS] and [SEP]
        cand_pos = [i for i, t in enumerate(input_ids) if t not in (CLS_ID, SEP_ID)]
        shuffle(cand_pos)

        # Standard BERT: mask 15% of tokens
        n_pred = min(max_mask, max(1, int(round(len(input_ids) * 0.15))))

        masked_tokens = []  # true tokens we want model to predict
        masked_pos = []     # positions where we applied mask

        for pos in cand_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])  # save original token at pos

            # BERT masking rule:
            # 80% -> replace by [MASK]
            # 10% -> replace by random token
            # 10% -> keep original token
            r = random.random()
            if r < 0.8:
                input_ids[pos] = MASK_ID
            elif r < 0.9:
                input_ids[pos] = randint(0, vocab_size - 1)
            else:
                # keep the original token
                pass

        # ------------------- Pad masked tokens/positions -------------------
        # We pad masked arrays to fixed size max_mask for batching
        masked_tokens = pad_to_len(masked_tokens, PAD_ID, max_mask)
        masked_pos = pad_to_len(masked_pos, 0, max_mask)

        # ------------------- Pad input and segment ids -------------------
        input_ids = pad_to_len(input_ids, PAD_ID, max_len)
        segment_ids = pad_to_len(segment_ids, 0, max_len)

        # ------------------- Count positive/negative -------------------
        if is_next == 1:
            positive += 1
        else:
            negative += 1

        # Add one training example (one sentence pair) into batch
        batch.append((input_ids, segment_ids, masked_tokens, masked_pos, is_next))

    return batch

In [14]:
batch = make_batch()

In [15]:
# convert to tensors (PyTorch needs tensors, not Python lists)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(
    torch.LongTensor, zip(*batch)
)

print("input_ids:", input_ids.shape)         # [batch, max_len]
print("segment_ids:", segment_ids.shape)     # [batch, max_len]
print("masked_tokens:", masked_tokens.shape) # [batch, max_mask]
print("masked_pos:", masked_pos.shape)       # [batch, max_mask]
print("isNext:", isNext.shape)               # [batch]

input_ids: torch.Size([6, 1000])
segment_ids: torch.Size([6, 1000])
masked_tokens: torch.Size([6, 5])
masked_pos: torch.Size([6, 5])
isNext: torch.Size([6])


In [16]:
class BERTEmbedding(nn.Module):
    """
    Combines 3 embeddings:
    1) Token embedding: maps token id -> vector
    2) Position embedding: gives order information (0,1,2,...)
    3) Segment embedding: tells whether token belongs to sentence A or B
    """
    def __init__(self, vocab_size, max_len, n_segments, d_model):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.seg_emb = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, segment_ids):
        """
        input_ids: [B, S]
        segment_ids: [B, S]
        returns: [B, S, d_model]
        """
        B, S = input_ids.size()

        # positions: [0,1,2,...,S-1] for every row in batch
        pos = torch.arange(S, device=input_ids.device).unsqueeze(0).expand(B, S)

        x = self.tok_emb(input_ids) + self.pos_emb(pos) + self.seg_emb(segment_ids)
        return self.norm(x)


In [17]:
def get_attn_pad_mask(input_ids):
    """
    Create a mask where PAD positions are True (masked out).

    input_ids: [B, S]
    return mask: [B, 1, 1, S] (broadcastable for attention)
    """
    pad_mask = input_ids.eq(PAD_ID)  # True where PAD
    return pad_mask.unsqueeze(1).unsqueeze(2)  # [B,1,1,S]


In [18]:
class ScaledDotProductAttention(nn.Module):
    """
    Attention(Q,K,V) = softmax(QK^T / sqrt(d_k)) V
    """
    def __init__(self, d_k):
        super().__init__()
        self.scale = math.sqrt(d_k)

    def forward(self, Q, K, V, attn_mask):
        """
        Q,K,V: [B, H, S, d_k]
        attn_mask: [B, 1, 1, S] -> broadcast to [B,H,S,S]
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale  # [B,H,S,S]

        # Mask PAD tokens: set large negative so softmax makes it ~0
        scores = scores.masked_fill(attn_mask, -1e9)

        attn = torch.softmax(scores, dim=-1)                       # [B,H,S,S]
        context = torch.matmul(attn, V)                            # [B,H,S,d_k]
        return context


In [19]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention.
    """
    def __init__(self, d_model, n_heads, d_k):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k

        # Linear projections for Q, K, V
        self.W_Q = nn.Linear(d_model, n_heads * d_k)
        self.W_K = nn.Linear(d_model, n_heads * d_k)
        self.W_V = nn.Linear(d_model, n_heads * d_k)

        # Output projection
        self.W_O = nn.Linear(n_heads * d_k, d_model)

        self.attn = ScaledDotProductAttention(d_k)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask):
        """
        x: [B, S, d_model]
        """
        residual = x
        B, S, _ = x.size()

        # Project and reshape: [B,S,n_heads*d_k] -> [B,n_heads,S,d_k]
        Q = self.W_Q(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)

        # Broadcast mask to all heads: [B,1,1,S] -> [B,H,1,S]
        attn_mask = attn_mask.expand(B, self.n_heads, 1, S)

        context = self.attn(Q, K, V, attn_mask)  # [B,H,S,d_k]

        # Merge heads: [B,H,S,d_k] -> [B,S,H*d_k]
        context = context.transpose(1, 2).contiguous().view(B, S, self.n_heads * self.d_k)

        out = self.W_O(context)  # [B,S,d_model]
        return self.norm(out + residual)  # add & norm


In [20]:
class FeedForward(nn.Module):
    """
    FFN = Linear -> GELU -> Linear, plus residual + layernorm
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        x = self.fc2(F.gelu(self.fc1(x)))
        return self.norm(x + residual)


In [21]:
class EncoderLayer(nn.Module):
    """
    One Transformer encoder block:
    1) Multi-head self-attention
    2) Feed-forward
    """
    def __init__(self, d_model, n_heads, d_k, d_ff):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, n_heads, d_k)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x, attn_mask):
        x = self.mha(x, attn_mask)
        x = self.ffn(x)
        return x


In [22]:
class BERT(nn.Module):
    """
    BERT pretraining model:
    - Encoder stack
    - MLM head (predict masked tokens)
    - NSP head (predict if B is next sentence)
    """
    def __init__(self, vocab_size, max_len, n_segments,
                 n_layers=4, d_model=256, n_heads=4, d_k=64, d_ff=1024):
        super().__init__()

        self.embedding = BERTEmbedding(vocab_size, max_len, n_segments, d_model)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_k, d_ff) for _ in range(n_layers)
        ])

        # NSP head: uses [CLS] representation
        self.pooler = nn.Linear(d_model, d_model)
        self.nsp_classifier = nn.Linear(d_model, 2)

        # MLM head: predict vocab at masked positions
        self.mlm_transform = nn.Linear(d_model, d_model)
        self.mlm_norm = nn.LayerNorm(d_model)
        self.mlm_decoder = nn.Linear(d_model, vocab_size)


    def forward(self, input_ids, segment_ids, masked_pos=None, return_hidden=False):
      """
      input_ids: [B,S]
      segment_ids: [B,S]
      masked_pos: [B,max_mask]
      """
      x = self.embedding(input_ids, segment_ids)     # [B,S,H]
      attn_mask = get_attn_pad_mask(input_ids)       # [B,1,1,S]

      # ---- Encoder (run ONCE) ----
      for layer in self.layers:
          x = layer(x, attn_mask)

      # ---- For Task 2 (Sentence-BERT) ----
      if return_hidden:
          return x

      # ---- NSP head ----
      cls_vec = x[:, 0]
      pooled = torch.tanh(self.pooler(cls_vec))
      logits_nsp = self.nsp_classifier(pooled)       # [B,2]

      # ---- MLM head ----
      B, S, H = x.size()
      masked_pos = masked_pos.unsqueeze(-1).expand(-1, -1, H)   # [B,max_mask,H]
      h_masked = torch.gather(x, 1, masked_pos)                 # [B,max_mask,H]
      h_masked = self.mlm_norm(F.gelu(self.mlm_transform(h_masked)))
      logits_mlm = self.mlm_decoder(h_masked)                   # [B,max_mask,vocab_size]

      return logits_mlm, logits_nsp


In [23]:
# Sanity Test
# create one batch
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

# move to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BERT(vocab_size=vocab_size, max_len=max_len, n_segments=2).to(device)

logits_mlm, logits_nsp = model(
    input_ids.to(device),
    segment_ids.to(device),
    masked_pos.to(device)
)

print("logits_mlm:", logits_mlm.shape)  # [B, max_mask, vocab_size]
print("logits_nsp:", logits_nsp.shape)  # [B, 2]


logits_mlm: torch.Size([6, 5, 20889])
logits_nsp: torch.Size([6, 2])


In [24]:
import torch.optim as optim

# MLM loss: predict the original token at masked positions
# ignore_index=PAD_ID means: padded masked slots do not contribute to loss
mlm_criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

# NSP loss: classify 0/1 (random/next)
nsp_criterion = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr=2e-4)

In [25]:
# ---- Training settings ----
epochs = 4
steps_per_epoch = 200   # increase if you have time/GPU
print_every = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

for epoch in range(epochs):
    total_loss = 0.0
    total_mlm = 0.0
    total_nsp = 0.0

    for step in range(1, steps_per_epoch + 1):

        # 1) Create one batch using your make_batch()
        batch = make_batch()

        # 2) Convert batch into torch tensors
        input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(
            torch.LongTensor, zip(*batch)
        )

        # 3) Move to GPU/CPU
        input_ids = input_ids.to(device)
        segment_ids = segment_ids.to(device)
        masked_tokens = masked_tokens.to(device)
        masked_pos = masked_pos.to(device)
        isNext = isNext.to(device)

        # 4) Forward pass
        logits_mlm, logits_nsp = model(input_ids, segment_ids, masked_pos)

        # 5) Compute MLM loss
        # logits_mlm: [B, max_mask, vocab_size]
        # CrossEntropy expects: [B, vocab_size, max_mask]
        loss_mlm = mlm_criterion(logits_mlm.transpose(1, 2), masked_tokens)

        # 6) Compute NSP loss
        # logits_nsp: [B,2], isNext: [B]
        loss_nsp = nsp_criterion(logits_nsp, isNext)

        # 7) Total loss = MLM + NSP (like original BERT pretraining)
        loss = loss_mlm + loss_nsp

        # 8) Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_mlm += loss_mlm.item()
        total_nsp += loss_nsp.item()

        # 9) Print progress sometimes
        if step % print_every == 0:
            print(
                f"Epoch {epoch+1}/{epochs} | Step {step}/{steps_per_epoch} | "
                f"Loss={total_loss/step:.4f} | MLM={total_mlm/step:.4f} | NSP={total_nsp/step:.4f}"
            )

    print(f"Epoch {epoch+1} finished. Avg Loss={total_loss/steps_per_epoch:.4f}")

Epoch 1/4 | Step 50/200 | Loss=9.8730 | MLM=9.1697 | NSP=0.7033
Epoch 1/4 | Step 100/200 | Loss=8.9769 | MLM=8.2791 | NSP=0.6978
Epoch 1/4 | Step 150/200 | Loss=8.5292 | MLM=7.8329 | NSP=0.6963
Epoch 1/4 | Step 200/200 | Loss=8.2861 | MLM=7.5907 | NSP=0.6954
Epoch 1 finished. Avg Loss=8.2861
Epoch 2/4 | Step 50/200 | Loss=7.4321 | MLM=6.7389 | NSP=0.6933
Epoch 2/4 | Step 100/200 | Loss=7.4548 | MLM=6.7622 | NSP=0.6926
Epoch 2/4 | Step 150/200 | Loss=7.4480 | MLM=6.7541 | NSP=0.6940
Epoch 2/4 | Step 200/200 | Loss=7.3901 | MLM=6.6961 | NSP=0.6940
Epoch 2 finished. Avg Loss=7.3901
Epoch 3/4 | Step 50/200 | Loss=7.3396 | MLM=6.6479 | NSP=0.6917
Epoch 3/4 | Step 100/200 | Loss=7.3057 | MLM=6.6082 | NSP=0.6974
Epoch 3/4 | Step 150/200 | Loss=7.2826 | MLM=6.5861 | NSP=0.6965
Epoch 3/4 | Step 200/200 | Loss=7.3047 | MLM=6.6081 | NSP=0.6966
Epoch 3 finished. Avg Loss=7.3047
Epoch 4/4 | Step 50/200 | Loss=7.3574 | MLM=6.6643 | NSP=0.6932
Epoch 4/4 | Step 100/200 | Loss=7.3173 | MLM=6.6207 | NSP

In [26]:
save_obj = {
    "model_state_dict": model.state_dict(),
    "config": {
        "vocab_size": vocab_size,
        "max_len": max_len,
        "n_segments": 2,
        "n_layers": model.layers.__len__(),
        "d_model": model.embedding.tok_emb.embedding_dim,
        "n_heads": model.layers[0].mha.n_heads,
        "d_k": model.layers[0].mha.d_k,
        "d_ff": model.layers[0].ffn.fc1.out_features
    },
    "word2id": word2id,
    "id2word": id2word
}


torch.save(save_obj, "bert_mlm.pt")
print("Saved: bert_mlm.pt")


Saved: bert_mlm.pt


In [None]:
device = torch.device("cpu")
model = model.to(device).eval()

class EncoderOnly(nn.Module):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
    def forward(self, input_ids, segment_ids):
        return self.bert(input_ids, segment_ids, return_hidden=True)

encoder = EncoderOnly(model).eval()

PAD_ID = word2id["[PAD]"]
S = max_len

example_input_ids = torch.full((1, S), PAD_ID, dtype=torch.long)
example_segment_ids = torch.zeros((1, S), dtype=torch.long)

traced = torch.jit.trace(encoder, (example_input_ids, example_segment_ids), strict=False)
traced.save("bert_encoder.pt")

print("Saved: bert_encoder.pt")

Saved: bert_encoder.pt
