In [None]:
# Transformers installation
! pip install transformers datasets

In [14]:
import torch
import random
import math
from datasets import load_dataset
from transformers import AutoTokenizer, BigBirdForMaskedLM
from transformers import Trainer
from transformers import TrainingArguments
from transformers import DataCollatorForLanguageModeling
# from transformers import AutoModelForSequenceClassification


In [15]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Specify Dataset

In [16]:
# Corpus to fine-tune on. Supported values: "WikiText-2", "Enwik8", "PennTreeBank".
Dataset = "WikiText-2"

### Specify Hyperparameters

In [None]:
# Length, in tokens, of each chunk produced by group_texts.
MAX_SEQ_LENGTH = 160
# Per-device train and eval batch size; also reused as batch_size for the group_texts map.
BATCH_SIZE = 128
# Passed to BigBirdForMaskedLM.from_pretrained as block_size and num_random_blocks.
BLOCK_SIZE = 16
NUM_RANDOM_BLOCKS = 2
# learning_rate and weight_decay handed to TrainingArguments.
LR = 2e-4
WEIGHT_DECAY = 0.01
# num_train_epochs handed to TrainingArguments.
EPOCHS = 3

# Fine-tuning a pretrained BigBird model

## Preparing the datasets

In [17]:
# Fast tokenizer shared by every tokenize_function_* helper in this notebook.
# Alternative that was tried: 'google/reformer-crime-and-punishment'.
# NOTE(review): this tokenizer comes from a different checkpoint than the
# BigBird model loaded later - confirm that the mismatch is intentional.
tokenizer = AutoTokenizer.from_pretrained(
    "distilroberta-base",
    use_fast=True,
)


In [18]:
def tokenize_function_ptb(examples):
    """Tokenize the ``sentence`` column of a batch (Penn Treebank layout).

    Calls the notebook-level ``tokenizer`` with ``padding="max_length"`` and
    ``truncation=True`` and returns its output unchanged.

    NOTE(review): the padded output is later concatenated by ``group_texts``,
    so padding tokens end up inside the grouped chunks - confirm this is wanted.
    """
    sentences = examples["sentence"]
    encoded = tokenizer(sentences, truncation=True, padding="max_length")
    return encoded

def tokenize_function_wikitext2_enwik8(examples):
    """Tokenize the ``text`` column of a batch (WikiText-2 / Enwik8 layout).

    Calls the notebook-level ``tokenizer`` with ``padding="max_length"`` and
    ``truncation=True`` and returns its output unchanged.

    NOTE(review): the padded output is later concatenated by ``group_texts``,
    so padding tokens end up inside the grouped chunks - confirm this is wanted.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length")
    return encoded


def prep_enwik8(path, seed=None):
    """Shuffle the lines of ``enwik8`` and write 80/10/10 train/validation/test files.

    Reads the file ``enwik8`` inside ``path`` and writes ``enwik8_train.txt``,
    ``enwik8_validation.txt`` and ``enwik8_test.txt`` next to it.

    Args:
        path: Directory that contains ``enwik8``. A trailing separator is optional.
        seed: Optional seed for the shuffle. ``None`` (the default) shuffles with
            the global ``random`` state, exactly as before.

    Raises:
        OSError: If ``enwik8`` cannot be read or an output file cannot be written.
    """
    import os  # local import: the notebook's import cell does not provide os

    source = os.path.join(path, 'enwik8')

    # Read with an explicit encoding instead of the platform default.
    with open(source, encoding='utf-8') as f:
        lines = f.readlines()

    # A final line without "\n" would be glued to its successor once shuffled.
    if lines and not lines[-1].endswith('\n'):
        lines[-1] += '\n'

    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(lines)

    # Split boundaries: 80% train, 10% validation, remainder test.
    train_end = math.floor(len(lines) * .8)
    val_end = train_end + math.floor(len(lines) * .1)

    splits = {
        'enwik8_train.txt': lines[:train_end],
        'enwik8_validation.txt': lines[train_end:val_end],
        'enwik8_test.txt': lines[val_end:],
    }
    for filename, chunk in splits.items():
        with open(os.path.join(path, filename), 'w', encoding='utf-8') as out:
            out.writelines(chunk)

### Get tokenized datasets

In [None]:
if Dataset == "WikiText-2":
  raw_datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
  tokenized_datasets = raw_datasets.map(tokenize_function_wikitext2_enwik8, batched=True, num_proc=4, remove_columns=["text"])
elif Dataset == "PennTreeBank":
  raw_datasets = load_dataset("ptb_text_only")
  tokenized_datasets = raw_datasets.map(tokenize_function_ptb, batched=True, num_proc=4, remove_columns=["sentence"])
elif Dataset == "Enwik8":
  !wget https://data.deepai.org/enwik8.zip
  !unzip -qq 'tiny-imagenet-200.zip'
  prep_enwik8('/content/')
  datasets = load_dataset('text', data_files={'train': '/content/enwik8_train.txt','validation': '/content/enwik8_validation.txt','test': '/content/enwik8_test.txt'})
  tokenized_datasets = datasets.map(tokenize_function_wikitext2_enwik8, batched=True, num_proc=4, remove_columns=["text"])

### Group text with a max sequence length

In [20]:
def group_texts(examples, max_seq_length=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size chunks.

    Every column of ``examples`` (a mapping from column name to a list of
    per-example lists) is flattened into one long list and then cut into
    consecutive chunks of ``max_seq_length`` items. When at least one full
    chunk is available the trailing remainder is dropped; otherwise the
    single short chunk is kept. A ``labels`` column is added as a copy of the
    chunked ``input_ids``.

    Args:
        examples: Batch mapping; must contain an ``input_ids`` column.
        max_seq_length: Chunk length. Defaults to the notebook-level
            ``MAX_SEQ_LENGTH`` when omitted.

    Returns:
        dict with the same columns as ``examples`` plus ``labels``.

    Raises:
        ValueError: If the chunk length is not a positive integer.
    """
    if max_seq_length is None:
        max_seq_length = MAX_SEQ_LENGTH
    if max_seq_length <= 0:
        raise ValueError(f"max_seq_length must be positive, got {max_seq_length}")

    # Flatten each column in linear time (sum(list_of_lists, []) is quadratic).
    concatenated_examples = {
        k: [token for sequence in examples[k] for token in sequence]
        for k in examples.keys()
    }
    total_length = len(next(iter(concatenated_examples.values()), []))

    # Drop the remainder that does not fill a whole chunk.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length

    # Split by chunks of max_seq_length.
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [21]:
# Regroup the tokenized splits into fixed-length chunks with group_texts,
# handing it BATCH_SIZE examples at a time across 4 worker processes.
lm_datasets = tokenized_datasets.map(
    group_texts,
    num_proc=4,
    batch_size=BATCH_SIZE,
    batched=True,
)

### Next, load the pretrained model from the checkpoint and fine-tune it

In [None]:
## clear cache
# Ask PyTorch to release cached, currently unused CUDA memory before the model is loaded.
torch.cuda.empty_cache()

In [23]:
# Load the pretrained BigBird masked-LM checkpoint, overriding the sparse-attention
# settings with the notebook's BLOCK_SIZE / NUM_RANDOM_BLOCKS, then move it to `device`.
model_checkpoint = "google/bigbird-roberta-base"
model = BigBirdForMaskedLM.from_pretrained(
    model_checkpoint,
    block_size=BLOCK_SIZE,
    num_random_blocks=NUM_RANDOM_BLOCKS,
)
model = model.to(device)

Some weights of the model checkpoint at google/bigbird-roberta-base were not used when initializing BigBirdForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BigBirdForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BigBirdForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [25]:
## Specify training arguments

model_name = model_checkpoint.split("/")[-1]
# Tag the run with the selected corpus ("WikiText-2" -> "wikitext2") so that runs
# on different datasets do not share, and overwrite, the same output directory.
dataset_tag = Dataset.lower().replace("-", "")
training_args = TrainingArguments(
    f"{model_name}-finetuned-{dataset_tag}",
    # NOTE(review): newer transformers releases appear to rename this argument
    # to `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Optional toggles, currently disabled:
    # no_cuda = True,
    # load_best_model_at_end=True,
    # push_to_hub=True,
)

In [27]:
## Build the masked-LM data collator and the Trainer that drives fine-tuning

# The collator is configured with mlm_probability=0.15 and the notebook's tokenizer.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

# Train on the "train" split and evaluate on the "validation" split of lm_datasets.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)

In [None]:
# Run fine-tuning; the returned object is kept so its metrics can be logged later.
train_results = trainer.train()

In [None]:
# Save the model through the Trainer. No path is given here;
# presumably it lands in the output directory of training_args - confirm.
trainer.save_model()

### Print results

In [None]:
# Evaluate with the Trainer and report perplexity as exp(eval_loss).
eval_results = trainer.evaluate()
perplexity = math.exp(eval_results['eval_loss'])
print(f"Perplexity: {perplexity:.2f}")

In [None]:
# Log the metrics returned by trainer.train() under the "train" split.
metrics = train_results.metrics
trainer.log_metrics(split="train", metrics=metrics)