<a href="https://colab.research.google.com/github/JuanJoseMV/neuraltextgen/blob/main/nGrams_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%%capture
!git clone --recursive https://github.com/JuanJoseMV/neuraltextgen.git
!pip install -r /content/neuraltextgen/texygen/requirements.txt

!git clone --recursive https://github.com/nyu-dl/bert-gen.git
!pip install simpletransformers

In [3]:
import os
import copy
import numpy as np
import pandas as pd
from collections import Counter
import nltk
nltk.download('punkt')
from transformers import AutoTokenizer

os.chdir("/content/neuraltextgen/texygen")
from utils.metrics.UniqueGram import UniqueGram
from utils.metrics.Bleu import Bleu

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
path = "/content/neuraltextgen/data/tbc.5k.txt"
file = open(path, "r")
tbc = file.readlines()

path = "/content/neuraltextgen/data/wiki103.5k.txt"
file = open(path, "r")
wiki = file.readlines()

In [5]:
from nltk.util import ngrams
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case="uncased" in "bert-base-uncased")

def getGrams(sents, n):
  grams = []
  for line in sents:
    line = tokenizer.convert_tokens_to_ids(line.split(" "))
    grams += UniqueGram(gram=n).get_gram(line)
  dictGrams = Counter(grams)
  return dictGrams

def compareUniqueGrams(pred_ngrams, ref_ngrams, max_n):
  pct_unique={}
  for i in range(2, max_n + 1):
    pred_ngram_counts = set(pred_ngrams[i].keys())
    total = sum(pred_ngrams[i].values())
    ref_ngram_counts = set(ref_ngrams[i].keys())
    pct_unique[i] = len(pred_ngram_counts.difference(ref_ngram_counts)) / total

  return pct_unique

def selfUniqueGrams(pred_ngrams, max_n):
  pct_unique={}
  for i in range(2, max_n+1):
    n_unique = len([k for k, v in pred_ngrams[i].items() if v == 1])
    total = sum(pred_ngrams[i].values())
    pct_unique[i] = n_unique/total

  return pct_unique

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [6]:
maxGrams = 4
wikiGrams={}
tbcGrams={}
for i in range(2, maxGrams+1):
  wikiGrams[i] = getGrams(wiki, i)
  tbcGrams[i] = getGrams(tbc, i)

In [7]:
import itertools

#supporting function
def _split_into_words(sentences):
  """Splits multiple sentences into words and flattens the result"""
  return list(itertools.chain(*[_.split(" ") for _ in sentences]))

#supporting function
def _get_word_ngrams(n, sentences):
  """Calculates word n-grams for multiple sentences.
  """
  assert len(sentences) > 0
  assert n > 0

  words = _split_into_words(sentences)
  return _get_ngrams(n, words)

#supporting function
def _get_ngrams(n, text):
  """Calcualtes n-grams.
  Args:
    n: which n-grams to calculate
    text: An array of tokens
  Returns:
    A set of n-grams
  """
  ngram_set = set()
  text_length = len(text)
  max_index_ngram_start = text_length - n
  for i in range(max_index_ngram_start + 1):
    ngram_set.add(tuple(text[i:i + n]))
  return ngram_set

def rouge_n(reference_sentences, evaluated_sentences, n=2):
  """
  Computes ROUGE-N of two text collections of sentences.
  Source: http://research.microsoft.com/en-us/um/people/cyl/download/
  papers/rouge-working-note-v1.3.1.pdf
  Args:
    evaluated_sentences: The sentences that have been picked by the summarizer
    reference_sentences: The sentences from the referene set
    n: Size of ngram.  Defaults to 2.
  Returns:
    recall rouge score(float)
  Raises:
    ValueError: raises exception if a param has len <= 0
  """
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
  reference_ngrams = _get_word_ngrams(n, reference_sentences)
  reference_count = len(reference_ngrams)
  evaluated_count = len(evaluated_ngrams)

  # Gets the overlapping ngrams between evaluated and reference
  overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
  overlapping_count = len(overlapping_ngrams)

  # Handle edge case. This isn't mathematically correct, but it's good enough
  if evaluated_count == 0:
    precision = 0.0
  else:
    precision = overlapping_count / evaluated_count

  if reference_count == 0:
    recall = 0.0
  else:
    recall = overlapping_count / reference_count

  f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))

  #just returning recall count in rouge, useful for our purpose
  return [precision,recall,f1_score]

In [8]:
models = [("/content/neuraltextgen/data/grid_search/Attention-1.csv", "/content/neuraltextgen/data/grid_search/text_generated_Attention-1"),
               ("/content/neuraltextgen/data/grid_search/Parallel-1.csv", "/content/neuraltextgen/data/grid_search/text_generated_Parallel-1"),
               ("/content/neuraltextgen/data/grid_search/parallel-0.15.csv", "/content/neuraltextgen/data/grid_search/text_generated_Parallel-015"),
               ("/content/neuraltextgen/data/grid_search/sequential.csv", "/content/neuraltextgen/data/grid_search/text_generated_sequential")]

uniqueGramsHeader = ["model","method","mean_tbc","std_tbc","mean_wiki","std_wiki","self n=2","self n=3","self n=4","WT103 n=2", "WT103 n=3", "WT103 n=4", "TBC n=2", "TBC n=3", "TBC n=4"]
uniqueGramsResults = pd.DataFrame(columns= uniqueGramsHeader)

rougeHeader = ["model","method","Wiki R1 Prec","Wiki R1 Rec","Wiki R1 F1", "Wiki R2 Prec","Wiki R2 Rec","Wiki R2 F1",
               "Wiki R3 Prec","Wiki R3 Rec","Wiki R3 F1", "Wiki R4 Prec","Wiki R4 Rec","Wiki R4 F1",
               "TBC R1 Prec","TBC R1 Rec","TBC R1 F1", "TBC R2 Prec","TBC R2 Rec","TBC R2 F1",
               "TBC R3 Prec","TBC R3 Rec","TBC R3 F1", "TBC R4 Prec","TBC R4 Rec","TBC R4 F1",]
rougeResults = pd.DataFrame(columns= rougeHeader)

for model in models:
  bleu, rootpath = model
  bleuDF = pd.read_csv(bleu, header="infer")
  method =rootpath.split("_")[-1]

  for index, row in bleuDF.iterrows():
    temperature = row["temperature"] if row["temperature"] == 0.1 else int(row["temperature"])
    filepath = f'/max_iter={str(row["max_iter"])},std_len={str(row["std_len"])},init_mask_prob={str(row["init_mask_prob"])},temperature={str(temperature)},sample={str(row["sample"])},top_k={str(row["top_k"])}.txt'
    path = rootpath + filepath
    file = open(path, "r")
    pred = file.readlines()

    modelGrams = {}
    for i in range(2, maxGrams+1):
      modelGrams[i] = getGrams(pred, i)

    pct_uniques_self = selfUniqueGrams(modelGrams, maxGrams)
    pct_uniques_wiki = compareUniqueGrams(modelGrams, wikiGrams, maxGrams)
    pct_uniques_tbc = compareUniqueGrams(modelGrams, tbcGrams, maxGrams)

    pct_uniques = [filepath[:-4], method] + [row["mean_tbc"],row["std_tbc"],row["mean_wiki"],row["std_wiki"]] +list(pct_uniques_self.values()) +list(pct_uniques_wiki.values()) + list(pct_uniques_tbc.values())
    pct_uniquesDF = pd.DataFrame([pct_uniques], columns=uniqueGramsHeader)

    uniqueGramsResults = pd.concat([uniqueGramsResults ,pct_uniquesDF], ignore_index=True)

    rougeWiki = []
    rougeTBC = []
    for k in range(1,5):
      rougeWiki += rouge_n(wiki, pred, n=k)
      rougeTBC += rouge_n(wiki, pred, n=k)
    
    rougeResult = [filepath[:-4], method] + rougeWiki + rougeTBC
    rougeResultDF = pd.DataFrame([rougeResult], columns = rougeHeader)

    rougeResults = pd.concat([rougeResults, rougeResultDF])


In [10]:
uniqueGramsResults

Unnamed: 0,model,method,mean_tbc,std_tbc,mean_wiki,std_wiki,self n=2,self n=3,self n=4,WT103 n=2,WT103 n=3,WT103 n=4,TBC n=2,TBC n=3,TBC n=4
0,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.038,0.003,0.049,0.006,0.544964,0.774681,0.900356,0.463129,0.736531,0.886909,0.501799,0.771411,0.911374
1,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.039,0.004,0.051,0.004,0.558746,0.785845,0.908773,0.477621,0.756496,0.898284,0.516784,0.790278,0.922918
2,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.041,0.004,0.049,0.006,0.568205,0.804715,0.923953,0.488791,0.768517,0.904032,0.521026,0.800152,0.929802
3,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.041,0.005,0.050,0.007,0.564359,0.796552,0.909355,0.477351,0.750116,0.894652,0.514875,0.789564,0.920343
4,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.019,0.000,0.020,0.001,0.852677,0.966264,0.992234,0.773917,0.945717,0.986049,0.801422,0.959878,0.991658
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
283,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.078,0.015,0.057,0.001,0.573356,0.809354,0.894268,0.471623,0.778993,0.894268,0.499537,0.802106,0.909554
284,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.115,0.021,0.062,0.008,0.408661,0.653492,0.769658,0.350731,0.670579,0.812673,0.346716,0.671322,0.820845
285,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.360,0.052,0.132,0.018,0.000000,0.000000,0.000000,0.005115,0.015027,0.021994,0.007161,0.016940,0.023460
286,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.009,0.001,0.010,0.001,0.949144,0.994720,0.999044,0.916816,0.991552,0.998635,0.932916,0.993928,0.999044


In [11]:
rougeResults

Unnamed: 0,model,method,Wiki R1 Prec,Wiki R1 Rec,Wiki R1 F1,Wiki R2 Prec,Wiki R2 Rec,Wiki R2 F1,Wiki R3 Prec,Wiki R3 Rec,Wiki R3 F1,Wiki R4 Prec,Wiki R4 Rec,Wiki R4 F1,TBC R1 Prec,TBC R1 Rec,TBC R1 F1,TBC R2 Prec,TBC R2 Rec,TBC R2 F1,TBC R3 Prec,TBC R3 Rec,TBC R3 F1,TBC R4 Prec,TBC R4 Rec,TBC R4 F1
0,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.270748,0.074557,0.116918,0.083927,0.007772,0.014226,0.014177,0.000937,0.001757,0.001159,0.000070,0.000132,0.270748,0.074557,0.116918,0.083927,0.007772,0.014226,0.014177,0.000937,0.001757,0.001159,0.000070,0.000132
0,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.282808,0.078437,0.122812,0.086085,0.008065,0.014748,0.016693,0.001120,0.002100,0.001995,0.000122,0.000231,0.282808,0.078437,0.122812,0.086085,0.008065,0.014748,0.016693,0.001120,0.002100,0.001995,0.000122,0.000231
0,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.278564,0.077732,0.121546,0.083645,0.007874,0.014394,0.014016,0.000946,0.001773,0.001416,0.000087,0.000165,0.278564,0.077732,0.121546,0.083645,0.007874,0.014394,0.014016,0.000946,0.001773,0.001416,0.000087,0.000165
0,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.281425,0.078014,0.122163,0.088027,0.008182,0.014973,0.016480,0.001091,0.002047,0.001446,0.000087,0.000165,0.281425,0.078014,0.122163,0.088027,0.008182,0.014973,0.016480,0.001091,0.002047,0.001446,0.000087,0.000165
0,"/max_iter=100,std_len=0,init_mask_prob=0,tempe...",Attention-1,0.339551,0.132256,0.190365,0.028728,0.003226,0.005801,0.001559,0.000116,0.000216,0.000000,0.000000,0.000000,0.339551,0.132256,0.190365,0.028728,0.003226,0.005801,0.001559,0.000116,0.000216,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.409466,0.094590,0.153679,0.138830,0.013432,0.024494,0.020634,0.001458,0.002724,0.001340,0.000087,0.000164,0.409466,0.094590,0.153679,0.138830,0.013432,0.024494,0.020634,0.001458,0.002724,0.001340,0.000087,0.000164
0,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.423218,0.065317,0.113168,0.160343,0.012332,0.022903,0.026845,0.001613,0.003043,0.002456,0.000140,0.000265,0.423218,0.065317,0.113168,0.160343,0.012332,0.022903,0.026845,0.001613,0.003043,0.002456,0.000140,0.000265
0,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.519231,0.001904,0.003795,0.258427,0.000337,0.000674,0.025424,0.000029,0.000058,0.006993,0.000009,0.000017,0.519231,0.001904,0.003795,0.258427,0.000337,0.000674,0.025424,0.000029,0.000058,0.006993,0.000009,0.000017
0,"/max_iter=500,std_len=5,init_mask_prob=1,tempe...",sequential,0.247335,0.127672,0.168411,0.000124,0.000015,0.000026,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.247335,0.127672,0.168411,0.000124,0.000015,0.000026,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
