In [1]:
import torch

EPOCH = 10
MAX_POSITION_EMBEDDINGS = 256

device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [2]:
from datasets import load_dataset

dataset = load_dataset("imdb")
dataset["train"].features

Found cached dataset imdb (/Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['neg', 'pos'], id=None)}

In [3]:
from transformers import BertTokenizerFast

tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained("bert-base-uncased")


def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_POSITION_EMBEDDINGS,
    )

In [4]:
def key_padding_mask_func(examples: dict):
    examples.update(
        {
            "attention_mask": [
                i.to(dtype=torch.float32) for i in examples["attention_mask"]
            ]
        }
    )
    return examples


def removeHtml(text: dict):
    import re

    clean = re.compile("<.*?>")
    text["text"] = [re.sub(clean, "", t) for t in text["text"]]
    return text


tokenized_datasets = dataset.map(removeHtml, batched=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text", "token_type_ids"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets = tokenized_datasets.map(key_padding_mask_func, batched=True)

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-35289c1ebf9945c5.arrow
Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-d6aa5d4b66ee0904.arrow
Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-4154ece1a39afb06.arrow
Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-792a6478f93045fe.arrow
Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-dae20fdde3b95898.arrow
Loading cached 

In [5]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(3000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

Loading cached shuffled indices for dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-f9b86c15fd64c1a3.arrow
Loading cached shuffled indices for dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-a70352b40bb64e74.arrow


In [6]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15, mlm=True
)

In [7]:
from bert import BertForMaskedLM
# from transformers import BertForMaskedLM
model = BertForMaskedLM(
    vocab_size=tokenizer.vocab_size,
    d_model=768,
    intermediate_size=4 * 768,
    max_position_embeddings=MAX_POSITION_EMBEDDINGS,
    num_attention_heads=8,
    hidden_dropout_prob=0.1,
    num_hidden_layers=10,
    output_attentions=True,
).to(device)

In [8]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="imdb_mlm_model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=EPOCH,
    weight_decay=0.01,
    use_mps_device=torch.backends.mps.is_available(),
    gradient_accumulation_steps=4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    data_collator=data_collator,
)
trainer.train()



  0%|          | 0/930 [00:00<?, ?it/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.9306661486625671, 'eval_runtime': 26.5602, 'eval_samples_per_second': 37.65, 'eval_steps_per_second': 4.706, 'epoch': 0.99}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.8424407839775085, 'eval_runtime': 25.6855, 'eval_samples_per_second': 38.933, 'eval_steps_per_second': 4.867, 'epoch': 1.99}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.8057705760002136, 'eval_runtime': 26.3083, 'eval_samples_per_second': 38.011, 'eval_steps_per_second': 4.751, 'epoch': 3.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.8013680577278137, 'eval_runtime': 26.1877, 'eval_samples_per_second': 38.186, 'eval_steps_per_second': 4.773, 'epoch': 4.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.793107807636261, 'eval_runtime': 26.23, 'eval_samples_per_second': 38.124, 'eval_steps_per_second': 4.766, 'epoch': 4.99}
{'loss': 0.8798, 'learning_rate': 9.24731182795699e-06, 'epoch': 5.33}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.7834706902503967, 'eval_runtime': 26.2716, 'eval_samples_per_second': 38.064, 'eval_steps_per_second': 4.758, 'epoch': 5.99}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.7929794788360596, 'eval_runtime': 26.2496, 'eval_samples_per_second': 38.096, 'eval_steps_per_second': 4.762, 'epoch': 7.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.7877641320228577, 'eval_runtime': 26.2447, 'eval_samples_per_second': 38.103, 'eval_steps_per_second': 4.763, 'epoch': 8.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.7962309718132019, 'eval_runtime': 26.2673, 'eval_samples_per_second': 38.07, 'eval_steps_per_second': 4.759, 'epoch': 8.99}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.7843843102455139, 'eval_runtime': 26.2239, 'eval_samples_per_second': 38.133, 'eval_steps_per_second': 4.767, 'epoch': 9.92}
{'train_runtime': 2987.6515, 'train_samples_per_second': 10.041, 'train_steps_per_second': 0.311, 'train_loss': 0.8446862005418346, 'epoch': 9.92}


TrainOutput(global_step=930, training_loss=0.8446862005418346, metrics={'train_runtime': 2987.6515, 'train_samples_per_second': 10.041, 'train_steps_per_second': 0.311, 'train_loss': 0.8446862005418346, 'epoch': 9.92})

In [9]:
trainer.save_model("imdb_mlm_model")

In [10]:
import math

eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

  0%|          | 0/125 [00:00<?, ?it/s]

Perplexity: 2.21


In [11]:
from bertviz import head_view, model_view

input_text = "The sign of a good movie is that it can toy with our emotions"
inputs = tokenizer.encode(input_text, return_tensors="pt")

# inputs = {k: v.to(device) for k, v in inputs.items()}
# inputs["input_ids"].transpose_(0, 1)
# inputs["attention_mask"] = inputs["attention_mask"].to(dtype=torch.float32)

model.to("cpu")
loss, logits, attentions = model(
    input_ids=inputs, labels=inputs, attention_mask=torch.ones_like(inputs)
)

tokens = tokenizer.convert_ids_to_tokens(inputs)
model_view(attentions, tokens)

<IPython.core.display.Javascript object>