<a href="https://colab.research.google.com/github/JuanJoseMV/neuraltextgen/blob/main/nGrams_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [100]:
%%capture
!git clone --recursive https://github.com/JuanJoseMV/neuraltextgen.git
!pip install -r /content/neuraltextgen/texygen/requirements.txt

!git clone --recursive https://github.com/nyu-dl/bert-gen.git

In [101]:
import os
import copy
import numpy as np
import pandas as pd
from collections import Counter
import nltk
nltk.download('punkt')

os.chdir("/content/neuraltextgen/texygen")
from utils.metrics.UniqueGram import UniqueGram
from utils.metrics.Bleu import Bleu

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [102]:
path = "/content/neuraltextgen/data/tbc.5k.txt"
file = open(path, "r")
tbc = file.readlines()

path = "/content/neuraltextgen/data/wiki103.5k.txt"
file = open(path, "r")
wiki = file.readlines()

In [103]:
def getGrams(sents, n):
  grams = []
  for line in sents:
    line = line.split(" ")
    grams += UniqueGram(gram=n).get_gram(line)
  dictGrams = Counter(grams)
  return dictGrams

def compareUniqueGrams(pred_ngrams, ref_ngrams, max_n):
  pct_unique={}
  for i in range(2, max_n + 1):
    pred_ngram_counts = set(pred_ngrams[i].keys())
    total = sum(pred_ngrams[i].values())
    ref_ngram_counts = set(ref_ngrams[i].keys())
    pct_unique[i] = len(pred_ngram_counts.difference(ref_ngram_counts)) / total

  return pct_unique

In [104]:
bestModels = [#'parallel-0.15/max_iter=100,std_len=0,init_mask_prob=1,temperature=10,sample=False,top_k=0.txt',
 #'parallel-0.15/max_iter=500,std_len=0,init_mask_prob=1,temperature=1,sample=False,top_k=0.txt',
 #'parallel-0.15/max_iter=500,std_len=5,init_mask_prob=1,temperature=0.1,sample=False,top_k=0.txt',
 '/content/drive/MyDrive/Data Science and Engineering - PoliTo2/2nd Semester/Machine Learning and Deep Learning/Project - NeuralTextGeneration/bertGenGenerations.txt',
 'sequential/max_iter=500,std_len=5,init_mask_prob=1,temperature=0.1,sample=False,top_k=0.txt',
 'sequential/max_iter=500,std_len=5,init_mask_prob=1,temperature=1,sample=False,top_k=0.txt',
 'sequential/max_iter=500,std_len=5,init_mask_prob=1,temperature=10,sample=False,top_k=0.txt',
 'Attention-1/max_iter=100,std_len=0,init_mask_prob=1,temperature=10,sample=False,top_k=0.txt',
 'Attention-1/max_iter=100,std_len=5,init_mask_prob=1,temperature=0.1,sample=True,top_k=0.txt',
 'Attention-1/max_iter=100,std_len=5,init_mask_prob=1,temperature=0.1,sample=True,top_k=100.txt',
 'Parallel-1/max_iter=500,std_len=5,init_mask_prob=1,temperature=1,sample=True,top_k=100.txt',
 'Parallel-1/max_iter=100,std_len=5,init_mask_prob=1,temperature=0.1,sample=False,top_k=0.txt',
 'Parallel-1/max_iter=100,std_len=5,init_mask_prob=1,temperature=1,sample=False,top_k=0.txt']

In [105]:
maxGrams = 4
wikiGrams={}
tbcGrams={}
for i in range(2, maxGrams+1):
  wikiGrams[i] = getGrams(wiki, i)
  tbcGrams[i] = getGrams(tbc, i)

In [106]:
header = ["model", "WT103 n=2", "WT103 n=3", "WT103 n=4", "TBC n=2", "TBC n=3", "TBC n=4"]
resultsDF = pd.DataFrame(columns= header)

for model in bestModels:
  if model.startswith("/content"):
    path = model
    model = model.split("/")[-1]
  else:
    path = "/content/neuraltextgen/data/grid_search/text_generated_"+model

  file = open(path, "r")
  pred = file.readlines()

  modelGrams = {}
  for i in range(2, maxGrams+1):
    modelGrams[i] = getGrams(pred, i)

  pct_uniques_wiki = compareUniqueGrams(modelGrams, wikiGrams, maxGrams)
  pct_uniques_tbc = compareUniqueGrams(modelGrams, tbcGrams, maxGrams)

  pct_uniques = [model] + list(pct_uniques_wiki.values()) + list(pct_uniques_tbc.values())
  pct_uniquesDF = pd.DataFrame([pct_uniques], columns=header)

  resultsDF = pd.concat([resultsDF ,pct_uniquesDF], ignore_index=True)

In [107]:
resultsDF

Unnamed: 0,model,WT103 n=2,WT103 n=3,WT103 n=4,TBC n=2,TBC n=3,TBC n=4
0,bertGenGenerations.txt,0.606933,0.89863,0.96338,0.6752,0.913973,0.964507
1,"sequential/max_iter=500,std_len=5,init_mask_pr...",0.017059,0.027619,0.032759,0.017059,0.027619,0.033103
2,"sequential/max_iter=500,std_len=5,init_mask_pr...",0.014205,0.023853,0.028146,0.015341,0.023853,0.028477
3,"sequential/max_iter=500,std_len=5,init_mask_pr...",0.013555,0.022131,0.025513,0.014578,0.022131,0.025806
4,"Attention-1/max_iter=100,std_len=0,init_mask_p...",0.356121,0.590324,0.725253,0.362914,0.590024,0.725253
5,"Attention-1/max_iter=100,std_len=5,init_mask_p...",0.399855,0.644203,0.770977,0.409593,0.645862,0.771446
6,"Attention-1/max_iter=100,std_len=5,init_mask_p...",0.400372,0.631486,0.751299,0.405334,0.633258,0.750628
7,"Parallel-1/max_iter=500,std_len=5,init_mask_pr...",0.707098,0.942554,0.982824,0.71166,0.936535,0.982553
8,"Parallel-1/max_iter=100,std_len=5,init_mask_pr...",0.052251,0.075418,0.08554,0.053698,0.075585,0.085889
9,"Parallel-1/max_iter=100,std_len=5,init_mask_pr...",0.042007,0.062136,0.071138,0.041822,0.06233,0.071341
