## Encoder Pretraining - Masked Language Modeling
"The cat sits on the mat" => "The [MASK] sits on the [MASK]"  
Task: Predict "cat" and "mat"

In [2]:
import multiprocessing as mp
from datasets import load_dataset
import torch
import torch.nn as nn

from GPT import Encoder
from preprocessing import load_tokenizer_and_dataset
from parameters import REPLACE_FRACTION, MASK_FRACTION, RANDOM_TOKEN_FRACTION, VOCAB_SIZE, EMBEDDING_SIZE, CONTEXT_SIZE

In [3]:
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train[:1%]")
tokenizer = load_tokenizer_and_dataset("./models/tokenizer.pkl")[0]

Using the latest cached version of the dataset since wikitext couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'wikitext-103-raw-v1' at C:\Users\q603178\.cache\huggingface\datasets\wikitext\wikitext-103-raw-v1\0.0.0\b08601e04326c79dfdd32d625aee71d232d685c3 (last modified on Thu Aug  1 10:44:00 2024).


In [12]:
list(tokenizer.vocab.values())[-10:]

[b' interessanter',
 b' selfish',
 b' ego',
 b' egoistisch',
 b'<|STARTOFTEXT|>',
 b'<|ENDOFTEXT|>',
 b'<|PAD|>',
 '<|MASK|>',
 '<|MASK|>',
 '<|MASK|>']

In [6]:
# add [MASK] token to tokenizer
from functools import partial

if b"<|MASK|>" not in tokenizer.vocab.values():
    tokenizer.vocab[max(tokenizer.vocab) + 1] = b"<|MASK|>"
pad_token = [key for key, value in tokenizer.vocab.items() if value == b"<|PAD|>"][0]
mask_token = tokenizer.encode("<|MASK|>")[0], tokenizer.encode("<|PAD|>")[0]

def clean_text(examples):
    return {"text": [text.strip() for text in examples["text"] if text.strip() != ""]}

# Apply cleaning function
cleaned_dataset = dataset.map(
    clean_text,
    batched=True,
    num_proc=mp.cpu_count(),
    desc="Cleaning text"
)

def tokenize_function(sample, tokenizer, mask_id, replace_fraction, mask_fraction, random_token_fraction):
    import random

    tokens = tokenizer.encode(sample["text"])
    
    # store index and original value of replaced tokens to be used as targets in traing
    original_values = []
    indices_to_replace = random.sample(range(len(tokens)), int(len(tokens) * replace_fraction))

    for i in indices_to_replace:
        original_values.append(tokens[i])

        rand = random.random()
        if rand < mask_fraction:
            tokens[i] = mask_id
        elif rand < (mask_fraction + random_token_fraction):
            tokens[i] = random.randint(0, mask_id)
    
    return {
        "tokens": tokens,
        "masked_indices": indices_to_replace,
        "original_values": original_values
    }
    

tokenize_function_partial = partial(
    tokenize_function,
    tokenizer=tokenizer,
    mask_id=mask_id,
    replace_fraction=REPLACE_FRACTION,
    mask_fraction=MASK_FRACTION,
    random_token_fraction=RANDOM_TOKEN_FRACTION
)
tokenized_dataset = cleaned_dataset.map(tokenize_function_partial, num_proc=mp.cpu_count(), remove_columns=["text"], desc="Tokenize & Replace Tokens")

ValueError: too many values to unpack (expected 2)

In [None]:
tokenized_dataset

NameError: name 'tokenized_dataset' is not defined

In [None]:
class MLM(nn.Module):
    def __init__(self):
        super().__init__()

        # add an Embedding Table for Character Embedding
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE+1, EMBEDDING_SIZE)
        self.position_embedding_table = nn.Embedding(CONTEXT_SIZE, EMBEDDING_SIZE)

        self.encoder = Encoder()

        self.lm_head = nn.Linear(EMBEDDING_SIZE, VOCAB_SIZE, bias=False)
        # weight sharing (use same weights for Input Embeddings (token_embedding_table) and lm_head)
        self.token_embedding_table.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)
