# Train LLM with MLM

In [1]:
from transformers import BertTokenizer, BertForMaskedLM
import torch
import numpy as np
import time

In [2]:
# from transformers import AdamW
from torch.optim import AdamW
from tqdm import tqdm

In [3]:
from llm_funcs import dataset_obj

In [4]:
batch_size=3*8 #24
epochs = 20

In [5]:
device = (torch.device('cuda') if torch.cuda.is_available() 
          else torch.device('cpu'))
device

device(type='cuda')

In [6]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased_local')
model     = BertForMaskedLM.from_pretrained('bert-base-init_local')

In [7]:
with open('mlm_text.text', 'r') as fp:
    text = fp.read().split('\n')

In [8]:
text[:5]

['From my grandfather Verus I learned good morals and the government of my temper.',
 'From the reputation and remembrance of my father, modesty and a manly character.',
 'From my mother, piety and beneficence, and abstinence, not only from evil deeds, but even from evil thoughts; and further, simplicity in my way of living, far removed from the habits of the rich.',
 'From my great-grandfather, not to have frequented public schools, and to have had good teachers at home, and to know that on such things a man should spend liberally.',
 "From my governor, to be neither of the green nor of the blue party at the games in the Circus, nor a partizan either of the Parmularius or the Scutarius at the gladiators' fights; from him too I learned endurance of labour, and to want little, and to work with my own hands, and not to meddle with other people's affairs, and not to be ready to listen to slander."]

In [9]:
inputs = tokenizer(text, return_tensors='pt', max_length=512,
                  truncation=True, padding='max_length')

In [10]:
inputs

{'input_ids': tensor([[  101,  2013,  2026,  ...,     0,     0,     0],
        [  101,  2013,  1996,  ...,     0,     0,     0],
        [  101,  2013,  2026,  ...,     0,     0,     0],
        ...,
        [  101,  2043, 15223,  ...,     0,     0,     0],
        [  101,  7887,  3288,  ...,     0,     0,     0],
        [  101,   102,     0,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 0,  ..., 0, 0, 0]])}

In [11]:
inputs.input_ids[0]

tensor([  101,  2013,  2026,  5615,  2310,  7946,  1045,  4342,  2204, 25288,
         1998,  1996,  2231,  1997,  2026, 12178,  1012,   102,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [12]:
inputs.token_type_ids[0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [13]:
inputs.attention_mask[0]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [14]:
type(inputs.input_ids[0]), type(inputs.input_ids), 

(torch.Tensor, torch.Tensor)

In [15]:
inputs['labels'] = inputs.input_ids.detach().clone()
inputs

{'input_ids': tensor([[  101,  2013,  2026,  ...,     0,     0,     0],
        [  101,  2013,  1996,  ...,     0,     0,     0],
        [  101,  2013,  2026,  ...,     0,     0,     0],
        ...,
        [  101,  2043, 15223,  ...,     0,     0,     0],
        [  101,  7887,  3288,  ...,     0,     0,     0],
        [  101,   102,     0,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 0,  ..., 0, 0, 0]]), 'labels': tensor([[  101,  2013,  2026,  ...,     0,     0,     0],
        [  101,  2013,  1996,  ...,     0,     0,     0],
        [  101,  2013, 

In [16]:
# Special tokens
PAD  = 0
CLS  = 101
SEP  = 102
MASK = 103

In [17]:
rand = torch.rand(inputs.input_ids.shape)
rand.shape

torch.Size([508, 512])

In [18]:
rand = torch.rand(inputs.input_ids.shape)

# select 15%, remove special tokens from mask
mask_arr = ((rand < 0.15)*
            (inputs.input_ids != CLS)*
            (inputs.input_ids != SEP)*
            (inputs.input_ids != PAD))

mask_arr.shape

torch.Size([508, 512])

In [19]:
mask_arr

tensor([[False, False, False,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [20]:
# index position of true values to be masked --> selection
selection = []

for i in np.arange(mask_arr.shape[0]):
    selection.append(torch.flatten(mask_arr[i].nonzero()).tolist())

In [21]:
selection[:5]

[[6, 10],
 [2, 6, 9],
 [10, 12, 22, 27, 31, 43, 45],
 [2, 5, 6, 11, 32],
 [10, 13, 20, 22, 40, 41, 48, 55, 58, 69, 75, 76, 87, 88]]

In [22]:
# replace mask token with selection 
for i in np.arange(mask_arr.shape[0]):
    inputs.input_ids[i, selection[i]] = MASK

In [23]:
inputs.input_ids

tensor([[  101,  2013,  2026,  ...,     0,     0,     0],
        [  101,  2013,   103,  ...,     0,     0,     0],
        [  101,  2013,  2026,  ...,     0,     0,     0],
        ...,
        [  101,  2043, 15223,  ...,     0,     0,     0],
        [  101,  7887,   103,  ...,     0,     0,     0],
        [  101,   102,     0,  ...,     0,     0,     0]])

In [24]:
# convert dataset to pytorch data object
dataset = dataset_obj(inputs)

In [25]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, # 24
                                        shuffle=True)

### training

In [26]:
model.to(device);
model.train();

In [27]:
optimizer = AdamW(model.parameters(), lr=1e-4); # 1e-5
clip_grad = 0.1

In [28]:
st = time.time()
model.zero_grad()
optimizer.zero_grad()

for epoch in np.arange(epochs):
    loop = tqdm(dataloader, leave=True)
    for batch in loop:
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask,
                       labels=labels)
        loss = outputs.loss
        loss.backward()

        # clip gradient
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
        optimizer.step()

        optimizer.zero_grad()
        model.zero_grad()
        
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 0: 100%|████████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.20it/s, loss=1.17]
Epoch 1: 100%|█████████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.26it/s, loss=1.4]
Epoch 2: 100%|████████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.21it/s, loss=1.88]
Epoch 3: 100%|███████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.22it/s, loss=0.673]
Epoch 4: 100%|███████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.25it/s, loss=0.831]
Epoch 5: 100%|████████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.24it/s, loss=6.78]
Epoch 6: 100%|███████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.24it/s, loss=0.987]
Epoch 7: 100%|████████████████████████████████████████████████████| 22/22 [00:09<00:00,  2.24it/s, loss=2.25]
Epoch 8: 100%|████████████████████████████

In [29]:
ed = time.time()
print(f'time: {np.round((ed-st),2)} sec')
# 23.4G GPU, 197.62 sec

time: 197.62 sec
