# Fine-tuning GPT2 with Hugging Face Transformers

In [1]:
%pip install datasets transformers==4.28.0

Note: you may need to restart the kernel to use updated packages.


In [2]:
try:
    from google.colab import drive

    IN_COLAB = True
except:
    IN_COLAB = False

import os

if IN_COLAB:
    drive.mount("/content/drive")

## Load dataset

Using Novembers version of REG wiki, HuggingFace `datasets` loads the text files as training samples. The data come directly from a Github wiki download. The library loads the individual files and treats each line as a data sample.

In [4]:
from datasets import load_dataset

data_dir = "./drive/MyDrive/data/wiki-reg/" if IN_COLAB else "../../data/wiki-raw/"

dataset = load_dataset("text", data_dir=data_dir)

  from .autonotebook import tqdm as notebook_tqdm
Resolving data files: 100%|██████████| 52/52 [00:00<00:00, 436207.62it/s]


Downloading and preparing dataset text/default to /Users/egabasova/.cache/huggingface/datasets/text/default-d7149dba1ea46735/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 951.31it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 68.35it/s]
                                                        

Dataset text downloaded and prepared to /Users/egabasova/.cache/huggingface/datasets/text/default-d7149dba1ea46735/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 417.84it/s]


Let's examine the data: 

In [5]:
dataset["train"][1:10]

{'text': ['',
  '## Context',
  'REG is a large team, with a relatively flat line management structure. Therefore we run a team-wide standardisation process where appraisals and progression award recommendations are reviewed to ensure consistency and fairness across the team.',
  '',
  'Our goal is to work within the wider HR policy and guidance (see [2022-23 HR performance review guide](https://mathison.turing.ac.uk/Utilities/Uploads/Handler/Uploader.ashx?area=composer&filename=PerformanceReviewGuide22-23.pdf&fileguid=945dbb37-769f-41b0-8f2b-72dc540692dd)) to prioritise what we think is most important for our team (fairness, transparency and a clear progression path) and provide clear guidance to achieve that.',
  '',
  '#### Principles',
  '1. REG folk who are operating at an equivalent level should receive equivalent pay.',
  '2. We are a learning team and our default assumption is that people will generally be growing and developing in their role as they gain more experience.']}

In [6]:
# Number of samples should correspond to the number of rows across all files

len(dataset["train"])

2434

## Loading a pre-trained model

From HuggingFace model repository: [HuggingFace models](https://huggingface.co/models)

I'm choosing a small version of GPT2 called [distilgpt2](https://huggingface.co/distilgpt2)

In [7]:
model_checkpoint = "distilgpt2"

### 1. Tokenizer
Each model comes with a tokenizer that was used for originally training the model.

In [11]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [14]:
def tokenize_function(examples):
    return tokenizer(examples["text"])


tokenized_dataset = dataset.map(
    tokenize_function, batched=True, remove_columns=["text"]
)

                                                                  

In [20]:
# Let's explore how the tokenized text looks like

tokenized_dataset["train"][3]

{'input_ids': [31553,
  318,
  257,
  1588,
  1074,
  11,
  351,
  257,
  5365,
  6228,
  1627,
  4542,
  4645,
  13,
  8447,
  356,
  1057,
  257,
  1074,
  12,
  4421,
  3210,
  5612,
  1429,
  810,
  28309,
  271,
  874,
  290,
  17085,
  5764,
  10763,
  389,
  11765,
  284,
  4155,
  15794,
  290,
  22692,
  1973,
  262,
  1074,
  13],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1]}

### 2. Reformat training data

Wrangle data into a shape to enable training.


In [21]:
# Maximum length of input of the model
block_size = tokenizer.model_max_length
print(block_size)

# this seems to be a bit too big for free Colab GPU RAM
block_size = 512

1024


In [22]:
# Reformat the training data to enable effective training - concatenate all
# the text and then split it into chunks of block_size length


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [23]:
lm_datasets = tokenized_dataset.map(
    group_texts,
    batched=True,
    batch_size=10,
    num_proc=4,
)

                                                                 

In [24]:
print(len(lm_datasets["train"]))  # batches
print(len(lm_datasets["train"][0]["input_ids"]))  # each batch is a block

13
512


In [25]:
# How do the data look now?
# We can use decode from the tokenizer

tokenizer.decode(lm_datasets["train"][0]["input_ids"])

'|:---:|:---:|:---:|:---:|:---:|:---:|:---|| 2023 | 5.0% | [8.8%](https://www.ons.gov.uk/economy/inflationandpriceindices/bulletins/consumerpriceinflation/january2023) | [9.2%](https://www.ons.gov.uk/economy/inflationandpriceindices/bulletins/consumerpriceinflation/december2022) | 20.74% | 22.56% | This is the first year the annual CoL increase has not matched or exceeded CPIH. [Mathison post from Jon Atkins](https://mathison.turing.ac.uk/page/2833) || 2022 | 5.0% | [4.9%](https://www.ons.gov.uk/economy/inflationandpriceindices/bulletins/consumerpriceinflation/january2022) | [4.8%](https://www.ons.gov.uk/economy/inflationandpriceindices/bulletins/consumerpriceinflation/december2021) | 14.99% | 12.65% | Additional £1,000 one-off cost of living support payment made to all staff in bands 1-3 to help offset the disproportionate effect of inflation on those on lower salaries. || 2021 | 1.5% | [0.9%](https://www.ons.gov.uk/economy/inflationandpriceindices/bulletins/consumerpriceinflation/jan

### 3. Fine-tuning the model

In [26]:
from transformers import AutoModelForCausalLM

# Load the same model we used for the tokenizer above
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

Downloading pytorch_model.bin: 100%|██████████| 353M/353M [00:09<00:00, 36.2MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 124/124 [00:00<00:00, 530kB/s]


In [29]:
from transformers import Trainer, TrainingArguments

model_name = model_checkpoint.split("/")[-1]
print(model_name)

distilgpt2


In [30]:
num_epochs = 10

# training parameters
training_args = TrainingArguments(
    f"{model_name}-finetuned-regwiki-{num_epochs}",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=False,
    num_train_epochs=num_epochs,
)

In [31]:
# create a trainer class
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["train"],  # This is obviously not correct, haha
)

In [32]:
trainer.train()


import math

eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

output_dir = (
    f"./drive/MyDrive/REGindald-trained/reginald-{model_name}-{num_epochs}/"
    if IN_COLAB
    else f"./reginald-{model_name}-{num_epochs}/"
)

trainer.save_model(output_dir=output_dir)

                                              
 10%|█         | 2/20 [00:11<01:16,  4.23s/it]

{'eval_loss': 3.368446111679077, 'eval_runtime': 2.3205, 'eval_samples_per_second': 5.602, 'eval_steps_per_second': 0.862, 'epoch': 1.0}


                                              
 20%|██        | 4/20 [00:20<01:10,  4.42s/it]

{'eval_loss': 3.3026812076568604, 'eval_runtime': 2.0788, 'eval_samples_per_second': 6.254, 'eval_steps_per_second': 0.962, 'epoch': 2.0}


                                              
 30%|███       | 6/20 [00:31<01:04,  4.58s/it]

{'eval_loss': 3.262787103652954, 'eval_runtime': 2.1142, 'eval_samples_per_second': 6.149, 'eval_steps_per_second': 0.946, 'epoch': 3.0}


                                              
 40%|████      | 8/20 [00:40<00:53,  4.46s/it]

{'eval_loss': 3.2287650108337402, 'eval_runtime': 2.0364, 'eval_samples_per_second': 6.384, 'eval_steps_per_second': 0.982, 'epoch': 4.0}


                                               
 50%|█████     | 10/20 [00:50<00:44,  4.42s/it]

{'eval_loss': 3.203280210494995, 'eval_runtime': 2.024, 'eval_samples_per_second': 6.423, 'eval_steps_per_second': 0.988, 'epoch': 5.0}


                                               
 60%|██████    | 12/20 [00:59<00:35,  4.44s/it]

{'eval_loss': 3.1815555095672607, 'eval_runtime': 2.0786, 'eval_samples_per_second': 6.254, 'eval_steps_per_second': 0.962, 'epoch': 6.0}


                                               
 70%|███████   | 14/20 [01:09<00:26,  4.39s/it]

{'eval_loss': 3.1651830673217773, 'eval_runtime': 2.1053, 'eval_samples_per_second': 6.175, 'eval_steps_per_second': 0.95, 'epoch': 7.0}


                                               
 80%|████████  | 16/20 [01:18<00:17,  4.43s/it]

{'eval_loss': 3.1536152362823486, 'eval_runtime': 2.0402, 'eval_samples_per_second': 6.372, 'eval_steps_per_second': 0.98, 'epoch': 8.0}


                                               
 90%|█████████ | 18/20 [01:28<00:08,  4.39s/it]

{'eval_loss': 3.146908760070801, 'eval_runtime': 2.1674, 'eval_samples_per_second': 5.998, 'eval_steps_per_second': 0.923, 'epoch': 9.0}


                                               
100%|██████████| 20/20 [01:38<00:00,  4.94s/it]


{'eval_loss': 3.144015073776245, 'eval_runtime': 2.3812, 'eval_samples_per_second': 5.46, 'eval_steps_per_second': 0.84, 'epoch': 10.0}
{'train_runtime': 98.8383, 'train_samples_per_second': 1.315, 'train_steps_per_second': 0.202, 'train_loss': 3.384896469116211, 'epoch': 10.0}


100%|██████████| 2/2 [00:00<00:00,  2.55it/s]


Perplexity: 23.20


In [34]:
if IN_COLAB:
    !ls ./drive/MyDrive/REGindald-trained/reginald-distilgpt2-10/
else:
    !ls ./reginald-distilgpt2-10/

config.json            pytorch_model.bin
generation_config.json training_args.bin
