# Fine-tuning a masked language model (PyTorch)

This notebook fine-tunes pretrained BERT models for masked language modelling on the WikiText-2 dataset, sweeping over several weight-decay and dropout values. It was adapted from the Hugging Face course tutorial of the same name: https://huggingface.co/course/chapter7/3?fw=pt. It also pre-processes the WikiText dataset and saves the processed version to disk.

In [None]:
from transformers import BertForMaskedLM

# Name of the pretrained checkpoint; reused below for the tokenizer and for
# every fine-tuning run
model_checkpoint = "bert-base-uncased"
# Load the checkpoint into a BERT model with a masked-language-modelling head
model = BertForMaskedLM.from_pretrained(model_checkpoint)

In [None]:
# Parameter count of the loaded model, expressed in millions
bert_num_parameters = model.num_parameters() / 1_000_000
# Report the count rounded to the nearest million
print(f"'>>> BERT number of parameters: {round(bert_num_parameters)}M'")


In [None]:
# Prompt with a single [MASK] token; the cells below ask the model to fill it in
text = "This is a great [MASK]."

In [None]:
from transformers import BertTokenizer

# Tokenizer loaded from the same checkpoint name as the model
tokenizer = BertTokenizer.from_pretrained(model_checkpoint)

In [None]:
import torch

# Encode the prompt as PyTorch tensors
inputs = tokenizer(text, return_tensors="pt")
# This is inference only, so skip autograd bookkeeping for the forward pass
with torch.no_grad():
    token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

# Show the prompt with each of the five candidates substituted for [MASK]
for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

In [None]:
from datasets import load_dataset

# Load the "wikitext-2-v1" configuration of the "wikitext" dataset
wiki_dataset = load_dataset("wikitext", "wikitext-2-v1")
# Bare expression so the notebook displays the loaded object as cell output
wiki_dataset

In [None]:
# Shuffle the training split with a fixed seed and keep the first three rows
sample = wiki_dataset["train"].shuffle(seed=42).select(range(3))

# Print the raw text of each selected row
for row in sample:
    print(f"\n'>>> Text: {row['text']}'")

In [None]:
def tokenize_function(examples):
    """Tokenize the ``"text"`` entry of a batch with the notebook's tokenizer.

    Args:
        examples: Batch mapping that contains a ``"text"`` entry.

    Returns:
        The tokenizer output. When ``tokenizer.is_fast`` is true, a
        ``"word_ids"`` entry is added holding ``word_ids(i)`` for every
        encoded sequence ``i``.
    """
    encoded = tokenizer(examples["text"])
    if not tokenizer.is_fast:
        return encoded
    # One word-id list per encoded sequence, in batch order
    n_sequences = len(encoded["input_ids"])
    encoded["word_ids"] = list(map(encoded.word_ids, range(n_sequences)))
    return encoded


# batched=True hands tokenize_function a batch of rows per call instead of a
# single row; the raw "text" column is removed from the result.
# NOTE(review): the "fast multithreading" speed-up presumably only applies to
# fast tokenizers (tokenizer.is_fast) — confirm for the BertTokenizer in use.
tokenized_datasets = wiki_dataset.map(
    tokenize_function, batched=True, remove_columns=["text"]
)
tokenized_datasets

In [None]:
# Display the tokenizer's model_max_length attribute
tokenizer.model_max_length

In [None]:
# Number of tokens per chunk; read by the chunking demo and by group_texts
chunk_size = 128

In [None]:
# Slicing produces a list of lists for each feature
tokenized_samples = tokenized_datasets["train"][:3]

# Print how many token ids each of the sliced samples contains
for idx, sample in enumerate(tokenized_samples["input_ids"]):
    print(f"'>>> Text {idx} length: {len(sample)}'")

In [None]:
# Flatten every feature's per-sample lists into one long list
# (sum with a [] start value concatenates the lists)
concatenated_examples = {
    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()
}
# Number of token ids once the samples are joined end to end
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated texts length: {total_length}'")

In [None]:
# Cut every concatenated feature into consecutive slices of chunk_size items.
# Nothing is dropped here, so the final slice may be shorter than chunk_size.
chunks = {
    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
    for k, t in concatenated_examples.items()
}

# Print the length of each input_ids chunk
for chunk in chunks["input_ids"]:
    print(f"'>>> Chunk length: {len(chunk)}'")

In [None]:
def group_texts(examples, block_size=None):
    """Concatenate a batch of tokenized texts and re-split it into equal chunks.

    Args:
        examples: Mapping from feature name to a list of per-example lists,
            for instance ``{"input_ids": [[...], [...]], ...}``. It must
            contain an ``"input_ids"`` entry.
        block_size: Number of items per output chunk. When omitted, the
            module-level ``chunk_size`` is used, which is what happens when
            the function is passed to ``map`` as a callback.

    Returns:
        A dict with the same keys as ``examples`` in which every value is a
        list of chunks holding exactly ``block_size`` items each, plus a
        ``"labels"`` entry that is a copy of the ``"input_ids"`` chunk list.
        A trailing remainder shorter than ``block_size`` is dropped.
    """
    from itertools import chain

    if block_size is None:
        block_size = chunk_size
    # Concatenate all texts. chain.from_iterable flattens in linear time,
    # whereas sum(lists, []) re-copies the accumulated list at every step.
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    # Length of the first feature; all features are assumed to be equally long
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the last chunk if it is smaller than block_size
    total_length = (total_length // block_size) * block_size
    # Split every feature into chunks of block_size items
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column mirroring the input ids
    result["labels"] = result["input_ids"].copy()
    return result

In [None]:
# Run group_texts batch-wise over the tokenized data to obtain fixed-size chunks
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

In [None]:
# Persist the chunked dataset under the path "processed_dataset"
lm_datasets.save_to_disk("processed_dataset")

In [None]:
from datasets import DatasetDict

# Reload the processed dataset saved earlier. load_from_disk is called on the
# class itself, so no empty DatasetDict instance has to be created first.
lm_datasets = DatasetDict.load_from_disk("processed_dataset")
lm_datasets

In [None]:
# Decode the input_ids of training row 1 back to text to inspect one chunk
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

In [None]:
from transformers import DataCollatorForLanguageModeling

# Collator that builds language-modelling batches with this tokenizer;
# mlm_probability=0.15 sets the fraction of tokens selected for masking
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [None]:
# Number of rows (chunks) in the training split
print(len(lm_datasets["train"]))

In [None]:
from transformers import TrainingArguments
import math

# Weight-decay sweep: fine-tune a freshly loaded pretrained model once per
# value and report the validation perplexity before and after training.
import torch
from transformers import Trainer

weight_decays = [0, 0.1, 0.01, 0.001]

# Settings shared by every run
model_checkpoint = "bert-base-uncased"
model_name = model_checkpoint.split("/")[-1]  # not used in this cell
batch_size = 20
# Log the training loss roughly once per epoch; never below 1 so that a
# training set smaller than one batch does not yield a zero logging interval
logging_steps = max(1, len(lm_datasets["train"]) // batch_size)

for decay_val in weight_decays:
    # Reload the pretrained weights so each run starts from the same point
    model = BertForMaskedLM.from_pretrained(model_checkpoint)

    training_args = TrainingArguments(
        output_dir=f"weight_decay_{decay_val}",
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        num_train_epochs=15,
        weight_decay=decay_val,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        push_to_hub=False,
        # Mixed precision needs a CUDA device; use full precision otherwise
        fp16=torch.cuda.is_available(),
        logging_steps=logging_steps,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets["validation"],
        data_collator=data_collator,
    )

    # Baseline perplexity of the pretrained model, measured before training
    eval_results = trainer.evaluate()
    print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

    trainer.train()

    # Perplexity after fine-tuning with this weight-decay value
    eval_results = trainer.evaluate()
    print(
        f">>> Weight decay: {decay_val}, "
        f"Perplexity: {math.exp(eval_results['eval_loss'])}"
    )

    # Save the fine-tuned model for this run
    trainer.save_model()

In [None]:
from transformers import BertConfig, TrainingArguments
import math
# Dropout sweep: fine-tune a freshly loaded pretrained model once per dropout
# rate and report the validation perplexity before and after training.
import torch
from transformers import Trainer

dropouts = [0, 0.2, 0.4]

# Settings shared by every run. The checkpoint name is assigned before its
# first use so this cell does not depend on an earlier cell having set it.
model_checkpoint = "bert-base-uncased"
model_name = model_checkpoint.split("/")[-1]  # not used in this cell
batch_size = 20
# Log the training loss roughly once per epoch; never below 1 so that a
# training set smaller than one batch does not yield a zero logging interval
logging_steps = max(1, len(lm_datasets["train"]) // batch_size)

for dropout_val in dropouts:
    # Start from the checkpoint's own configuration and override only the two
    # dropout rates, so the architecture settings always match the weights
    # being loaded instead of relying on BertConfig's constructor defaults.
    dropout_config = BertConfig.from_pretrained(
        model_checkpoint,
        hidden_dropout_prob=dropout_val,
        attention_probs_dropout_prob=dropout_val,
    )
    model = BertForMaskedLM.from_pretrained(model_checkpoint, config=dropout_config)

    training_args = TrainingArguments(
        output_dir=f"dropout_{dropout_val}",
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        num_train_epochs=15,
        weight_decay=0.01,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        push_to_hub=False,
        # Mixed precision needs a CUDA device; use full precision otherwise
        fp16=torch.cuda.is_available(),
        logging_steps=logging_steps,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets["validation"],
        data_collator=data_collator,
    )

    # Baseline perplexity of the pretrained model, measured before training
    eval_results = trainer.evaluate()
    print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

    trainer.train()

    # Perplexity after fine-tuning with this dropout rate
    eval_results = trainer.evaluate()
    print(
        f">>> Dropout: {dropout_val}, "
        f"Perplexity: {math.exp(eval_results['eval_loss'])}"
    )

    # Save the fine-tuned model for this run
    trainer.save_model()