
# Exploring BART Models to Find the Best Pre-Trained Option

## Example of word prediction

In [18]:
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
import statistics
import time
import random
# load the pre-trained 'facebook/bart-base' model and its tokeniser
tokeniser = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [21]:
text = "'customs duty reduction or elimination' means any customs duty reduction or elimination"

In [22]:
#use bart for summary of the sentence to check it all works
inputs = tokeniser.batch_encode_plus([text],return_tensors='pt')
summary_ids = model.generate(inputs['input_ids'], early_stopping=True)
bart_summaries = tokeniser.decode(summary_ids[0], skip_special_tokens=True)
print(bart_summaries)



'customs duty reduction or elimination' means any customs duty collection or elimination


In [23]:
#create text with a masked word
text = "'customs duty reduction or elimination' means any customs duty <mask> or elimination"

In [24]:
#now use BART to predict what the word is
input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
logits = model(input_ids).logits
masked_index = (input_ids[0] == tokeniser.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5) #only get top 5 predictions
tokeniser.decode(predictions).split()

['reduction', 'reductions', 'cut', 'increase', 'reduced']
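The cell above turns the model's raw scores (logits) at the masked position into probabilities with a softmax and keeps the five largest. As a minimal sketch of that step, using only the standard library and made-up logits for a hypothetical five-word vocabulary (these are not BART outputs):

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Return (index, probability) pairs for the k most likely entries."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical vocabulary and logits, chosen only for illustration.
vocab = ["reduction", "increase", "cut", "collection", "elimination"]
logits = [4.0, 1.0, 2.0, 0.5, 0.0]

probs = softmax(logits)
for index, p in top_k(probs, 3):
    print(vocab[index], round(p, 3))
```

The real model does the same thing over its full vocabulary, which is why `topk` is needed to pick out a handful of candidates.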

## Writing this into a function that generates a metric

In [12]:
def perplexity(text, model_name = 'facebook/bart-base'):
  #read in chosen model and set up an empty list for the true-token probabilities
  tokeniser = BartTokenizer.from_pretrained(model_name)
  model = BartForConditionalGeneration.from_pretrained(model_name)
  prod__pp_t = 1
  true_token_probs = []
  #tokenise text to get its length
  input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
  n = len(input_ids[0])
  #iterate through tokens, skipping the start-of-sequence token at position 0
  for i in range(1, n):
    #get full set of inputs, find real value then replace token with '<mask>'
    input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
    true_token = int(input_ids[0][i])
    input_ids[0][i] = tokeniser.mask_token_id
    #find sliding window of up to 1024 tokens to use as context
    if n < 1024:
      window_start = 0
      window_end = n
    else:
      window_start = i - 512
      window_end = i + 512
    #if the window overruns either end of the text, shift it back inside
    if window_start < 0:
      window_start = 0
      window_end = 1024
    if window_end > n:
      window_end = n
      window_start = n - 1024
    #subset the input_ids to only look at the sliding window
    input_ids = input_ids[:, window_start:window_end]
    #use BART to predict what this token is given the context in the window
    logits = model(input_ids).logits
    masked_index = (input_ids[0] == tokeniser.mask_token_id).nonzero().item()
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(1000)
    #get the probability of the true token within this prediction
    try:
      true_token_index = predictions.tolist().index(true_token)
      true_token_prob = values[true_token_index].detach().numpy().item()
    #deal with words that aren't in the top 1000 predictions by assigning them a very small probability
    except ValueError:
      true_token_prob = 0.00000000001
    #calculate the reciprocals of the probabilities and multiply together
    true_token_probs.append(true_token_prob)
    pp_t = 1/true_token_prob
    prod__pp_t *= pp_t
  #calculate the perplexity by normalising this, also (for comparison) show avg probabilities
  perplexity = prod__pp_t ** (1/n)
  print(perplexity)
  print(statistics.mean(true_token_probs))
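The sliding-window step in the function above is easier to check in isolation. The following is a minimal sketch, assuming the intent is a window of at most 1024 tokens that is roughly centred on the masked token and shifted back inside the text when it overruns either end. The helper name `context_window` and its `max_len` parameter are hypothetical and not part of the notebook.

```python
def context_window(i, n, max_len=1024):
    """Return (start, end) of a window of at most max_len tokens that contains position i."""
    if n <= max_len:
        # The whole text fits, so use all of it.
        return 0, n
    half = max_len // 2
    start, end = i - half, i + half
    if start < 0:
        # Overran the beginning: shift the window right.
        start, end = 0, max_len
    elif end > n:
        # Overran the end: shift the window left.
        start, end = n - max_len, n
    return start, end

# Positions near the start, middle and end of a hypothetical 2000-token text.
for position in (100, 1000, 1900):
    print(position, context_window(position, 2000))
```

In every case the window keeps the same length and still contains the masked position, which is the property the model call relies on.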

In [13]:
text = "After receipt by a Party's investigating authority of a properly documented application for an anti-dumping investigation or a countervailing duty investigation with respect to imports from the other Party and before initiating an investigation, the importing Party shall provide written notification to the other Party of its receipt of the application. "
start = time.perf_counter()
perplexity(text)
end = time.perf_counter()

15.062050290979776
0.3582748056628363
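The first number is the perplexity, the geometric mean of the reciprocal probabilities: $PP = \left(\prod_{i=1}^{N} \frac{1}{p_i}\right)^{1/N} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \ln p_i\right)$. Multiplying the reciprocals directly can exceed the range of a float on long texts, since a few dozen tokens at the $10^{-11}$ fallback probability are enough. A minimal sketch of the equivalent log-space form, with the hypothetical helper name `perplexity_from_probs`, might look like this:

```python
import math

def perplexity_from_probs(probs, floor=1e-11):
    """Geometric mean of 1/p over the given probabilities, computed in log space."""
    # Clip to the same small fallback probability used for unseen tokens.
    clipped = [max(p, floor) for p in probs]
    mean_neg_log = -sum(math.log(p) for p in clipped) / len(clipped)
    return math.exp(mean_neg_log)

# Made-up probabilities, only to show the call.
print(perplexity_from_probs([0.9, 0.5, 0.25, 0.8]))
```

Because the logs are summed rather than the reciprocals multiplied, the result stays finite however long the list of probabilities is.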


In [17]:
print('The text had ' + str(len(text.split())) + ' words and this took ' + str(end - start) + ' seconds to run')

The text had 51 words and it took 26.51806341400004 seconds to run
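Part of this cost comes from how the function is written: it loads the model on every call, re-tokenises the text on every iteration, and runs one forward pass per token. The sketch below only illustrates that last point, showing that scoring a text means building one masked copy of the token sequence per position. The helper name `masked_variants`, the toy ids and the stand-in `MASK_ID` are all hypothetical.

```python
MASK_ID = -1  # stand-in for the tokeniser's mask id, for illustration only

def masked_variants(token_ids, mask_id=MASK_ID):
    """Return one copy of token_ids per position, with that position replaced by mask_id."""
    variants = []
    for i in range(len(token_ids)):
        copy = list(token_ids)  # leave the original sequence untouched
        copy[i] = mask_id
        variants.append(copy)
    return variants

# Toy ids, not real vocabulary entries.
ids = [0, 11, 22, 33, 2]
for row in masked_variants(ids):
    print(row)
```

A text of N tokens therefore needs N model calls, so the run time grows at least linearly with the length of the text.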


In [32]:
#re-writing function to sample a window from the text and use that instead of the full text, so that we can measure perplexity on full FTAs
def perplexity(text, model_name = 'facebook/bart-base'):
  #read in chosen model and set up an empty list for the true-token probabilities
  tokeniser = BartTokenizer.from_pretrained(model_name)
  model = BartForConditionalGeneration.from_pretrained(model_name)
  prod__pp_t = 1
  true_token_probs = []
  #tokenise text to get its length
  input_ids = tokeniser([text], return_tensors="pt")["input_ids"]
  n = len(input_ids[0])
  #randomly select a window in which we will consider the perplexity
  if n > 1024:
    window_start = random.sample(range(0, n - 1024), 1)[0]
    window_end = window_start + 1024
  else:
    window_start = 0
    window_end = n
  #iterate through tokens in this window, using positions relative to the start of the window
  for i in range(window_end - window_start):
    #get the inputs for the window, find real value then replace its token with '<mask>'
    input_ids = tokeniser([text], return_tensors="pt")["input_ids"][:, window_start:window_end]
    true_token = int(input_ids[0][i])
    input_ids[0][i] = tokeniser.mask_token_id
    #use BART to predict what this token is given the context in the window
    logits = model(input_ids).logits
    masked_index = (input_ids[0] == tokeniser.mask_token_id).nonzero().item()
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(1000)
    #get the probability of the true token within this prediction
    try:
      true_token_index = predictions.tolist().index(true_token)
      true_token_prob = values[true_token_index].detach().numpy().item()
    #deal with words that aren't in the top 1000 predictions by assigning them a very small probability
    except ValueError:
      true_token_prob = 0.00000000001
    #calculate the reciprocals of the probabilities and multiply together
    true_token_probs.append(true_token_prob)
    pp_t = 1/true_token_prob
    prod__pp_t *= pp_t
  #calculate the perplexity by normalising this, also (for comparison) show avg probabilities
  perplexity = prod__pp_t ** (1/(window_end - window_start))
  print(perplexity)
  print(statistics.mean(true_token_probs))

In [20]:
#check this still works as expected for smaller blocks of text
text = "After receipt by a Party's investigating authority of a properly documented application for an anti-dumping investigation or a countervailing duty investigation with respect to imports from the other Party and before initiating an investigation, the importing Party shall provide written notification to the other Party of its receipt of the application. "
perplexity(text)

15.062052955338832
0.3686250351319093
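For texts longer than the model's 1024-token limit, the function samples a single window and scores only that. A minimal sketch of the sampling step on its own follows; the helper name `sample_window`, its `rng` parameter and the 3000-token length are hypothetical and used only for illustration.

```python
import random

def sample_window(n, max_len=1024, rng=random):
    """Pick a random contiguous window of at most max_len positions from range(n)."""
    if n <= max_len:
        # Short text: the window is the whole text.
        return 0, n
    # Valid starts run from 0 to n - max_len inclusive.
    start = rng.randrange(n - max_len + 1)
    return start, start + max_len

# Seeded generator so the sketch is repeatable.
rng = random.Random(0)
start, end = sample_window(3000, rng=rng)
print(start, end, end - start)

# Positions inside the sliced window are counted from 0, not from `start`.
relative_positions = range(end - start)
```

Short texts fall through to the full range, which is why the result for the 51-word example is close to the one from the first version of the function.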


In [None]:
#test on a large block of text
text = 'For the purposes of this Chapter: “bilateral safeguard measure” means a measure referred to in paragraph 2 of Article 3.6 (Application of a Bilateral Safeguard Measure); “customs duty reduction or elimination” means any customs duty reduction or elimination in accordance with paragraph 2 of Article 2.5 (Treatment of Customs Duties – Trade in Goods); “domestic industry” means, with respect to an imported good, the producers as a whole of the like or directly competitive good operating within the territory of a Party, or those whose collective output of the like or directly competitive good constitutes a major proportion of the total domestic production of the good; “serious injury” means a significant overall impairment in the position of a domestic industry; “threat of serious injury” means serious injury that is clearly imminent, in accordance with the provisions of Article 3.8 (Investigation Procedure). A determination of the existence of a threat of serious injury shall be based on facts and not merely on allegation, conjecture, or remote possibility; and “transition period” means, in relation to a good, the entry into force of this Agreement until five years after the completion of the customs duty reduction or elimination in relation to the good. After receipt by a Party’s investigating authority of a properly documented application for an anti-dumping investigation or a countervailing duty investigation with respect to imports from the other Party and before initiating an investigation, the importing Party shall provide written notification to the other Party of its receipt of the application. 2. 
Without prejudice to its other rights and obligations under the SCM Agreement, prior to initiating a countervailing duty investigation against imports from the other Party, the importing Party shall afford to the other Party a reasonable opportunity to consult with the aim of clarifying the situation on matters raised in the application and arriving at a mutually agreed solution. Any such consultations shall not unnecessarily delay or prevent a Party from proceeding expeditiously to initiate and conduct an investigation. 3. The Parties reaffirm their rights and obligations under Articles 6.2 and 6.3 of the AD Agreement and Article 12.2 of the SCM Agreement, including with respect to the rights of interested parties to present information orally and to defend their interests in the conduct of an anti-dumping investigation or a countervailing duty investigation. 4. Each Party shall ensure, before a final determination is made, full and meaningful disclosure of all essential facts under consideration which form the basis for the decision whether to apply definitive measures in an anti-dumping investigation or a countervailing duty investigation. This is without prejudice to Article 6.5 of the AD Agreement and Article 12.4 of the SCM Agreement. Disclosures shall be made in writing, and allow interested parties sufficient time to defend their interests. 
For the purposes of this Chapter: “bilateral safeguard measure” means a measure referred to in paragraph 2 of Article 3.6 (Application of a Bilateral Safeguard Measure); “customs duty reduction or elimination” means any customs duty reduction or elimination in accordance with paragraph 2 of Article 2.5 (Treatment of Customs Duties – Trade in Goods); “domestic industry” means, with respect to an imported good, the producers as a whole of the like or directly competitive good operating within the territory of a Party, or those whose collective output of the like or directly competitive good constitutes a major proportion of the total domestic production of the good; “serious injury” means a significant overall impairment in the position of a domestic industry; “threat of serious injury” means serious injury that is clearly imminent, in accordance with the provisions of Article 3.8 (Investigation Procedure). A determination of the existence of a threat of serious injury shall be based on facts and not merely on allegation, conjecture, or remote possibility; and “transition period” means, in relation to a good, the entry into force of this Agreement until five years after the completion of the customs duty reduction or elimination in relation to the good. After receipt by a Party’s investigating authority of a properly documented application for an anti-dumping investigation or a countervailing duty investigation with respect to imports from the other Party and before initiating an investigation, the importing Party shall provide written notification to the other Party of its receipt of the application. 2. 
Without prejudice to its other rights and obligations under the SCM Agreement, prior to initiating a countervailing duty investigation against imports from the other Party, the importing Party shall afford to the other Party a reasonable opportunity to consult with the aim of clarifying the situation on matters raised in the application and arriving at a mutually agreed solution. Any such consultations shall not unnecessarily delay or prevent a Party from proceeding expeditiously to initiate and conduct an investigation. 3. The Parties reaffirm their rights and obligations under Articles 6.2 and 6.3 of the AD Agreement and Article 12.2 of the SCM Agreement, including with respect to the rights of interested parties to present information orally and to defend their interests in the conduct of an anti-dumping investigation or a countervailing duty investigation. 4. Each Party shall ensure, before a final determination is made, full and meaningful disclosure of all essential facts under consideration which form the basis for the decision whether to apply definitive measures in an anti-dumping investigation or a countervailing duty investigation. This is without prejudice to Article 6.5 of the AD Agreement and Article 12.4 of the SCM Agreement. Disclosures shall be made in writing, and allow interested parties sufficient time to defend their interests. 
For the purposes of this Chapter: “bilateral safeguard measure” means a measure referred to in paragraph 2 of Article 3.6 (Application of a Bilateral Safeguard Measure); “customs duty reduction or elimination” means any customs duty reduction or elimination in accordance with paragraph 2 of Article 2.5 (Treatment of Customs Duties – Trade in Goods); “domestic industry” means, with respect to an imported good, the producers as a whole of the like or directly competitive good operating within the territory of a Party, or those whose collective output of the like or directly competitive good constitutes a major proportion of the total domestic production of the good; “serious injury” means a significant overall impairment in the position of a domestic industry; “threat of serious injury” means serious injury that is clearly imminent, in accordance with the provisions of Article 3.8 (Investigation Procedure). A determination of the existence of a threat of serious injury shall be based on facts and not merely on allegation, conjecture, or remote possibility; and “transition period” means, in relation to a good, the entry into force of this Agreement until five years after the completion of the customs duty reduction or elimination in relation to the good. After receipt by a Party’s investigating authority of a properly documented application for an anti-dumping investigation or a countervailing duty investigation with respect to imports from the other Party and before initiating an investigation, the importing Party shall provide written notification to the other Party of its receipt of the application. 2. 
Without prejudice to its other rights and obligations under the SCM Agreement, prior to initiating a countervailing duty investigation against imports from the other Party, the importing Party shall afford to the other Party a reasonable opportunity to consult with the aim of clarifying the situation on matters raised in the application and arriving at a mutually agreed solution. Any such consultations shall not unnecessarily delay or prevent a Party from proceeding expeditiously to initiate and conduct an investigation. 3. The Parties reaffirm their rights and obligations under Articles 6.2 and 6.3 of the AD Agreement and Article 12.2 of the SCM Agreement, including with respect to the rights of interested parties to present information orally and to defend their interests in the conduct of an anti-dumping investigation or a countervailing duty investigation. 4. Each Party shall ensure, before a final determination is made, full and meaningful disclosure of all essential facts under consideration which form the basis for the decision whether to apply definitive measures in an anti-dumping investigation or a countervailing duty investigation. This is without prejudice to Article 6.5 of the AD Agreement and Article 12.4 of the SCM Agreement. Disclosures shall be made in writing, and allow interested parties sufficient time to defend their interests.'
start = time.perf_counter()
perplexity(text)
end = time.perf_counter()

Token indices sequence length is longer than the specified maximum sequence length for this model (1730 > 1024). Running this sequence through the model will result in indexing errors


In [26]:
print('The text had ' + str(len(text.split())) + ' words and this took ' + str(end - start) + ' seconds to run')

The text had 461 words and this took 1189.684408608 seconds to run
