This notebook takes the Spanish-language [GPT-2 base model](https://github.com/PlanTL-GOB-ES/lm-spanish) created by the Barcelona Supercomputing Center (see the paper [here](https://arxiv.org/abs/2107.07253)) and fine-tunes it on the Spanish portion of the [Europarl corpus](https://www.statmt.org/europarl/).

The fine-tuned model can be used to generate speeches in the style of the European Parliament.

# Load initial models

In [1]:
import numpy as np
import torch
import math

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("PlanTL-GOB-ES/gpt2-base-bne")

model = AutoModelForCausalLM.from_pretrained("PlanTL-GOB-ES/gpt2-base-bne")

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


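We need to resize the model's token embedding matrix to match the tokenizer's vocabulary size (`len(tokenizer)`), which includes the special tokens mentioned in the warning above. Otherwise, token IDs beyond the original embedding size would cause an error during training.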

In [3]:
model.resize_token_embeddings(len(tokenizer))

Embedding(50263, 768)
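Resizing keeps the pretrained rows of the embedding matrix and appends new rows for the added tokens, so that every token ID the tokenizer can produce has a vector to look up. The sketch below illustrates the idea with plain Python lists and toy sizes. It is a simplified assumption of the mechanics, not the `transformers` implementation (which, among other things, initialises the new rows with the model's own weight-initialisation scheme); `resize_embedding` is a hypothetical helper.

```python
import random


def resize_embedding(matrix, new_num_tokens, seed=0):
    """Return a copy of `matrix` (a list of rows) with `new_num_tokens` rows.

    Existing rows are copied unchanged; any new rows are filled with small
    random values (a hypothetical stand-in for the model's initialiser).
    """
    rng = random.Random(seed)
    dim = len(matrix[0])
    resized = [row[:] for row in matrix[:new_num_tokens]]
    while len(resized) < new_num_tokens:
        resized.append([rng.gauss(0.0, 0.02) for _ in range(dim)])
    return resized


# Toy "pretrained" embedding: 5 tokens, 4 dimensions each.
old = [[float(i)] * 4 for i in range(5)]

# A tokenizer that gained one special token can emit ID 5,
# which has no row in the old matrix.
try:
    _ = old[5]
    lookup_failed = False
except IndexError:
    lookup_failed = True

new = resize_embedding(old, 6)
print(lookup_failed, len(new), len(new[5]))  # True 6 4
```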

# Prepare data

We can create a custom dataset from text files. Here we use an 80-10-10 train-validation-test split of the Europarl corpus that we created beforehand; only the training and validation files are needed in this notebook.
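The notebook does not show how the split was made. As a hypothetical sketch, assuming the corpus is a single UTF-8 text file with one segment per line, a reproducible 80-10-10 split could shuffle the lines with a fixed seed and slice them. The helper names, the seed, and the `Europarl_es_test.txt` file name are assumptions; only the train and validation file names appear in this notebook.

```python
import random
from pathlib import Path


def split_lines(lines, train_frac=0.8, valid_frac=0.1, seed=42):
    """Shuffle `lines` reproducibly and split them into train/valid/test lists."""
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * train_frac)
    n_valid = int(len(shuffled) * valid_frac)
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:]
    return train, valid, test


def write_splits(src_path, out_dir, prefix="Europarl_es"):
    """Read non-empty lines from `src_path` and write <prefix>_{train,valid,test}.txt."""
    text = Path(src_path).read_text(encoding="utf-8")
    lines = [line for line in text.splitlines() if line.strip()]
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for name, part in zip(("train", "valid", "test"), split_lines(lines)):
        (out / f"{prefix}_{name}.txt").write_text("\n".join(part) + "\n", encoding="utf-8")


# Quick check on synthetic data.
demo_lines = [f"line {i}" for i in range(100)]
train, valid, test = split_lines(demo_lines)
print(len(train), len(valid), len(test))  # 80 10 10
```

Shuffling at the line level mixes segments from the same session across splits. Splitting contiguous ranges of the file instead would keep neighbouring, topically related lines out of the validation and test sets, at the cost of a less uniform sample.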

In [4]:
from datasets import load_dataset
datasets = load_dataset("text", data_files={"train": "DATA/Europarl/Europarl_es_train.txt", "validation": "DATA/Europarl/Europarl_es_valid.txt"})

Using custom data configuration default
Reusing dataset text (/home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)


Let's print a training sample to check that everything looks right.

In [5]:
print(datasets['train'][5]['text'])

Señora Presidenta, una cuestión de procedimiento. Sabrá usted por la prensa y la televisión que se han producido una serie de explosiones y asesinatos en Sri Lanka. Una de las personas que recientemente han asesinado en Sri Lanka ha sido al Sr. Kumar Ponnambalam, quien hace pocos meses visitó el Parlamento Europeo. ¿Sería apropiado que usted, Señora Presidenta, escribiese una carta al Presidente de Sri Lanka expresando las condolencias del Parlamento por esa y otras muertes violentas, pidiéndole que haga todo lo posible para encontrar una reconciliación pacífica ante la extremadamente difícil situación que está viviendo su país?


Next, we prepare the data for training. The following cells follow the Hugging Face causal language modelling example script [`run_clm.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py).

In [6]:
def tokenize_function(examples):
    return tokenizer(examples["text"])

In [7]:
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])

Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-b11fbe7d2ad3a8e4.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-9d2ebaed1a0f0935.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-922596aaae554e3d.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-8ced01a05940f33d.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8

In [9]:
block_size = 64
print(block_size)

64


In [10]:
from itertools import chain

def group_texts(examples):
    # Concatenate all texts in the batch (chain avoids the quadratic cost of sum(..., [])).
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model supported it.
    # Customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size tokens.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal language modelling, the labels are a copy of the inputs.
    result["labels"] = result["input_ids"].copy()
    return result
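To see what `group_texts` does, we can run the same logic on a toy batch shaped like the output of `tokenize_function` (a dict mapping each field to a list of token-ID lists). The token IDs below are made up, and `group_texts_demo` is a hypothetical copy of the function that takes `block_size` as an argument so it can use a tiny block size of 3.

```python
from itertools import chain


def group_texts_demo(examples, block_size):
    """Same logic as `group_texts`, with `block_size` passed explicitly."""
    concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples}
    total_length = len(concatenated["input_ids"])
    # Drop the tail that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# A toy batch of three "lines" with made-up token IDs.
batch = {
    "input_ids": [[11, 12, 13], [21, 22, 23, 24], [31]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1], [1]],
}
out = group_texts_demo(batch, block_size=3)
print(out["input_ids"])  # [[11, 12, 13], [21, 22, 23]]
print(out["labels"])     # [[11, 12, 13], [21, 22, 23]]
```

The eight tokens are concatenated, the last two (`24` and `31`) are dropped because they do not fill a third block, and `labels` is a plain copy of `input_ids`: GPT-2's language-modelling head shifts the labels by one position internally when computing the loss. Because `map` is called with `batched=True, batch_size=1000`, the remainder is dropped once per batch of 1,000 lines rather than once for the whole dataset.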

In [11]:
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)

Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-a882e3d2eb46433f.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-dc9b99bfba5fecc4.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-69bf80efcf4d353f.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-193f83fad421aad8.arrow
Loading cached processed dataset at /home/investigacion/.cache/huggingface/datasets/text/default-629021ad4b30cc92/0.0.0/daf90a707a433ac193b369c8

# Train new model

In [13]:
from transformers import Trainer, TrainingArguments

In [14]:
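training_args = TrainingArguments(
    "gpt2_PLANTL_base_ft_europarl",  # output_dir: checkpoints and the saved model go here
    evaluation_strategy="epoch",     # validate after every epoch (newer transformers: eval_strategy)
    learning_rate=2e-5,
    weight_decay=0.01,
    dataloader_drop_last=True,       # drop the last incomplete batch
    # All other settings keep their defaults, e.g. num_train_epochs=3, per_device_train_batch_size=8.
)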

In [15]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)

In [16]:
trainer.train()

| Epoch | Training loss | Validation loss | Eval runtime (s) | Eval samples/s |
|---|---|---|---|---|
| 1 | 2.9483 | 3.131453 | 199.9955 | 479.371 |
| 2 | 2.8829 | 3.098344 | 199.9948 | 479.372 |
| 3 | 2.8025 | 3.094039 | 199.9978 | 479.365 |

TrainOutput(global_step=159549, training_loss=2.902411611152092, metrics={'train_runtime': 18745.9129, 'train_samples_per_second': 8.511, 'total_flos': 121603552719667200, 'epoch': 3.0})
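The logged numbers can be cross-checked with some arithmetic. This is a sketch under assumptions the log does not state: the `TrainingArguments` defaults of `per_device_train_batch_size=8` and `gradient_accumulation_steps=1`, and training on a single device (with more devices, the block and token counts scale accordingly). Note that the reported `train_samples_per_second` equals `global_step / train_runtime`, so in the `transformers` version used here it is effectively optimizer steps per second. The validation-set size is only approximate, because the logged rates are rounded.

```python
# Figures copied from the training log above.
global_step = 159549
epochs = 3
train_runtime_s = 18745.9129
eval_runtime_s = 199.9955
eval_samples_per_s = 479.371

# Assumptions (not in the log): default batch size of 8, one device.
batch_size = 8
block_size = 64  # from the notebook

steps_per_epoch = global_step // epochs
train_blocks_per_epoch = steps_per_epoch * batch_size
train_tokens_per_epoch = train_blocks_per_epoch * block_size
steps_per_s = global_step / train_runtime_s
approx_eval_blocks = round(eval_runtime_s * eval_samples_per_s)

print(steps_per_epoch)         # 53183
print(train_tokens_per_epoch)  # 27229696
print(round(steps_per_s, 3))   # 8.511
print(approx_eval_blocks)      # 95872
```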

In [17]:
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

Perplexity: 22.07
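Perplexity is the exponential of the mean per-token cross-entropy loss (in nats), i.e. perplexity = exp(validation loss), which is what the cell above computes. Applying the same formula to the per-epoch validation losses from the training table shows how perplexity evolved; the epoch-3 value reproduces the 22.07 above, since the evaluated model is the one from the end of training.

```python
import math

# Validation losses per epoch, copied from the training table above.
val_losses = {1: 3.131453, 2: 3.098344, 3: 3.094039}

perplexities = {epoch: math.exp(loss) for epoch, loss in val_losses.items()}
for epoch, ppl in perplexities.items():
    print(f"Epoch {epoch}: perplexity {ppl:.2f}")
# Epoch 1: perplexity 22.91
# Epoch 2: perplexity 22.16
# Epoch 3: perplexity 22.07
```

Most of the improvement happens in the first two epochs: between epochs 2 and 3 the validation perplexity barely moves while the training loss keeps falling, which suggests that further epochs would bring little gain on this validation set.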


In [20]:
trainer.save_model()