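### Sample script for training a Reformer masked-language model from a randomly initialized config, reusing the pretrained google/reformer-crime-and-punishment tokenizer.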
### https://huggingface.co/docs/transformers/en/model_doc/reformer
### https://huggingface.co/google/reformer-crime-and-punishment

In [None]:
%%writefile fine_tune_reformer.py


# Lib versions
# transformers_version='4.6'
# pytorch_version='1.6'
# py_version='py36'


# Headers
import pandas as pd
import os
import torch
import warnings
import pickle as plk
from datasets import Dataset
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import (
    ReformerForMaskedLM,
    ReformerTokenizer,
    ReformerConfig,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)


# Load the data
# Provide the path to your data: a pickle file holding a list of dicts, each with a 'text' and an 'id' key, as in the example below
# [{'text': 'Large Transformer models routinely achieve state-of-the-art results on a number of tasks ...',
#   'id': 0}, 
#  {'text': 'The resulting model, the Reformer, performs on par with Transformer models while being ...',
#   'id': 1},]
data_location = "YOUR_DATA_PATH"
# SECURITY NOTE: unpickling can execute arbitrary code embedded in the file.
# Only point data_location at a file you created or otherwise trust; for
# third-party data prefer a non-executable format such as JSON Lines.
with open(data_location, "rb") as fin:
    all_texts = plk.load(fin)
df_train = pd.DataFrame(all_texts)
# Fail fast with a clear message instead of a KeyError deep inside dataset.map().
if df_train.empty or "text" not in df_train.columns:
    raise ValueError(
        f"Expected {data_location!r} to contain a non-empty list of records "
        f"with a 'text' field; got columns {list(df_train.columns)}"
    )
dataset = Dataset.from_pandas(df_train)


# Load the tokenizer
# Pretrained SentencePiece tokenizer, plus the two extra special tokens this
# script relies on: a mask token for the MLM data collator and a pad token for
# fixed-length ("max_length") padding. Both are registered in one call.
MODEL_CKPT = "google/reformer-crime-and-punishment"
tokenizer = ReformerTokenizer.from_pretrained(MODEL_CKPT)
tokenizer.add_special_tokens({"mask_token": "[MASK]", "pad_token": "[PAD]"})


# Set the sequence length
# Fixed training length in tokens (16384). The model config's axial_pos_shape
# (128 x 128) multiplies out to this same value.
sequence_length = 1 << 14


# Prepare the dataset
def tokenize_function(batched_data):
    """Tokenize a batch of examples into fixed-length model inputs.

    Intended for ``dataset.map(..., batched=True)``, so
    ``batched_data['text']`` is a list of strings. Every sequence is
    truncated or padded to exactly ``sequence_length`` tokens.

    Args:
        batched_data: Mapping whose ``'text'`` entry holds the batch's strings.

    Returns:
        The tokenizer's encoding with ``input_ids`` and ``attention_mask``
        (token type ids are explicitly not requested).
    """
    # Pad every sequence to the same fixed length, because Reformer training
    # expects the length to match the axial position shape.
    # `padding="max_length"` alone selects this; the legacy `pad_to_max_length`
    # flag is deprecated and redundant next to it, so it is not passed.
    return tokenizer(
        batched_data["text"],
        max_length=sequence_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_token_type_ids=False,
    )

# Drop every raw column (not only 'text'/'id') so just the tokenizer outputs
# remain, regardless of which extra fields the source records carry.
tokenized_datasets = dataset.map(
    tokenize_function, batched=True, remove_columns=dataset.column_names
)
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask"])
# Masked-LM collator: masks tokens at random with probability 0.15 per batch.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
# Hold out 10% for evaluation; a fixed seed makes the split reproducible.
SPLIT_SEED = 42
dataset_split = tokenized_datasets.train_test_split(test_size=0.1, seed=SPLIT_SEED)


# Build a new (randomly initialized) Reformer model with the architecture below
# Model architecture. The model is instantiated from this config with random
# weights; no pretrained weights are loaded, only the tokenizer comes from
# MODEL_CKPT.
reformer_kwargs = {
    "attention_head_size": 128,
    "attn_layers": [
        "local", "local", "lsh", "local",
        "local", "local", "lsh", "local",
        "local", "local", "lsh", "local",
    ],
    "axial_norm_std": 1.0,
    "axial_pos_embds": True,
    # 256 + 768 == hidden_size (1024).
    "axial_pos_embds_dim": [256, 768],
    # Must multiply out to the training sequence length (checked below).
    "axial_pos_shape": [128, 128],
    "chunk_size_feed_forward": 0,
    "chunk_size_lm_head": 0,
    "eos_token_id": 2,
    "feed_forward_size": 4096,
    "hidden_act": "relu",
    "hidden_dropout_prob": 0.2,
    "hidden_size": 1024,
    "initializer_range": 0.02,
    "is_decoder": False,
    "layer_norm_eps": 1e-12,
    "local_attention_probs_dropout_prob": 0.2,
    "local_attn_chunk_length": 128,
    "local_num_chunks_after": 0,
    "local_num_chunks_before": 1,
    "lsh_attention_probs_dropout_prob": 0.1,
    "lsh_attn_chunk_length": 256,
    "lsh_num_chunks_after": 0,
    "lsh_num_chunks_before": 1,
    "max_position_embeddings": sequence_length,
    "model_type": "reformer",
    "num_attention_heads": 8,
    "num_buckets": 512,
    "num_hashes": 1,
    # Keep the model's pad id in sync with the [PAD] token added to the tokenizer.
    "pad_token_id": tokenizer.pad_token_id,
    # Size embeddings/LM head to the full tokenizer vocabulary, including the
    # special tokens ([MASK], [PAD]) added after loading it.
    "vocab_size": len(tokenizer),
}

# During training, Reformer's axial position embeddings require the input
# length to equal the product of axial_pos_shape; fail before building the model.
n_axial_positions = 1
for axial_dim in reformer_kwargs["axial_pos_shape"]:
    n_axial_positions *= axial_dim
if n_axial_positions != sequence_length:
    raise ValueError(
        f"axial_pos_shape {reformer_kwargs['axial_pos_shape']} covers "
        f"{n_axial_positions} positions, but sequence_length is {sequence_length}"
    )

config = ReformerConfig(**reformer_kwargs)
model = ReformerForMaskedLM(config)
model = model.train()


# Define the training args
# Effective train batch per device: per_device_train_batch_size (2) *
# gradient_accumulation_steps (4) = 8 sequences per optimizer step.
training_args = TrainingArguments(
    output_dir="YOUR_OUTPUT_PATH",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=2,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; adjust when upgrading past the version pinned above.
    evaluation_strategy="epoch",
    learning_rate=1e-3,
    warmup_steps=0,
    weight_decay=0.001,
    logging_steps=4,
    save_steps=20,
    # Mixed precision needs a CUDA device; TrainingArguments in the 4.x line
    # raises a ValueError for fp16 without one. On CPU-only machines, fall
    # back to fp32 instead of crashing.
    fp16=torch.cuda.is_available(),
)


# Create the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],
    data_collator=data_collator,
)


# Train
trainer.train()

# Evaluate on the held-out split and report the resulting metrics.
eval_result = trainer.evaluate(eval_dataset=dataset_split["test"])
print(f"Evaluation results: {eval_result}")


# Save the model, and the tokenizer next to it. Special tokens ([MASK], [PAD])
# were added to the tokenizer at load time, so the saved tokenizer is needed
# to reproduce the token ids the model was trained on.
save_path = "YOUR_SAVE_PATH"
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)