# Causal Language Modeling Task
A series of experiments demonstrating causal language modeling, and the resulting differences in training performance, on a pretrained Reformer model. The model, datasets, and examples are sourced from Hugging Face.


---


## Test Models
**Reformer**
* 6-layer
* 256-hidden
* 2-heads
* 3M parameters
* Trained on English text: the novel *Crime and Punishment* by Fyodor Dostoyevsky.

In [None]:
!pip install datasets transformers sentencepiece

In [None]:
# Imports
import math, random, torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments

In [None]:
# Select the compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
def tokenize_function_ptb(examples):
    """Tokenize the ``sentence`` column of a batch with the global ``tokenizer``.

    Args:
        examples: Batch mapping that contains a ``"sentence"`` entry.

    Returns:
        Whatever the module-level ``tokenizer`` returns for that column.
    """
    sentences = examples["sentence"]
    return tokenizer(sentences)

def tokenize_function_wt2_enwik8(examples):
    """Tokenize the ``text`` column of a batch with the global ``tokenizer``.

    Args:
        examples: Batch mapping that contains a ``"text"`` entry.

    Returns:
        Whatever the module-level ``tokenizer`` returns for that column.
    """
    texts = examples["text"]
    return tokenizer(texts)
    
def group_texts(examples):
    """Concatenate a batch of token lists and re-split it into fixed-size blocks.

    Args:
        examples: Mapping from column name (it must include ``"input_ids"``)
            to a list of per-example lists.

    Returns:
        A dict with the same keys as ``examples`` in which every value is a
        list of chunks of exactly ``block_size`` items, plus a ``"labels"``
        entry that is a shallow copy of the ``"input_ids"`` chunk list.

    The chunk length is read from the module-level ``block_size``. Items left
    over after the last full chunk are dropped.
    """
    # Function-scope import keeps this definition self-contained.
    from itertools import chain

    # Flatten every column in linear time. ``sum(lists, [])`` re-copies the
    # growing accumulator for each example, which is quadratic in batch size.
    concatenated = {
        key: list(chain.from_iterable(column)) for key, column in examples.items()
    }
    first_key = next(iter(examples.keys()))
    total_length = len(concatenated[first_key])
    # Drop the small remainder that does not fill a whole block; padding could
    # be used instead if the model supported it.
    total_length = (total_length // block_size) * block_size
    # Split each flattened column into consecutive chunks of block_size.
    result = {
        key: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for key, tokens in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

def prep_enwik8(path):
    """Shuffle the lines of ``enwik8`` and write 80/10/10 split files.

    Args:
        path: Prefix that is concatenated directly with the file names, so it
            must end with a path separator (e.g. ``'/content/data/'``).

    Reads ``path + 'enwik8'`` and writes ``enwik8_train.txt``,
    ``enwik8_validation.txt`` and ``enwik8_test.txt`` under the same prefix
    (train / validation / test respectively). The shuffle uses the global
    ``random`` state, so seed it beforehand for a reproducible split.
    """
    # Pin the encoding so the result does not depend on the platform default.
    with open(path + 'enwik8', encoding='utf-8') as source:
        lines = source.readlines()

    # A final line without a newline would otherwise be glued to whichever
    # line ends up after it once the order is shuffled.
    if lines and not lines[-1].endswith('\n'):
        lines[-1] += '\n'

    random.shuffle(lines)

    # Split boundaries: 80/10/10 - train/val/test (test takes the remainder).
    train_end = math.floor(len(lines) * .8)
    val_end = train_end + math.floor(len(lines) * .1)

    splits = (
        ('enwik8_train.txt', lines[:train_end]),
        ('enwik8_validation.txt', lines[train_end:val_end]),
        ('enwik8_test.txt', lines[val_end:]),
    )
    for filename, chunk in splits:
        with open(path + filename, 'w', encoding='utf-8') as out:
            out.writelines(chunk)

In [None]:
# --- Dataset selection ---
# 0 = wikitext-2, 1 = penn treebank, 2 = enwik8
DATASET_SELECT = 0
# Prefix for the enwik8 files; it is joined to file names by plain string
# concatenation, so keep the trailing slash.
PATH_TO_ENWIK8 = "/content/data/"

# --- Hyperparameters ---
NUM_EPOCHS = 30
BATCH_SIZE = 16        # used for both the train and the eval per-device batch size
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 0.01
block_size = 2048      # chunk length used by group_texts

# --- Model / publishing switches ---
AXIAL_POS = False      # passed as `axial_pos_embds` when the model is loaded
PUSH_HUB = False       # passed as `push_to_hub` to TrainingArguments

In [None]:
# Checkpoint identifier used to load both the tokenizer and the model.
model_id = "google/reformer-crime-and-punishment"

In [None]:
# Load the tokenizer that belongs to the chosen checkpoint.
# NOTE(review): `padding` and `truncation` look like call-time tokenizer
# options rather than load-time ones; presumably they have no effect here.
# Confirm against the transformers documentation before relying on them.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=True,
    padding=True,
    truncation=True,
)

In [None]:
# Load the corpus chosen by DATASET_SELECT and remember which tokenize
# function and raw text column belong to it.
if DATASET_SELECT == 0:
    datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
    tokenize_fn, text_column = tokenize_function_wt2_enwik8, "text"
elif DATASET_SELECT == 1:
    datasets = load_dataset("ptb_text_only")
    tokenize_fn, text_column = tokenize_function_ptb, "sentence"
elif DATASET_SELECT == 2:
    # Create the train/validation/test text files before loading them.
    prep_enwik8(PATH_TO_ENWIK8)
    datasets = load_dataset(
        'text',
        data_files={
            'train': PATH_TO_ENWIK8+'enwik8_train.txt',
            'validation': PATH_TO_ENWIK8+'enwik8_validation.txt',
            'test': PATH_TO_ENWIK8+'enwik8_test.txt',
        },
    )
    tokenize_fn, text_column = tokenize_function_wt2_enwik8, "text"
else:
    # Fail here instead of leaving `tokenized_datasets` undefined (or stale
    # from an earlier run) for the cells that follow.
    raise ValueError(
        f"Unsupported DATASET_SELECT={DATASET_SELECT!r}; "
        "expected 0 (wikitext-2), 1 (penn treebank) or 2 (enwik8)."
    )

# Tokenize in batches and drop the raw text column so only the tokenizer
# outputs remain.
tokenized_datasets = datasets.map(
    tokenize_fn,
    batched=True,
    num_proc=4,
    remove_columns=[text_column],
)

In [None]:
#tokenized_datasets["train"][1]

In [None]:
# Re-chunk the tokenized text with group_texts, which is called on batches of
# 1000 examples; the map is run with 4 processes.
lm_datasets = tokenized_datasets.map(group_texts, batched=True, batch_size=1000, num_proc=4)

In [None]:
#tokenizer.decode(lm_datasets["train"][1]["input_ids"])

In [None]:
# Load the causal-LM checkpoint, passing AXIAL_POS as the `axial_pos_embds`
# option, then move the model to the selected device.
model = AutoModelForCausalLM.from_pretrained(model_id, axial_pos_embds=AXIAL_POS)
model = model.to(device)

In [None]:
model_name = model_id.split("/")[-1]

# Name the output directory after the dataset actually selected, so runs on
# different datasets do not share (and overwrite) one directory. The default
# selection (0) keeps the original "-finetuned-wikitext2" name.
dataset_tags = {0: "wikitext2", 1: "ptb", 2: "enwik8"}
dataset_tag = dataset_tags.get(DATASET_SELECT, "wikitext2")

training_args = TrainingArguments(
    f"{model_name}-finetuned-{dataset_tag}",
    evaluation_strategy="epoch",  # evaluate once per epoch
    adafactor=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    push_to_hub=PUSH_HUB,
)

In [None]:
# Train on the "train" split and evaluate on the "validation" split of the
# chunked datasets.
train_set = lm_datasets["train"]
eval_set = lm_datasets["validation"]
trainer = Trainer(model=model, args=training_args, train_dataset=train_set, eval_dataset=eval_set)

In [None]:
# Run training; the returned object's `.metrics` attribute is read in a later cell.
train_results = trainer.train()

In [None]:
# Evaluate with the trainer and report perplexity, computed as exp(eval_loss).
eval_results = trainer.evaluate()
perplexity = math.exp(eval_results["eval_loss"])
print(f"Perplexity: {perplexity:.2f}")

In [None]:
# Display Metrics
# Take the metrics collected by the training run and hand them to the
# trainer's logger under the "train" label.
metrics = train_results.metrics
trainer.log_metrics("train", metrics)