
# Exploring BART Models to Find the Best Pre-Trained Option

## Example of word prediction

In [6]:
import random
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
# load the pre-trained 'facebook/bart-base' model and tokeniser
tokeniser = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [6]:
text = "There is nothing quite like a sunny day to remind someone of their own mortality."

In [11]:
#use BART to summarise the sentence, as a check that everything works
inputs = tokeniser([text], return_tensors='pt')
summary_ids = model.generate(inputs['input_ids'], early_stopping=True)
bart_summaries = tokeniser.decode(summary_ids[0], skip_special_tokens=True)
print(bart_summaries)



There is nothing quite like a sunny day to remind someone of their own mortality. There is also nothing like a sun-soaked beach to remind you that you are not immortal. There are no guarantees in life, but there are some things that can be learned from the sun.


In [4]:
#create text with a masked word
text = "There is nothing quite like a sunny <mask> to remind someone of their own mortality."

In [9]:
#now use BART to predict what the word is
input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
logits = model(input_ids).logits
masked_index = (input_ids[0] == tokeniser.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5) #only get top 5 predictions
tokeniser.decode(predictions).split()

['day,', 'morning', 'moment', 'afternoon']
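
The list above has four entries although `topk(5)` returns five token ids. One likely explanation, offered here as an assumption and not something the notebook verifies, is that `decode` joins the tokens into one string before `split()` cuts on whitespace, so a token without a leading space (such as a comma) is glued to its neighbour. The sketch below reproduces the effect in plain Python; the vocabulary, the logits and the helper names `softmax` and `top_k` are invented for illustration.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Return (index, probability) pairs for the k most probable entries."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical vocabulary and logits, invented for illustration only.
# A leading space is treated as part of the token, as in byte-level BPE.
vocab = [" day", ",", " morning", " moment", " afternoon", " spell"]
logits = [5.0, 3.0, 2.5, 2.0, 1.5, -1.0]

probs = softmax(logits)
best = top_k(probs, 5)
tokens = [vocab[i] for i, _ in best]

# Joining and then splitting on whitespace merges " day" and "," into "day,".
joined_then_split = "".join(tokens).split()

# Keeping one string per token preserves all five predictions.
per_token = [t.strip() for t in tokens]
```

In the notebook itself, decoding each id separately, for example `[tokeniser.decode([int(i)]) for i in predictions]`, should keep the five predictions apart.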

## Writing this into a function that generates a metric

In [30]:
def is_pred_good(text, model_name = 'facebook/bart-base'):
  #read in chosen model (the argument is called model_name so the loaded model does not overwrite it)
  tokeniser = BartTokenizer.from_pretrained(model_name)
  model = BartForConditionalGeneration.from_pretrained(model_name)
  #tokenise text and sample 10% of tokens from it
  input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
  n = len(input_ids[0])
  n_masks = int(n/10)
  masks_sample = random.sample(range(1, n - 1), n_masks) #avoid the first and last tokens (<s> and </s>) which are static
  #iterate through the sampled tokens
  for i in range(n_masks):
    print('loop ' + str(i))
    #replace sampled token with '<mask>'
    input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
    true_token = int(input_ids[0][masks_sample[i]])
    input_ids[0][masks_sample[i]] = tokeniser.mask_token_id
    #use BART to predict what this token is
    logits = model(input_ids).logits
    masked_index = masks_sample[i] #the position just masked
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(500)
    #get the probability of the true token if it is among the top 500 predictions
    try:
      true_token_index = predictions.tolist().index(true_token)
      true_token_prob = values[true_token_index]
    except ValueError: #true token is not in the top 500
      true_token_prob = 0
    print(true_token_prob)
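
The position sampling inside `is_pred_good` can be looked at in isolation. The sketch below uses a hypothetical helper, `sample_mask_positions`, which is not part of the notebook. It assumes that the first and last positions hold the `<s>` and `</s>` special tokens and so should never be masked, and it takes an optional seed (also an addition) so that a run can be repeated.

```python
import random

def sample_mask_positions(n_tokens, fraction=0.1, seed=None):
    """Pick token positions to mask, skipping the first and last position."""
    rng = random.Random(seed)  # private generator, leaves global state alone
    candidates = range(1, n_tokens - 1)  # excludes index 0 and index n_tokens - 1
    n_masks = min(int(n_tokens * fraction), len(candidates))
    return sorted(rng.sample(candidates, n_masks))

# Example: a sequence of 100 tokens, masking about 10% of them.
positions = sample_mask_positions(100, fraction=0.1, seed=0)
```

Fixing the seed makes it possible to compare two checkpoints on the same masked positions, which is harder to do when every call draws a fresh sample.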

In [31]:
text = "Amy penned: beep beep boop boop beep beep boop boop beep beep boop boop 'Hey all, I've got some news which isn't easy to share. I've recently been diagnosed with breast cancer but I'm determined to get back on that dance floor before you know it. Welsh love Amy.'  Amy has battled gut condition Crohn's Disease since she was a child and admitted she has already been through 'quite a lot' in her life with her health struggles."
is_pred_good(text)

loop 0
tensor(0.9969, grad_fn=<SelectBackward0>)
loop 1
0
loop 2
tensor(0.0112, grad_fn=<SelectBackward0>)
loop 3
tensor(0.7863, grad_fn=<SelectBackward0>)
loop 4
tensor(0.3439, grad_fn=<SelectBackward0>)
loop 5
tensor(0.5515, grad_fn=<SelectBackward0>)
loop 6
tensor(0.0357, grad_fn=<SelectBackward0>)
loop 7
tensor(0.9620, grad_fn=<SelectBackward0>)
loop 8
tensor(0.2836, grad_fn=<SelectBackward0>)
loop 9
tensor(0.0419, grad_fn=<SelectBackward0>)
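
The function prints one probability per masked token but does not combine them into a single number. One possible aggregation, shown here as a sketch and not as the notebook's method, is the mean probability together with the share of masked tokens that appeared in the top 500 at all. The values are copied from the output above; the helper name `summarise_probs` is hypothetical.

```python
def summarise_probs(probs):
    """Reduce per-token probabilities to two summary numbers."""
    if not probs:
        raise ValueError("need at least one probability")
    mean_prob = sum(probs) / len(probs)
    # A probability of 0 is what the loop records when the true token
    # was not among the top 500 predictions.
    hit_rate = sum(1 for p in probs if p > 0) / len(probs)
    return {"mean_prob": mean_prob, "hit_rate": hit_rate}

# Probabilities printed by the run above, in loop order.
probs = [0.9969, 0.0, 0.0112, 0.7863, 0.3439,
         0.5515, 0.0357, 0.9620, 0.2836, 0.0419]
summary = summarise_probs(probs)
```

Because only ten tokens are sampled per call, such a summary will vary from run to run. Averaging over several texts or several samples would give a steadier figure.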
