In [1]:
import math
import re
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from datasets import load_dataset
from transformers import (
    TrainerCallback,
    GPT2Config,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AdamW,
    TrainingArguments,
    Trainer,
)
from transformers import BartTokenizer, BartTokenizerFast, BartModel, BartForConditionalGeneration



In [3]:
# Name of the directory that the fine-tuned model, tokenizer and training results are written to.
finetuned_model_name = 'BART-movie-plot-generator'
# Start from the pre-trained 'facebook/bart-base' checkpoint for both the tokenizer and the model.
tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [None]:
# Show the tokenizer's begin-, end- and separator tokens, one per line.
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.sep_token, sep="\n")

We need to replace the special tokens used in the dataset with the ones the pre-trained tokenizer already knows, and add a new token for each genre.

# Load and process dataset


In [None]:
# Load the plain-text dataset from "data_top_15_genres.txt" and keep its 'train' split
# (no separate validation split is created here).
dataset = load_dataset("text", data_files="data_top_15_genres.txt")['train']

def processText(example):
    """
    Swap the dataset's own markers for the tokenizer's special tokens.

    '<BOS>' becomes '<s>', while both '<EOS>' and '<SEP>' become '</s>'.
    Expects a batch whose 'text' entry is a list of strings; that entry is
    overwritten on the batch object, which is then returned.
    """
    marker_to_token = (('<BOS>', '<s>'), ('<EOS>', '</s>'), ('<SEP>', '</s>'))

    converted = []
    for line in example['text']:
        # Apply the substitutions in the same order: BOS, then EOS, then SEP.
        for marker, token in marker_to_token:
            line = line.replace(marker, token)
        converted.append(line)

    example['text'] = converted
    return example

# Apply processText to the dataset in batches (each call receives a batch of texts),
# then display the resulting dataset.
dataset = dataset.map(processText, batched=True)
dataset

In [None]:
# Look at the first processed example to check the token replacement.
dataset[0]['text']

## Tokenization

We now need to tokenize the dataset. The original tokenizer doesn't have all the special tokens we require.

We need to add the special tokens that we use in our dataset. 

In [None]:
# Add one special token per genre, written as '<genre>' (e.g. '<horror>').
genres = ['romantic drama', 'short film', 'family film',
          'adventure', 'action/adventure', 'indie',
          'black-and-white', 'horror', 'crime fiction',
          'world cinema', 'action', 'thriller',
          'romance film', 'comedy', 'drama']

print(f'Number of added genres: {len(genres)}')
special_tokens_dict = {'additional_special_tokens': [f'<{genre}>' for genre in genres]}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

# We must resize token embeddings since new special tokens were added
model.resize_token_embeddings(len(tokenizer))

# len(tokenizer) is the full vocabulary size including added tokens, and it is
# the value the embeddings were just resized to. Summing `vocab_size` and
# `get_added_vocab()` by hand can double-count the built-in special tokens on
# some library versions, which would trip this check spuriously.
print(model.config.vocab_size, len(tokenizer))
assert model.config.vocab_size == len(tokenizer)
print(*tokenizer.all_special_tokens)

In [None]:
# Sanity check: inspect how a sample line with a genre token and the special tokens is split.
tokenizer.tokenize('<s> <drama> This is the title </s> here is the plot </s>')

**Tokenize the dataset**

We tokenize the dataset. The tokenized examples contain the columns 'attention_mask', which is a mask for padding tokens, and 'input_ids', which holds the id of each token. We drop the text column as it is not needed anymore.

Note that we duplicate the inputs to add our labels. This is because the models of the 🤗 Transformers library apply the shifting to the right themselves, so we don't need to do it manually.

In [None]:
def tokenize_function(examples, max_length=512):
    """
    Tokenize the 'text' entry of `examples` and attach training labels.

    padding='max_length' pads every sequence to `max_length` tokens and
    truncation=True cuts longer sequences down to that length.

    The labels are a copy of the input ids in which every padded position
    (attention_mask == 0) is replaced by -100, the index that the loss
    ignores. Without this the model would also be trained to predict padding.

    Args:
        examples: mapping with a 'text' entry (a list of strings when called
            through `dataset.map(..., batched=True)`, or a single string).
        max_length: number of tokens every sequence is padded/truncated to.
            Defaults to 512; the notebook notes the model's maximum input is 1024.

    Returns:
        The tokenizer output with an extra 'labels' entry.
    """
    ignore_index = -100

    def mask_padding(ids, mask):
        # Keep real tokens, blank out padded positions.
        return [token_id if keep else ignore_index for token_id, keep in zip(ids, mask)]

    result = tokenizer(examples["text"], max_length=max_length, padding='max_length', truncation=True)
    input_ids = result["input_ids"]
    attention_mask = result["attention_mask"]

    if len(input_ids) > 0 and isinstance(input_ids[0], (list, tuple)):
        # Batched call: one list of ids per example.
        result["labels"] = [mask_padding(ids, mask) for ids, mask in zip(input_ids, attention_mask)]
    else:
        # Single example (or empty batch): a flat list of ids.
        result["labels"] = mask_padding(input_ids, attention_mask)
    return result

# Tokenize the whole dataset in batches and drop the raw "text" column.
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Have the dataset return its columns as PyTorch tensors.
tokenized_dataset.set_format("torch")

In [None]:
# Train on the full tokenized dataset. For a quick run on a subset, enable the
# commented-out `.select(...)` call, which keeps only the first 10 examples.
train_set = tokenized_dataset#.select(list(range(10)))
train_set

### Training
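First, set up the training arguments.
The last argument sets everything up so that we can push the model to the Hub regularly during training. (Note: the cell below does not currently pass a `push_to_hub` argument, so checkpoints are only saved locally.)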

Then pass training args to Trainer.

In [None]:
class SaveTokenizer(TrainerCallback):
    """
    Trainer callback that stores the tokenizer alongside the model.

    Each time the Trainer fires its save event, the notebook-level `tokenizer`
    is written to the `finetuned_model_name` directory.
    """
    def on_save(self, args, state, control, **kwargs):
        # Both names are notebook globals, looked up when the event fires.
        target_dir = finetuned_model_name
        tokenizer.save_pretrained(target_dir)

        
# Cross-entropy with PyTorch defaults: mean reduction, label positions equal to -100 are ignored.
ce_loss = torch.nn.CrossEntropyLoss()


def compute_metrics(eval_pred):
    """
    Compute the perplexity of the model predictions.

    The compute function needs to receive a tuple (with logits and labels)
    and has to return a dictionary with string keys (the name of the metric) and float values.
    It will be called at the end of each evaluation phase on the whole arrays of predictions/labels.

    Perplexity is "the exponentiation of the cross-entropy between the data and
    model predictions" (https://huggingface.co/transformers/perplexity.html).

    Args:
        eval_pred: pair (logits, labels). Both may be NumPy arrays or tensors.
            `logits` has the vocabulary as its last dimension and may itself be
            a tuple of model outputs, in which case its first element is used.

    Returns:
        {'perplexity': float}; `inf` if the exponent overflows.
    """
    logits, labels = eval_pred

    # Models that return several outputs hand them over as a tuple; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    # The Trainer passes NumPy arrays, while the loss function requires tensors.
    logits = torch.as_tensor(logits, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.long)

    # CrossEntropyLoss wants (N, vocab) scores against (N,) targets, so fold the
    # batch and sequence dimensions together.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)
    loss = ce_loss(flat_logits, flat_labels).item()

    try:
        perplexity = math.exp(loss)
    except OverflowError:
        # A very large loss cannot be exponentiated into a float.
        perplexity = float('inf')

    return {'perplexity': perplexity}

In [None]:
# Free cached GPU memory before building the Trainer.
torch.cuda.empty_cache()
batch_size = 1 # 1:34:39 for one epoch (no evaluation steps) with batch_size = 2

# Checkpoint every 2500 steps and keep at most one checkpoint on disk.
training_args = TrainingArguments(
    output_dir=finetuned_model_name,
    per_device_train_batch_size=batch_size,
    num_train_epochs=1,
    save_steps=2500,
    save_total_limit=1,
)

# SaveTokenizer is handed over as a class, exactly as before.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    callbacks=[SaveTokenizer],
    compute_metrics=compute_metrics,
)

In [None]:
train_results = trainer.train()

# Save the model and tokenizer first: save_pretrained creates the target
# directory if it does not exist yet, so the pickle below has somewhere to go.
model.save_pretrained(finetuned_model_name)
tokenizer.save_pretrained(finetuned_model_name)

# Store the training summary. The context manager makes sure the file is
# flushed and closed. To load it again:
#   with open(finetuned_model_name + "/train_results.pickle", "rb") as f:
#       train_results = pickle.load(f)
with open(finetuned_model_name + "/train_results.pickle", "wb") as results_file:
    pickle.dump(train_results, results_file)

In [None]:
# Inference test
# The model is a BartForConditionalGeneration (encoder-decoder), so it is served
# by the "text2text-generation" pipeline; "text-generation" is for causal LMs.
generator = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # first GPU if present, otherwise CPU
)
# The prompt must end with the real '</s>' token. The previous '<\s>' was an
# invalid escape sequence and not a token the tokenizer knows.
# do_sample=True lets the pipeline return several different sequences per prompt.
stories = generator("<s> <horror> Testing </s>", max_length=1024, num_return_sequences=4, do_sample=True)
print(*[story['generated_text'] + "\n\n\n------------------------\n" for story in stories])