In [1]:
import math
import re
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from datasets import load_dataset
from transformers import (
    TrainerCallback,
    GPT2Config,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AdamW,
    TrainingArguments,
    Trainer,
)

In [2]:
start_from_checkpoint = False

# Load pretrained tokenizer and model
finetuned_model_name = 'movie-plot-generator'

if start_from_checkpoint:
    config=AutoConfig.from_pretrained(finetuned_model_name)
    tokenizer = AutoTokenizer.from_pretrained(finetuned_model_name)
    model = AutoModelForCausalLM.from_pretrained(finetuned_model_name, config=config)
else:
    model_name = 'gpt2' 
    config=AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
    model = AutoModelForCausalLM.from_pretrained(model_name)

In [5]:
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
stories = generator("<BOS> <action> Shrek in the swamp <SEP> He was in the ", max_length=200, num_return_sequences=2)
print(*[story['generated_text'] + "\n\n\n------------------------\n" for story in stories])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<BOS> <action> Shrek in the swamp <SEP> He was in the  That the world of Shrek is falling apart; a mysterious girl named Aaliya's life is falling apart, as is her own. When she asks him what she has forgotten, he admits to her that she must come to a conclusion: that she is not alone. Soon after her awakening, Aaliya, a man on the verge of disaster, confronts him in the woods and he kills her with his sword. She is put down by a mob. When he wakes up, the girl screams. It is later revealed that this was the father of the family. With His wife having died, the father has now decided to go back to town, where he would find his daughter and make her return home to her parents. 


------------------------
 <BOS> <action> Shrek in the swamp <SEP> He was in the Riverboat Graveyard by his grandparents. The boat he was on, and the swamp he was on, still had mud walls and flooded water. He lived a great part of the night in his car. A friend of Mr. Marge's and his wife  was working at the bank 

# Load dataset

First, we load the dataset and split into train and validation 

In [4]:
# Load dataset from text file called "data.txt" and split into train/val
datasets = load_dataset("text", data_files="data_top_15_genres.txt")['train']
datasets = datasets.train_test_split(train_size=0.985, seed=42)
datasets['validation'] = datasets.pop('test')

datasets = datasets.map(processText, batched=True)
datasets

Using custom data configuration default-2fcf8d2135508f85
Reusing dataset text (C:\Users\Anton\.cache\huggingface\datasets\text\default-2fcf8d2135508f85\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached split indices for dataset at C:\Users\Anton\.cache\huggingface\datasets\text\default-2fcf8d2135508f85\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-1dc70cf5269acfd9.arrow and C:\Users\Anton\.cache\huggingface\datasets\text\default-2fcf8d2135508f85\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-cd027351125f5409.arrow
Loading cached processed dataset at C:\Users\Anton\.cache\huggingface\datasets\text\default-2fcf8d2135508f85\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-f67f75d4d783a388.arrow
Loading cached processed dataset at C:\Users\Anton\.cache\huggingface\datasets\text\default-2fcf8d2135508f85\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-049c6570017b7ca7.arrow





DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 36475
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 556
    })
})

As can be seen, the examples are of different lengths. Examples longer than 1024 tokens needs to be truncated as this is the maximum input to GPT2.

In [5]:
# Example
print(datasets['train'][0]['text'] + '\n')
print(datasets['train'][1]['text'] + '\n')
print(datasets['train'][3]['text'] + '\n')


<BOS> <action> Gundaraj <SEP> Ajay Chauhan lives with his parents and younger sister. He is in love with Pooja, and hopes to marry her someday. His father wants him to get a job and settle down, and then get married. Ajay applies for a job in Bombay, and soon receives a letter asking him to appear for an interview. He attends the interview, and is hired. Delighted to see all his dreams coming true, he goes to offer his thanks to God, and it is there a woman named Pratika Jetley sees him and notifies the police that he is indeed the one who had brutally raped three young women in a college campus. Ajay vehemently denies this, but is personally identified and criminally held responsible, convicted and sentenced to prison. Several years later he is released from prison, and finds out that his father and Pooja had committed suicide while his mother and sister are untraceable. He sets out to put his life together and meets with a ruthless police inspector, whose daughter was one of the rape

## Tokenization

We now need to tokenize the dataset. The original tokenizer don't have all special tokens we require.

In [5]:
print(*tokenizer.all_special_tokens)

<|endoftext|>


We need to add the special tokens that we use in our dataset. 

In [6]:
# Set special tokens
if not start_from_checkpoint:
    tokenizer.bos_token = '<BOS>'
    tokenizer.eos_token = '<EOS>'
    tokenizer.pad_token = '<PAD>'
    tokenizer.sep_token = '<SEP>'

    # Add special tokens for each genre
    genres = ['romantic drama', 'short film', 'family film',
              'adventure', 'action/adventure', 'indie',
              'black-and-white', 'horror', 'crime fiction',
              'world cinema', 'action', 'thriller', 
              'romance film', 'comedy', 'drama']

    print(f'Number of added genres: {len(genres)}')
    new_special_tokens = ['<BOS>', '<EOS>', '<PAD>', '<SEP>']
    new_special_tokens.extend([f'<{genre}>' for genre in genres])
    special_tokens = tokenizer.additional_special_tokens

    special_tokens.extend(new_special_tokens) 
    new_special_tokens_dict = {'additional_special_tokens': special_tokens}
    num_added_toks = tokenizer.add_special_tokens(new_special_tokens_dict)

    # We must resize token embeddings since new special tokens were added
    model.resize_token_embeddings(len(tokenizer))

print(model.config.vocab_size, tokenizer.vocab_size + len(tokenizer.get_added_vocab()))
assert(model.config.vocab_size == tokenizer.vocab_size + len(tokenizer.get_added_vocab()))
print(*tokenizer.all_special_tokens)

Number of added genres: 15
50276 50276
<BOS> <EOS> <|endoftext|> <SEP> <PAD> <romantic drama> <short film> <family film> <adventure> <action/adventure> <indie> <black-and-white> <horror> <crime fiction> <world cinema> <action> <thriller> <romance film> <comedy> <drama>


In [7]:
tokenizer.tokenize('<BOS> <drama> He was' )

['<BOS>', 'Ġ', '<drama>', 'ĠHe', 'Ġwas']

**Tokenize the dataset**

We tokenize the dataset. The tokenized examples contain the column names 'attention_mask' which is a mask for padding tokens and 'input_ids' which is the id of each token corrsponding to a word. We drop the text as that is not needed anymore. 

Note that we duplicate the inputs to add our labels. This is because the model of the 🤗 Transformers library apply the shifting to the right, so we don't need to do it manually.

In [8]:
def tokenize_function(examples):
    """
    padding='max_length' to pad to a length specified by the max_length argument 
    or the maximum length accepted by the model.
    truncation=True to truncate each sequence to the maximum length accepted by the model
    """
    result = tokenizer(examples["text"], padding='max_length', truncation=True) # Max input according to model(1024)

    result["labels"] = result["input_ids"].copy()
    return result

tokenized_datasets = datasets.map(tokenize_function, batched=True, remove_columns=["text"])

#Make dataset format pytorch tensors
tokenized_datasets.set_format("torch")
tokenized_datasets

HBox(children=(FloatProgress(value=0.0, max=37.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [10]:
# Finally, extract the datasets and select a subset if wanted
train_set = tokenized_datasets['train']#.select(list(range(10)))
valid_set = tokenized_datasets['validation']#.select(list(range(2)))
print(train_set, valid_set)

### Training
First, setup training args.
The last argument to setup everything so we can push the model to the Hub regularly during training..

Then pass training args to Trainer.

In [12]:
class SaveTokenizer(TrainerCallback):
    """
    A callback used to save the tokenizer whenever a model checkpoint is saved.
    """
    def on_save(self, args, state, control, **kwargs):
        tokenizer.save_pretrained(finetuned_model_name)

        
ce_loss = torch.nn.CrossEntropyLoss()
        
def compute_metrics(eval_pred):
    """
    The compute function needs to receive a tuple (with logits and labels)
    and has to return a dictionary with string keys (the name of the metric) and float values.
    It will be called at the end of each evaluation phase on the whole arrays of predictions/labels.
    """
    logits, labels = eval_pred
    # Calculate perplexity https://huggingface.co/transformers/perplexity.html
    # "the exponentiation of the cross-entropy between the data and model predictions."
    
    perplexity = math.exp(ce_loss(logits, labels))
    
    return {'perplexity': perplexity}

In [13]:
torch.cuda.empty_cache()
batch_size = 1 

training_args = TrainingArguments(
    finetuned_model_name,
    evaluation_strategy = "no",
    num_train_epochs=1,
    learning_rate=1e-6,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=1,
    save_steps=2000,
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=valid_set,
    compute_metrics=compute_metrics,
    callbacks=[SaveTokenizer],
)

In [14]:
train_results=trainer.train()
pickle.dump(train_results, open("train_results.pickle", "wb")) #Load: train_results = pickle.load(open("train_results.pickle", "rb"))

model.save_pretrained("movie-plot-generator")
tokenizer.save_pretrained("movie-plot-generator")

***** Running training *****
  Num examples = 36475
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 36475


Step,Training Loss
500,11.5128
1000,1.9574
1500,1.6135
2000,1.5035
2500,1.4678
3000,1.4301
3500,1.3342
4000,1.4458
4500,1.4535
5000,1.3974


Saving model checkpoint to pranavpsv/gpt2-genre-story-generator\checkpoint-2000
Configuration saved in pranavpsv/gpt2-genre-story-generator\checkpoint-2000\config.json
Model weights saved in pranavpsv/gpt2-genre-story-generator\checkpoint-2000\pytorch_model.bin
tokenizer config file saved in ./pranavpsv/gpt2-genre-story-generator/tokenizer_config.json
Special tokens file saved in ./pranavpsv/gpt2-genre-story-generator/special_tokens_map.json
Saving model checkpoint to pranavpsv/gpt2-genre-story-generator\checkpoint-4000
Configuration saved in pranavpsv/gpt2-genre-story-generator\checkpoint-4000\config.json
Model weights saved in pranavpsv/gpt2-genre-story-generator\checkpoint-4000\pytorch_model.bin
Deleting older checkpoint [pranavpsv\gpt2-genre-story-generator\checkpoint-2000] due to args.save_total_limit
tokenizer config file saved in ./pranavpsv/gpt2-genre-story-generator/tokenizer_config.json
Special tokens file saved in ./pranavpsv/gpt2-genre-story-generator/special_tokens_map.jso

Saving model checkpoint to pranavpsv/gpt2-genre-story-generator\checkpoint-32000
Configuration saved in pranavpsv/gpt2-genre-story-generator\checkpoint-32000\config.json
Model weights saved in pranavpsv/gpt2-genre-story-generator\checkpoint-32000\pytorch_model.bin
Deleting older checkpoint [pranavpsv\gpt2-genre-story-generator\checkpoint-30000] due to args.save_total_limit
tokenizer config file saved in ./pranavpsv/gpt2-genre-story-generator/tokenizer_config.json
Special tokens file saved in ./pranavpsv/gpt2-genre-story-generator/special_tokens_map.json
Saving model checkpoint to pranavpsv/gpt2-genre-story-generator\checkpoint-34000
Configuration saved in pranavpsv/gpt2-genre-story-generator\checkpoint-34000\config.json
Model weights saved in pranavpsv/gpt2-genre-story-generator\checkpoint-34000\pytorch_model.bin
Deleting older checkpoint [pranavpsv\gpt2-genre-story-generator\checkpoint-32000] due to args.save_total_limit
tokenizer config file saved in ./pranavpsv/gpt2-genre-story-gene

('movie-plot-generator\\tokenizer_config.json',
 'movie-plot-generator\\special_tokens_map.json',
 'movie-plot-generator\\vocab.json',
 'movie-plot-generator\\merges.txt',
 'movie-plot-generator\\added_tokens.json',
 'movie-plot-generator\\tokenizer.json')

In [None]:
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

In [None]:
tokenizer.push_to_hub(finetuned_model_name)
trainer.push_to_hub(finetuned_model_name)

In [19]:
# Inference test
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
stories = generator("<BOS> <drama> Expecting the unexpected <SEP>", max_length=512, num_return_sequences=4)
print(*[story['generated_text'] + "\n\n\n------------------------\n" for story in stories])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<BOS> <romantic drama> Expecting the unexpected <SEP> Kajsa and Anton are inseparable while they meet each other in a pub.  A boy and his sister are hired to take over the family's daughter's job and is hired to take over the family's job. She is then sent to work for the family. She is then taken over by the family, and is then brought out to a girl who has a child, and she is then introduced to the girls who also have a child, and she is then brought to the main group, and she is then brought up on a tour of the family's farm as a bus driver. She is then brought back to the family to take her back to the family, and she is brought on a bus to a group that is getting the kids to be with each other. In the end, the group is brought into a group which is brought to a farm in the afternoon. The group is brought up to a group of boys, while the boys are allowed to be with the boys. The boys are given that they will be for the other boys and they are handed a few things. The group is then 