# Finetuning BART for abstractive text summarisation with fastai2

A great thing about working in NLP at the moment is being able to park a hard problem for a few weeks and discover that the community has made massive amounts of progress on your behalf. I used to be overwhelmed by the challenge of just training a summarisation model to generate plausible-looking text without burning through tonnes of cash on GPUs. Then [BertExtAbs](../finetuning-bertsumextabs) came along and solved that problem. Unfortunately, it still generated incoherent sentences sometimes and had a habit of confusing entities in an article. You certainly couldn't trust it to convey the facts of an article reliably.

Enter BART (Bidirectional and Auto-Regressive Transformers). Here we have a model that generates staggeringly good summaries and has a wonderful implementation from Sam Shleifer at HuggingFace. It's still a work in progress, but after digging around in the Transformers pull requests and with help from [Morgan McGuire's FastHugs notebook](https://github.com/morganmcg1/fasthugs) I have put together this notebook for fine-tuning BART and generating summaries. Feedback welcome!

I should mention that this is a big model requiring big inputs. For fine-tuning I've been able to get a batch size of 4 and a maximum sequence length of 512 on an AWS p3.2xlarge (a single 16 GB V100 GPU, ~£4 an hour).

We begin with a bunch of imports and an args object for storing variables we will need. We'll be finetuning the model on the Curation Corpus of abstractive text summaries. We load it into a dataframe using Pandas. For more information about how to access this dataset for your own purposes please see our [article introducing the dataset](https://medium.com/curation-corporation/teaching-an-ai-to-abstract-a-new-dataset-for-abstractive-auto-summarisation-5227f546caa8).

In [2]:
%reload_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append('..')
import logging
logging.getLogger().setLevel(100)
from fastprogress import progress_bar
from fastai2.basics import Transform, Datasets, RandomSplitter, Module, Learner, ranger, params, load_learner
from fastai2.text.all import TensorText
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import PreTrainedTokenizer, BartTokenizer, BartForConditionalGeneration, BartConfig 
import torch
from torch.nn import functional as F
from torch import nn

Hopefully we will be able to increase our batch size and/or maximum sequence length once some pull requests that reduce the model's memory footprint get merged into the Transformers repository.

In [None]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        
args = Namespace(
    batch_size=4,
    max_seq_len=512,
    data_path="../data/private_dataset.file",
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),  # or torch.device("cpu") to force CPU
    stories_folder='../data/my_own_stories',
    subset=None,
    test_pct=0.1
)

In [None]:
ds = pd.read_feather(args.data_path).iloc[:args.subset]
ds = ds[ds['summary'] != '']
train_ds, test_ds = train_test_split(ds, test_size=args.test_pct, random_state=42)
valid_ds, test_ds = train_test_split(test_ds, test_size=0.5, random_state=42)

To pass our data to the model in our fastai2 Learner object we need DataLoaders. To create them we need a Datasets object, a batch size, and a device type. To create a Datasets object, we have to pass a few things:
- Our raw data, which in our case is a Pandas dataframe.
- A list of transforms. Or, to be more precise, a list containing the list of transforms to perform on our inputs and the list of transforms to perform on our desired outputs. I've defined a transform below that encodes the text using the BART tokenizer. Mostly it is the encodes method that gets called by fastai2, but the decodes method can also be useful if you want to reverse the process.
- A split into training and validation sets, made here with fastai2's RandomSplitter class. Note that this carves the validation set out of train_ds, so the valid_ds created above isn't used.
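The encodes/decodes pairing is easiest to see on a toy example. Below is a dependency-free sketch of what the transform does to one string: tokenise, add start and end tokens, truncate and pad to a fixed length, then decode while skipping special tokens. ToyTokenizer is hypothetical, a word-level stand-in for BartTokenizer that borrows BART's ids for the start, padding and end tokens (0, 1 and 2) but nothing else; the real tokenizer uses byte-level BPE.

```python
BOS_ID, PAD_ID, EOS_ID = 0, 1, 2  # BART's <s>, <pad> and </s> ids


class ToyTokenizer:
    """Hypothetical word-level tokenizer that builds its vocabulary on the fly."""

    def __init__(self):
        self.vocab = {}
        self.inv = {}

    def _id(self, word):
        if word not in self.vocab:
            idx = len(self.vocab) + 4  # ids 0-3 are reserved for special tokens
            self.vocab[word] = idx
            self.inv[idx] = word
        return self.vocab[word]

    def encode(self, text, max_length):
        ids = [BOS_ID] + [self._id(w) for w in text.split()] + [EOS_ID]
        # truncate long inputs (the real tokenizer keeps </s>; this toy doesn't)
        ids = ids[:max_length]
        # pad short inputs up to a fixed length, like pad_to_max_length=True
        return ids + [PAD_ID] * (max_length - len(ids))

    def decode(self, ids):
        # drop special tokens, mirroring skip_special_tokens=True
        return " ".join(self.inv[i] for i in ids if i in self.inv)


tok = ToyTokenizer()
ids = tok.encode("the cat sat", max_length=8)
print(ids)              # [0, 4, 5, 6, 2, 1, 1, 1]
print(tok.decode(ids))  # the cat sat
```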

In [None]:
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn', add_prefix_space=True)

I'm still exploring whether it is necessary to pass any of the masks and other ids manually or if it is handled for us. Any advice here would be much appreciated!

In [None]:
class DataTransform(Transform):
    def __init__(self, tokenizer:PreTrainedTokenizer, column:str):
        self.tokenizer = tokenizer
        self.column = column
        
    def encodes(self, inp):  
        tokenized = self.tokenizer.batch_encode_plus(
            [inp[self.column]],  # a batch of one string; list(...) would split it into characters
            max_length=args.max_seq_len, 
            pad_to_max_length=True, 
            return_tensors='pt'
        )
        return TensorText(tokenized['input_ids']).squeeze()
        
    def decodes(self, encoded):
        decoded = [
            self.tokenizer.decode(
                o, 
                skip_special_tokens=True, 
                clean_up_tokenization_spaces=False
            ) for o in encoded
        ]
        return decoded
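On the masks question: the attention mask for a padded batch is just a 0/1 flag per position marking real tokens, so it can be rebuilt from the padded input ids. A minimal stdlib sketch, assuming BART's padding id of 1; attention_mask here is a hypothetical helper and the token ids are made up. In tensor code the same thing is (x != 1), and the tokenizer's batch_encode_plus can also return an attention_mask alongside input_ids.

```python
PAD_ID = 1  # BART's <pad> token id


def attention_mask(batch_ids, pad_id=PAD_ID):
    """1 for real tokens, 0 for padding, same shape as batch_ids."""
    return [[int(t != pad_id) for t in ids] for ids in batch_ids]


batch = [
    [0, 713, 16, 2, 1, 1],    # shorter sequence, padded with 1s
    [0, 713, 16, 10, 20, 2],  # full-length sequence
]
print(attention_mask(batch))  # [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]
```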

In [None]:
x_tfms = [DataTransform(tokenizer, column='text')]
y_tfms = [DataTransform(tokenizer, column='summary')]
dss = Datasets(
    train_ds, 
    tfms=[x_tfms, y_tfms], 
    splits=RandomSplitter(valid_pct=0.1)(range(train_ds.shape[0]))
)

In [None]:
dls = dss.dataloaders(bs=args.batch_size, device=args.device.type)

This function lets us choose between loading the model architecture with Facebook's pretrained weights, the model architecture with our own weights stored locally, or the model architecture with no pretraining at all.

In [None]:
def load_hf_model(config, pretrained=False, path=None): 
    if pretrained:    
        if path:
            model = BartForConditionalGeneration.from_pretrained(
                "bart-large-cnn", 
                state_dict=torch.load(path, map_location=torch.device(args.device)), 
                config=config
            )
        else: 
            model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", config=config)
    else:
        model = BartForConditionalGeneration(config)  # randomly initialised weights

    return model.to(args.device)

The model returns a tuple of outputs, but we only need the logits (its first element) to calculate the loss when training, so we will wrap the model in this class to control what gets passed to the loss function.

In [None]:
class FastaiWrapper(Module):
    def __init__(self):
        self.config = BartConfig(vocab_size=50264, output_past=True)
        self.bart = load_hf_model(config=self.config, pretrained=True)
        
    def forward(self, x):
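        # mask out padding (BART's pad id is 1) so the encoder doesn't attend to it
        output = self.bart(x, attention_mask=(x != 1).long())[0]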
        return output
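Something I'd double-check here: if I'm reading the Transformers source correctly, when BART's forward is called with only input_ids it builds its decoder inputs by shifting those same input_ids one step to the right. That means the decoder sees the article, not the summary, while the loss compares its output against the summary. The usual seq2seq setup (teacher forcing) feeds the decoder the target shifted right, so at each position it predicts the next summary token from the previous ones. Below is a stdlib sketch of that shift; shift_right is a hypothetical helper (Transformers has its own shift_tokens_right, which differs in detail), and the word ids are made up. BART's decoder conventionally starts from the end-of-sequence token, id 2.

```python
EOS_ID, PAD_ID = 2, 1  # BART's </s> and <pad> ids


def shift_right(target_ids, start_id=EOS_ID):
    """Decoder inputs for teacher forcing: prepend the start token, drop the last."""
    return [start_id] + target_ids[:-1]


summary_ids = [0, 50, 51, 2, 1]  # <s> w1 w2 </s> <pad> (word ids made up)
decoder_inputs = shift_right(summary_ids)
for inp, tgt in zip(decoder_inputs, summary_ids):
    print(f"decoder sees {inp:>2} -> should predict {tgt:>2}")
```

Wiring this into the Learner would mean getting the summary ids into the wrapper's forward and passing them to self.bart as decoder_input_ids, which I haven't shown here.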

You can think of seq2seq tasks as a series of attempts to classify which word should come next, so cross entropy loss is a natural fit. We want to normalise it by the number of non-padding tokens in the batch, so that padding doesn't count towards the loss.

In [None]:
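class SummarisationLoss(Module):
    def __init__(self, pad_idx=1):
        # BART's <pad> token id is 1. ignore_index drops padding positions, and the
        # default reduction='mean' then averages over the remaining non-padding tokens
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        
    def forward(self, output, target):
        # CrossEntropyLoss applies log_softmax itself, so pass the raw logits
        return self.criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))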
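To see why ignoring padding matters, here is the same masked-mean computation in plain Python: the per-token negative log-likelihood from logits, averaged only over non-padding targets. masked_mean_nll and mean_nll are hypothetical helpers; on tensors, PyTorch's nn.CrossEntropyLoss(ignore_index=1) with reduction='mean' computes the masked version.

```python
import math

PAD_ID = 1  # BART's <pad> token id


def token_nll(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def masked_mean_nll(batch_logits, targets, pad_id=PAD_ID):
    """Average NLL over positions whose target is not padding."""
    losses = [token_nll(row, t) for row, t in zip(batch_logits, targets) if t != pad_id]
    return sum(losses) / len(losses)


def mean_nll(batch_logits, targets):
    """Average NLL over every position, padding included."""
    losses = [token_nll(row, t) for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Two real tokens the model is unsure about, then two pads it predicts confidently
logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 10.0, 0.0]]
targets = [0, 2, PAD_ID, PAD_ID]
print(masked_mean_nll(logits, targets))  # log(3), about 1.0986
print(mean_nll(logits, targets))         # roughly half that: the easy pads dilute the loss
```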

### Training

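When fine-tuning the model we start by training only the top layer(s). You can experiment by unfreezing layers further down in the decoder, and then (if you're feeling bold) the encoder. fastai2 provides an easy way to split the model up into parameter groups that can be frozen or unfrozen. Note that freeze_to(-1) unfreezes only the last group the splitter returns, which below is decoder.layernorm_embedding rather than the decoder layers themselves, so order the groups to match the order in which you want to unfreeze them.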

In [None]:
def bart_splitter(model):
    return [
        params(model.bart.model.encoder), 
        params(model.bart.model.decoder.embed_tokens),
        params(model.bart.model.decoder.embed_positions),
        params(model.bart.model.decoder.layers),
        params(model.bart.model.decoder.layernorm_embedding),
    ]
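The rule fastai2's freeze_to(n) applies is easy to model with list slicing: groups from index n onwards stay trainable and the rest are frozen, with negative n counting from the end. A stdlib sketch; trainable_groups is a hypothetical helper that models only this indexing, not fastai2's implementation, and the names simply mirror bart_splitter.

```python
# Group names in the order bart_splitter returns them
GROUPS = [
    "encoder",
    "decoder.embed_tokens",
    "decoder.embed_positions",
    "decoder.layers",
    "decoder.layernorm_embedding",
]


def trainable_groups(groups, n):
    """Groups left trainable by freeze_to(n): everything from index n onwards."""
    return groups[n:]


print(trainable_groups(GROUPS, -1))  # ['decoder.layernorm_embedding']
print(trainable_groups(GROUPS, -2))  # ['decoder.layers', 'decoder.layernorm_embedding']
```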

I've been experimenting with half-precision training. In theory this should save a lot of memory. However, I find my loss quickly becomes a bunch of NaNs. This may be an issue with HuggingFace's implementation or it may be an issue with my code. I'll update if I work out how to get to_fp16() working. Do let me know if you have any ideas!

In [None]:
learn = Learner(
    dls, 
    FastaiWrapper(), 
    loss_func=SummarisationLoss(), 
    opt_func=ranger,
    splitter=bart_splitter
)#.to_fp16()
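I can't say for sure what is causing the NaNs, but one common culprit in half precision is its narrow range: values above 65504 overflow to infinity and very small gradients underflow to zero, which is why mixed-precision training scales the loss. Python's struct module can round-trip floats through IEEE 754 half precision with its "e" format, which makes the limits easy to see. fp16_round_trip is a hypothetical helper for illustration.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite half-precision value


def fp16_round_trip(x):
    """Return x as it would be stored in IEEE 754 half precision."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses out-of-range values; on a GPU they become +/-inf
        return math.copysign(math.inf, x)


print(fp16_round_trip(FP16_MAX))  # 65504.0
print(fp16_round_trip(1e5))       # inf: overflow
print(fp16_round_trip(1e-8))      # 0.0: underflow, a tiny gradient vanishes
print(fp16_round_trip(0.1))       # 0.0999755859375: only ~3 significant digits
```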

In [None]:
learn.freeze_to(-1)

I've been finding that the learning rate finder suggests values that are too high. Your mileage may vary though.

In [None]:
# learn.lr_find()

In [None]:
learn.fit_flat_cos(
    1,
    lr=1e-4
)

If you do carry on unfreezing layers, you may find that you need to reduce your batch size to fit everything in memory. Also you should probably lower your learning rate.

In [None]:
learn.freeze_to(-2)
learn.dls.train.bs = args.batch_size//2
learn.dls.valid.bs = args.batch_size//2

In [None]:
learn.lr_find()

In [None]:
learn.fit_flat_cos(
    2,
    lr=1e-5
)

Now that training is done, we can export the learner.

In [None]:
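learn.export('../models/finetuned_bart.pkl')

### Inference

Because the learner is pickled, the custom classes and functions defined above (DataTransform, FastaiWrapper, SummarisationLoss and bart_splitter) need to be defined or importable before calling load_learner.

In [5]:
learn = load_learner('../models/finetuned_bart.pkl')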

The following code for generating the summaries is adapted from [Sam Shleifer's example in the Transformers repository](https://github.com/huggingface/transformers/blob/master/examples/summarization/bart/evaluate_cnn.py).

In [None]:
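from pathlib import Path

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]

def generate_summaries(lns, out_file, batch_size=4):
    model = learn.model.bart.to(args.device)
    model.eval()  # switch off dropout while generating
    dec = []
    for batch in progress_bar(list(chunks(lns, batch_size))):
        dct = tokenizer.batch_encode_plus(
            batch, 
            max_length=1024, 
            return_tensors="pt", 
            pad_to_max_length=True
        )
        
        summaries = model.generate(
            input_ids=dct["input_ids"].to(args.device),
            attention_mask=dct["attention_mask"].to(args.device),  # ignore padding
            num_beams=4,
            length_penalty=2.0,
            max_length=142,
            min_length=56,
            no_repeat_ngram_size=3,
        )
        
        dec.extend([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries])
    
    # Write one summary per line to out_file, creating its folder if needed
    Path(out_file).parent.mkdir(parents=True, exist_ok=True)
    with open(out_file, 'w') as f:
        f.write('\n'.join(dec) + '\n')
    return dec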
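The no_repeat_ngram_size=3 setting stops beam search from producing any trigram twice, which helps curb the repetition summarisation models are prone to. The core check is small: look at the last n-1 generated tokens and ban whichever tokens followed that same prefix earlier. A stdlib sketch with a hypothetical banned_next_tokens helper working on words; Transformers applies the equivalent check per beam on token ids, setting the banned tokens' scores to minus infinity.

```python
def banned_next_tokens(tokens, n=3):
    """Tokens that would repeat an n-gram already present in tokens.

    The next token completes an n-gram whose first n-1 tokens are the current
    suffix; any token that followed that same prefix before is banned.
    """
    if len(tokens) < n - 1:
        return set()
    prefix = tuple(tokens[len(tokens) - (n - 1):])
    banned = set()
    for i in range(len(tokens) - n + 1):
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned


# After "the cat" for the second time, "sat" would repeat the trigram "the cat sat"
words = "the cat sat on the cat".split()
print(banned_next_tokens(words, n=3))  # {'sat'}
```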

In [None]:
lns = [" " + x.rstrip() for x in list(test_ds['text'])[:8]]
bart_sums = generate_summaries(lns, f'{args.stories_folder}/output.txt', batch_size=args.batch_size)

In [None]:
for s in bart_sums[:8]:
    print(s)
    print("***************")