<a href="https://colab.research.google.com/github/aecoaker/FTA-Summary/blob/master/Exploring_BART_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exploring BART Models to find best Pre-Trained Option

## Example of work prediction

In [18]:
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
import statistics
import time
import random
# load a pre-trained model and tokenizer 'bart-large-cnn'
tokeniser = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [None]:
text = "There is nothing quite like a sunny day to remind someone of their own mortality."

In [None]:
#use bart for summary of the sentence to check it all works
inputs = tokeniser.batch_encode_plus([text],return_tensors='pt')
summary_ids = model.generate(inputs['input_ids'], early_stopping=True)
bart_summaries = tokeniser.decode(summary_ids[0], skip_special_tokens=True)
print(bart_summaries)



There is nothing quite like a sunny day to remind someone of their own mortality. There is also nothing like a sun-soaked beach to remind you that you are not immortal. There are no guarantees in life, but there are some things that can be learned from the sun.


In [None]:
#create text with a masked word
text = "There is nothing quite like a sunny <mask> to remind someone of their own mortality."

In [None]:
#now use BART to predict what the word is
input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
logits = model(input_ids).logits
masked_index = (input_ids[0] == tokeniser.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5) #only get top 5 predictions
tokeniser.decode(predictions).split()

['day,', 'morning', 'moment', 'afternoon']

In [None]:
input_ids

tensor([[    0,   970,    16,  1085,  1341,   101,    10,  5419, 50264,     7,
          8736,   951,     9,    49,   308, 15812,     4,     2]])

In [None]:
[input_ids[0][1:10].detach().numpy().tolist()]

[[970, 16, 1085, 1341, 101, 10, 5419, 50264, 7]]

## Writing this into a function that generates a metric

In [12]:
def perplexity(text, model = 'facebook/bart-base'):
  #read in chosen model and set up empty lists to use
  tokeniser = BartTokenizer.from_pretrained(model)
  model = BartForConditionalGeneration.from_pretrained(model)
  prod__pp_t = 1
  probs = []
  #tokenise text to get its length
  input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
  n = len(input_ids[0])
  #iterate through tokens
  for i in range(1, n):
    #get full set of inputs, find real value then replace token with '<mask>'
    input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
    true_token = int(input_ids[0][i])
    input_ids[0][i] = 50264
    #find sliding window of tokens to use as context
    if n < 1024:
      window_start = 0
      window_end = n
    else:
      window_start = i - 512
      window_end = i + 512
    if window_start < 0:
      window_start = 0
      window_end += -(i - 512)
    if window_end > n:
      window_end = n
      window_start += -(n - i - 512)
    #subset the input_ids to only look at the sliding window
    input_ids = input_ids[:, window_start:window_end]
    #use BART to predict what this token is given the context in the window
    logits = model(input_ids).logits
    masked_index = (input_ids[0] == tokeniser.mask_token_id).nonzero().item()
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(1000)
    #get the probability of the true token within this prediction
    try:
      true_token_index = predictions.tolist().index(true_token)
      true_token_prob = values[true_token_index].detach().numpy().item()
    #deal with words that aren't in the top 1000 predictions by assigning them a very small probability
    except:
      true_token_prob = 0.00000000001
    #calculate the reciprocals of the probabilities and multiple together
    probs.append(true_token_prob)
    pp_t = 1/true_token_prob
    prod__pp_t *= pp_t
  #calculate the perplexity by normalising this, also (for comparison) show avg probabilities
  perplexity = prod__pp_t ** (1/n)
  print(perplexity)
  print(statistics.mean(probs))

In [13]:
text = "After receipt by a Party's investigating authority of a properly documented application for an anti-dumping investigation or a countervailing duty investigation with respect to imports from the other Party and before initiating an investigation, the importing Party shall provide written notification to the other Party of its receipt of the application. "
start = time.perf_counter()
perplexity(text)
end = time.perf_counter()

15.062050290979776
0.3582748056628363


In [17]:
print('The text had ' + str(len(text.split())) + ' words and this took ' + str(end - start) + ' seconds to run')

The text had 51 words and it took 26.51806341400004 seconds to run


In [None]:
#re-writing function to sample a window from the text, and use that instead of the full text in order to allow us to measure for FTAs
def perplexity(text, model = 'facebook/bart-base'):
  #read in chosen model and set up empty lists to use
  tokeniser = BartTokenizer.from_pretrained(model)
  model = BartForConditionalGeneration.from_pretrained(model)
  prod__pp_t = 1
  test = []
  #tokenise text to get its length
  input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
  n = len(input_ids[0])
  #iterate through tokens
  for i in range(1, n):
    #get full set of inputs, find real value then replace token with '<mask>'
    input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
    true_token = int(input_ids[0][i])
    input_ids[0][i] = 50264
    #find sliding window of tokens to use as context
    if n < 1024:
      window_start = 0
      window_end = n
    else:
      window_start = i - 512
      window_end = i + 512
    if window_start < 0:
      window_start = 0
      window_end += -(i - 512)
    if window_end > n:
      window_end = n
      window_start += -(n - i - 512)
    #subset the input_ids to only look at the sliding window
    input_ids = input_ids[:, window_start:window_end]
    #use BART to predict what this token is given the context in the window
    logits = model(input_ids).logits
    masked_index = (input_ids[0] == tokeniser.mask_token_id).nonzero().item()
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(1000)
    #get the probability of the true token within this prediction
    try:
      true_token_index = predictions.tolist().index(true_token)
      true_token_prob = values[true_token_index].detach().numpy().item()
    #deal with words that aren't in the top 1000 predictions by assigning them a very small probability
    except:
      true_token_prob = 0.00000000001
    #calculate the reciprocals of the probabilities and multiple together
    test.append(true_token_prob)
    pp_t = 1/true_token_prob
    prod__pp_t *= pp_t
  #calculate the perplexity by normalising this, also (for comparison) show avg probabilities
  perplexity = prod__pp_t ** (1/n)
  print(perplexity)
  print(statistics.mean(test))