In [None]:
from transformers import BertForMaskedLM, BertTokenizer

# Weights and vocabulary are loaded from the same Hub checkpoint name.
model_path = tokenizer_path = "bert-base-uncased"
# output_hidden_states=True: forward passes also return the per-layer hidden states.
model = BertForMaskedLM.from_pretrained(model_path, output_hidden_states=True)
# do_lower_case=True lower-cases input text, matching the uncased checkpoint.
tokenizer = BertTokenizer.from_pretrained(tokenizer_path, do_lower_case=True)

In [None]:
from datasets import load_dataset
# Raw (untokenized) WikiText-2, keyed by split ("train" and "validation" are used later).
datasets = load_dataset(path="wikitext", name="wikitext-2-raw-v1")

In [None]:
import random
def tokenize_function(examples, rng=random):
    """Word-shuffle every line of a batch, then tokenize the scrambled lines.

    Each entry of ``examples["text"]`` is split on whitespace. Its words are
    shuffled and then re-joined with single spaces, so runs of whitespace
    collapse and empty lines stay empty. The scrambled strings are passed to
    the module-level ``tokenizer`` with truncation enabled.

    Args:
        examples: Batch mapping with a ``"text"`` column of strings, as given
            by ``Dataset.map(..., batched=True)``. It is NOT modified.
        rng: Object providing ``shuffle``. It defaults to the global ``random``
            module, so seeding ``random`` makes the shuffle reproducible. Pass
            ``random.Random(seed)`` to use an independent stream instead.

    Returns:
        Whatever ``tokenizer(...)`` returns for the scrambled batch.
    """
    scrambled = []
    for line in examples["text"]:
        words = line.split()
        rng.shuffle(words)
        scrambled.append(" ".join(words))
    # Build a new list rather than writing back into the caller's batch.
    return tokenizer(scrambled, truncation=True)

In [None]:
# tokenize_function shuffles words with Python's global `random` generator.
# Left unseeded, every fresh run would build a different scrambled corpus,
# so seed it right before the map. With num_proc=1 the map runs in this
# process, so this seed governs the shuffle.
# NOTE(review): `datasets` may serve a cached map result from an earlier run
# instead of re-shuffling — clear the cache if an exact rebuild is needed.
SEED = 42
random.seed(SEED)
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=1,
    remove_columns=["text"],
)

In [None]:
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized sequences and cut it into fixed-length chunks.

    Every column (e.g. ``input_ids``, ``attention_mask``) is concatenated across
    the batch and split into consecutive windows of ``chunk_size`` tokens. The
    trailing remainder that does not fill a whole window is dropped. A
    ``labels`` column is added as a copy of the chunked ``input_ids``.

    Args:
        examples: Mapping of column name -> list of token lists, as produced
            by a batched ``Dataset.map`` over the tokenized dataset.
        chunk_size: Window length in tokens. Defaults to the module-level
            ``block_size``, looked up at call time, so existing
            ``map(group_texts, ...)`` calls behave as before.

    Returns:
        Dict with the same columns plus ``labels``, each a list of chunks.

    Raises:
        ValueError: If the effective chunk size is not positive.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"chunk size must be positive, got {size}")

    # chain.from_iterable flattens in linear time. sum(lists, []) re-copies its
    # accumulator on every addition, which is quadratic in the batch size.
    concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated[next(iter(examples.keys()))])
    # Drop the small remainder; padding could be used instead if partial
    # blocks were wanted.
    total_length = (total_length // size) * size
    result = {
        k: [tokens[i : i + size] for i in range(0, total_length, size)]
        for k, tokens in concatenated.items()
    }
    # Shallow copy: a separate outer list whose inner chunks are shared.
    result["labels"] = list(result["input_ids"])
    return result

In [None]:
# Token length of each training chunk built by group_texts. The tokenizer's
# model_max_length is the other natural choice; 128 gives shorter sequences.
block_size = 128

lm_datasets = tokenized_datasets.map(
    group_texts, batched=True, batch_size=1000, num_proc=4
)

In [None]:
from transformers import TrainingArguments, Trainer

# Single source of truth for the learning rate. The output directory name is
# derived from it, so the name and the value actually used cannot drift apart.
LEARNING_RATE = 2.5e-4

training_args = TrainingArguments(
    f"finetuned-scrambled-wikitext2-{LEARNING_RATE:g}",
    learning_rate=LEARNING_RATE,
    num_train_epochs=8,
    # Checkpoint every 80 steps. NOTE(review): no save_total_limit is set, so
    # presumably every checkpoint is kept — check the disk budget for 8 epochs.
    save_steps=80,
)

In [None]:
from transformers import DataCollatorForLanguageModeling
# Masked-language-modelling collator: tokens are selected for masking with
# probability 0.15 when batches are assembled.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [None]:
# Connect the model, hyperparameters, MLM collator and the chunked splits.
# NOTE(review): training_args sets no evaluation strategy, so eval_dataset is
# presumably only used by explicit trainer.evaluate() calls — confirm.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)

In [None]:
# Run the fine-tuning loop configured above; the return value is displayed as this cell's output.
trainer.train()