# Fine-tuning a masked language model (TensorFlow)

Install the Transformers, Datasets, and Evaluate libraries to run this notebook.

In [1]:
!pip -q install datasets evaluate transformers[sentencepiece]
!apt install git-lfs
!pip -q install --upgrade fsspec
!pip -q install --upgrade datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.3).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.1/199.1 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.
torch 2.6.0+cu124 requires nvidia-cuda-cupti-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cuda-cupti-cu12 12.5.82 which is 

You will need to set up Git: replace the email and name in the following cell with your own.

In [None]:
!git config --global user.email "your email"
!git config --global user.name "your name"

You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials.

In [None]:
from huggingface_hub import notebook_login

notebook_login()

# Download DistilBERT

In [4]:
from transformers import TFAutoModelForMaskedLM

model_checkpoint = "distilbert-base-uncased"
model = TFAutoModelForMaskedLM.from_pretrained(model_checkpoint)

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

All PyTorch model weights were used when initializing TFDistilBertForMaskedLM.

All the weights of TFDistilBertForMaskedLM were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForMaskedLM for predictions without further training.


In [6]:
model.summary()

Model: "tf_distil_bert_for_masked_lm"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 distilbert (TFDistilBertMa  multiple                  66362880  
 inLayer)                                                        
                                                                 
 vocab_transform (Dense)     multiple                  590592    
                                                                 
 vocab_layer_norm (LayerNor  multiple                  1536      
 malization)                                                     
                                                                 
 vocab_projector (TFDistilB  multiple                  23866170  
 ertLMHead)                                                      
                                                                 
Total params: 66985530 (255.53 MB)
Trainable params: 66985530 (255.53 MB)
Non-trainable params: 0 (0.00 

In [7]:
text = "This is a great [MASK]."

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [9]:
tokenizer

DistilBertTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [11]:
import numpy as np
import tensorflow as tf

inputs = tokenizer(text, return_tensors="np")
token_logits = model(**inputs).logits
# print(token_logits)
# Find the location of [MASK] and extract its logits
mask_token_index = np.argwhere(inputs["input_ids"] == tokenizer.mask_token_id)[0, 1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
# We negate the array before argsort to get the largest, not the smallest, logits
top_5_tokens = np.argsort(-mask_token_logits)[:5].tolist()

for token in top_5_tokens:
    print(f">>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

>>> This is a great deal.
>>> This is a great success.
>>> This is a great adventure.
>>> This is a great idea.
>>> This is a great feat.


# The dataset
## Large Movie Review Dataset (or IMDb for short)

In [12]:
from datasets import load_dataset

imdb_dataset = load_dataset("imdb")
imdb_dataset

README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [13]:
sample = imdb_dataset["train"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f"\n'>>> Review: {row['text']}'")
    print(f"'>>> Label: {row['label']}'")


'>>> Review: There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier's plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it's the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...'
'>>> Label: 1'

'>>> Review: This movie is a great. The plot is very true to the book which is a classic written by Mark Twain. The movie starts of with a scene where Hank sings a song with a bunch of kids called "when you stu

In [17]:
label = imdb_dataset["train"].features["label"]
print(f"label:{label}")

label_names = label.names
print(f"label_names:{label_names}")

label:ClassLabel(names=['neg', 'pos'], id=None)
label_names:['neg', 'pos']


**✏️ Try it out!** Create a random sample of the unsupervised split and verify that the labels are neither 0 nor 1. While you’re at it, you could also check that the labels in the train and test splits are indeed 0 or 1 — this is a useful sanity check that every NLP practitioner should perform at the start of a new project!

In [21]:
sample = imdb_dataset["unsupervised"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f"\n'>>> Review: {row['text']}'")
    print(f"'>>> Label: {row['label']}'")


'>>> Review: If you've seen the classic Roger Corman version starring Vincent Price it's hard to put it out of your head, but you probably should do because this one is totally different. Subtlety has been abandoned in favour of gross-out horror - nudity, gore and all-round unpleasantness. OK it's ridiculous, trashy, sensationalised and historically dubious (did any members of the Inquisition really wear horn-rimmed glasses?), but despite all this it is strangely compelling. I literally couldn't tear myself away from the screen until the end of the movie. If there's a bigger compliment you can pay to a film I don't know what it is.'
'>>> Label: -1'

'>>> Review: For me, this was the most moving film of the decade. Samira Makhmalbaf shows pure bravery and vision in the making. She has an intelligence and gift for speaking to the people, regardless of their nationality or beliefs. I am inspired and touched by her humanity and can only hope that she has touched many people the same way. 

**Sanity check for label values.**
To verify that the labels in the unsupervised split are neither 0 nor 1 (and that the train and test splits contain only binary labels), you could start by looking at one example per split:

In [19]:
# Check label values in all splits
print("Train label example:", imdb_dataset["train"][0]["label"])
print("Test label example:", imdb_dataset["test"][0]["label"])
print("Unsupervised label example:", imdb_dataset["unsupervised"][0]["label"])

Train label example: 0
Test label example: 0
Unsupervised label example: -1


Or, even better, collect every distinct label value in each split:

In [20]:
print("Unique labels in train:", set(imdb_dataset["train"]["label"]))
print("Unique labels in test:", set(imdb_dataset["test"]["label"]))
print("Unique labels in unsupervised:", set(imdb_dataset["unsupervised"]["label"]))

Unique labels in train: {0, 1}
Unique labels in test: {0, 1}
Unique labels in unsupervised: {-1}


# Preprocessing the data

In [22]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Use batched=True to activate fast multithreading!
tokenized_datasets = imdb_dataset.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)
tokenized_datasets

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (720 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 50000
    })
})

In [23]:
tokenizer.model_max_length

512

**✏️ Try it out!** Some Transformer models, like **BigBird** and **Longformer**, have a much longer context length than BERT and other early Transformer models. Instantiate the tokenizer for one of these checkpoints and verify that the model_max_length agrees with what’s quoted on its model card.

In [24]:
from transformers import AutoTokenizer

tokenizer_bigbird = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
tokenizer_bigbird.model_max_length

tokenizer_config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/760 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/846k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

4096

In [25]:
chunk_size = 128

In [26]:
# Slicing produces a list of lists for each feature
tokenized_samples = tokenized_datasets["train"][:3]

for idx, sample in enumerate(tokenized_samples["input_ids"]):
    print(f"'>>> Review {idx} length: {len(sample)}'")

'>>> Review 0 length: 363'
'>>> Review 1 length: 304'
'>>> Review 2 length: 133'


In [27]:
concatenated_examples = {
    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated reviews length: {total_length}'")

'>>> Concatenated reviews length: 800'


In [28]:
chunks = {
    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
    for k, t in concatenated_examples.items()
}

for chunk in chunks["input_ids"]:
    print(f"'>>> Chunk length: {len(chunk)}'")

'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 128'
'>>> Chunk length: 32'


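As the output above shows, the last chunk is generally smaller than chunk_size. There are two main strategies for dealing with this:

1. Drop the last chunk if it’s smaller than chunk_size.
2. Pad the last chunk until its length equals chunk_size.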

We’ll take the first approach here, so let’s wrap all of the above logic in a single function that we can apply to our tokenized datasets:

In [64]:
def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result
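
For comparison, the second strategy (padding the last chunk) could look roughly like the sketch below. It is not used in this notebook. The function name `group_texts_padded` and the `pad_values` argument are hypothetical, and the toy batch only carries `input_ids` and `attention_mask`. Padded positions get a label of -100 so that the loss ignores them.

```python
from itertools import chain


def group_texts_padded(examples, chunk_size, pad_values):
    """Concatenate, split into chunks, and pad the last chunk (illustrative only)."""
    result = {}
    for key, sequences in examples.items():
        # Flatten the list of lists into one long sequence
        flat = list(chain.from_iterable(sequences))
        chunks = [flat[i : i + chunk_size] for i in range(0, len(flat), chunk_size)]
        # Pad the final chunk instead of dropping it
        if chunks and len(chunks[-1]) < chunk_size:
            missing = chunk_size - len(chunks[-1])
            chunks[-1] = chunks[-1] + [pad_values[key]] * missing
        result[key] = chunks
    # Labels copy the input IDs, except at padded positions (attention_mask == 0)
    result["labels"] = [
        [tok if keep == 1 else -100 for tok, keep in zip(ids, mask)]
        for ids, mask in zip(result["input_ids"], result["attention_mask"])
    ]
    return result


toy_batch = {
    "input_ids": [[101, 7, 8, 102], [101, 9, 102]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1]],
}
padded = group_texts_padded(
    toy_batch, chunk_size=4, pad_values={"input_ids": 0, "attention_mask": 0}
)
print(padded["input_ids"])  # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(padded["labels"])     # [[101, 7, 8, 102], [101, 9, 102, -100]]
```

Padding keeps every token at the cost of some wasted computation on the final chunk, whereas dropping is simpler and loses at most chunk_size - 1 tokens per batch.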

In [65]:
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 61291
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 59904
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 122957
    })
})

In [31]:
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

"as the vietnam war and race issues in the united states. in between asking politicians and ordinary denizens of stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men. < br / > < br / > what kills me about i am curious - yellow is that 40 years ago, this was considered pornographic. really, the sex and nudity scenes are few and far between, even then it ' s not shot like some cheaply made porno. while my countrymen mind find it shocking, in reality sex and nudity are a major staple in swedish cinema. even ingmar bergman,"

In [32]:
tokenizer.decode(lm_datasets["train"][1]["labels"])

"as the vietnam war and race issues in the united states. in between asking politicians and ordinary denizens of stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men. < br / > < br / > what kills me about i am curious - yellow is that 40 years ago, this was considered pornographic. really, the sex and nudity scenes are few and far between, even then it ' s not shot like some cheaply made porno. while my countrymen mind find it shocking, in reality sex and nudity are a major staple in swedish cinema. even ingmar bergman,"

# Fine-tuning DistilBERT with Keras

In [33]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

To see how the random masking works, let’s feed a few examples to the data collator. Since it expects a list of dicts, where each dict represents a single chunk of contiguous text, we first iterate over the dataset before feeding the batch to the collator. We remove the "word_ids" key for this data collator as it does not expect it:

In [37]:
samples = [lm_datasets["train"][i] for i in range(2)]
for sample in samples:
    _ = sample.pop("word_ids")

for chunk in data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] [MASK] rented i am [MASK] - ct [MASK] my video store because of all the controversy that surrounded it when it was first released in 1967. i also [MASK] that at first it was seized by u. s [MASK] customs if it ever tried to [MASK] this [MASK], therefore being a fan of films considered " controversial " i really had to [MASK] this for myself. < br / > < br / > the plot is centered around a [MASK] swedish [MASK] [MASK] named lena who wants to learn everything she can about life. in particular she wants to focus her [MASK]s [MASK] making some sort of documentary on what the average swede thought [MASK] certain political issues such'

'>>> as the vietnam war and race issues in the united states. in between [MASK] politicians and ordinary denize [MASK] of stockholm about their opinions on [MASK], she has sex with her drama teacher, classmates, and married men [MASK] < br / > < br / [MASK] what kills me about i [MASK] curious - [MASK] is [MASK] 40 lowest [MASK], this was consider

**✏️ Try it out!** Run the code snippet above several times to see the random masking happen in front of your very eyes! Also replace the tokenizer.decode() method with tokenizer.convert_ids_to_tokens() to see that sometimes a single token from a given word is masked, and not the others.

In [40]:
samples = [lm_datasets["train"][i] for i in range(2)]
for sample in samples:
    _ = sample.pop("word_ids")

for chunk in data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.convert_ids_to_tokens(chunk)}'")


'>>> ['[CLS]', 'i', 'rented', 'i', 'am', 'curious', '-', 'yellow', 'from', 'my', 'video', 'store', 'because', 'of', 'all', 'the', 'controversy', 'that', 'surrounded', 'it', 'when', 'it', 'was', '[MASK]', 'released', 'in', '1967', '.', 'i', 'also', 'heard', 'that', 'at', 'first', 'it', 'was', 'seized', 'by', 'u', '.', 's', '.', 'customs', 'if', 'it', 'ever', 'tried', 'spreading', 'enter', 'this', '[MASK]', ',', 'therefore', 'being', 'a', 'fan', 'of', 'films', 'considered', '"', 'controversial', '"', 'i', '[MASK]', 'had', 'to', 'see', 'this', 'for', 'myself', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'the', 'plot', 'is', 'centered', 'around', 'a', 'young', 'swedish', 'drama', 'student', '[MASK]', 'lena', 'who', 'wants', 'to', 'learn', 'everything', 'she', 'can', 'about', '##ɐ', '[MASK]', 'in', 'particular', 'she', 'wants', 'to', 'focus', '[MASK]', 'attention', '##s', 'to', 'making', 'some', 'sort', 'of', 'documentary', '[MASK]', '[MASK]', 'the', 'monsoon', '[MASK]', '##ede', 'thoug
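
To make the token-level masking more concrete, here is a simplified plain-Python sketch of what the collator does by default: each non-special token is selected with probability mlm_probability, and a selected token is replaced by [MASK] 80% of the time, by a random token 10% of the time, and left unchanged otherwise. The random replacements are why unexpected tokens such as 'monsoon' appear in the output above. The function name `mask_tokens` is hypothetical, the IDs are copied from the tokenizer printout earlier, and this is an illustration rather than the library's implementation.

```python
import random

MASK_ID = 103                          # [MASK]
SPECIAL_IDS = {0, 100, 101, 102, 103}  # [PAD], [UNK], [CLS], [SEP], [MASK]
VOCAB_SIZE = 30522
IGNORE_INDEX = -100                    # label value the loss skips


def mask_tokens(input_ids, mlm_probability=0.15, rng=None):
    """Return (masked_input_ids, labels) for one sequence (illustrative only)."""
    rng = rng or random.Random()
    masked = list(input_ids)
    labels = [IGNORE_INDEX] * len(input_ids)
    for i, token_id in enumerate(input_ids):
        # Special tokens are never selected; other tokens are picked at random
        if token_id in SPECIAL_IDS or rng.random() >= mlm_probability:
            continue
        labels[i] = token_id           # the model must predict the original token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_ID        # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return masked, labels


ids = [101] + list(range(2000, 2020)) + [102]
masked_ids, labels = mask_tokens(ids, rng=random.Random(0))
print(masked_ids)
print(labels)
```

Because each token is selected independently, a word split into several subword tokens can end up only partially masked.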

In [66]:
import collections
import numpy as np

from transformers.data.data_collator import tf_default_data_collator

wwm_probability = 0.2


def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["labels"] = new_labels

    return tf_default_data_collator(features)
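
The key step in the collator above is the mapping from word indices to token positions. The standalone sketch below isolates that loop. The helper name `group_token_indices_by_word` and the example `word_ids` list are made up for illustration, imagining a word split into three subword tokens followed by a single-token word.

```python
import collections


def group_token_indices_by_word(word_ids):
    """Map each word index to the positions of its tokens (illustrative only)."""
    mapping = collections.defaultdict(list)
    current_word_index = -1
    current_word = None
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue  # special tokens such as [CLS] and [SEP] belong to no word
        if word_id != current_word:
            current_word = word_id
            current_word_index += 1
        mapping[current_word_index].append(idx)
    return dict(mapping)


# Hypothetical word_ids for: [CLS] <3-token word> <1-token word> [SEP]
word_ids = [None, 0, 0, 0, 1, None]
print(group_token_indices_by_word(word_ids))
# {0: [1, 2, 3], 1: [4]}
```

Masking is then decided once per key of this mapping, so all token positions of a selected word are replaced by [MASK] together.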

In [45]:
samples = [lm_datasets["train"][i] for i in range(2)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")


'>>> [CLS] i rented i am curious - yellow [MASK] my video store because of all the [MASK] that surrounded it when [MASK] was first released in 1967. [MASK] also heard [MASK] at first it was seized by u. [MASK] [MASK] customs if it ever tried to enter this [MASK], [MASK] being a [MASK] [MASK] [MASK] considered " controversial [MASK] i really had to see this for myself. < br / > < br [MASK] [MASK] the plot is centered around a young [MASK] drama student named lena who [MASK] to learn [MASK] she can [MASK] life. [MASK] particular [MASK] wants to focus her attentions to making some sort [MASK] [MASK] [MASK] what the average [MASK] [MASK] thought about certain political issues such'

'>>> [MASK] the vietnam war [MASK] race issues in the united states. in between asking politicians and ordinary denizens of stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and [MASK] [MASK]. < br / > [MASK] [MASK] / [MASK] [MASK] kills [MASK] [MASK] i am curious - ye

**✏️ Try it out!** Run the code snippet above several times to see the random masking happen in front of your very eyes! Also replace the tokenizer.decode() method with tokenizer.convert_ids_to_tokens() to see that the tokens from a given word are always masked together.

In [47]:
samples = [lm_datasets["train"][i] for i in range(2)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    print(f"\n'>>> {tokenizer.convert_ids_to_tokens(chunk)}'")


'>>> ['[CLS]', 'i', '[MASK]', 'i', 'am', 'curious', '-', 'yellow', 'from', 'my', 'video', 'store', '[MASK]', 'of', '[MASK]', 'the', 'controversy', 'that', 'surrounded', 'it', 'when', 'it', 'was', '[MASK]', 'released', '[MASK]', '[MASK]', '.', 'i', 'also', '[MASK]', '[MASK]', 'at', '[MASK]', 'it', '[MASK]', 'seized', 'by', '[MASK]', '.', 's', '[MASK]', '[MASK]', 'if', 'it', 'ever', 'tried', '[MASK]', 'enter', '[MASK]', 'country', ',', 'therefore', 'being', 'a', 'fan', '[MASK]', 'films', '[MASK]', '[MASK]', 'controversial', '"', 'i', 'really', 'had', 'to', 'see', '[MASK]', 'for', 'myself', '.', '<', '[MASK]', '/', '>', '<', 'br', '/', '>', 'the', 'plot', '[MASK]', 'centered', '[MASK]', 'a', 'young', '[MASK]', 'drama', 'student', '[MASK]', 'lena', 'who', 'wants', 'to', 'learn', 'everything', 'she', 'can', '[MASK]', 'life', '.', 'in', 'particular', '[MASK]', 'wants', 'to', 'focus', '[MASK]', 'attention', '##s', 'to', 'making', 'some', 'sort', 'of', 'documentary', 'on', 'what', 'the', 'ave

In [48]:
train_size = 10_000
test_size = int(0.1 * train_size)

downsampled_dataset = lm_datasets["train"].train_test_split(
    train_size=train_size, test_size=test_size, seed=42
)
downsampled_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 1000
    })
})

In [49]:
from huggingface_hub import notebook_login

notebook_login()


Once we’re logged in, we can create our tf.data datasets. To do so, we’ll use the prepare_tf_dataset() method, which uses our model to automatically infer which columns should go into the dataset. If you want to control exactly which columns to use, you can use the Dataset.to_tf_dataset() method instead. To keep things simple, we’ll just use the standard data collator here, but you can also try the whole word masking collator and compare the results as an exercise:

In [50]:
tf_train_dataset = model.prepare_tf_dataset(
    downsampled_dataset["train"],
    collate_fn=data_collator,
    shuffle=True,
    batch_size=32,
)

tf_eval_dataset = model.prepare_tf_dataset(
    downsampled_dataset["test"],
    collate_fn=data_collator,
    shuffle=False,
    batch_size=32,
)

In [54]:
from transformers import create_optimizer
from transformers.keras_callbacks import PushToHubCallback
import tensorflow as tf

num_train_steps = len(tf_train_dataset)
optimizer, schedule = create_optimizer(
    init_lr=2e-5,
    num_warmup_steps=1_000,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)
model.compile(optimizer=optimizer)

# Train in mixed-precision float16
# tf.keras.mixed_precision.set_global_policy("mixed_float16")

model_name = model_checkpoint.split("/")[-1]
callback = PushToHubCallback(
    output_dir=f"{model_name}-fineTuned-imdb", tokenizer=tokenizer
)

For more details, please read https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.
Cloning https://huggingface.co/Mhammad2023/distilbert-base-uncased-fineTuned-imdb into local empty directory.
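
One detail worth checking in the optimizer setup is how num_warmup_steps compares with the number of training steps. The sketch below assumes a linear warmup (exponent 1.0), which we believe is the default behaviour of create_optimizer; the helper `warmup_lr` is hypothetical and only models the warmup phase. With 10,000 training chunks and a batch size of 32, one epoch is 312 full batches, which is fewer than the 1,000 warmup steps.

```python
def warmup_lr(step, init_lr=2e-5, num_warmup_steps=1_000, power=1.0):
    """Learning rate during the warmup phase only (illustrative only)."""
    fraction = min(step / num_warmup_steps, 1.0)
    return init_lr * fraction**power


steps_per_epoch = 10_000 // 32   # 312 full batches
print(steps_per_epoch)
print(warmup_lr(steps_per_epoch))  # roughly 6.24e-06, about 31% of init_lr
```

Under these assumptions, a single epoch ends while the learning rate is still ramping up, so lowering num_warmup_steps or training for more epochs may be worth trying.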


We’re now ready to run model.fit() — but before doing so let’s briefly look at perplexity, which is a common metric to evaluate the performance of language models.

# Perplexity for language models

In [55]:
import math

eval_loss = model.evaluate(tf_eval_dataset)
print(f"Perplexity: {math.exp(eval_loss):.2f}")

Perplexity: 22.50
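
As the cell above suggests, perplexity is computed here as the exponential of the mean cross-entropy loss. The small sketch below shows the same calculation from per-token probabilities; the function name `perplexity` and the probabilities are made up for illustration.

```python
import math


def perplexity(token_probs):
    """exp of the mean negative log-likelihood of the correct tokens."""
    losses = [-math.log(p) for p in token_probs]
    mean_loss = sum(losses) / len(losses)
    return math.exp(mean_loss)


# A model that always assigns probability 1/4 to the correct token is as
# uncertain as a uniform choice among 4 options, so its perplexity is about 4.
print(perplexity([0.25, 0.25, 0.25, 0.25]))

# Going the other way: a perplexity of 22.50 corresponds to a loss of about 3.11
print(round(math.log(22.50), 2))
```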


A lower perplexity score means a better language model, and we can see here that our starting model has a somewhat large value. Let’s see if we can lower it by fine-tuning! To do that, we first run the training loop:

In [80]:
model.fit(tf_train_dataset, validation_data=tf_eval_dataset, callbacks=[callback])

  6/312 [..............................] - ETA: 3:06 - loss: 2.6364





<tf_keras.src.callbacks.History at 0x784a1878de50>

In [81]:
eval_loss = model.evaluate(tf_eval_dataset)
print(f"Perplexity: {math.exp(eval_loss):.2f}")

Perplexity: 12.35


# Using our fine-tuned model

In [82]:
from transformers import pipeline

mask_filler = pipeline(
    "fill-mask", model="Mhammad2023/distilbert-base-uncased-fineTuned-imdb"
)

tf_model.h5:   0%|          | 0.00/363M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFDistilBertForMaskedLM.

All the layers of TFDistilBertForMaskedLM were initialized from the model checkpoint at Mhammad2023/distilbert-base-uncased-fineTuned-imdb.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForMaskedLM for predictions without further training.
Device set to use 0


In [83]:
preds = mask_filler(text)

for pred in preds:
    print(f">>> {pred['sequence']}")

>>> this is a great idea.
>>> this is a great deal.
>>> this is a great movie.
>>> this is a great film.
>>> this is a great adventure.
