<a href="https://colab.research.google.com/github/JuanJoseMV/neuraltextgen/blob/main/nGrams_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!git clone --recursive https://github.com/JuanJoseMV/neuraltextgen.git
!pip install -r /content/neuraltextgen/texygen/requirements.txt

!git clone --recursive https://github.com/nyu-dl/bert-gen.git
!pip install simpletransformers

In [2]:
import os
import copy
import numpy as np
import pandas as pd
from collections import Counter
import nltk
nltk.download('punkt')
from transformers import AutoTokenizer

os.chdir("/content/neuraltextgen/texygen")
from utils.metrics.UniqueGram import UniqueGram
from utils.metrics.Bleu import Bleu

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [3]:
path = "/content/neuraltextgen/data/tbc.5k.txt"
file = open(path, "r")
tbc = file.readlines()

path = "/content/neuraltextgen/data/wiki103.5k.txt"
file = open(path, "r")
wiki = file.readlines()

In [4]:
from nltk.util import ngrams
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case="uncased" in "bert-base-uncased")

def getGrams(sents, n):
  grams = []
  for line in sents:
    line = tokenizer.convert_tokens_to_ids(line.split(" "))
    grams += UniqueGram(gram=n).get_gram(line)
  dictGrams = Counter(grams)
  return dictGrams

def compareUniqueGrams(pred_ngrams, ref_ngrams, max_n):
  pct_unique={}
  for i in range(2, max_n + 1):
    pred_ngram_counts = set(pred_ngrams[i].keys())
    total = sum(pred_ngrams[i].values())
    ref_ngram_counts = set(ref_ngrams[i].keys())
    pct_unique[i] = len(pred_ngram_counts.difference(ref_ngram_counts)) / total

  return pct_unique

def selfUniqueGrams(pred_ngrams, max_n):
  pct_unique={}
  for i in range(2, max_n+1):
    n_unique = len([k for k, v in pred_ngrams[i].items() if v == 1])
    total = sum(pred_ngrams[i].values())
    pct_unique[i] = n_unique/total

  return pct_unique

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [5]:
maxGrams = 4
wikiGrams={}
tbcGrams={}
for i in range(2, maxGrams+1):
  wikiGrams[i] = getGrams(wiki, i)
  tbcGrams[i] = getGrams(tbc, i)

In [6]:
models = [("/content/neuraltextgen/data/grid_search/Attention-1.csv", "/content/neuraltextgen/data/grid_search/text_generated_Attention-1"),
               ("/content/neuraltextgen/data/grid_search/Parallel-1.csv", "/content/neuraltextgen/data/grid_search/text_generated_Parallel-1"),
               ("/content/neuraltextgen/data/grid_search/parallel-0.15.csv", "/content/neuraltextgen/data/grid_search/text_generated_Parallel-015"),
               ("/content/neuraltextgen/data/grid_search/sequential.csv", "/content/neuraltextgen/data/grid_search/text_generated_sequential")]

header = ["model","method","mean_tbc","std_tbc","mean_wiki","std_wiki","self n=2","self n=3","self n=4","WT103 n=2", "WT103 n=3", "WT103 n=4", "TBC n=2", "TBC n=3", "TBC n=4"]
resultsDF = pd.DataFrame(columns= header)

for model in models:
  bleu, rootpath = model
  bleuDF = pd.read_csv(bleu, header="infer")
  method =rootpath.split("_")[-1]

  for index, row in bleuDF.iterrows():
    temperature = row["temperature"] if row["temperature"] == 0.1 else int(row["temperature"])
    filepath = f'/max_iter={str(row["max_iter"])},std_len={str(row["std_len"])},init_mask_prob={str(row["init_mask_prob"])},temperature={str(temperature)},sample={str(row["sample"])},top_k={str(row["top_k"])}.txt'
    path = rootpath + filepath
    file = open(path, "r")
    pred = file.readlines()

    modelGrams = {}
    for i in range(2, maxGrams+1):
      modelGrams[i] = getGrams(pred, i)

    pct_uniques_self = selfUniqueGrams(modelGrams, maxGrams)
    pct_uniques_wiki = compareUniqueGrams(modelGrams, wikiGrams, maxGrams)
    pct_uniques_tbc = compareUniqueGrams(modelGrams, tbcGrams, maxGrams)

    pct_uniques = [filepath[:-4], method] + [row["mean_tbc"],row["std_tbc"],row["mean_wiki"],row["std_wiki"]] +list(pct_uniques_self.values()) +list(pct_uniques_wiki.values()) + list(pct_uniques_tbc.values())
    pct_uniquesDF = pd.DataFrame([pct_uniques], columns=header)

    resultsDF = pd.concat([resultsDF ,pct_uniquesDF], ignore_index=True)

In [7]:
resultsDF

Unnamed: 0,model,method,mean_tbc,std_tbc,mean_wiki,std_wiki,self n=2,self n=3,self n=4,WT103 n=2,WT103 n=3,WT103 n=4,TBC n=2,TBC n=3,TBC n=4
0,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.038,0.003,0.049,0.006,0.544964,0.774681,0.900356,0.463129,0.736531,0.886909,0.501799,0.771411,0.911374
1,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.039,0.004,0.051,0.004,0.558746,0.785845,0.908773,0.477621,0.756496,0.898284,0.516784,0.790278,0.922918
2,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.041,0.004,0.049,0.006,0.568205,0.804715,0.923953,0.488791,0.768517,0.904032,0.521026,0.800152,0.929802
3,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.041,0.005,0.050,0.007,0.564359,0.796552,0.909355,0.477351,0.750116,0.894652,0.514875,0.789564,0.920343
4,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.019,0.000,0.020,0.001,0.852677,0.966264,0.992234,0.773917,0.945717,0.986049,0.801422,0.959878,0.991658
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
283,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.078,0.015,0.057,0.001,0.573356,0.809354,0.894268,0.471623,0.778993,0.894268,0.499537,0.802106,0.909554
284,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.115,0.021,0.062,0.008,0.408661,0.653492,0.769658,0.350731,0.670579,0.812673,0.346716,0.671322,0.820845
285,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.360,0.052,0.132,0.018,0.000000,0.000000,0.000000,0.005115,0.015027,0.021994,0.007161,0.016940,0.023460
286,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.009,0.001,0.010,0.001,0.949144,0.994720,0.999044,0.916816,0.991552,0.998635,0.932916,0.993928,0.999044


# Rouge score

In [45]:

import itertools

#supporting function
def _split_into_words(sentences):
  """Splits multiple sentences into words and flattens the result"""
  return list(itertools.chain(*[_.split(" ") for _ in sentences]))

#supporting function
def _get_word_ngrams(n, sentences):
  """Calculates word n-grams for multiple sentences.
  """
  assert len(sentences) > 0
  assert n > 0

  words = _split_into_words(sentences)
  return _get_ngrams(n, words)

#supporting function
def _get_ngrams(n, text):
  """Calcualtes n-grams.
  Args:
    n: which n-grams to calculate
    text: An array of tokens
  Returns:
    A set of n-grams
  """
  ngram_set = set()
  text_length = len(text)
  max_index_ngram_start = text_length - n
  for i in range(max_index_ngram_start + 1):
    ngram_set.add(tuple(text[i:i + n]))
  return ngram_set

def rouge_n(reference_sentences, evaluated_sentences, n=2):
  """
  Computes ROUGE-N of two text collections of sentences.
  Source: http://research.microsoft.com/en-us/um/people/cyl/download/
  papers/rouge-working-note-v1.3.1.pdf
  Args:
    evaluated_sentences: The sentences that have been picked by the summarizer
    reference_sentences: The sentences from the referene set
    n: Size of ngram.  Defaults to 2.
  Returns:
    recall rouge score(float)
  Raises:
    ValueError: raises exception if a param has len <= 0
  """
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
  reference_ngrams = _get_word_ngrams(n, reference_sentences)
  reference_count = len(reference_ngrams)
  evaluated_count = len(evaluated_ngrams)

  # Gets the overlapping ngrams between evaluated and reference
  overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
  overlapping_count = len(overlapping_ngrams)

  # Handle edge case. This isn't mathematically correct, but it's good enough
  if evaluated_count == 0:
    precision = 0.0
  else:
    precision = overlapping_count / evaluated_count

  if reference_count == 0:
    recall = 0.0
  else:
    recall = overlapping_count / reference_count

  f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))

  #just returning recall count in rouge, useful for our purpose
  return (precision,recall,f1_score)

In [63]:
attention_bad = open(os.path.join("/content/neuraltextgen/data/grid_search/text_generated_Attention-1", "max_iter=100,std_len=0,init_mask_prob=0,temperature=10,sample=True,top_k=100.txt")).readlines()
attention_good = open(os.path.join("/content/neuraltextgen/data/grid_search/text_generated_Attention-1", "max_iter=500,std_len=0,init_mask_prob=0,temperature=0.1,sample=True,top_k=0.txt")).readlines()
original = open("/content/bert-gen/data/bert-base-uncased-len40-burnin250-topk100-temp1.000.txt").readlines()[:250]
similar_original = open(os.path.join("/content/neuraltextgen/data/grid_search/text_generated_Parallel-1","max_iter=500,std_len=0,init_mask_prob=1,temperature=1,sample=True,top_k=100.txt")).readlines()

In [73]:
titles = ['original', "similar_original", 'attention_bad', 'attention_good']
for ref,title in zip([original, similar_original , attention_bad, attention_good],titles):
  print(title)
  for k in range(1,5):
    precision, recall, f1score = rouge_n(wiki, ref, n=k)
    print(f" - rouge-{k} precision={precision:.2f}, recall={recall:.2f}, f1={f1score:.2f}")
    


original
 - rouge-1 precision=0.55, recall=0.11, f1=0.19
 - rouge-2 precision=0.20, recall=0.02, f1=0.04
 - rouge-3 precision=0.03, recall=0.00, f1=0.01
 - rouge-4 precision=0.00, recall=0.00, f1=0.00
similar_original
 - rouge-1 precision=0.42, recall=0.08, f1=0.14
 - rouge-2 precision=0.15, recall=0.02, f1=0.03
 - rouge-3 precision=0.02, recall=0.00, f1=0.00
 - rouge-4 precision=0.00, recall=0.00, f1=0.00
attention_bad
 - rouge-1 precision=0.31, recall=0.02, f1=0.04
 - rouge-2 precision=0.11, recall=0.01, f1=0.01
 - rouge-3 precision=0.02, recall=0.00, f1=0.00
 - rouge-4 precision=0.00, recall=0.00, f1=0.00
attention_good
 - rouge-1 precision=0.34, recall=0.07, f1=0.12
 - rouge-2 precision=0.14, recall=0.01, f1=0.02
 - rouge-3 precision=0.03, recall=0.00, f1=0.00
 - rouge-4 precision=0.00, recall=0.00, f1=0.00
