## Load Pretrained Model and Tokenizer

In [1]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_checkpoint = "distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


  from .autonotebook import tqdm as notebook_tqdm


## Load IMDb Dataset

In [2]:
from datasets import load_dataset

imdb_dataset = load_dataset("imdb")


In [3]:
imdb_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

## Tokenize the Dataset

In [4]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result

tokenized_datasets = imdb_dataset.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)


## Group Texts into Chunks

In [5]:
chunk_size = tokenizer.mask_token_id

# Drop the last chunk if it’s smaller than chunk_size.
def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# Pad the last chunk until its length equals chunk_size.
def group_texts_with_padding(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute total length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    
    # If the last chunk is smaller than chunk_size, pad it
    for key in result.keys():
        if len(result[key][-1]) < chunk_size:
            padding_length = chunk_size - len(result[key][-1])
            if key == "input_ids":
                # Pad input_ids with the tokenizer's pad token ID
                result[key][-1] += [tokenizer.pad_token_id] * padding_length
            else:
                # Pad other keys (e.g., attention_mask) with 0
                result[key][-1] += [0] * padding_length

    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result


In [6]:
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 76170
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 74448
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 152809
    })
})

In [7]:
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

'she wants to focus her attentions to making some sort of documentary on what the average swede thought about certain political issues such as the vietnam war and race issues in the united states. in between asking politicians and ordinary denizens of stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men. < br / > < br / > what kills me about i am curious - yellow is that 40 years ago, this was considered pornographic. really, the sex and nudity'

In [8]:
tokenizer.decode(lm_datasets["train"][1]["labels"])

'she wants to focus her attentions to making some sort of documentary on what the average swede thought about certain political issues such as the vietnam war and race issues in the united states. in between asking politicians and ordinary denizens of stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men. < br / > < br / > what kills me about i am curious - yellow is that 40 years ago, this was considered pornographic. really, the sex and nudity'

## Combine train and unsupervised Datasets for Training

In [9]:
from datasets import concatenate_datasets

training_dataset = concatenate_datasets([lm_datasets["train"], lm_datasets["unsupervised"]])
evaluation_dataset = lm_datasets["test"]


## Apply Random Masking to Evaluation Dataset

In [10]:
# There are some fluctuations in our perplexity scores with each training run, One way to eliminate this source of randomness is to apply the masking once on the whole test set, and then use the default data collator in 🤗 Transformers to collect the batches during evaluation.

In [11]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)


In [12]:
def insert_random_mask(batch):
    features = [dict(zip(batch, t)) for t in zip(*batch.values())]
    masked_inputs = data_collator(features)
    # Create a new "masked" column for each column in the dataset
    return {"masked_" + k: v.numpy() for k, v in masked_inputs.items()}

In [13]:
evaluation_dataset = evaluation_dataset.remove_columns(["word_ids"])
training_dataset = training_dataset.remove_columns(["word_ids"])

In [14]:
## apply random masking to the evaluation dataset in advance to eliminate the randomness during evaluation, which ensures consistent perplexity scores across training runs. 

eval_dataset = evaluation_dataset.map(
    insert_random_mask,
    batched=True,
    remove_columns=evaluation_dataset.column_names,
)


eval_dataset = eval_dataset.rename_columns(
    {
        "masked_input_ids": "input_ids",
        "masked_attention_mask": "attention_mask",
        "masked_labels": "labels",
    }
)

## Prepare DataLoaders

In [15]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

# Define batch size
batch_size = 64


# Note: 'data_collator' is used here to ensure random masking for MLM tasks in every training batch
train_dataloader = DataLoader(
    training_dataset, 
    shuffle=True,
    batch_size=batch_size, 
    collate_fn=data_collator
)

# Evaluation DataLoader with default collator
# Note: 'default_data_collator' or None is used here because masking was already applied to the evaluation dataset
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    collate_fn=default_data_collator  # Use default behavior; no random masking during evaluation
)


## Setup Optimizer and Learning Rate Scheduler

In [16]:
from torch.optim import AdamW
from transformers import get_scheduler

optimizer = AdamW(model.parameters(), lr=5e-5)

num_train_epochs = 10
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)


##  Initialize 🤗 Accelerate

In [17]:
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)


## Training Loop

In [18]:
import json
import os
import math
from tqdm.auto import tqdm
import torch

# Define directories
output_dir = "distilbert-finetuned-imdb-mlm-accelerate-checkpoint"
final_output_dir = "distilbert-finetuned-imdb-mlm-accelerate"
metrics_output_file = os.path.join(output_dir, "metrics.json")
log_history_file = os.path.join(output_dir, "log_history.json")

# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# Early Stopping Parameters
metric_for_best_model = "eval_loss"
greater_is_better = False  # Lower loss is better
patience = 3  # Number of epochs to wait for improvement
best_metric = float("inf") if not greater_is_better else -float("inf")
patience_counter = 0

# Initialize tracking variables
progress_bar = tqdm(range(num_training_steps))
all_metrics = {}  # To store metrics for each epoch
log_history = []  # To store log history for each epoch

# Training and evaluation loop
for epoch in range(num_train_epochs):
    # ===== Training Phase =====
    model.train()
    for batch in train_dataloader:
        # Forward pass and loss computation
        outputs = model(**batch)
        loss = outputs.loss

        # Backward pass and optimizer step
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Update progress bar
        progress_bar.update(1)

    # ===== Evaluation Phase =====
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        # Gather losses across all devices
        losses.append(accelerator.gather(loss.repeat(batch_size)))

    # Concatenate all losses and compute mean loss
    losses = torch.cat(losses)
    losses = losses[: len(eval_dataset)]  # Truncate to match dataset size
    mean_loss = torch.mean(losses).item()
    try:
        perplexity = math.exp(mean_loss)
    except OverflowError:
        perplexity = float("inf")

    # Log epoch metrics
    epoch_metrics = {
        "epoch": epoch,
        "evaluation": {
            "mean_loss": mean_loss,
            "perplexity": perplexity,
        },
    }
    all_metrics[f"epoch_{epoch}"] = epoch_metrics
    print(f"Epoch {epoch}: Loss={mean_loss:.4f}, Perplexity={perplexity:.4f}")

    # Save metrics and log history
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Update global metrics file
        with open(metrics_output_file, "w") as f:
            json.dump(all_metrics, f, indent=4)

        # Update log history
        log_entry = {
            "epoch": epoch,
            "mean_loss": mean_loss,
            "perplexity": perplexity,
        }
        log_history.append(log_entry)
        with open(log_history_file, "w") as f:
            json.dump(log_history, f, indent=4)

    # ===== Early Stopping Logic =====
    current_metric = mean_loss if metric_for_best_model == "eval_loss" else perplexity
    if (greater_is_better and current_metric > best_metric) or (not greater_is_better and current_metric < best_metric):
        best_metric = current_metric
        patience_counter = 0  # Reset patience counter
        # Save the best model
        print(f"New best model found at epoch {epoch} with {metric_for_best_model}: {current_metric:.4f}")
        best_model_dir = os.path.join(output_dir, "best_model")
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(best_model_dir, save_function=accelerator.save)
            tokenizer.save_pretrained(best_model_dir)
    else:
        patience_counter += 1
        print(f"No improvement for {patience_counter} epoch(s). Best {metric_for_best_model}: {best_metric:.4f}")

    # Stop training if patience is exceeded
    if patience_counter >= patience:
        print(f"Early stopping triggered after {patience} epochs of no improvement.")
        break

# ===== Save Final Model =====
if accelerator.is_main_process:
    print("Saving the final model...")
    final_model_dir = os.path.join(final_output_dir)
    os.makedirs(final_model_dir, exist_ok=True)

    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(final_model_dir, save_function=accelerator.save)
    tokenizer.save_pretrained(final_model_dir)

    # Save final metrics and logs
    with open(os.path.join(final_model_dir, "metrics.json"), "w") as f:
        json.dump(all_metrics, f, indent=4)
    with open(os.path.join(final_model_dir, "log_history.json"), "w") as f:
        json.dump(log_history, f, indent=4)

    print(f"Final model and metrics saved to {final_model_dir}")


 10%|█         | 3578/35780 [13:02<1:49:04,  4.92it/s]

Epoch 0: Loss=2.2544, Perplexity=9.5294
New best model found at epoch 0 with eval_loss: 2.2544


 20%|██        | 7156/35780 [27:36<1:37:49,  4.88it/s]  

Epoch 1: Loss=2.1854, Perplexity=8.8946
New best model found at epoch 1 with eval_loss: 2.1854


 30%|███       | 10734/35780 [42:11<1:24:59,  4.91it/s] 

Epoch 2: Loss=2.1430, Perplexity=8.5250
New best model found at epoch 2 with eval_loss: 2.1430


 40%|████      | 14312/35780 [56:45<1:14:37,  4.79it/s]  

Epoch 3: Loss=2.1107, Perplexity=8.2537
New best model found at epoch 3 with eval_loss: 2.1107


 50%|█████     | 17890/35780 [1:11:20<1:01:23,  4.86it/s]

Epoch 4: Loss=2.0868, Perplexity=8.0589
New best model found at epoch 4 with eval_loss: 2.0868


 60%|██████    | 21468/35780 [1:25:55<48:45,  4.89it/s]    

Epoch 5: Loss=2.0646, Perplexity=7.8823
New best model found at epoch 5 with eval_loss: 2.0646


 70%|███████   | 25046/35780 [1:40:29<36:47,  4.86it/s]    

Epoch 6: Loss=2.0485, Perplexity=7.7563
New best model found at epoch 6 with eval_loss: 2.0485


 80%|████████  | 28624/35780 [1:55:04<24:30,  4.87it/s]   

Epoch 7: Loss=2.0338, Perplexity=7.6430
New best model found at epoch 7 with eval_loss: 2.0338


 90%|█████████ | 32202/35780 [2:09:39<12:10,  4.90it/s]   

Epoch 8: Loss=2.0227, Perplexity=7.5585
New best model found at epoch 8 with eval_loss: 2.0227


100%|██████████| 35780/35780 [2:24:13<00:00,  4.89it/s]   

Epoch 9: Loss=2.0173, Perplexity=7.5178
New best model found at epoch 9 with eval_loss: 2.0173
Saving the final model...
Final model and metrics saved to distilbert-finetuned-imdb-mlm-accelerate


In [3]:
from transformers import pipeline

mask_filler = pipeline(
    "fill-mask", model="distilbert-finetuned-imdb-mlm-accelerate"
)

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [7]:
text = "This is a great [MASK]."
preds = mask_filler(text)

for pred in preds:
    print(f">>> {pred['sequence']}")

>>> this is a great movie.
>>> this is a great film.
>>> this is a great show.
>>> this is a great story.
>>> this is a great documentary.
