# Train BERT on OpenWebText

### Some Flags

In [1]:
# Switches for the optional stages of this notebook.
# LOAD_DATASET: gate the `load_dataset('openwebtext', ...)` call.
LOAD_DATASET = False
# CHUNK_DATA: gate writing the dataset out as text files of 10,000 documents each.
CHUNK_DATA = False
# TRAIN_TOKENIZER: gate training and saving a WordPiece tokenizer on those text files.
TRAIN_TOKENIZER = False
# HUGGINGFACE_TOKENIZER: when True, the 'bert-base-uncased' tokenizer is loaded for preprocessing.
HUGGINGFACE_TOKENIZER = True

### Load OpenWebText dataset

In [2]:
from datasets import load_dataset

# The chunking step further down reads `dataset`, so the corpus must be loaded
# whenever either flag asks for it; otherwise CHUNK_DATA=True with
# LOAD_DATASET=False fails with a NameError.
if LOAD_DATASET or CHUNK_DATA:
    dataset = load_dataset('openwebtext', split='train')

### Save the data to files

In [3]:
from pathlib import Path
import numpy as np

# Folder that holds the chunked plain-text files; created on first run.
datadir = Path('splited_data')
datadir.mkdir(exist_ok=True)

# Filename template: '{}' is filled in later with the zero-padded chunk index.
file_name = str(datadir.joinpath('text_{}.txt'))

In [5]:
def chunker(seq, size):
    """Return a generator over consecutive slices of ``seq``, each at most ``size`` long.

    The start offsets are computed immediately (so an invalid ``size`` fails at
    call time); the slices themselves are produced lazily. The last slice is
    shorter when ``len(seq)`` is not a multiple of ``size``.
    """
    starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in starts)

if CHUNK_DATA:
    chunk_size = 10_000

    # Number of complete chunks; drives both the progress display and the
    # zero-padding width of the file index.
    full_chunks = len(dataset) // chunk_size

    # Same width as before for large datasets. The max(..., 1) clamp avoids
    # log10(0) = -inf when the dataset holds fewer than `chunk_size` rows.
    digits = int(np.ceil(np.log10(max(full_chunks, 1)))) + 2

    for i, sample in enumerate(chunker(dataset, chunk_size)):
        print(f'\r{i}/{full_chunks}', end='')
        # Each output line is one document. Embedded newlines become a space
        # (not an empty string) so the words on either side are not fused.
        sample = [x.replace('\n', ' ') for x in sample['text']]
        with open(file_name.format(str(i).zfill(digits)), 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(sample))

### Tokenizer

Train tokenizer:

In [6]:
# !pip install tokenizers
from tokenizers import BertWordPieceTokenizer

# Untrained WordPiece tokenizer; the options are spelled out one per line
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=True,
)

# Inputs for training: target vocabulary size and every .txt file under datadir
vocab_size = 30_000
paths = list(map(str, datadir.glob('**/*.txt')))

In [7]:
if TRAIN_TOKENIZER:
    # Fit the WordPiece vocabulary on the text files collected in `paths`.
    tokenizer.train(
        files=paths,
        vocab_size=vocab_size,
        min_frequency=2,
        limit_alphabet=1000,
        wordpieces_prefix='##',
        special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
        show_progress=True,
    )

    # Save the trained model into ./tokenizer, creating the folder if needed.
    tokenizer_dir = Path('tokenizer')
    tokenizer_dir.mkdir(exist_ok=True)
    tokenizer.save_model(str(tokenizer_dir))

Load tokenizer:

In [8]:
from transformers import BertTokenizer


if HUGGINGFACE_TOKENIZER:
    # Off-the-shelf vocabulary published with the original uncased BERT.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
else:
    # Use the locally trained vocabulary. The directory is named here rather
    # than taken from the training cell, because that cell only defines it when
    # TRAIN_TOKENIZER is True. Loading through BertTokenizer in every case also
    # gives the rest of the notebook one API (callable tokenizer, *_token_id
    # attributes) whether or not the tokenizer was trained in this session.
    tokenizer_dir = Path('tokenizer')
    tokenizer = BertTokenizer.from_pretrained(str(tokenizer_dir))

### Preprocess

Tokenize one file

In [9]:
# glob() returns files in no guaranteed order; sort first so the same chunk
# (the second one by name) is selected on every run and every machine.
sample_path = sorted(datadir.glob('**/*.txt'))[1]

# The chunk files hold one document per line.
with open(str(sample_path), 'r', encoding='utf-8') as fp:
    lines = fp.read().split('\n')

In [10]:
# Encode every line, padding or truncating each one to exactly 512 token ids.
batch = tokenizer(
    lines,
    max_length=512,
    padding='max_length',
    truncation=True,
)

# Sanity check: each returned field should have one entry per input line.
for field in ('input_ids', 'token_type_ids', 'attention_mask'):
    print(len(batch[field]))

10000
10000
10000


In [11]:
import torch

# Token ids before any masking; these serve as the training targets.
labels = torch.tensor(batch['input_ids'])
# Attention mask as returned by the tokenizer.
mask = torch.tensor(batch['attention_mask'])

In [12]:
# Ids of the special tokens, read from the tokenizer that is actually loaded.
PAD, CLS, SEP, MASK = (
    tokenizer.pad_token_id,
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.mask_token_id,
)

special_chars = dict(PAD=PAD, CLS=CLS, SEP=SEP, MASK=MASK)
for token_name, token_id in special_chars.items():
    print(token_name, ':', token_id)

PAD : 0
CLS : 101
SEP : 102
MASK : 103


In [13]:
mask_percent = 0.15

# Work on a copy of the unmasked token ids; the copy becomes the model input.
input_ids = labels.detach().clone()

# One uniform random draw in [0, 1) per token position.
rand = torch.rand(input_ids.shape)

# Select roughly 15% of the positions, never [PAD], [CLS] or [SEP].
# The ids come from the tokenizer rather than the literals 0/1/2: for
# 'bert-base-uncased' [CLS] is 101 and [SEP] is 102, so hard-coded 1 and 2
# would leave those tokens eligible for masking.
mask_arr = (
    (rand < mask_percent)
    & (input_ids != PAD)
    & (input_ids != CLS)
    & (input_ids != SEP)
)

# Overwrite the selected positions with [MASK] in one vectorised assignment.
# The attention mask is deliberately left unchanged.
input_ids[mask_arr] = MASK

# NOTE(review): `labels` still holds every token id, so the loss is computed on
# all positions (padding included), not only the masked ones. Setting the
# unselected label positions to -100 would restrict the loss to masked tokens -
# confirm the intended objective before changing it.

In [14]:
# Bundle the three aligned tensors under the keyword names the model expects.
encodings = dict(input_ids=input_ids, attention_mask=mask, labels=labels)

In [15]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of tensors that share their first dimension."""

    def __init__(self, encodings):
        # Mapping of field name -> tensor; must contain an 'input_ids' entry.
        self.encodings = encodings

    def __len__(self):
        # The number of samples is the leading dimension of 'input_ids'.
        ids = self.encodings['input_ids']
        return ids.shape[0]

    def __getitem__(self, i):
        # Take row `i` of every field, keeping the field names as keys.
        item = {}
        for key, tensor in self.encodings.items():
            item[key] = tensor[i]
        return item

In [16]:
# Wrap the encoded tensors and serve them in shuffled mini-batches of 16.
dataset = Dataset(encodings)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
)

### Prepare BERT

In [17]:
from transformers import BertConfig, BertForMaskedLM, AdamW

# Configuration for a BERT model built from scratch. The vocabulary size and
# padding id are read from the tokenizer; the rest are fixed hyper-parameters.
bert_config = BertConfig(
    # tied to the tokenizer
    vocab_size=len(tokenizer.vocab),
    pad_token_id=PAD,
    # encoder size
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    # embeddings
    max_position_embeddings=512,
    position_embedding_type='absolute',
    type_vocab_size=2,
    # regularisation
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    use_cache=True,
)

In [18]:
# Randomly initialised masked-language model built from the config above.
bert = BertForMaskedLM(bert_config)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
bert.to(device)
bert.train()

# Use PyTorch's AdamW: the implementation shipped with transformers is
# deprecated in favour of this one. eps and weight_decay are set explicitly to
# keep the defaults of the transformers version (1e-6 and 0.0) instead of
# PyTorch's (1e-8 and 0.01).
optim = torch.optim.AdamW(bert.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)

### Train

In [None]:
from tqdm import tqdm

epochs = 2

for epoch in range(epochs):
    loop = tqdm(loader, leave=True)
    # The label is constant for the whole epoch, so set it once up front.
    loop.set_description(f'Epoch {epoch}')

    # Loop-local names are prefixed so they do not overwrite the notebook-level
    # `batch`, `input_ids` and `labels` created during preprocessing; clobbering
    # those would make the earlier cells impossible to re-run after training.
    for train_batch in loop:
        # clear gradients left over from the previous step
        optim.zero_grad()

        # move the tensors needed for this step to the training device
        batch_input_ids = train_batch['input_ids'].to(device)
        batch_attention_mask = train_batch['attention_mask'].to(device)
        batch_labels = train_batch['labels'].to(device)

        # forward pass; the model computes the loss itself when given labels
        outputs = bert(batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)

        # backward pass and parameter update
        loss = outputs.loss
        loss.backward()
        optim.step()

        # show the current loss in the progress bar
        loop.set_postfix(loss=loss.item())

Epoch 0:   8%|█▌                  | 49/625 [21:43<3:57:22, 24.73s/it, loss=4.32]