# BERT (Updated 1 Feb 2025, Available CUDA)

We shall implement BERT.  For this tutorial, you may want to first look at my Transformers tutorial to get a basic understanding of Transformers. 

For BERT, the main difference is on how we process the datasets, i.e., masking.   Aside from that, the backbone model is still the Transformers.

In [1]:
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# Set GPU device
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# os.environ['http_proxy']  = 'http://192.41.170.23:3128'
# os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.cuda.get_device_name(0)

mps


## 1. Data

For simplicity, we shall use very simple data like this.

In [3]:
from datasets import load_dataset

# Load BookCorpus dataset
# The first 1% of `train` split.
dataset = load_dataset('roneneldan/TinyStories', split='train[:1%]')
dataset

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['text'],
    num_rows: 21197
})

In [4]:
sentences = dataset['text']
text = [x.lower() for x in sentences] #lower case
text = [re.sub("[.,!?\\-]", '', x) for x in text] #clean all symbols
# text

In [5]:
for sentence in text:
    print(sentence, "_____")
    words = sentence.split()
    print(words)
    break

one day a little girl named lily found a needle in her room she knew it was difficult to play with it because it was sharp lily wanted to share the needle with her mom so she could sew a button on her shirt

lily went to her mom and said "mom i found this needle can you share it with me and sew my shirt" her mom smiled and said "yes lily we can share the needle and fix your shirt"

together they shared the needle and sewed the button on lily's shirt it was not difficult for them because they were sharing and helping each other after they finished lily thanked her mom for sharing the needle and fixing her shirt they both felt happy because they had shared and worked together _____
['one', 'day', 'a', 'little', 'girl', 'named', 'lily', 'found', 'a', 'needle', 'in', 'her', 'room', 'she', 'knew', 'it', 'was', 'difficult', 'to', 'play', 'with', 'it', 'because', 'it', 'was', 'sharp', 'lily', 'wanted', 'to', 'share', 'the', 'needle', 'with', 'her', 'mom', 'so', 'she', 'could', 'sew', 'a', 'bu

### Making vocabs

Before making the vocabs, let's remove all question marks and perios, etc, then turn everything to lowercase, and then simply split the text. 

In [6]:
from tqdm.auto import tqdm

# Combine everything into one to make vocab
word_list = list(set(" ".join(text).split()))
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}  # special tokens

# Create the word2id in a single pass
for i, w in tqdm(enumerate(word_list), desc="Creating word2id"):
    word2id[w] = i + 4  # because 0-3 are already occupied

# Precompute the id2word mapping (this can be done once after word2id is fully populated)
id2word = {v: k for k, v in word2id.items()}
vocab_size = len(word2id)
vocab_size

Creating word2id: 18705it [00:00, 3613747.41it/s]


18709

In [7]:
vocab_size = len(word2id)

# List of all tokens for the whole text
token_list = []

# Process sentences more efficiently
for sentence in tqdm(text, desc="Processing sentences"):
    token_list.append([word2id[word] for word in sentence.split()])

# Now token_list contains the tokenized sentences

Processing sentences: 100%|██████████| 21197/21197 [00:00<00:00, 71709.85it/s]


In [8]:
#take a look at sentences
sentences[:2]

['One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.',
 'Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.\n\nOne day, Beep was driving in the park when he saw a big tree. The tree had many leaves that we

In [9]:
#take a look at token_list
token_list[:2]

[[12666,
  2860,
  4797,
  17342,
  8928,
  2184,
  8194,
  12918,
  4797,
  4599,
  2497,
  10823,
  18316,
  17852,
  5404,
  6076,
  6098,
  5333,
  4594,
  10073,
  16853,
  6076,
  2984,
  6076,
  6098,
  11954,
  8194,
  15775,
  4594,
  3620,
  9234,
  4599,
  16853,
  10823,
  9939,
  8135,
  17852,
  14186,
  6612,
  4797,
  3946,
  14822,
  10823,
  2812,
  8194,
  1522,
  4594,
  10823,
  9939,
  14987,
  14077,
  5063,
  17575,
  12918,
  1302,
  4599,
  9395,
  11967,
  3620,
  6076,
  16853,
  13387,
  14987,
  6612,
  9811,
  259,
  10823,
  9939,
  15122,
  14987,
  14077,
  4254,
  8194,
  614,
  9395,
  3620,
  9234,
  4599,
  14987,
  12515,
  3572,
  259,
  13363,
  13640,
  6129,
  9234,
  4599,
  14987,
  4036,
  9234,
  3946,
  14822,
  13123,
  2812,
  6076,
  6098,
  8868,
  5333,
  15326,
  6383,
  2984,
  13640,
  9025,
  2842,
  14987,
  12153,
  2805,
  14855,
  15093,
  13640,
  13833,
  8194,
  7224,
  10823,
  9939,
  15326,
  2842,
  9234,
  4599,
  149

In [10]:
#testing one sentence
for tokens in token_list[0]:
    print(id2word[tokens])

one
day
a
little
girl
named
lily
found
a
needle
in
her
room
she
knew
it
was
difficult
to
play
with
it
because
it
was
sharp
lily
wanted
to
share
the
needle
with
her
mom
so
she
could
sew
a
button
on
her
shirt
lily
went
to
her
mom
and
said
"mom
i
found
this
needle
can
you
share
it
with
me
and
sew
my
shirt"
her
mom
smiled
and
said
"yes
lily
we
can
share
the
needle
and
fix
your
shirt"
together
they
shared
the
needle
and
sewed
the
button
on
lily's
shirt
it
was
not
difficult
for
them
because
they
were
sharing
and
helping
each
other
after
they
finished
lily
thanked
her
mom
for
sharing
the
needle
and
fixing
her
shirt
they
both
felt
happy
because
they
had
shared
and
worked
together


## 2. Data loader

We gonna make dataloader.  Inside here, we need to make two types of embeddings: **token embedding** and **segment embedding**

1. **Token embedding** - Given “The cat is walking. The dog is barking”, we add [CLS] and [SEP] >> “[CLS] the cat is walking [SEP] the dog is barking”. 

2. **Segment embedding**
A segment embedding separates two sentences, i.e., [0 0 0 0 1 1 1 1 ]

3. **Masking**
As mentioned in the original paper, BERT randomly assigns masks to 15% of the sequence. In this 15%, 80% is replaced with masks, while 10% is replaced with random tokens, and the rest 10% is left as is.  Here we specified `max_pred` 

4. **Padding**
Once we mask, we will add padding. For simplicity, here we padded until some specified `max_len`. 

Note:  `positive` and `negative` are just simply counts to keep track of the batch size.  `positive` refers to two sentences that are really next to one another.

In [11]:
batch_size = 6
max_mask   = 5  # max masked tokens when 15% exceed, it will only be max_pred
max_len    = 700 # maximum of length to be padded; 

In [12]:
def make_batch():
    batch = []
    positive = negative = 0  #count of batch size;  we want to have half batch that are positive pairs (i.e., next sentence pairs)
    while positive != batch_size/2 or negative != batch_size/2:
        
        #randomly choose two sentence so we can put [SEP]
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        #retrieve the two sentences
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

        #1. token embedding - append CLS and SEP
        input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]

        #2. segment embedding - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        #3. mask language modeling
        #masked 15%, but should be at least 1 but does not exceed max_mask
        n_pred =  min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        #get the pos that excludes CLS and SEP and shuffle them
        cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        #simply loop and change the input_ids to [MASK]
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)  #remember the position
            masked_tokens.append(input_ids[pos]) #remember the tokens
            #80% replace with a [MASK], but 10% will replace with a random token
            if random() < 0.1:  # 10%
                index = randint(0, vocab_size - 1) # random index in vocabulary
                input_ids[pos] = word2id[id2word[index]] # replace
            elif random() < 0.9:  # 80%
                input_ids[pos] = word2id['[MASK]'] # make mask
            else:  #10% do nothing
                pass

        # pad the input_ids and segment ids until the max len
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # pad the masked_tokens and masked_pos to make sure the lenth is max_mask
        if max_mask > n_pred:
            n_pad = max_mask - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        #check if first sentence is really comes before the second sentence
        #also make sure positive is exactly half the batch size
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
            negative += 1
            
    return batch

In [13]:
batch = make_batch()

In [14]:
#len of batch
len(batch)

6

In [15]:
#we can deconstruct using map and zip
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))
input_ids.shape, segment_ids.shape, masked_tokens.shape, masked_pos.shape, isNext.shape

(torch.Size([6, 700]),
 torch.Size([6, 700]),
 torch.Size([6, 5]),
 torch.Size([6, 5]),
 torch.Size([6]))

## 3. Model

Recall that BERT only uses the encoder.

BERT has the following components:

- Embedding layers
- Attention Mask
- Encoder layer
- Multi-head attention
- Scaled dot product attention
- Position-wise feed-forward network
- BERT (assembling all the components)

## 3.1 Embedding

Here we simply generate the positional embedding, and sum the token embedding, positional embedding, and segment embedding together.

<img src = "figures/BERT_embed.png" width=500>

In [16]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        #x, seg: (bs, len)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(self.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

## 3.2 Attention mask

In [17]:
def get_attn_pad_mask(seq_q, seq_k, device):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

### Testing the attention mask

In [18]:
print(get_attn_pad_mask(input_ids, input_ids, device).shape)

torch.Size([6, 700, 700])


## 3.3 Encoder

The encoder has two main components: 

- Multi-head Attention
- Position-wise feed-forward network

First let's make the wrapper called `EncoderLayer`

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn       = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

Let's define the scaled dot attention, to be used inside the multihead attention

In [20]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn 

Let's define the parameters first

In [21]:
n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 8    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

Here is the Multiheadattention.

In [22]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.device = device
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(self.n_heads * self.d_v, self.d_model, device=self.device)(context)
        return nn.LayerNorm(self.d_model, device=self.device)(output + residual), attn # output: [batch_size x len_q x d_model]

Here is the PoswiseFeedForwardNet.

In [23]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(F.gelu(self.fc1(x)))

## 3.4 Putting them together

In [24]:
class BERT(nn.Module):
    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super(BERT, self).__init__()
        self.params = {'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
                       'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
                       'vocab_size': vocab_size, 'max_len': max_len}
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        
        # 1. predict next sentence
        # it will be decided by first token(CLS)
        h_pooled   = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled) # [batch_size, 2]

        # 2. predict the masked token
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked  = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp
    
    def get_last_hidden_state(self, input_ids, segment_ids):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        return output

## 4. Training

In [25]:
from tqdm.auto import tqdm

n_layers = 12    # number of Encoder of Encoder Layer
n_heads  = 12    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = d_model * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

num_epoch = 1000
model = BERT(
    n_layers, 
    n_heads, 
    d_model, 
    d_ff, 
    d_k, 
    n_segments, 
    vocab_size, 
    max_len, 
    device
).to(device)  # Move model to GPU

In [26]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [27]:
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

# Move inputs to GPU
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_tokens = masked_tokens.to(device)
masked_pos = masked_pos.to(device)
isNext = isNext.to(device)

# Wrap the epoch loop with tqdm
for epoch in tqdm(range(num_epoch), desc="Training Epochs"):
    optimizer.zero_grad()
    logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)    
    #logits_lm: (bs, max_mask, vocab_size) ==> (6, 5, 34)
    #logits_nsp: (bs, yes/no) ==> (6, 2)

    #1. mlm loss
    #logits_lm.transpose: (bs, vocab_size, max_mask) vs. masked_tokens: (bs, max_mask)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
    loss_lm = (loss_lm.float()).mean()
    #2. nsp loss
    #logits_nsp: (bs, 2) vs. isNext: (bs, )
    loss_nsp = criterion(logits_nsp, isNext) # for sentence classification
    
    #3. combine loss
    loss = loss_lm + loss_nsp
    if epoch % 100 == 0:
        print('Epoch:', '%02d' % (epoch), 'loss =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 00 loss = 110.094528


Training Epochs:  10%|█         | 100/1000 [01:41<15:09,  1.01s/it]

Epoch: 100 loss = 4.995957


Training Epochs:  20%|██        | 200/1000 [03:23<13:31,  1.01s/it]

Epoch: 200 loss = 4.389479


Training Epochs:  30%|███       | 300/1000 [05:03<11:29,  1.01it/s]

Epoch: 300 loss = 4.008092


Training Epochs:  40%|████      | 400/1000 [06:43<09:49,  1.02it/s]

Epoch: 400 loss = 4.012590


Training Epochs:  50%|█████     | 500/1000 [08:22<08:17,  1.00it/s]

Epoch: 500 loss = 3.990366


Training Epochs:  60%|██████    | 600/1000 [10:02<06:39,  1.00it/s]

Epoch: 600 loss = 3.966225


Training Epochs:  70%|███████   | 700/1000 [11:41<04:59,  1.00it/s]

Epoch: 700 loss = 3.967532


Training Epochs:  80%|████████  | 800/1000 [13:21<03:19,  1.00it/s]

Epoch: 800 loss = 3.967981


Training Epochs:  90%|█████████ | 900/1000 [15:01<01:39,  1.00it/s]

Epoch: 900 loss = 3.987228


Training Epochs: 100%|██████████| 1000/1000 [16:41<00:00,  1.00s/it]


In [28]:
# Save the model after training
torch.save(model.state_dict(), 'bert_model.pth')
print("Model saved to bert_model.pth")

Model saved to bert_model.pth


## 5. Inference

Since our dataset is very small, it won't work very well, but just for the sake of demonstration.

In [29]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[2]))
print([id2word[w.item()] for w in input_ids[0] if id2word[w.item()] != '[PAD]'])
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_tokens = masked_tokens.to(device)
masked_pos = masked_pos.to(device)
isNext = isNext.to(device)

logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)
#logits_lm:  (1, max_mask, vocab_size) ==> (1, 5, 34)
#logits_nsp: (1, yes/no) ==> (1, 2)

#predict masked tokens
#max the probability along the vocab dim (2), [1] is the indices of the max, and [0] is the first value
logits_lm = logits_lm.data.cpu().max(2)[1][0].data.numpy() 
#note that zero is padding we add to the masked_tokens
print('masked tokens (words) : ',[id2word[pos.item()] for pos in masked_tokens[0]])
print('masked tokens list : ',[pos.item() for pos in masked_tokens[0]])
print('masked tokens (words) : ',[id2word[pos.item()] for pos in logits_lm])
print('predict masked tokens list : ', [pos for pos in logits_lm])

#predict nsp
logits_nsp = logits_nsp.cpu().data.max(1)[1][0].data.numpy()
print(logits_nsp)
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_nsp else False)

['[CLS]', 'lily', 'and', 'ben', 'were', 'twins', 'who', 'liked', 'to', 'help', 'their', 'mom', 'bake', 'pies', '[MASK]', 'day', 'mom', 'made', 'a', 'big', 'apple', 'pie', 'and', 'put', 'it', 'on', 'the', 'table', 'to', 'cool', 'she', 'told', 'lily', 'and', 'ben', 'not', 'to', 'touch', 'the', 'pie', 'until', 'it', 'was', 'ready', 'but', 'lily', 'and', 'ben', 'were', 'very', 'hungry', 'and', 'curious', 'they', 'wanted', 'to', '[MASK]', 'the', '[MASK]', 'they', 'waited', 'until', 'mom', 'went', 'to', 'the', 'garden', 'and', 'then', 'they', 'tiptoed', 'to', 'the', 'table', 'lily', 'reached', 'for', 'the', 'pie', 'and', 'lifted', 'it', 'with', 'both', 'hands', 'it', 'was', 'heavy', 'and', 'hot', '"be', 'careful', 'lily"', 'ben', 'whispered', '"don\'t', 'drop', 'the', 'pie"', 'but', 'lily', 'was', 'not', 'careful', 'she', 'lost', 'her', 'balance', 'and', 'the', 'pie', 'slipped', 'from', 'her', 'hands', 'it', 'fell', 'on', 'the', 'floor', 'with', 'a', 'loud', 'thud', 'the', 'pie', 'broke', 'i

Trying a bigger dataset should be able to see the difference.

## SNLI and MNLI datasets 

In [30]:
import datasets
snli = datasets.load_dataset('snli')
mnli = datasets.load_dataset('glue', 'mnli')
mnli['train'].features, snli['train'].features

({'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
  'idx': Value(dtype='int32', id=None)},
 {'premise': Value(dtype='string', id=None),
  'hypothesis': Value(dtype='string', id=None),
  'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)})

In [31]:
# List of datasets to remove 'idx' column from
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [32]:
# Remove 'idx' column from each dataset
for column_names in mnli.column_names.keys():
    mnli[column_names] = mnli[column_names].remove_columns('idx')

In [33]:
mnli.column_names.keys()

dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])

In [34]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([-1,  0,  1,  2]))

In [35]:
# there are -1 values in the label feature, these are where no class could be decided so we remove
snli = snli.filter(
    lambda x: 0 if x['label'] == -1 else 1
)

In [36]:
import numpy as np
np.unique(mnli['train']['label']), np.unique(snli['train']['label'])
#snli also have -1

(array([0, 1, 2]), array([0, 1, 2]))

In [37]:
# Assuming you have your two DatasetDict objects named snli and mnli
from datasets import DatasetDict
# Merge the two DatasetDict objects
raw_dataset = DatasetDict({
    'train': datasets.concatenate_datasets([snli['train'], mnli['train']]).shuffle(seed=55).select(list(range(1000))),
    'test': datasets.concatenate_datasets([snli['test'], mnli['test_mismatched']]).shuffle(seed=55).select(list(range(100))),
    'validation': datasets.concatenate_datasets([snli['validation'], mnli['validation_mismatched']]).shuffle(seed=55).select(list(range(1000)))
})
#remove .select(list(range(1000))) in order to use full dataset
# Now, merged_dataset_dict contains the combined datasets from snli and mnli
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
})

In [38]:
sentences_premise = raw_dataset['train']['premise']
text_premise = [x.lower() for x in sentences_premise] #lower case
text_premise = [re.sub("[.,!?\\-]", '', x) for x in text_premise] #clean all symbols
# text

sentences_hypothesis = raw_dataset['train']['hypothesis']
text_hypothesis = [x.lower() for x in sentences_hypothesis] #lower case
text_hypothesis = [re.sub("[.,!?\\-]", '', x) for x in text_hypothesis] #clean all symbols
# text

In [39]:
from tqdm.auto import tqdm

# Combine everything into one to make vocab
word_list_premise = list(set(" ".join(text_premise).split()))
word_list_hypothesis = list((set(" ".join(text_hypothesis).split())))

word2id_premise = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}  # special tokens
word2id_hypothesis = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}  # special tokens

# Create the word2id in a single pass
for i, w in tqdm(enumerate(word_list_premise), desc="Creating word2id"):
    word2id_premise[w] = i + 4  # because 0-3 are already occupied

for i, w in tqdm(enumerate(word_list_hypothesis), desc="Creating word2id"):
    word2id_hypothesis[w] = i + 4  # because 0-3 are already occupied

# Precompute the id2word mapping (this can be done once after word2id is fully populated)
id2word_premise = {v: k for k, v in word2id_premise.items()}
vocab_size_premis = len(word2id_premise)


id2word_hypothesis = {v: k for k, v in word2id_hypothesis.items()}
vocab_size_hypothesis = len(word2id_hypothesis)


Creating word2id: 3956it [00:00, 3845345.68it/s]
Creating word2id: 2502it [00:00, 4374384.58it/s]


In [40]:
vocab_size = len(word2id)

# List of all tokens for the whole text
token_list_premise = []
token_list_hypothesis = []

# Process sentences more efficiently
for sentence in tqdm(text_premise, desc="Processing sentences"):
    token_list_premise.append([word2id_premise[word] for word in sentence.split()])

for sentence in tqdm(text_hypothesis, desc="Processing sentences"):
    token_list_hypothesis.append([word2id_hypothesis[word] for word in sentence.split()])

# Now token_list contains the tokenized sentences

Processing sentences: 100%|██████████| 1000/1000 [00:00<00:00, 463356.61it/s]
Processing sentences: 100%|██████████| 1000/1000 [00:00<00:00, 846137.58it/s]


In [41]:
batch_size = 6
max_mask   = 5  # max masked tokens when 15% exceed, it will only be max_pred
max_len    = 700 # maximum of length to be padded; 

In [42]:
def make_batch():
    batch = []
    #randomly choose two sentence so we can put [SEP]
    i = 0
    while i != 6:
        #retrieve the two sentences
        tokens_a = token_list_premise[i]
    
        #1. token embedding - append CLS and SEP
        input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']]
    
        #2. segment embedding - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
        segment_ids = [0] * (1 + len(tokens_a) + 1)
    
        #3. mask language modeling
        #masked 15%, but should be at least 1 but does not exceed max_mask
        n_pred =  min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        #get the pos that excludes CLS and SEP and shuffle them
        cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
        shuffle(cand_maked_pos)
        masked_pos = []
        #simply loop and change the input_ids to [MASK]
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)  #remember the position

        # pad the input_ids and segment ids until the max len
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
    
        # pad the masked_tokens and masked_pos to make sure the lenth is max_mask
        if max_mask > n_pred:
            n_pad = max_mask - n_pred
            masked_pos.extend([0] * n_pad)
    
        #check if first sentence is really comes before the second sentence
        #also make sure positive is exactly half the batch size
        batch.append([input_ids, segment_ids, masked_pos]) # NotNext
        i += 1 
            
    return batch

In [43]:
batch_a = make_batch()

In [44]:
input_ids_a, segment_ids_a, masked_pos_a = map(torch.LongTensor, zip(*batch_a))
input_ids_a.shape, segment_ids_a.shape, masked_pos_a.shape

(torch.Size([6, 700]), torch.Size([6, 700]), torch.Size([6, 5]))

In [45]:
def make_batch():
    batch = []
    #randomly choose two sentence so we can put [SEP]
    i = 0
    while i != 6:
        #retrieve the two sentences
        tokens_b = token_list_hypothesis[i]
    
        #1. token embedding - append CLS and SEP
        input_ids = [word2id['[CLS]']] + tokens_b + [word2id['[SEP]']]
    
        #2. segment embedding - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
        segment_ids = [0] * (1 + len(tokens_b) + 1)
    
        #3. mask language modeling
        #masked 15%, but should be at least 1 but does not exceed max_mask
        n_pred =  min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        #get the pos that excludes CLS and SEP and shuffle them
        cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
        shuffle(cand_maked_pos)
        masked_pos = []
        #simply loop and change the input_ids to [MASK]
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)  #remember the position

        # pad the input_ids and segment ids until the max len
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
    
        # pad the masked_tokens and masked_pos to make sure the lenth is max_mask
        if max_mask > n_pred:
            n_pad = max_mask - n_pred
            masked_pos.extend([0] * n_pad)
    
        #check if first sentence is really comes before the second sentence
        #also make sure positive is exactly half the batch size
        batch.append([input_ids, segment_ids, masked_pos]) # NotNext
        i += 1 
            
    return batch

In [46]:
batch_b = make_batch()

In [47]:
input_ids_b, segment_ids_b, masked_pos_b = map(torch.LongTensor, zip(*batch_a))
input_ids_b.shape, segment_ids_b.shape, masked_pos_b.shape

(torch.Size([6, 700]), torch.Size([6, 700]), torch.Size([6, 5]))

In [48]:
input_ids_a = input_ids_a.to(device)
segment_ids_a = segment_ids_a.to(device)
masked_pos_a = masked_pos_a.to(device)

input_ids_b = input_ids_b.to(device)
segment_ids_b = segment_ids_b.to(device)
masked_pos_b = masked_pos_b.to(device)


In [49]:
classifier_head = torch.nn.Linear(56127, 3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()


# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

In [50]:
i = 0
label = []
while i != 6:
    label.append(raw_dataset['train']['label'][i])
    i += 1 

In [51]:
label = torch.LongTensor(label)

In [52]:
num_epoch = 2
label_iter = 0
similarity_score = 0
# 1 epoch should be enough, increase if wanted
for epoch in range(num_epoch):
    model.train()  
    classifier_head.train()
    # initialize the dataloader loop with tqdm (tqdm == progress bar)
    for step, batch in enumerate(tqdm(range(num_epoch), leave=True)):
        # zero all gradients on each new step
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        # prepare batches and more all to the active device

        label = label.to(device)
        # extract token embeddings from BERT at last_hidden_state
        u,_ = model(input_ids_a, segment_ids_a, masked_pos_a)  
        v,_ = model(input_ids_b, segment_ids_b, masked_pos_b)  

        # build the |u-v| tensor
        uv = torch.sub(u, v)   # batch_size,hidden_dim
        uv_abs = torch.abs(uv) # batch_size,hidden_dim
        
        # concatenate u, v, |u-v|
        x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
        
        # process concatenated tensor through classifier_head
        x = classifier_head(x) #batch_size, classifer

        x = x.mean(dim=1)
        
        # calculate the 'softmax-loss' between predicted and true label
        loss = criterion(x, label)
        label_iter += 1
        
        # using loss, calculate gradients and then optimizerize
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        similarity_score = similarity_score +  cosine_similarity(u.reshape(1, -1).detach().cpu(), v.reshape(1, -1).detach().cpu())[0, 0]

        
    print(f'Epoch: {epoch + 1} | loss = {loss.item():.6f}')
print('Similarity Score is',similarity_score/6)

100%|██████████| 2/2 [00:05<00:00,  2.56s/it]


Epoch: 1 | loss = 6.826176


100%|██████████| 2/2 [00:04<00:00,  2.20s/it]

Epoch: 2 | loss = 9.535089
Similarity Score is 0.666640559832255





In [53]:
total = 0 
accuracy = 0
for i in range(label.size(0)):
   total += 1
   if torch.argmax(x,dim=1)[i].item() == label[i].item():
      accuracy += 1
      
print('Accuracy is',((accuracy / total) * 100 ))

Accuracy is 16.666666666666664


In [54]:
# Save the model after training
torch.save(model.state_dict(), 'sentence_classification.pth')
print("Model saved to bert_model.pth")

Model saved to bert_model.pth


In [55]:
sentence_a = 'Your contribution helped make it possible for us to provide our students with a quality study.'
sentence_b = "Your contribution were of no help with our student's learn."

text_a = sentence_a.lower()#lower case
text_a = re.sub("[.,!?\\-]", '', text_a) #clean all symbols

text_b = sentence_b.lower()#lower case
text_b = re.sub("[.,!?\\-]", '', text_b) #clean all symbols



token_list_a = []

# Process sentences more efficiently

token_list_a.append([word2id[word] for word in text_a.split()])

token_list_b = []

# Process sentences more efficiently

token_list_b.append([word2id[word] for word in text_b.split()])



input_ids_a = [word2id['[CLS]']] + token_list_a[0] + [word2id['[SEP]']]
input_ids_b = [word2id['[CLS]']] + token_list_b[0] + [word2id['[SEP]']]

segment_ids_a = [0] * (1 + len(token_list_a[0]) + 1)
segment_ids_b = [0] * (1 + len(token_list_b[0]) + 1)

In [56]:
cand_maked_pos = [i for i, token in enumerate(input_ids_a) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
shuffle(cand_maked_pos)
masked_pos_a = []
for pos in cand_maked_pos[:5]:
        masked_pos_a.append(pos)  

n_pad = max_len - len(input_ids_a)
input_ids_a.extend([0] * n_pad)
segment_ids_a.extend([0] * n_pad)

if max_mask > 5:
        n_pad = max_mask - 5
        masked_pos_a.extend([0] * n_pad)

In [57]:
cand_maked_pos = [i for i, token in enumerate(segment_ids_b) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
shuffle(cand_maked_pos)
masked_pos_b = []
for pos in cand_maked_pos[:5]:
        masked_pos_b.append(pos)  

n_pad = max_len - len(input_ids_b)
input_ids_b.extend([0] * n_pad)
segment_ids_b.extend([0] * n_pad)

if max_mask > 5:
        n_pad = max_mask - 5
        masked_pos_b.extend([0] * n_pad)


In [58]:
input_ids_a = torch.LongTensor(input_ids_a).to(device)
segment_ids_a = torch.LongTensor(segment_ids_a).to(device)
masked_pos_a = torch.LongTensor(masked_pos_a).to(device)

In [59]:
input_ids_b = torch.LongTensor(input_ids_b).to(device)
segment_ids_b = torch.LongTensor(segment_ids_b).to(device)
masked_pos_b = torch.LongTensor(masked_pos_b).to(device)

In [60]:
result_a,_ = model(input_ids_a.unsqueeze(0), segment_ids_a.unsqueeze(0), masked_pos_a.unsqueeze(0))  
result_b,_ = model(input_ids_b.unsqueeze(0), segment_ids_b.unsqueeze(0), masked_pos_b.unsqueeze(0))  

In [61]:
result_a.reshape(1, -1)
result_b.reshape(1, -1)

tensor([[ 11.6655,  15.0687,  38.4533,  ..., -25.6486,  28.4112, -47.6650]],
       device='mps:0', grad_fn=<ViewBackward0>)

In [62]:

similarity_score = cosine_similarity(result_a.reshape(1, -1).detach().cpu(), result_b.reshape(1, -1).detach().cpu())[0, 0]

In [63]:
np.round(similarity_score)

1.0